Multi-Head Attention 与 Transformer 核心组件

分类: 注意力与Transformer · 难度: 中级 · 关联讲座: L05

Multi-Head Attention 与 Transformer 核心组件

本文完整推导 Multi-Head Attention 的数学形式(多头投影→独立注意力→拼接融合),包含参数量和计算量分析,证明多头与单头的总计算量相同。随后推导 Transformer 的三大核心组件:正弦位置编码(含相对位置的旋转矩阵证明)、前馈网络 FFN、残差连接与 Layer Normalization。


1. Multi-Head Attention

📐 Multi-Head Attention:完整推导

变量定义

  • hh = 注意力头数(Transformer base 中 h=8h=8
  • dmodeld_\text{model} = 模型维度(Transformer base 中 dmodel=512d_\text{model}=512
  • dk=dv=dmodel/hd_k = d_v = d_\text{model} / h = 每个头的维度(base 中 dk=64d_k=64
  • WiQRdmodel×dkW_i^Q \in \mathbb{R}^{d_\text{model} \times d_k}WiKRdmodel×dkW_i^K \in \mathbb{R}^{d_\text{model} \times d_k}WiVRdmodel×dvW_i^V \in \mathbb{R}^{d_\text{model} \times d_v} = 第 ii 个头的投影矩阵
  • WORhdv×dmodelW^O \in \mathbb{R}^{h d_v \times d_\text{model}} = 输出投影矩阵

推导过程

第 1 步:每个头独立计算注意力

ii 个头将输入投影到低维子空间后计算注意力:

headi=Attention(QWiQ,  KWiK,  VWiV)Rn×dv\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V) \in \mathbb{R}^{n \times d_v}

每个头有自己独立的 WiQ,WiK,WiVW_i^Q, W_i^K, W_i^V——这是关键,让每个头可以”专注”不同方面。

第 2 步:拼接所有头的输出

沿最后维度拼接 hh 个头的输出:

Concat(head1,,headh)Rn×(hdv)=Rn×dmodel\text{Concat}(\text{head}_1, \ldots, \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_v)} = \mathbb{R}^{n \times d_\text{model}}

(因为 hdv=h(dmodel/h)=dmodelh \cdot d_v = h \cdot (d_\text{model}/h) = d_\text{model}

第 3 步:输出投影

WOW^O 将拼接结果投影回 dmodeld_\text{model} 维,融合来自不同头的信息:

MultiHead(Q,K,V)=Concat(head1,,headh)WORn×dmodel\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \in \mathbb{R}^{n \times d_\text{model}}

参数量分析(每个 MultiHead 层):

  • hh 个头的 WiQW_i^Qh×dmodel×dk=dmodel2h \times d_\text{model} \times d_k = d_\text{model}^2(因为 hdk=dmodelh \cdot d_k = d_\text{model}
  • 同理 WiKW_i^KWiVW_i^V:各 dmodel2d_\text{model}^2
  • WOW^Ohdv×dmodel=dmodel2h d_v \times d_\text{model} = d_\text{model}^2
  • 总参数量4dmodel24 d_\text{model}^2(与单头完全相同!)

计算量分析hh 个头,每个头维度 dk=dmodel/hd_k = d_\text{model}/h,计算量 h×O(n2dk)=O(n2dmodel)\approx h \times O(n^2 d_k) = O(n^2 d_\text{model})(与单头相同量级)。


2. Transformer 核心组件

📐 Transformer 核心组件:完整推导

一、正弦位置编码

变量定义

  • pospos = 序列中位置(0,1,,n10, 1, \ldots, n-1
  • ii = 编码向量的维度索引(0,1,,dmodel/210, 1, \ldots, d_\text{model}/2 - 1
  • dmodeld_\text{model} = 模型维度

公式

PE(pos,2i)=sin ⁣(pos100002i/dmodel),PE(pos,2i+1)=cos ⁣(pos100002i/dmodel)PE_{(pos,\, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_\text{model}}}\right), \quad PE_{(pos,\, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_\text{model}}}\right)

为什么用正弦/余弦?三个关键性质

性质 1:每个位置有唯一编码——不同 pospos 的编码向量两两不同。

性质 2:可以用线性变换表示相对位置偏移。用三角恒等式展开:

sin(A+B)=sinAcosB+cosAsinB\sin(A + B) = \sin A \cos B + \cos A \sin B

cos(A+B)=cosAcosBsinAsinB\cos(A + B) = \cos A \cos B - \sin A \sin B

ωi=1/100002i/dmodel\omega_i = 1/10000^{2i/d_\text{model}},则 PEpos+kPE_{pos+k} 可由 PEposPE_{pos} 线性表示:

[PE(pos+k,2i)PE(pos+k,2i+1)]=[cos(kωi)sin(kωi)sin(kωi)cos(kωi)][PE(pos,2i)PE(pos,2i+1)]\begin{bmatrix} PE_{(pos+k, 2i)} \\ PE_{(pos+k, 2i+1)} \end{bmatrix} = \begin{bmatrix} \cos(k\omega_i) & \sin(k\omega_i) \\ -\sin(k\omega_i) & \cos(k\omega_i) \end{bmatrix} \begin{bmatrix} PE_{(pos, 2i)} \\ PE_{(pos, 2i+1)} \end{bmatrix}

旋转矩阵只依赖偏移 kk,不依赖绝对位置 pospos——这意味着模型可以学习”相对偏移为 kk 的词之间的关系”,而不受绝对位置影响。

性质 3:低频分量(大 ii,小 ωi\omega_i)变化慢,捕捉长距离结构;高频分量(小 ii,大 ωi\omega_i)变化快,捕捉局部精细位置。

二、前馈网络(FFN)

FFN(x)=W2ReLU(W1x+b1)+b2\text{FFN}(x) = W_2 \cdot \text{ReLU}(W_1 x + b_1) + b_2

  • W1Rdff×dmodelW_1 \in \mathbb{R}^{d_{ff} \times d_\text{model}}W2Rdmodel×dffW_2 \in \mathbb{R}^{d_\text{model} \times d_{ff}}dff=4dmodeld_{ff} = 4 d_\text{model}(先扩展 4 倍再压缩)
  • 作用:在注意力聚合信息后,对每个位置独立做非线性变换
  • 参数量:2×dmodel×dff=8dmodel22 \times d_\text{model} \times d_{ff} = 8 d_\text{model}^2(约占 Transformer 总参数 2/3)

三、残差连接 + Layer Normalization(Post-LN)

output=LayerNorm(x+SubLayer(x))\text{output} = \text{LayerNorm}(x + \text{SubLayer}(x))

残差连接的梯度传播:设 y=x+f(x)y = x + f(x),则:

Lx=Lyyx=Ly(1+fx)\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y} \cdot \frac{\partial y}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\left(1 + \frac{\partial f}{\partial x}\right)

梯度中有 “+1” 项,即使 f/x\partial f/\partial x 很小(子层几乎不学习),梯度仍然能通过 “+1” 路径直接传到更早的层,避免梯度消失。

Layer Normalization

LayerNorm(x)=xμσ+ϵγ+β\text{LayerNorm}(x) = \frac{x - \mu}{\sigma + \epsilon} \cdot \gamma + \beta

其中 μ,σ\mu, \sigma 是在特征维度上计算的(不同于 BatchNorm 在 batch 维度上计算)。

Pre-LN vs Post-LN

  • Post-LN(原始 Transformer):LayerNorm(x+SubLayer(x))\text{LayerNorm}(x + \text{SubLayer}(x))
  • Pre-LN(现代 LLM 标准):x+SubLayer(LayerNorm(x))x + \text{SubLayer}(\text{LayerNorm}(x))
  • Pre-LN 训练更稳定(梯度直接通过残差路径流动,不经过 LayerNorm),现代 LLM(GPT-2 之后)几乎全部采用 Pre-LN。