Transformer 计算复杂度分析
分类: 注意力与Transformer · 难度: 中级 · 关联讲座: L05
本文逐步拆解 Transformer 各组件(QKV 投影、注意力矩阵、FFN)的计算复杂度,分析 O(n2d) 与 O(nd2) 两项在不同序列长度下的主导关系,并与 RNN、CNN 进行横向对比。
📐 计算复杂度分析:完整推导
变量定义:
- n = 序列长度
- d = 模型维度(dmodel)
- L = Transformer 层数
自注意力复杂度逐步分析:
步骤 1:计算 Q,K,V 投影(三次矩阵乘法,每次 (n×d)⋅(d×dk)):
复杂度=3×O(n⋅d⋅dk)=O(nd2)(通常 dk≈d)
步骤 2:计算注意力矩阵 QKT((n×dk)⋅(dk×n)):
复杂度=O(n2dk)=O(n2d)
步骤 3:注意力加权 AV((n×n)⋅(n×dv)):
复杂度=O(n2dv)=O(n2d)
步骤 4:FFN(两次线性层,dff=4d):
复杂度=O(n⋅d⋅4d)+O(n⋅4d⋅d)=O(nd2)
单层总复杂度:
投影+FFNO(nd2)+注意力矩阵O(n2d)
当 n≪d 时(短序列),O(nd2) 主导;当 n≫d 时(长序列),O(n2d) 主导。
内存复杂度:需要存储注意力矩阵 A∈Rn×n,内存 O(n2)。
与其他序列模型对比:
| 属性 | Self-Attention | RNN | CNN(核大小 k) |
|---|
| 单层复杂度 | O(n2d) | O(nd2) | O(knd2) |
| 最大路径长度 | O(1) | O(n) | O(n/k) |
| 并行化 | 完全并行 | 顺序(无法并行) | 完全并行 |
| 长距离依赖 | 直接(1步) | 困难(n 步) | 有限(取决于 k) |
数值示例
BERT-base 配置:n=512,d=768,L=12,h=12
注意力矩阵计算量(每层):
n2×d=5122×768=262144×768≈201 M 次乘加
不同序列长度下的注意力计算量对比:
| 序列长度 n | 注意力计算量(n2d,d=768) | 相对 n=512 |
|---|
| 128 | 1282×768≈12.6 M | 0.06× |
| 512 | 5122×768≈201 M | 1×(基准) |
| 2048 | 20482×768≈3.22 B | 16× |
| 8192 | 81922×768≈51.5 B | 256× |
| 100000 | 1010×768≈7.68 T | 38400× |
结论:n 翻 4 倍,计算量翻 16 倍(n2 的代价)。这就是为什么处理长文档(书、代码库)的 Transformer 需要专门的高效注意力方案。
FFN vs 注意力的计算量比较(n=512,d=768,dff=3072):
- FFN:2×n×d×dff=2×512×768×3072≈2.42 B
- 注意力:≈201 M
- FFN 计算量约为注意力的 12 倍(所以 FFN 才是实际计算瓶颈,注意力是内存瓶颈)