Self-Distillation for Multi-Token Prediction
论文笔记:Self-Distillation for Multi-Token Prediction
元信息
| 项目 | 内容 |
|---|---|
| 机构 | Tencent, Large Language Model Department |
| 日期 | March 2026 |
| 项目主页 | 无 |
| 对比基线 | DeepSeek-V3 MTP |
| 链接 | arXiv |
一句话总结
通过梯度分离的自蒸馏策略提升 MTP head 的接受率(+7.5%),并用循环扩展将 MTP head 经济地扩展到 16 个,实现 LLM 推理 3.2 倍加速。
核心贡献
自蒸馏训练策略 MTP-D: 梯度分离 + TopNN logits 选择的自蒸馏,在不影响主头性能的前提下大幅提升 MTP head 接受率
循环扩展策略 (Looped Extension): 利用 MTP head 的结构一致性,通过权重复制 + 继续预训练将 head 数从 4 经济地扩展到 8/16
系统性消融与分析: 对梯度分离、TopN 选择、KL 变体、集成策略等进行全面消融,为 MTP 训练提供实践指导
问题背景
要解决的问题
LLM 的自回归逐 token 生成存在固有的延迟瓶颈,Multi-Token Prediction (MTP) 通过并行预测多个未来 token 加速推理
但 MTP head 的接受率有限,且联合训练困难(“跷跷板效应”使主头性能受损)
现有方法的局限
DeepSeek-V3 的级联 MTP 架构虽然保持了完整因果链,但 MTP head 的损失随 head 索引增加而上升
累积接受率 (CAR) 指数下降:DeepSeek MTP 在 1-to-8 循环设置下,第 3 个 head 的 CAR 降至 0.6%
实际部署中大多数 LLM 仅使用 1-4 个 MTP head,限制了加速潜力
本文的动机
主头(main head)已经学到了丰富的语义知识和分布信息,可以通过自蒸馏将这些知识传递给 MTP head
DeepSeek MTP 架构的结构一致性为循环扩展提供了天然基础
方法详解
模型架构
MTP-D 基于 DeepSeek-V3 的级联 MTP 架构:
- 输入: token 序列
- Backbone: Dense Transformer(2B)或 MoE Transformer(N10B A1B)
- 核心模块: 个级联 MTP head,每个 head 由单层 Dense layer 构成,与主模型共享 embedding 和 output head
- 输出: 每个 MTP head 预测第 个未来 token 的概率分布
- 注意力: GQA,RoPE 位置编码,RMSNorm 归一化
核心模块
模块1: 梯度分离的自蒸馏 (Gradient-Detached Self-Distillation)
设计动机: 利用主头的输出分布作为 teacher 信号,通过 KL 散度 蒸馏来提升 MTP head 的预测质量,同时用 Stop-Gradient 保护主头不被蒸馏梯度干扰
具体实现:
- 对主头 logits 应用 stop-gradient ,阻断梯度回传
- 仅选择主头输出中 TopN 个最高概率 token 的 logits 进行蒸馏(默认 )
- 对选定 logits 分别计算 teacher 端 softmax 和 student 端 log-softmax,用 forward KL 进行蒸馏
模块2: TopNN Logits 选择 (TopNN-Selected Logits)
设计动机: 词表规模巨大(如 122,880),全词表蒸馏带来冗余计算、内存开销和数值不稳定性
具体实现:
- 主头 logits 经 softmax 后呈 长尾分布:Top10,000 覆盖 99.52% 累积概率
- 选择主头 Top-N logits 对应索引
- 对 teacher 和 student 均只提取这些索引对应的 logits,再分别做 softmax/log-softmax
- 避免了低概率 token 带来的弱监督信号和对数运算数值不稳定
模块3: 循环扩展策略 (Looped MTP Head Extension)
设计动机: DeepSeek 级联 MTP 架构具有强结构一致性和输入输出相似性,天然支持通过循环复制扩展 head 数
具体实现:
- 已训练的 个 MTP head 的权重复制初始化第 到 个 head
- 冻结主模型和已训练的前 个 head,仅训练新 head
- 继续预训练使用相同的自蒸馏策略,70B tokens 即可达到良好效果
- 可选集成策略:用主头 + 已训练 MTP head 的集成 logits 作为 teacher
关键公式
公式1: MTP 交叉熵损失
含义: 标准的多 token 预测交叉熵损失,将每个 MTP head 的预测与 ground-truth token 对齐
符号说明:
- : MTP head 总数
- : 第 个 MTP head 的权重系数
- : 第 个 MTP head 对位置 到 的预测分布
- : ground-truth token 序列
公式2: TopN Logits 索引选择
含义: 从主头在位置 的 logits 中选择概率最高的 个 token 的索引
符号说明:
- : 主头在位置 的 logits 向量
- : 选取的 top logits 数量(默认 10,000)
公式3: Teacher 端 softmax 归一化
含义: 对主头选定索引的 logits 做 softmax,得到 teacher 概率分布
符号说明:
- : softmax 函数
- : 仅保留 TopN 索引对应的 logits 分量
公式4: Student 端 log-softmax
含义: 对 MTP head 选定索引的 logits 做 log-softmax,用于 KL 散度计算
符号说明:
- : 第 个 MTP head 的 logits
公式5: KL 自蒸馏损失
含义: 梯度分离的 KL 蒸馏损失,让每个 MTP head 的输出分布对齐主头的分布
符号说明:
- : 第 个 MTP head 的 KL 损失权重系数(默认 ,K=4 时用 0.5)
- : stop-gradient 操作,阻断梯度回传到主头
- : forward KL 散度
公式6: MTP-D 总损失
含义: CE 确保与 ground-truth 的基础对齐,KL 让 MTP head 从主头获取语义知识并约束分布一致性
符号说明:
- : 交叉熵损失
- : KL 自蒸馏损失
公式7: 接受率
含义: 第 个 MTP head 的 draft token 被主头验证接受的比率
符号说明:
- : 总样本数
- : 第 个样本中第 个 head 被接受的 token 数
- : 第 个样本中第 个 head 被验证的 token 总数
公式8: 累积接受率
含义: 相对于总生成步数的接受率,反映实际加速效果
符号说明:
- : 第 个样本中每个 head 生成的总 token 数
关键图表
Figure 1: Overview / 自蒸馏方法总览
{:width 600}
说明: MTP-D 的梯度分离、TopNN-logits-selected 自蒸馏方法概览。主头(main head)输出经 stop-gradient 后,选取 Top-N logits 作为 teacher 信号,通过 KL 散度蒸馏到各个 MTP head,同时保持 CE 损失与 ground-truth 的对齐。
Figure 2: Looped Extension / 循环扩展训练策略
{:width 600}
说明: 循环扩展策略示意图。灰色块为冻结的主模型和已训练的 MTP head 1m,其权重被复制初始化 head m+12m(橙色块),仅训练新 head,大幅降低训练成本。
Figure 3: Training-Free Looped Extension Results / 免训练循环扩展结果
{:width 600}
{:width 600}
说明: AGIEval-en 上免训练循环扩展的结果。(a) 累积接受率:MTP-D 在循环连接点虽有下降但保持可接受水平(26.70%),而 DeepSeek MTP 在第 3 个 head 就降至 0.6%。(b) 逐 head 接受率:循环连接点有明显下降但逐渐恢复。
Figure 4: Speedup Ratios / 加速比
{:width 600}
说明: 2B Dense 模型在不同 MTP 方法和 K 设置下的加速比。MTP-D 在 K=4 时实现 22.9% 加速(vs MTP),K=1 时也有约 14% 提升。
Figure 5: Looped Extension Strategies Comparison / 循环扩展策略对比
{:width 600}
{:width 600}
说明: 不同循环扩展策略在所有 benchmark 上的平均 CAR 和加速比。MTP-D 4-to-8 和 4-to-16 设置表现最佳,4-to-16 达到平均 3.204 倍加速。
Figure 6: TopN Probability Distribution / TopN 概率分布分析
{:width 400}
{:width 400}
{:width 400}
说明: 主头 logits 的概率分布在不同 TopN 设置下的可视化。全词表呈明显长尾分布,Top10,000 覆盖 99.52% 累积概率,Top1,000 覆盖 83.41%。
Figure 7: Training Loss Curves / 训练损失曲线
{:width 600}
说明: 2B Dense 模型 4 head 配置下的训练损失曲线。MTP-D 的 MTP head 损失始终低于 DeepSeek MTP,且主头损失保持一致。
Figure 8: MoE Speedup Ratios / MoE 模型加速比
{:width 600}
说明: A1B MoE 模型的加速比。趋势与 Dense 模型一致,MTP-D 在 MoE 架构上同样有效。
Figure 9: Loss Comparison / 损失对比
{:width 400}
{:width 400}
说明: (a) 主头损失对比:MTP-D 和 DeepSeek MTP 的主头损失几乎重合,证明自蒸馏不影响主头。(b) MTP head 损失对比:MTP-D 的 MTP head 损失显著更低。
Figure 10: Full Vocabulary Training Loss / 全词表训练损失
{:width 600}
说明: 使用全词表(不做 TopN 选择)时的训练损失曲线,出现明显的数值不稳定(loss spike),验证了 TopN 选择的必要性。
Figure 11: Training-Free Looped CAR per Benchmark / 免训练循环 CAR
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
说明: 7 个 benchmark 上免训练循环扩展的 CAR 详细结果。MTP-D 在所有 benchmark 上均显著优于 DeepSeek MTP。
Figure 12: Training-Free Looped AR per Benchmark / 免训练循环 AR
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
说明: 7 个 benchmark 上免训练循环扩展的逐 head 接受率。循环连接点的 AR 下降在 MTP-D 中更快恢复。
Figure 13: Continued Pre-Training Looped CAR (up to 8) / 继续预训练循环 CAR
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
说明: 继续预训练后循环扩展到 8 head 的 CAR。经 70B tokens 继续预训练后 CAR 显著提升。
Figure 14: Continued Pre-Training Looped CAR (up to 16) / 继续预训练循环 CAR (16 head)
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
{:width 400}
说明: 扩展到 16 head 的 CAR。第 16 个 head 仍维持 5-10% CAR,高接受率 benchmark 从 16 head 中获益最多。
Table 1: 主实验结果(K=4, 2B Dense)
| Benchmark | Head | MTP CAR | MTP AR | MTP-D CAR | MTP-D AR |
|---|---|---|---|---|---|
| AGIEval-en | 1 | 85.86 | 85.86 | 85.86 | 85.86 |
| AGIEval-en | 4 | 45.47 | 85.75 | 52.96 | 88.98 |
| GSM8K | 4 | 65.22 | 91.61 | 72.76 | 92.89 |
| MATH | 4 | 37.37 | 78.39 | 46.42 | 81.81 |
| NaturalQuestions | 4 | 63.63 | 93.91 | 71.22 | 95.34 |
| SimpleQA | 4 | 39.15 | 76.82 | 46.27 | 80.38 |
| TriviaQA | 4 | 41.45 | 80.83 | 48.52 | 82.85 |
| SuperGPQA | 4 | 47.59 | 84.82 | 54.98 | 87.94 |
表格说明: MTP-D 在所有 benchmark 的第 4 个 head 上 CAR 平均提升约 7.5%,AR 平均提升约 3.9%。
Table 1 (续): A1B MoE 结果
| Benchmark | Head | MTP CAR | MTP AR | MTP-D CAR | MTP-D AR |
|---|---|---|---|---|---|
| AGIEval-en | 4 | 48.82 | 85.41 | 52.47 | 90.98 |
| GSM8K | 4 | 61.93 | 89.38 | 71.47 | 92.86 |
| MATH | 4 | 30.38 | 73.57 | 37.89 | 76.16 |
| NaturalQuestions | 4 | 60.35 | 90.80 | 68.05 | 92.92 |
| SimpleQA | 4 | 42.38 | 78.75 | 49.18 | 81.48 |
| TriviaQA | 4 | 37.03 | 78.12 | 44.51 | 80.91 |
| SuperGPQA | 4 | 47.31 | 83.37 | 54.28 | 86.44 |
表格说明: MoE 架构上 MTP-D 的提升更为显著,第 4 个 head 的 CAR 平均提升约 4.7%。
Table 2: 消融实验(K=1, 2B Dense)
| 配置 | Main Acc (↑) | MTP Acc (↑) | 说明 |
|---|---|---|---|
| MTP-D (default) | 90.06 | 11.68 | 基线配置 |
| w/o Detach | 88.57 (-1.49) | 12.05 | 主头性能受损 |
| TopN=1,000 | 90.47 | 11.47 | 覆盖率不足 |
| Union (10K×2) | 90.51 | 11.27 | 额外收益有限 |
| β_k=0.1 | 87.66 | 10.99 | KL 权重不足 |
| β_k=0.3 | 88.83 | 11.31 | - |
| β_k=0.5 | 89.72 | 11.47 | - |
| β_k=1.5 | 90.80 | 11.64 | 主头损失上升 |
| Reverse KL | 90.22 | 11.62 | Forward KL 更优 |
| Hybrid KL | 88.92 | 10.99 | 性能下降 |
关键发现: (1) 梯度分离对保护主头至关重要(去除后主头 Acc -1.49);(2) β_k=1.0 是最佳平衡点;(3) Forward KL 优于 Reverse/Hybrid KL。
Table 5: 主头性能对比
| 配置 | NTP Loss | Main Acc |
|---|---|---|
| NTP baseline | 2.5457 | 10.96 |
| MTP K=1 | 2.5458 | 11.28 |
| MTP-D K=1 | 2.5482 | 11.68 |
| MTP K=4 | 2.5430 | 11.06 |
| MTP-D K=4 | 2.5495 | 10.96 |
表格说明: MTP-D 的 NTP loss 与 MTP 基线几乎一致(差异 <0.007),证明自蒸馏对主头无损。
Table 6-7: K=4 集成策略消融
| 策略 | Main Acc | 4th Head CAR |
|---|---|---|
| MTP-D (β=0.5) | 10.96 | baseline |
| β=0.3 | 11.13 | lower |
| Ensemble mean | 11.22 | improved |
| Ensemble split | 11.20 | improved |
关键发现: 集成策略同时提升主头性能和 MTP 接受率,多元监督信号互补有益。
Table 8: 循环扩展加速结果汇总
| 配置 | 方法 | 平均加速比 |
|---|---|---|
| K=1 | MTP-D | 1.142× |
| K=4 | MTP-D | 2.372× |
| 1-to-8 (免训练) | MTP | 1.769× |
| 1-to-8 (免训练) | MTP-D* | 1.769× |
| 1-to-8 (继续预训练) | MTP | 2.644× |
| 1-to-8 (继续预训练) | MTP-D | 2.846× |
| 4-to-8 | MTP-D | 3.052× |
| 1-to-16 (免训练) | MTP-D* | 1.657× |
| 4-to-16 | MTP-D | 3.204× |
关键发现: (1) 4-to-16 MTP-D 达到最高 3.204× 加速;(2) 分组循环(4-to-8, 4-to-16)优于单 head 循环(1-to-8, 1-to-16);(3) 70B tokens 继续预训练即可获得显著提升。
实验
数据集
| 数据集 | 规模 | 特点 | 用途 |
|---|---|---|---|
| FineWeb-Edu-350BT | 350B tokens | 教育领域高质量网页文本 | 预训练 |
评估基准
| Benchmark | 特点 | 评估能力 |
|---|---|---|
| AGIEval-en | 人类标准化考试 | 综合推理 |
| GSM8K | 小学数学 | 数学推理 |
| MATH | 竞赛数学 | 高级数学 |
| NaturalQuestions | 知识密集 QA | 知识检索 |
| SimpleQA | 短事实查询 | 事实性 |
| SuperGPQA | 研究生级知识 | 深度推理 |
| TriviaQA | 阅读理解 | 阅读理解 |
实现细节
模型: 2B Dense (32 layers, hidden 2048, 16 heads) + N10B A1B MoE (22 layers, 128 experts, top-8)
优化器: AdamW,,
学习率: max ,Cosine Decay 到 0
Batch Size: 8192
序列长度: 4096
训练量: 预训练 350B tokens,循环扩展 70B tokens
硬件: 256 × NVIDIA H20 GPU(预训练约 30 天)
Warmup: 预训练 1000 steps,循环扩展 0 steps
正则化: weight decay 0.1, gradient clipping norm 1.0
推理: 贪心解码,单 batch,最大生成 100 tokens,无 KV cache
可视化结果
MTP-D 在所有 7 个 benchmark 上的 CAR 和 AR 均一致性提升
循环扩展到 16 head 后加速比饱和,受级联架构 CAR 递减限制
训练损失曲线显示 MTP-D 的 MTP head 损失始终更低,主头损失几乎不受影响
批判性思考
优点
设计简洁高效: 仅增加 KL 损失项即可显著提升 MTP 接受率,无需修改模型架构
理论动机清晰: 梯度分离保护主头 + TopN 选择解决效率/稳定性问题,每个设计选择都有充分的消融验证
循环扩展策略实用: 利用架构结构一致性,70B tokens 即可扩展 head 数,训练成本极低
消融实验全面: 对所有关键设计(detach、TopN、β_k、KL 变体、集成策略)都进行了系统消融
局限性
仅验证了预训练阶段: 未探索 SFT/RLHF 等 post-training 阶段的适用性
模型规模有限: 仅在 2B 和 10B 上验证,未在 70B+ 超大模型上测试
推理设置受限: 单 batch、无 KV cache、最大 100 tokens,不完全反映实际部署场景
α_k、β_k 与 K 的理论关系不充分: 最优超参数依赖经验调参,缺乏理论指导
CAR 递减问题未根本解决: 级联架构的固有限制使 16+ head 的边际收益递减
潜在改进方向
将 MTP-D 扩展到 post-training 阶段(SFT 微调、RLHF 对齐)
探索非级联架构(如 Meta 的并行 MTP)+ 自蒸馏的组合
引入自适应 β_k 调度策略(如与训练进度或 head 索引挂钩)
在实际部署场景(长文本生成、大 batch、KV cache)中评估加速效果
结合 EAGLE 等 speculative decoding 方法进一步提升验证效率
可复现性评估
- 代码开源(未提供)
- 预训练模型(未提供)
- 训练细节完整(超参数、数据集、硬件均详细记录)
- 数据集可获取(FineWeb-Edu 公开可用)
关联笔记
基于
DeepSeek-V3: 级联 MTP 架构的基础
Hinton KD: 知识蒸馏的经典框架(Hinton et al., 2015)
对比
Medusa: 另一种 MTP 推理加速框架
EAGLE: Speculative Sampling,重新思考特征不确定性
方法相关
Multi-Token Prediction: 核心预测范式
自蒸馏: 核心训练策略
Speculative Decoding: 推理加速的验证框架
KL Divergence: 蒸馏损失函数
Stop-Gradient: 保护主头的梯度分离技术
硬件/数据相关
FineWeb-Edu: 预训练数据集
NVIDIA H20: 训练硬件
速查卡片
Self-Distillation for Multi-Token Prediction
- 核心: 梯度分离的自蒸馏提升 MTP head 接受率,循环扩展经济地增加 head 数
- 方法: 主头 → stop-gradient → TopN logits → KL 蒸馏 → MTP head;循环权重复制 + 70B tokens 继续预训练
- 结果: K=4 接受率 +7.5%(22.9% 加速),循环扩展到 16 head 达 3.2× 加速
- 代码: 未公开
笔记创建时间: 2026-03-27