Multi-Head Attention

分类: 深度学习基础

Multi-Head Attention

定义

Transformer 的核心组件,将输入投影到多个独立的注意力子空间(head),各 head 并行计算 Softmax 注意力后拼接输出

数学形式

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O headi=Attention(QWiQ,KWiK,VWiV)=softmax(QWiQ(KWiK)dk)VWiV\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V) = \text{softmax}\left(\frac{Q W_i^Q (K W_i^K)^\top}{\sqrt{d_k}}\right) V W_i^V

核心要点

多子空间: 每个 head 学习不同的注意力模式(如语法关系、语义关联、位置模式等)

投影矩阵: 每个 head 有独立的 WiQ,WiK,WiVRd×dkW^Q_i, W^K_i, W^V_i \in \mathbb{R}^{d \times d_k},输出投影 WORhdk×dW^O \in \mathbb{R}^{hd_k \times d}

缩放因子: 1/dk1/\sqrt{d_k} 防止点积过大导致 softmax 梯度消失

头数与维度: 通常 dk=d/hd_k = d / h,总计算量与单头注意力相同

变体: GQA(Grouped Query Attention)、MQA(Multi-Query Attention)减少 KV head 数量以降低推理开销

代表工作

Vaswani et al. 2017: “Attention Is All You Need”(原始 Transformer)

Growth Transformer Training: 分析了 9 头注意力的各投影矩阵(q/k/v/o_proj)的可预测性差异

Michel et al. 2019: 证明多数注意力头可在推理时安全移除

相关概念

Transformer

Softmax

GQA

FlashAttention

RoPE