Linear Attention

分类: 深度学习基础

Linear Attention

定义

将 softmax attention 中的指数核替换为线性核,使得注意力计算可以利用矩阵乘法结合律从 O(N2)O(N^2) 降为 O(N)O(N)

数学形式

原始形式(primal form):

LinearAttn(q,K,V)=qKV=i=1N(qki)vi\text{LinearAttn}(\mathbf{q}, \mathbf{K}, \mathbf{V}) = \mathbf{q}\mathbf{K}^\top\mathbf{V} = \sum_{i=1}^{N} (\mathbf{q}\mathbf{k}_i^\top) \mathbf{v}_i

对偶形式(dual form):

LinearAttn(q,K,V)=qWN,WN=i=1Nkivi\text{LinearAttn}(\mathbf{q}, \mathbf{K}, \mathbf{V}) = \mathbf{q}\mathbf{W}_N, \quad \mathbf{W}_N = \sum_{i=1}^{N} \mathbf{k}_i^\top \mathbf{v}_i

核心要点

对偶形式将注意力重写为 query 与隐式权重矩阵的线性变换

权重矩阵 WN\mathbf{W}_N 是所有 token 的 rank-1 外积之和

这一视角启发了 IWP 将 token pruning 理解为隐式权重剪枝

代表模型:Katharopoulos et al. (2020), RWKV, DeltaNet, Kimi Linear

代表工作

IWP: 利用对偶形式将 softmax attention 也重写为隐式权重矩阵

RWKV: 线性注意力的 RNN 形式

DeltaNet: 带 delta rule 的线性注意力

相关概念

Softmax Attention

核方法

RKHS