Transformer 计算复杂度分析

分类: 注意力与Transformer · 难度: 中级 · 关联讲座: L05

Transformer 计算复杂度分析

本文逐步拆解 Transformer 各组件(QKV 投影、注意力矩阵、FFN)的计算复杂度,分析 O(n2d)O(n^2 d)O(nd2)O(nd^2) 两项在不同序列长度下的主导关系,并与 RNN、CNN 进行横向对比。


📐 计算复杂度分析:完整推导

变量定义

  • nn = 序列长度
  • dd = 模型维度(dmodeld_\text{model}
  • LL = Transformer 层数

自注意力复杂度逐步分析

步骤 1:计算 Q,K,VQ, K, V 投影(三次矩阵乘法,每次 (n×d)(d×dk)(n \times d) \cdot (d \times d_k)):

复杂度=3×O(nddk)=O(nd2)(通常 dkd\text{复杂度} = 3 \times O(n \cdot d \cdot d_k) = O(n d^2) \quad \text{(通常 } d_k \approx d\text{)}

步骤 2:计算注意力矩阵 QKTQK^T(n×dk)(dk×n)(n \times d_k) \cdot (d_k \times n)):

复杂度=O(n2dk)=O(n2d)\text{复杂度} = O(n^2 d_k) = O(n^2 d)

步骤 3:注意力加权 AVAV(n×n)(n×dv)(n \times n) \cdot (n \times d_v)):

复杂度=O(n2dv)=O(n2d)\text{复杂度} = O(n^2 d_v) = O(n^2 d)

步骤 4:FFN(两次线性层,dff=4dd_{ff} = 4d):

复杂度=O(nd4d)+O(n4dd)=O(nd2)\text{复杂度} = O(n \cdot d \cdot 4d) + O(n \cdot 4d \cdot d) = O(n d^2)

单层总复杂度

O(nd2)投影+FFN+O(n2d)注意力矩阵\underbrace{O(n d^2)}_{\text{投影+FFN}} + \underbrace{O(n^2 d)}_{\text{注意力矩阵}}

ndn \ll d 时(短序列),O(nd2)O(nd^2) 主导;当 ndn \gg d 时(长序列),O(n2d)O(n^2 d) 主导。

内存复杂度:需要存储注意力矩阵 ARn×nA \in \mathbb{R}^{n \times n},内存 O(n2)O(n^2)

与其他序列模型对比

属性Self-AttentionRNNCNN(核大小 kk
单层复杂度O(n2d)O(n^2 d)O(nd2)O(n d^2)O(knd2)O(k n d^2)
最大路径长度O(1)O(1)O(n)O(n)O(n/k)O(n/k)
并行化完全并行顺序(无法并行)完全并行
长距离依赖直接(1步)困难(nn 步)有限(取决于 kk

数值示例

BERT-base 配置n=512n=512d=768d=768L=12L=12h=12h=12

注意力矩阵计算量(每层):

n2×d=5122×768=262144×768201 M 次乘加n^2 \times d = 512^2 \times 768 = 262144 \times 768 \approx 201 \text{ M 次乘加}

不同序列长度下的注意力计算量对比

序列长度 nn注意力计算量(n2dn^2 dd=768d=768相对 n=512n=512
1281282×76812.6128^2 \times 768 \approx 12.6 M0.06×0.06\times
5125122×768201512^2 \times 768 \approx 201 M1×1\times(基准)
204820482×7683.222048^2 \times 768 \approx 3.22 B16×16\times
819281922×76851.58192^2 \times 768 \approx 51.5 B256×256\times
1000001010×7687.6810^{10} \times 768 \approx 7.68 T38400×38400\times

结论nn 翻 4 倍,计算量翻 16 倍n2n^2 的代价)。这就是为什么处理长文档(书、代码库)的 Transformer 需要专门的高效注意力方案。

FFN vs 注意力的计算量比较n=512n=512d=768d=768dff=3072d_{ff}=3072):

  • FFN:2×n×d×dff=2×512×768×30722.422 \times n \times d \times d_{ff} = 2 \times 512 \times 768 \times 3072 \approx 2.42 B
  • 注意力:201\approx 201 M
  • FFN 计算量约为注意力的 12 倍(所以 FFN 才是实际计算瓶颈,注意力是内存瓶颈)