RLHF 完整数学推导

分类: 预训练与微调 · 难度: 进阶 · 关联讲座: L08

RLHF 完整数学推导

本文整理从人类反馈中强化学习(RLHF)的完整数学框架,涵盖预训练 LM 与助手模型的目标差异、SFT 目标函数、Bradley-Terry 偏好建模、奖励模型训练,以及 PPO 优化目标。同时收录 InstructGPT 三阶段训练流水线的形式化描述。


1. 预训练 LM 与助手模型的目标差异

📐 预训练 LM 与助手模型的目标差异

预训练语言模型的目标:在给定任意前缀的情况下,预测最可能的下一个 token:

PLM(xt+1xt;θ)P_{LM}(x_{t+1} \mid x_{\leq t}; \theta)

这个目标对输入内容完全无偏——给定 “The sky is” 和给定 “How do I make a bomb? Answer:” 都一视同仁,只预测下一个最可能的 token。

助手模型的目标:给定用户请求 xx 和系统角色描述 rr,生成有帮助的回复:

Passistant(yx,r;θ)其中 y 最大化用户满意度P_{assistant}(y \mid x, r; \theta) \quad \text{其中 } y \text{ 最大化用户满意度}

两者的根本差异:LM 目标是描述性的(世界上的文本是什么样的),助手目标是规范性的(回复应该是什么样的)。这个差距——对齐问题(Alignment Problem)——是后训练(post-training)存在的根本原因。

2. SFT 目标函数

📐 SFT 目标函数

监督微调(Supervised Fine-Tuning) 与标准语言模型训练的区别仅在于损失计算的范围

JSFT(θ)=(xi,yi)DSFTt=1yilogPθ ⁣(yi(t)yi(<t),xi)J_{SFT}(\theta) = -\sum_{(x_i, y_i) \in D_{SFT}} \sum_{t=1}^{|y_i|} \log P_\theta\!\left(y_i^{(t)} \mid y_i^{(<t)}, x_i\right)

  • (xi,yi)(x_i, y_i):(指令 / 用户请求,示范回复)对
  • 只在回复 yiy_i 上计算损失,指令 xix_i 的 token 损失被 mask 掉(不参与梯度更新)
  • 这与标准 CLM 预训练的区别:SFT 的模型需要学的是”给定指令,生成好的回复”,而不是”给定任意前缀,续写任意文本”

数学上,SFT 等价于在条件分布 P(yx)P(y \mid x) 上做最大似然估计(MLE),以人类示范回复作为正样本。

3. RLHF 完整流程

📐 RLHF 完整流程的数学推导

Step 1:训练奖励模型(Reward Model)

收集人类偏好数据:对同一输入 xx,标注者标记哪个回复更好,形成偏好对 (yw,yl)(y_w, y_l)(winner, loser)。

基于 Bradley-Terry 偏好模型(人类选择 ywy_w 优于 yly_l 的概率):

P(ywylx)=σ(r(x,yw)r(x,yl))P(y_w \succ y_l \mid x) = \sigma(r(x, y_w) - r(x, y_l))

训练 RM 最大化偏好数据的对数似然(最小化负对数似然):

JRM=E(x,yw,yl)Dpref[logσ ⁣(rϕ(x,yw)rϕ(x,yl))]J_{RM} = -\mathbb{E}_{(x, y_w, y_l) \sim D_{pref}} \left[\log \sigma\!\left(r_\phi(x, y_w) - r_\phi(x, y_l)\right)\right]

Step 2:PPO 微调策略

最大化奖励,同时通过 KL 惩罚约束策略不偏离参考模型过远:

JPPO(θ)=ExD,yπθ(x) ⁣[rϕ(x,y)βlogπθ(yx)πref(yx)]J_{PPO}(\theta) = \mathbb{E}_{x \sim D,\, y \sim \pi_\theta(\cdot|x)}\!\left[r_\phi(x, y) - \beta \cdot \log\frac{\pi_\theta(y \mid x)}{\pi_{ref}(y \mid x)}\right]

其中 β[0.01,0.1]\beta \in [0.01, 0.1] 是 KL 惩罚系数,πref\pi_{ref} 是 SFT 模型(固定不更新)。

完整训练所需的模型数量:Policy(更新) + Reference(固定) + Reward(固定) + Value(更新)= 4 个模型同时在显存中,这是 PPO 内存开销大的根本原因。

4. InstructGPT 三阶段训练流水线

📐 InstructGPT 三阶段训练流水线

阶段 1 — SFT:在人工示范数据上监督微调,得到 πSFT\pi_{SFT}θSFT=argmaxθ(x,y)DSFTlogPθ(yx)\theta_{SFT} = \arg\max_\theta \sum_{(x,y)\in D_{SFT}} \log P_\theta(y \mid x)

阶段 2 — RM 训练:学习人类偏好排序,得到奖励函数 rϕ(x,y)r_\phi(x, y)ϕ=argminϕE(x,yw,yl)[logσ(rϕ(x,yw)rϕ(x,yl))]\phi^* = \arg\min_\phi -\mathbb{E}_{(x,y_w,y_l)}\left[\log\sigma(r_\phi(x,y_w) - r_\phi(x,y_l))\right]

