预训练目标函数与架构对比

分类: 预训练与微调 · 难度: 中级 · 关联讲座: L07

预训练目标函数与架构对比

本文系统梳理预训练的核心数学框架:从预训练的样本效率优势出发,对比 MLM 与 CLM 两大目标函数,形式化三种预训练架构(Encoder / Decoder / Encoder-Decoder)的差异,并讨论数据混合比例的优化框架与 Chinchilla 最优缩放律。


1. 预训练 vs 随机初始化的样本效率

📐 预训练 vs 随机初始化的样本效率

变量定义:设下游任务需要 kk 个标注样本使预训练模型达到性能 PP^*,随机初始化模型达到同等性能 PP^* 所需样本数为 kk'

推导过程:预训练模型已经从大规模无标注数据 DpreD_{pre}(数百亿 token)中习得语言的统计规律,其参数初始化点 θ0pre\theta_0^{pre} 在损失景观中已处于接近下游任务最优解的低谷附近。而随机初始化点 θ0rand\theta_0^{rand} 则处于高损失区域,需要更多梯度步才能到达同等低谷。

形式化地,fine-tuning 的优化路径为: θ=θ0preηi=1kθL(xi,yi;θ)\theta^* = \theta_0^{pre} - \eta \sum_{i=1}^{k} \nabla_\theta \mathcal{L}(x_i, y_i; \theta)

由于起点更优,同样的 kk 步优化可以达到更低的损失,等效为:

kk×C(C1,实验中 C100)k' \approx k \times C \quad (C \gg 1, \text{实验中 } C \approx 100)

结论:预训练模型在低资源场景下的样本效率比随机初始化高约 100 倍。

2. MLM 与 CLM 目标函数对比

📐 MLM 与 CLM 目标函数对比

Masked Language Modeling(BERT 类)

随机选择 15% 的 token 位置集合 MM,要求模型从未被遮盖的上下文中预测这些位置的原始 token:

JMLM=ExDiMlogP(xixM;θ)J_{MLM} = -\mathbb{E}_{x \sim D} \sum_{i \in M} \log P(x_i \mid x_{\setminus M}; \theta)

其中 xMx_{\setminus M} 表示去掉 mask 位置后的序列,注意力是双向的(每个位置可看到所有非 mask 位置)。

Causal Language Modeling(GPT 类)

给定前缀预测下一个 token,注意力通过 causal mask 强制单向

JCLM=t=1TlogP(xtx<t;θ)J_{CLM} = -\sum_{t=1}^{T} \log P(x_t \mid x_{<t}; \theta)

关键区别

  • MLM:每个样本可以提供 15% × T 个训练信号,但需要特殊的 [MASK] token(预训练-微调不一致)
  • CLM:每个样本只提供 T 个训练信号,但可以直接用于自回归生成,无训练-推理差异

3. 三种架构的形式化对比

📐 三种架构的形式化对比

Encoder-only(BERT):双向注意力,输出每个位置的上下文表示:

H=TransformerEncoder(x1,,xn)Rn×dH = \text{TransformerEncoder}(x_1, \ldots, x_n) \in \mathbb{R}^{n \times d}

注意力矩阵 ARn×nA \in \mathbb{R}^{n \times n},无 mask(每个位置可看全局)。

Decoder-only(GPT):单向注意力(causal mask),输出自回归概率:

P(xtx<t)=softmax(htWV)其中 Aij=0 if j>iP(x_t \mid x_{<t}) = \text{softmax}(h_t W_V) \quad \text{其中 } A_{ij} = 0 \text{ if } j > i

Encoder-Decoder(T5):encoder 双向处理输入,decoder 通过 cross-attention 条件生成:

P(yty<t,x)=softmax(hdec,tWV),hdec,t=CrossAttn(hdec,t,Henc)P(y_t \mid y_{<t}, x) = \text{softmax}(h_{dec,t} W_V), \quad h_{dec,t} = \text{CrossAttn}(h_{dec,t}, H_{enc})

参数规模对比(以 hidden=768, L=12, H=12 为基准):

模型参数量架构预训练目标
BERT-Base110MEncoder-onlyMLM + NSP
GPT-2-Base117MDecoder-onlyCLM
T5-Base250MEncoder-DecoderSpan Corruption

T5 参数量约为同规格 BERT 的 2.3 倍,因为 encoder 和 decoder 各自都有完整的 Transformer 层。

4. 数据混合比例的优化框架

📐 数据混合比例的优化框架

问题形式化:设有 KK 个数据源 {D1,,DK}\{D_1, \ldots, D_K\},混合权重 α=(α1,,αK)\alpha = (\alpha_1, \ldots, \alpha_K),满足 kαk=1,αk0\sum_k \alpha_k = 1, \alpha_k \geq 0

训练损失为各数据源损失的加权和: Ltrain(θ;α)=k=1KαkLk(θ)\mathcal{L}_{train}(\theta; \alpha) = \sum_{k=1}^{K} \alpha_k \cdot \mathcal{L}_k(\theta)

目标是最小化在目标评估集(如多个下游任务)上的损失: α=argminαLeval(θ(α))\alpha^* = \arg\min_{\alpha} \mathcal{L}_{eval}(\theta^*(\alpha))

重复数据的危害:若数据 DkD_k 重复 nn 次,等效权重从 αk\alpha_k 变为 nαk/(1+(n1)αk)n \cdot \alpha_k / (1 + (n-1)\alpha_k)(归一化后),模型可能记忆训练样本(memorization),在成员推断攻击(membership inference)中表现脆弱。

Chinchilla 最优比例(Hoffmann et al., 2022):在固定算力预算 CC(FLOPs)下: NC0.5,DC0.5N^* \propto C^{0.5}, \quad D^* \propto C^{0.5}

模型参数量和训练 token 数应等比例增长,最优比约为 20 tokens per parameter。