SFT / CLM / Instruction FT 完整详解

分类: 预训练与微调 · 难度: 中级 · 关联讲座: L08

SFT / CLM / Instruction FT 完整详解

本文系统讲解三个常被混淆的概念——CLM(因果语言建模)、SFT(监督微调)、Instruction Fine-Tuning(指令微调)——的精确定义、数学形式、关键区别,以及 token 级损失计算的完整推导。


1. 三个概念的层次关系

概念层次是否需要标注数据典型代表
CLM目标函数否(自监督)GPT 预训练
SFT训练方法是(人工示范)InstructGPT Stage 1
Instruction FT训练策略/范式是(跨任务)Flan-T5、InstructGPT

包含关系:Instruction FT SFT 方法,SFT 类 CLM 的 MLE 目标(但限定 token 范围)。


2. CLM:因果语言建模

2.1 目标函数

给定序列 x=[x1,x2,,xT]\mathbf{x} = [x_1, x_2, \ldots, x_T],CLM 最大化序列的对数似然:

JCLM(θ)=t=1TlogPθ(xtx1,,xt1)=t=1TlogPθ(xtx<t)J_{CLM}(\theta) = \sum_{t=1}^{T} \log P_\theta(x_t \mid x_1, \ldots, x_{t-1}) = \sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})

训练目标是最小化负对数似然(NLL Loss):

LCLM(θ)=1Tt=1TlogPθ(xtx<t)\mathcal{L}_{CLM}(\theta) = -\frac{1}{T} \sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})

2.2 关键特性

  • 无监督:数据无需标注,用文本自身做监督信号(xtx_t 既是上下文又是标签)
  • 全 token 计损失:序列中每个位置都参与梯度更新
  • 自回归生成:推断时从左到右逐 token 采样
  • 适用场景:预训练,学习通用文本分布

3. SFT:监督微调

3.1 目标函数

给定 (指令, 回复) 数据集 DSFT={(xi,yi)}i=1N\mathcal{D}_{SFT} = \{(x_i, y_i)\}_{i=1}^{N},SFT 最小化:

