Multi-Head Attention 与 Transformer 核心组件
分类: 注意力与Transformer · 难度: 中级 · 关联讲座: L05
本文完整推导 Multi-Head Attention 的数学形式(多头投影→独立注意力→拼接融合),包含参数量和计算量分析,证明多头与单头的总计算量相同。随后推导 Transformer 的三大核心组件:正弦位置编码(含相对位置的旋转矩阵证明)、前馈网络 FFN、残差连接与 Layer Normalization。
1. Multi-Head Attention
📐 Multi-Head Attention:完整推导
变量定义:
- h = 注意力头数(Transformer base 中 h=8)
- dmodel = 模型维度(Transformer base 中 dmodel=512)
- dk=dv=dmodel/h = 每个头的维度(base 中 dk=64)
- WiQ∈Rdmodel×dk,WiK∈Rdmodel×dk,WiV∈Rdmodel×dv = 第 i 个头的投影矩阵
- WO∈Rhdv×dmodel = 输出投影矩阵
推导过程:
第 1 步:每个头独立计算注意力
第 i 个头将输入投影到低维子空间后计算注意力:
headi=Attention(QWiQ,KWiK,VWiV)∈Rn×dv
每个头有自己独立的 WiQ,WiK,WiV——这是关键,让每个头可以”专注”不同方面。
第 2 步:拼接所有头的输出
沿最后维度拼接 h 个头的输出:
Concat(head1,…,headh)∈Rn×(h⋅dv)=Rn×dmodel
(因为 h⋅dv=h⋅(dmodel/h)=dmodel)
第 3 步:输出投影
用 WO 将拼接结果投影回 dmodel 维,融合来自不同头的信息:
MultiHead(Q,K,V)=Concat(head1,…,headh)WO∈Rn×dmodel
参数量分析(每个 MultiHead 层):
- h 个头的 WiQ:h×dmodel×dk=dmodel2(因为 h⋅dk=dmodel)
- 同理 WiK,WiV:各 dmodel2
- WO:hdv×dmodel=dmodel2
- 总参数量:4dmodel2(与单头完全相同!)
计算量分析:h 个头,每个头维度 dk=dmodel/h,计算量 ≈h×O(n2dk)=O(n2dmodel)(与单头相同量级)。
一、正弦位置编码
变量定义:
- pos = 序列中位置(0,1,…,n−1)
- i = 编码向量的维度索引(0,1,…,dmodel/2−1)
- dmodel = 模型维度
公式:
PE(pos,2i)=sin(100002i/dmodelpos),PE(pos,2i+1)=cos(100002i/dmodelpos)
为什么用正弦/余弦?三个关键性质:
性质 1:每个位置有唯一编码——不同 pos 的编码向量两两不同。
性质 2:可以用线性变换表示相对位置偏移。用三角恒等式展开:
sin(A+B)=sinAcosB+cosAsinB
cos(A+B)=cosAcosB−sinAsinB
设 ωi=1/100002i/dmodel,则 PEpos+k 可由 PEpos 线性表示:
[PE(pos+k,2i)PE(pos+k,2i+1)]=[cos(kωi)−sin(kωi)sin(kωi)cos(kωi)][PE(pos,2i)PE(pos,2i+1)]
旋转矩阵只依赖偏移 k,不依赖绝对位置 pos——这意味着模型可以学习”相对偏移为 k 的词之间的关系”,而不受绝对位置影响。
性质 3:低频分量(大 i,小 ωi)变化慢,捕捉长距离结构;高频分量(小 i,大 ωi)变化快,捕捉局部精细位置。
二、前馈网络(FFN)
FFN(x)=W2⋅ReLU(W1x+b1)+b2
- W1∈Rdff×dmodel,W2∈Rdmodel×dff,dff=4dmodel(先扩展 4 倍再压缩)
- 作用:在注意力聚合信息后,对每个位置独立做非线性变换
- 参数量:2×dmodel×dff=8dmodel2(约占 Transformer 总参数 2/3)
三、残差连接 + Layer Normalization(Post-LN)
output=LayerNorm(x+SubLayer(x))
残差连接的梯度传播:设 y=x+f(x),则:
∂x∂L=∂y∂L⋅∂x∂y=∂y∂L(1+∂x∂f)
梯度中有 “+1” 项,即使 ∂f/∂x 很小(子层几乎不学习),梯度仍然能通过 “+1” 路径直接传到更早的层,避免梯度消失。
Layer Normalization:
LayerNorm(x)=σ+ϵx−μ⋅γ+β
其中 μ,σ 是在特征维度上计算的(不同于 BatchNorm 在 batch 维度上计算)。
Pre-LN vs Post-LN:
- Post-LN(原始 Transformer):LayerNorm(x+SubLayer(x))
- Pre-LN(现代 LLM 标准):x+SubLayer(LayerNorm(x))
- Pre-LN 训练更稳定(梯度直接通过残差路径流动,不经过 LayerNorm),现代 LLM(GPT-2 之后)几乎全部采用 Pre-LN。