GQA

分类: 网络架构

GQA

定义

Grouped Query Attention (GQA) 是 Ainslie et al. (2023) 提出的注意力机制变体,将查询头分组共享 KV 头,在 MHA(全头注意力)和 MQA(单 KV 头)之间取得平衡。

数学形式

GQA(Q,K,V)=Concat(head1,,headh)WO\text{GQA}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O

其中每 gg 个查询头共享一组 (Kj,Vj)(K_j, V_j),KV 头数 =h/g= h / g。MHA 对应 g=1g=1,MQA 对应 g=hg=h

核心要点

减少 KV cache 大小,显著降低推理时的内存占用和延迟

LLaMA 2 70B、Mistral 等主流 LLM 均采用 GQA

可通过 mean pooling 现有 MHA checkpoint 的 KV 头来初始化 GQA,实现低成本迁移

在长序列场景下,KV cache 节省效果尤为显著

代表工作

Ainslie et al. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” (EMNLP 2023)

相关概念

ViT

FlashAttention

PagedAttention

SnapKV