GKD

分类: 知识蒸馏

GKD

定义

Google 提出的广义知识蒸馏框架,统一了 on-policy 和 off-policy LLM 蒸馏,通过在学生生成的序列上计算 divergence 来缓解 train-test distribution mismatch

数学形式

LGKD=ExDEypθ[Df(pT(x,y<t)pθ(x,y<t))]\mathcal{L}_{\text{GKD}} = \mathbb{E}_{x \sim \mathcal{D}} \mathbb{E}_{y \sim p_\theta} [D_f(p_T(\cdot|x, y_{<t}) \| p_\theta(\cdot|x, y_{<t}))]

其中 DfD_f 可以是任意 f-divergence(KL, reverse KL, JSD 等)

核心要点

将 on-policy 数据生成与多种 divergence 度量解耦

支持 f-divergence 家族的统一框架

学生在自身采样序列上训练,避免 exposure bias

在 summarization 和 translation 任务上优于标准 KD

代表工作

Agarwal et al., “On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes” (ICLR 2024)

相关概念

MiniLLM

DistiLLM

FitNet