LSFT(θ)=1Ni=1N1yit=1yilogPθ ⁣(yi(t)yi(<t),xi)\mathcal{L}_{SFT}(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log P_\theta\!\left(y_i^{(t)} \mid y_i^{(<t)}, x_i\right)

其中:

  • xi=[xi(1),,xi(m)]x_i = [x_i^{(1)}, \ldots, x_i^{(m)}]:指令 token 序列(不计损失,但参与 attention)
  • yi=[yi(1),,yi(n)]y_i = [y_i^{(1)}, \ldots, y_i^{(n)}]:回复 token 序列(只在此处计损失

3.2 Loss Mask 的形式化

将拼接序列记为 Si=[xi;yi]S_i = [x_i; y_i],引入 loss mask m{0,1}m+n\mathbf{m} \in \{0,1\}^{m+n}

mt={01tm(指令部分)1m<tm+n(回复部分)m_t = \begin{cases} 0 & 1 \leq t \leq m \quad (\text{指令部分}) \\ 1 & m < t \leq m+n \quad (\text{回复部分}) \end{cases}

则 SFT 损失可写成与 CLM 平行的形式:

LSFT(θ)=1tmtt=1m+nmtlogPθ(sts<t)\mathcal{L}_{SFT}(\theta) = -\frac{1}{\sum_t m_t} \sum_{t=1}^{m+n} m_t \cdot \log P_\theta(s_t \mid s_{<t})

梯度只来自 response token

LSFTθ=1nt=m+1m+nlogPθ(sts<t)θ\frac{\partial \mathcal{L}_{SFT}}{\partial \theta} = -\frac{1}{n} \sum_{t=m+1}^{m+n} \frac{\partial \log P_\theta(s_t \mid s_{<t})}{\partial \theta}

注意:虽然指令 token 不产生梯度,但它们仍参与前向传播——在 causal attention 中作为 context 向量可见,是生成正确回复的必要条件。

3.3 CLM vs SFT 的对比

维度CLMSFT
数据格式纯文本 x\mathbf{x}(指令 xx, 回复 yy) 对
损失范围所有 token仅回复 token
是否需要标注
学习目标P(x)P(\mathbf{x}) 联合分布P(yx)P(y \mid x) 条件分布
梯度信噪比低(指令 token 引入噪声)高(仅学习”如何回复”)
典型应用预训练指令微调、对话模型

4. Token 级损失计算:完整数值示例

设定

  • 指令 x=[[Q],翻译:,]x = [\texttt{[Q]}, \texttt{翻译:}, \texttt{猫}]m=3m = 3
  • 回复 y=[[A],cat,.,[EOS]]y = [\texttt{[A]}, \texttt{cat}, \texttt{.}, \texttt{[EOS]}]n=4n = 4
  • 完整序列 SST=7T = 7

各位置损失贡献

位置 ttTokenlogPθ(sts<t)-\log P_\theta(s_t \mid s_{<t})CLM 计入?SFT 计入?
1[Q]2.12
2翻译:1.71
32.53
4[A]0.80
5cat0.48
6.0.21
7[EOS]0.31

LCLM=2.12+1.71+2.53+0.80+0.48+0.21+0.3171.17\mathcal{L}_{CLM} = \frac{2.12+1.71+2.53+0.80+0.48+0.21+0.31}{7} \approx 1.17

LSFT=0.80+0.48+0.21+0.314=0.45\mathcal{L}_{SFT} = \frac{0.80+0.48+0.21+0.31}{4} = 0.45

关键观察

  1. 梯度信号:CLM 中约 3743%\frac{3}{7} \approx 43\% 的梯度来自”生成指令”的方向,与任务无关;SFT 100% 来自”生成正确回复”。

  2. 长 prompt 放大效应:若 few-shot prompt 包含 50 个示例 token,只有 4 个回复 token,CLM 中 505493%\frac{50}{54} \approx 93\% 的梯度浪费在 prompt 上。SFT 不受此影响。

  3. 数值大小差异:SFT 损失数值更小不代表模型”更好学”,而是排除了概率低的指令 token(1=2.12,2=1.71\ell_1 = 2.12, \ell_2 = 1.71)。


5. Instruction Fine-Tuning:策略层面

5.1 什么是 Instruction FT?

Instruction Fine-Tuning(指令微调)不是一个具体算法,而是一种训练策略

在包含多种任务多种指令格式的数据集上,用 SFT 方法微调预训练模型,使其具备泛化的指令遵循能力。

5.2 与普通 SFT 的区别

普通 SFT 可能只针对单一任务(如”翻译”),而 Instruction FT 要求:

  • 任务多样性:摘要、翻译、问答、代码生成……覆盖广泛
  • 指令多样性:同一任务用不同措辞表达
  • 零样本泛化:训练后能响应从未见过的任务指令

5.3 代表工作

模型预训练基座任务数数据规模关键发现
FLAN (Wei et al., 2021)LaMDA-PT62-Instruction FT 使 LM 能零样本泛化
T0 (Sanh et al., 2021)T5171-多任务提示格式的重要性
Super-Natural InstructionsGPT-Neo1,6163M+任务数量是关键因素
Flan-T5 (Chung et al., 2022)T5-XXL1,836-更大模型 → 更大的 IFT 增益
InstructGPT (Ouyang et al., 2022)GPT-3-13K SFTSFT + RLHF 流水线
LIMA (Zhou et al., 2023)LLaMA-1K1000 条高质量 > 52K 自动生成

5.4 数据质量 vs 数量

LIMA 的核心发现(“Less Is More for Alignment”):

1K 高质量示范52K Alpaca 自动生成(人类偏好评测)\text{1K 高质量示范} \geq \text{52K Alpaca 自动生成} \quad (\text{人类偏好评测})

这表明 SFT 的作用是激发预训练中已存在的能力,而非注入新知识。模型的能力边界由预训练决定,SFT 只是”调整格式”和”解锁行为模式”。


6. 为什么不直接用 CLM 做指令微调?

理论层面:CLM 学习联合分布 P(x,y)P(x, y),SFT 学习条件分布 P(yx)P(y \mid x)。对指令遵循任务,我们需要的是后者——已知指令 xx,生成最佳回复 yy,不需要”如何生成指令”的知识。

实践层面的三个问题:

  1. 梯度污染:instruction token 的梯度方向与任务目标无关,甚至相反(学习”续写 prompt”而非”回答 prompt”)。

  2. 信噪比低:实际指令往往比回复长(few-shot 场景),CLM 大量梯度浪费在 prompt 上。

  3. 行为模式错误:用 CLM 微调的模型,面对 “Explain X to me” 的 prompt,可能会生成 “Explain Y to me” 的续写,而非真正的解释——因为这是训练数据中最可能的续写。

实验验证(Wang et al., 2022):Response-only SFT 在指令遵循任务上持续优于 full-sequence CLM SFT,差距在长指令场景下更显著。


7. 实现细节:如何在代码中实现 Loss Mask

以 HuggingFace Transformers 为例(PyTorch):

# 构造 loss mask:指令部分为 -100(忽略),回复部分为真实 token id
def build_sft_labels(input_ids: list[int], instruction_len: int) -> list[int]:
    """
    input_ids: 完整序列 [instruction_tokens + response_tokens]
    instruction_len: 指令部分的 token 数量
    返回: labels,-100 表示不计损失
    """
    labels = [-100] * instruction_len + input_ids[instruction_len:]
    return labels

# 使用示例
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-8B")

instruction = "[INST] 将'猫追老鼠'翻译成英文 [/INST]"
response = "The cat chases the mouse."

inst_ids = tokenizer.encode(instruction, add_special_tokens=True)
resp_ids = tokenizer.encode(response, add_special_tokens=False)
resp_ids.append(tokenizer.eos_token_id)

input_ids = inst_ids + resp_ids
labels = build_sft_labels(input_ids, len(inst_ids))
# labels: [-100, -100, ..., -100, token_id_for_"The", ..., eos_id]

# CrossEntropyLoss 自动忽略 label=-100 的位置
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

ignore_index=-100 是 PyTorch CrossEntropyLoss 的标准用法,HuggingFace 的 CausalLM 模型默认用此约定处理 loss mask。


相关阅读