FlashAttention

分类: 高效推理与部署

FlashAttention

定义

通过 IO-aware 的分块(tiling)计算,将注意力的中间矩阵保留在 SRAM 而非 HBM,从而降低内存带宽瓶颈,实现近线性内存复杂度的精确注意力计算。

数学形式

Attention(Q,K,V)=softmax ⁣(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V

FlashAttention 不改变公式语义,改变的是计算顺序:将矩阵分成 tile,逐块更新 softmax 分母(online softmax),避免写回完整 N×NN \times N 矩阵到 HBM。

核心要点

FlashAttention-1(2022):分块计算 + online softmax,减少 HBM 读写次数

FlashAttention-2(2023):进一步优化并行度和 GEMM 利用率

FlashAttention-3(2024):面向 H100 的 Warp specialization 和 pingpong pipeline

精确计算(非近似),无精度损失

与 BinaryAttention 的区别:FA 优化内存带宽,BinaryAttention 降低计算精度

代表工作

Dao et al. (2022), FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Dao (2023), FlashAttention-2

相关概念

MLP 模块

Softmax

BinaryAttention