online softmax

分类: 深度学习基础

online softmax

定义

一种增量计算 softmax 归一化的方法,通过维护 running max 和 log-sum-exp 统计量,支持逐步合并多段 softmax 结果而无需一次性访问所有 logits

数学形式

给定两段统计量 (o(1),m(1),(1))(o^{(1)}, m^{(1)}, \ell^{(1)})(o(2),m(2),(2))(o^{(2)}, m^{(2)}, \ell^{(2)})

m=max(m(1),m(2)),h=em(1)mo(1)+em(2)mo(2)em(1)m(1)+em(2)m(2)m = \max(m^{(1)}, m^{(2)}), \quad \boldsymbol{h} = \frac{e^{m^{(1)}-m} \boldsymbol{o}^{(1)} + e^{m^{(2)}-m} \boldsymbol{o}^{(2)}}{e^{m^{(1)}-m} \ell^{(1)} + e^{m^{(2)}-m} \ell^{(2)}}

核心要点

FlashAttention 的数学基础

AttnRes 的 two-phase 推理中用于合并 inter-block 和 intra-block attention 结果

支持 element-wise kernel fusion,减少 I/O

代表工作

Milakov & Gimelshein 2018: Online normalizer calculation for softmax

FlashAttention: 利用 online softmax 实现 IO-aware exact attention

AttnRes: two-phase computation 中合并 Phase 1 和 Phase 2 的 attention 输出

相关概念

Softmax

FlashAttention