Self-Attention 完整推导

分类: 注意力与Transformer · 难度: 中级 · 关联讲座: L05

Self-Attention 完整推导

本文从 Seq2Seq 注意力出发,完整推导 Self-Attention 的三步计算流程(打分→归一化→加权求和),再推进到 Transformer 中 Query-Key-Value 框架下的缩放点积注意力公式 Attention(Q,K,V)\text{Attention}(Q,K,V),包含缩放因子 1/dk1/\sqrt{d_k} 的方差分析、完整的形状追踪和计算复杂度。


1. 三步注意力计算

📐 三步注意力计算:完整推导

变量定义

  • stRds_t \in \mathbb{R}^d = decoder 在时间步 tt 的隐状态(Query)
  • hiRdh_i \in \mathbb{R}^d = encoder 第 ii 个位置的隐状态(Key/Value)
  • H=[h1,,hn]Rd×nH = [h_1, \ldots, h_n] \in \mathbb{R}^{d \times n} = 所有 encoder 隐状态
  • eiRe_i \in \mathbb{R} = 第 ii 个位置的注意力分数(标量)
  • αRn\alpha \in \mathbb{R}^n = 注意力权重向量(概率分布)
  • atRda_t \in \mathbb{R}^d = 注意力输出(加权上下文向量)

推导过程

第 1 步:计算注意力分数(三种等价变体)

点积注意力(Luong et al., 2015):

ei=stThie_i = s_t^T h_i

缩放点积注意力(Vaswani et al., 2017):

ei=stThide_i = \frac{s_t^T h_i}{\sqrt{d}}

缩放原因:stThis_t^T h_i 的期望方差为 dd(若 st,his_t, h_i 各分量 i.i.d. 均值0方差1),d\sqrt{d} 缩放将方差归一化为1。

加性注意力(Bahdanau et al., 2015):

ei=vTtanh(W1hi+W2st),vRda,W1,W2Rda×de_i = v^T \tanh(W_1 h_i + W_2 s_t), \quad v \in \mathbb{R}^{d_a}, W_1, W_2 \in \mathbb{R}^{d_a \times d}

参数 v,W1,W2v, W_1, W_2 通过反向传播学习,表达能力更强但参数更多。

第 2 步:Softmax 归一化

nn 个原始分数归一化为概率分布:

αi=exp(ei)j=1nexp(ej),αRn,i=1nαi=1,αi0\alpha_i = \frac{\exp(e_i)}{\sum_{j=1}^{n} \exp(e_j)}, \quad \alpha \in \mathbb{R}^n, \quad \sum_{i=1}^{n} \alpha_i = 1, \quad \alpha_i \geq 0

矩阵形式:α=softmax(e)\alpha = \text{softmax}(e)

第 3 步:加权求和得上下文向量

at=i=1nαihi=HαRda_t = \sum_{i=1}^{n} \alpha_i h_i = H \alpha \in \mathbb{R}^d

最终将 ata_tsts_t 拼接,送入输出层:s~t=tanh(Wc[at;st])\tilde{s}_t = \tanh(W_c [a_t; s_t]),用于预测 yty_t


2. Attention(Q, K, V) 公式

📐 Attention(Q, K, V):完整推导

变量定义

  • XRn×dmodelX \in \mathbb{R}^{n \times d_\text{model}} = 输入矩阵(序列长度 nn,每个位置 dmodeld_\text{model} 维)
  • WQRdmodel×dkW^Q \in \mathbb{R}^{d_\text{model} \times d_k} = Query 投影矩阵
  • WKRdmodel×dkW^K \in \mathbb{R}^{d_\text{model} \times d_k} = Key 投影矩阵
  • WVRdmodel×dvW^V \in \mathbb{R}^{d_\text{model} \times d_v} = Value 投影矩阵
  • Q=XWQRn×dkQ = X W^Q \in \mathbb{R}^{n \times d_k}K=XWKRn×dkK = X W^K \in \mathbb{R}^{n \times d_k}V=XWVRn×dvV = X W^V \in \mathbb{R}^{n \times d_v}

推导过程

第 1 步:线性投影

每个输入向量 xiRdmodelx_i \in \mathbb{R}^{d_\text{model}} 通过三个独立线性变换得到 Q/K/V:

qi=xiWQ,ki=xiWK,vi=xiWVq_i = x_i W^Q, \quad k_i = x_i W^K, \quad v_i = x_i W^V

矩阵形式同时处理所有位置:Q=XWQQ = XW^QK=XWKK = XW^KV=XWVV = XW^V

第 2 步:计算相似度矩阵

计算所有 Query 与所有 Key 的点积,得到 n×nn \times n 注意力分数矩阵:

E=QKTRn×n,Eij=qiTkjE = QK^T \in \mathbb{R}^{n \times n}, \quad E_{ij} = q_i^T k_j

EijE_{ij} 表示位置 ii 对位置 jj 的原始注意力分数。

第 3 步:缩放(关键步骤)

为什么需要除以 dk\sqrt{d_k}

qi,kjq_i, k_j 的各分量 i.i.d.,均值 0,方差 1,则:

Var(qiTkj)=Var ⁣(l=1dkqilkjl)=l=1dkVar(qil)Var(kjl)=dk\text{Var}(q_i^T k_j) = \text{Var}\!\left(\sum_{l=1}^{d_k} q_{il} k_{jl}\right) = \sum_{l=1}^{d_k} \text{Var}(q_{il}) \cdot \text{Var}(k_{jl}) = d_k

所以 qiTkjq_i^T k_j 的标准差为 dk\sqrt{d_k}。不缩放时,dkd_k 较大(如 64)的分数会使 softmax 饱和在接近独热分布的区域,梯度近乎为零。缩放后方差归一化为 1:

E~=QKTdk\tilde{E} = \frac{QK^T}{\sqrt{d_k}}

第 4 步:Softmax 按行归一化

A=softmax ⁣(QKTdk)Rn×nA = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) \in \mathbb{R}^{n \times n}

按行归一化:Aij=exp(E~ij)l=1nexp(E~il)A_{ij} = \dfrac{\exp(\tilde{E}_{ij})}{\sum_{l=1}^{n} \exp(\tilde{E}_{il})},每行之和为 1。

AijA_{ij} 是位置 ii 分配给位置 jj 的注意力权重。

第 5 步:加权求和 Value 得输出

O=AVRn×dv,Oi=j=1nAijvjO = AV \in \mathbb{R}^{n \times d_v}, \quad O_i = \sum_{j=1}^{n} A_{ij} v_j

位置 ii 的输出是所有位置 Value 向量按注意力权重的加权平均。

完整公式

Attention(Q,K,V)=softmax ⁣(QKTdk)V\boxed{\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V}

形状追踪(务必熟练):

矩阵形状说明
QQ(n×dk)(n \times d_k)
KTK^T(dk×n)(d_k \times n)
QKTQK^T(n×n)(n \times n)注意力矩阵,与序列长度平方成正比
A=softmax()A = \text{softmax}(\cdot)(n×n)(n \times n)行归一化概率矩阵
VV(n×dv)(n \times d_v)
O=AVO = AV(n×dv)(n \times d_v)最终输出

计算复杂度

  • QKTQK^T(n×dk)(dk×n)O(n2dk)(n \times d_k) \cdot (d_k \times n) \Rightarrow O(n^2 d_k)
  • AVAV(n×n)(n×dv)O(n2dv)(n \times n) \cdot (n \times d_v) \Rightarrow O(n^2 d_v)
  • 总:O(n2d)O(n^2 d)n2n^2 是瓶颈)