阶段 3 — PPO:在 RM 的奖励下优化策略,约束不偏离 πSFT\pi_{SFT}θ=argmaxθEyπθ[rϕ(x,y)βKL(πθπSFT)]\theta^* = \arg\max_\theta \mathbb{E}_{y \sim \pi_\theta}\left[r_\phi(x,y) - \beta \cdot KL(\pi_\theta \| \pi_{SFT})\right]

关键结论(InstructGPT 论文原文数字)

在人类评测中,1.3B 参数 InstructGPT 优于 175B 参数 GPT-3(85% 的评测者偏好 InstructGPT 输出)。参数量相差 134 倍,但对齐使小模型胜出——说明行为对齐的收益远超单纯的模型规模

5. PPO Clip 算法机制与 Token 级实现

📐 PPO Clip 目标完整推导

问题:朴素 REINFORCE 的高方差

REINFORCE 梯度:θJ=Eyπθ ⁣[θlogπθ(yx)A(x,y)]\nabla_\theta J = \mathbb{E}_{y \sim \pi_\theta}\!\left[\nabla_\theta \log \pi_\theta(y|x) \cdot A(x,y)\right]

当一步更新过大时,重要性采样权重 πθπold\frac{\pi_\theta}{\pi_{old}} 方差爆炸,训练崩溃。

PPO 的核心:Clip 重要性比

定义 token tt 的重要性比(Importance Ratio):

rt(θ)=πθ(ytx,y<t)πold(ytx,y<t)r_t(\theta) = \frac{\pi_\theta(y_t \mid x, y_{<t})}{\pi_{old}(y_t \mid x, y_{<t})}

PPO Clip 损失(ϵ=0.2\epsilon = 0.2):

LCLIP(θ)=Et ⁣[min ⁣(rt(θ)At,    clip(rt(θ),1ϵ,1+ϵ)At)]L^{CLIP}(\theta) = \mathbb{E}_t\!\left[\min\!\left(r_t(\theta)\cdot A_t,\;\; \text{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\cdot A_t\right)\right]

min 函数确保目标是悲观下界:只允许对策略有利的情况按实际计,对策略不利时取保守截断值。

Clip 行为分情况

优势符号rtr_t 越界时clip 作用
At>0A_t > 0(好动作),rt>1.2r_t > 1.2已过度增加概率截断,禁止继续提高
At<0A_t < 0(坏动作),rt<0.8r_t < 0.8已过度降低概率截断,禁止继续打压

LLM 后训练中的 Token 级 MDP 映射

RL 概念LLM 中的对应
状态 sts_t已生成序列 (x,y1,,yt1)(x, y_1, \ldots, y_{t-1})
动作 ata_t下一个 token yty_t($
终末奖励RM 对完整回复评分 rϕ(x,y1:T)r_\phi(x, y_{1:T})
中间惩罚Per-token KL:βlogπθ(yt)πref(yt)-\beta \log\frac{\pi_\theta(y_t)}{\pi_{ref}(y_t)}

完整的 token 级奖励信号:

r~t={rϕ(x,y1:T)βlogπθ(yTsT)πref(yTsT)t=Tβlogπθ(ytst)πref(ytst)t<T\tilde{r}_t = \begin{cases} r_\phi(x, y_{1:T}) - \beta \log\dfrac{\pi_\theta(y_T|s_T)}{\pi_{ref}(y_T|s_T)} & t = T \\ -\beta \log\dfrac{\pi_\theta(y_t|s_t)}{\pi_{ref}(y_t|s_t)} & t < T \end{cases}

优势函数用 GAE 计算:AtGAE=k0(γλ)kδt+kA_t^{GAE} = \sum_{k \geq 0}(\gamma\lambda)^k \delta_{t+k},其中 δt=r~t+γVϕ(st+1)Vϕ(st)\delta_t = \tilde{r}_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t)

4 模型架构

模型是否更新作用
Policy πθ\pi_\theta生成回复
Value Network VϕV_\phi估计每 token 期望回报(与 Policy 等大)
Reward Model rϕr_\phi对完整序列打分
Reference Policy πref\pi_{ref}KL 惩罚基准

Value Network 与 Policy 参数量相当且需要梯度,是 PPO 显存开销是 DPO 两倍的根本原因。

6. RLHF / PPO / DPO / GRPO 层次关系

RLHF(框架): 用人类偏好数据训练 RM,再用 RL 最大化 RM 奖励

├── PPO(标准 RL 实现)
│   需要:Policy + Value + RM + Reference(4 模型)
│   在线采样 → 适合探索,代价高

├── GRPO(PPO 的轻量变体)
│   去掉 Value Network,用组内统计估计优势
│   需要:Policy + RM + Reference(3 模型)
│   适合:稀疏自动评分任务(数学/代码)

└── DPO(RLHF 目标的数学重参数化)
    绕过 RL 和 RM,直接在偏好对上优化
    需要:Policy + Reference(2 模型)
    离线,简单稳定,但无在线探索能力

三者共享底层假设:Bradley-Terry 偏好模型 P(ywylx)=σ(rwrl)P(y_w \succ y_l \mid x) = \sigma(r_w - r_l)