Self-Distillation for Multi-Token Prediction

作者: Guoliang Zhao, Ruobing Xie, An Wang, Shuaipeng Li, Huaibing Xie, Xingwu Sun 年份: 2026 会议: arXiv 分类: 高效推理与部署

论文笔记:Self-Distillation for Multi-Token Prediction

元信息

项目内容
机构Tencent, Large Language Model Department
日期March 2026
项目主页
对比基线DeepSeek-V3 MTP
链接arXiv

一句话总结

通过梯度分离的自蒸馏策略提升 MTP head 的接受率(+7.5%),并用循环扩展将 MTP head 经济地扩展到 16 个,实现 LLM 推理 3.2 倍加速。

核心贡献

自蒸馏训练策略 MTP-D: 梯度分离 + TopNN logits 选择的自蒸馏,在不影响主头性能的前提下大幅提升 MTP head 接受率

循环扩展策略 (Looped Extension): 利用 MTP head 的结构一致性,通过权重复制 + 继续预训练将 head 数从 4 经济地扩展到 8/16

系统性消融与分析: 对梯度分离、TopN 选择、KL 变体、集成策略等进行全面消融,为 MTP 训练提供实践指导

问题背景

要解决的问题

LLM 的自回归逐 token 生成存在固有的延迟瓶颈,Multi-Token Prediction (MTP) 通过并行预测多个未来 token 加速推理

但 MTP head 的接受率有限,且联合训练困难(“跷跷板效应”使主头性能受损)

现有方法的局限

DeepSeek-V3 的级联 MTP 架构虽然保持了完整因果链,但 MTP head 的损失随 head 索引增加而上升

累积接受率 (CAR) 指数下降:DeepSeek MTP 在 1-to-8 循环设置下,第 3 个 head 的 CAR 降至 0.6%

实际部署中大多数 LLM 仅使用 1-4 个 MTP head,限制了加速潜力

本文的动机

主头(main head)已经学到了丰富的语义知识和分布信息,可以通过自蒸馏将这些知识传递给 MTP head

DeepSeek MTP 架构的结构一致性为循环扩展提供了天然基础

方法详解

模型架构

MTP-D 基于 DeepSeek-V3 的级联 MTP 架构:

  • 输入: token 序列 t1,t2,,tTt_1, t_2, \ldots, t_T
  • Backbone: Dense Transformer(2B)或 MoE Transformer(N10B A1B)
  • 核心模块: KK 个级联 MTP head,每个 head 由单层 Dense layer 构成,与主模型共享 embedding 和 output head
  • 输出: 每个 MTP head kk 预测第 k+1k+1 个未来 token 的概率分布 P^k\hat{P}^k
  • 注意力: GQA,RoPE 位置编码,RMSNorm 归一化

核心模块

模块1: 梯度分离的自蒸馏 (Gradient-Detached Self-Distillation)

设计动机: 利用主头的输出分布作为 teacher 信号,通过 KL 散度 蒸馏来提升 MTP head 的预测质量,同时用 Stop-Gradient 保护主头不被蒸馏梯度干扰

具体实现:

  • 对主头 logits Q^\hat{Q} 应用 stop-gradient sg()\operatorname{sg}(\cdot),阻断梯度回传
  • 仅选择主头输出中 TopN 个最高概率 token 的 logits 进行蒸馏(默认 N=10,000N=10{,}000
  • 对选定 logits 分别计算 teacher 端 softmax 和 student 端 log-softmax,用 forward KL 进行蒸馏

模块2: TopNN Logits 选择 (TopNN-Selected Logits)

设计动机: 词表规模巨大(如 122,880),全词表蒸馏带来冗余计算、内存开销和数值不稳定性

具体实现:

  • 主头 logits 经 softmax 后呈 长尾分布:Top10,000 覆盖 99.52% 累积概率
  • 选择主头 Top-N logits 对应索引 ItN=TopK(Q^t,N)\mathcal{I}_t^N = \text{TopK}(\hat{Q}_t, N)
  • 对 teacher 和 student 均只提取这些索引对应的 logits,再分别做 softmax/log-softmax
  • 避免了低概率 token 带来的弱监督信号和对数运算数值不稳定

模块3: 循环扩展策略 (Looped MTP Head Extension)

设计动机: DeepSeek 级联 MTP 架构具有强结构一致性和输入输出相似性,天然支持通过循环复制扩展 head 数

具体实现:

  • 已训练的 mm 个 MTP head 的权重复制初始化第 m+1m+12m2m 个 head
  • 冻结主模型和已训练的前 mm 个 head,仅训练新 head
  • 继续预训练使用相同的自蒸馏策略,70B tokens 即可达到良好效果
  • 可选集成策略:用主头 + 已训练 MTP head 的集成 logits 作为 teacher

关键公式

公式1: MTP 交叉熵损失

LmtpCE=k=1KαkCE(P^k+1:T+1k,  tk+1:T+1)\mathcal{L}_{\text{mtp}}^{\text{CE}} = \sum_{k=1}^{K} \alpha_k \cdot \text{CE}(\hat{P}_{k+1:T+1}^k, \; t_{k+1:T+1})

含义: 标准的多 token 预测交叉熵损失,将每个 MTP head 的预测与 ground-truth token 对齐

符号说明:

  • KK: MTP head 总数
  • αk\alpha_k: 第 kk 个 MTP head 的权重系数
  • P^k+1:T+1k\hat{P}^k_{k+1:T+1}: 第 kk 个 MTP head 对位置 k+1k+1T+1T+1 的预测分布
  • tk+1:T+1t_{k+1:T+1}: ground-truth token 序列

公式2: TopN Logits 索引选择

ItN=TopK(Q^t,N)\mathcal{I}_t^N = \text{TopK}(\hat{Q}_t, N)

含义: 从主头在位置 tt 的 logits 中选择概率最高的 NN 个 token 的索引

符号说明:

  • Q^t\hat{Q}_t: 主头在位置 tt 的 logits 向量
  • NN: 选取的 top logits 数量(默认 10,000)

公式3: Teacher 端 softmax 归一化

Q~k+1:T+1=σ(Q^k+1:T+1[,ItN])\tilde{Q}_{k+1:T+1} = \sigma\bigl(\hat{Q}_{k+1:T+1}[\ldots, \mathcal{I}_t^N]\bigr)

含义: 对主头选定索引的 logits 做 softmax,得到 teacher 概率分布

符号说明:

  • σ()\sigma(\cdot): softmax 函数
  • Q^[,ItN]\hat{Q}[\ldots, \mathcal{I}_t^N]: 仅保留 TopN 索引对应的 logits 分量

公式4: Student 端 log-softmax

P~k+1:T+1k=log(σ(P^k+1:T+1k[,ItN]))\tilde{P}_{k+1:T+1}^k = \log\bigl(\sigma(\hat{P}_{k+1:T+1}^k[\ldots, \mathcal{I}_t^N])\bigr)

含义: 对 MTP head 选定索引的 logits 做 log-softmax,用于 KL 散度计算

符号说明:

  • P^k\hat{P}^k: 第 kk 个 MTP head 的 logits

公式5: KL 自蒸馏损失

LmtpKL=k=1KβkKL(P~k+1:T+1k,  sg(Q~)k+1:T+1)\mathcal{L}_{\text{mtp}}^{\text{KL}} = \sum_{k=1}^{K} \beta_k \cdot \text{KL}\bigl(\tilde{P}_{k+1:T+1}^k, \; \operatorname{sg}(\tilde{Q})_{k+1:T+1}\bigr)

含义: 梯度分离的 KL 蒸馏损失,让每个 MTP head 的输出分布对齐主头的分布

符号说明:

  • βk\beta_k: 第 kk 个 MTP head 的 KL 损失权重系数(默认 βk=1.0\beta_k = 1.0,K=4 时用 0.5)
  • sg()\operatorname{sg}(\cdot): stop-gradient 操作,阻断梯度回传到主头
  • KL(,)\text{KL}(\cdot, \cdot): forward KL 散度

公式6: MTP-D 总损失

Lmtp=LmtpCE+LmtpKL\mathcal{L}_{\text{mtp}} = \mathcal{L}_{\text{mtp}}^{\text{CE}} + \mathcal{L}_{\text{mtp}}^{\text{KL}}

含义: CE 确保与 ground-truth 的基础对齐,KL 让 MTP head 从主头获取语义知识并约束分布一致性

符号说明:

  • LmtpCE\mathcal{L}_{\text{mtp}}^{\text{CE}}: 交叉熵损失
  • LmtpKL\mathcal{L}_{\text{mtp}}^{\text{KL}}: KL 自蒸馏损失

公式7: 接受率

ARj=s=1SAj(s)s=1SCjcmp(s)\text{AR}_j = \frac{\sum_{s=1}^{S} A_j^{(s)}}{\sum_{s=1}^{S} C_j^{\text{cmp}(s)}}

含义: 第 jj 个 MTP head 的 draft token 被主头验证接受的比率

符号说明:

  • SS: 总样本数
  • Aj(s)A_j^{(s)}: 第 ss 个样本中第 jj 个 head 被接受的 token 数
  • Cjcmp(s)C_j^{\text{cmp}(s)}: 第 ss 个样本中第 jj 个 head 被验证的 token 总数

公式8: 累积接受率

CARj=s=1SAj(s)s=1SCstep(s)\text{CAR}_j = \frac{\sum_{s=1}^{S} A_j^{(s)}}{\sum_{s=1}^{S} C_{\text{step}}^{(s)}}

含义: 相对于总生成步数的接受率,反映实际加速效果

符号说明:

  • Cstep(s)C_{\text{step}}^{(s)}: 第 ss 个样本中每个 head 生成的总 token 数

关键图表

Figure 1: Overview / 自蒸馏方法总览

Figure 1: Overview{:width 600}

说明: MTP-D 的梯度分离、TopNN-logits-selected 自蒸馏方法概览。主头(main head)输出经 stop-gradient 后,选取 Top-N logits 作为 teacher 信号,通过 KL 散度蒸馏到各个 MTP head,同时保持 CE 损失与 ground-truth 的对齐。

Figure 2: Looped Extension / 循环扩展训练策略

Figure 2: Looped Extension{:width 600}

说明: 循环扩展策略示意图。灰色块为冻结的主模型和已训练的 MTP head 1m,其权重被复制初始化 head m+12m(橙色块),仅训练新 head,大幅降低训练成本。

Figure 3: Training-Free Looped Extension Results / 免训练循环扩展结果

Figure 3a: Cumulative Acceptance Rates{:width 600}

Figure 3b: Acceptance Rates{:width 600}

说明: AGIEval-en 上免训练循环扩展的结果。(a) 累积接受率:MTP-D 在循环连接点虽有下降但保持可接受水平(26.70%),而 DeepSeek MTP 在第 3 个 head 就降至 0.6%。(b) 逐 head 接受率:循环连接点有明显下降但逐渐恢复。

Figure 4: Speedup Ratios / 加速比

Figure 4: Speedup{:width 600}

说明: 2B Dense 模型在不同 MTP 方法和 K 设置下的加速比。MTP-D 在 K=4 时实现 22.9% 加速(vs MTP),K=1 时也有约 14% 提升。

Figure 5: Looped Extension Strategies Comparison / 循环扩展策略对比

Figure 5a: Average CAR{:width 600}

Figure 5b: Average Speedup{:width 600}

说明: 不同循环扩展策略在所有 benchmark 上的平均 CAR 和加速比。MTP-D 4-to-8 和 4-to-16 设置表现最佳,4-to-16 达到平均 3.204 倍加速。

Figure 6: TopN Probability Distribution / TopN 概率分布分析

Figure 6: TopN Distribution{:width 400} Figure 6: Top10000{:width 400} Figure 6: Top1000{:width 400}

说明: 主头 logits 的概率分布在不同 TopN 设置下的可视化。全词表呈明显长尾分布,Top10,000 覆盖 99.52% 累积概率,Top1,000 覆盖 83.41%。

Figure 7: Training Loss Curves / 训练损失曲线

Figure 7: Loss Curves{:width 600}

说明: 2B Dense 模型 4 head 配置下的训练损失曲线。MTP-D 的 MTP head 损失始终低于 DeepSeek MTP,且主头损失保持一致。

Figure 8: MoE Speedup Ratios / MoE 模型加速比

Figure 8: MoE Speedup{:width 600}

说明: A1B MoE 模型的加速比。趋势与 Dense 模型一致,MTP-D 在 MoE 架构上同样有效。

Figure 9: Loss Comparison / 损失对比

Figure 9a: Main Head Loss{:width 400} Figure 9b: MTP Head Loss{:width 400}

说明: (a) 主头损失对比:MTP-D 和 DeepSeek MTP 的主头损失几乎重合,证明自蒸馏不影响主头。(b) MTP head 损失对比:MTP-D 的 MTP head 损失显著更低。

Figure 10: Full Vocabulary Training Loss / 全词表训练损失

Figure 10: Full Vocab Loss{:width 600}

说明: 使用全词表(不做 TopN 选择)时的训练损失曲线,出现明显的数值不稳定(loss spike),验证了 TopN 选择的必要性。

Figure 11: Training-Free Looped CAR per Benchmark / 免训练循环 CAR

Figure 11: CAR AGIEval{:width 400} Figure 11: CAR GSM8K{:width 400} Figure 11: CAR MATH{:width 400} Figure 11: CAR NQ{:width 400} Figure 11: CAR SimpleQA{:width 400} Figure 11: CAR SuperGPQA{:width 400} Figure 11: CAR TriviaQA{:width 400}

说明: 7 个 benchmark 上免训练循环扩展的 CAR 详细结果。MTP-D 在所有 benchmark 上均显著优于 DeepSeek MTP。

Figure 12: Training-Free Looped AR per Benchmark / 免训练循环 AR

Figure 12: AR AGIEval{:width 400} Figure 12: AR GSM8K{:width 400} Figure 12: AR MATH{:width 400} Figure 12: AR NQ{:width 400} Figure 12: AR SimpleQA{:width 400} Figure 12: AR SuperGPQA{:width 400} Figure 12: AR TriviaQA{:width 400}

说明: 7 个 benchmark 上免训练循环扩展的逐 head 接受率。循环连接点的 AR 下降在 MTP-D 中更快恢复。

Figure 13: Continued Pre-Training Looped CAR (up to 8) / 继续预训练循环 CAR

Figure 13: CPT CAR AGIEval{:width 400} Figure 13: CPT CAR GSM8K{:width 400} Figure 13: CPT CAR MATH{:width 400} Figure 13: CPT CAR NQ{:width 400} Figure 13: CPT CAR SimpleQA{:width 400} Figure 13: CPT CAR SuperGPQA{:width 400} Figure 13: CPT CAR TriviaQA{:width 400}

说明: 继续预训练后循环扩展到 8 head 的 CAR。经 70B tokens 继续预训练后 CAR 显著提升。

Figure 14: Continued Pre-Training Looped CAR (up to 16) / 继续预训练循环 CAR (16 head)

Figure 14: CPT16 CAR AGIEval{:width 400} Figure 14: CPT16 CAR GSM8K{:width 400} Figure 14: CPT16 CAR MATH{:width 400} Figure 14: CPT16 CAR NQ{:width 400} Figure 14: CPT16 CAR SimpleQA{:width 400} Figure 14: CPT16 CAR SuperGPQA{:width 400} Figure 14: CPT16 CAR TriviaQA{:width 400}

说明: 扩展到 16 head 的 CAR。第 16 个 head 仍维持 5-10% CAR,高接受率 benchmark 从 16 head 中获益最多。

Table 1: 主实验结果(K=4, 2B Dense)

BenchmarkHeadMTP CARMTP ARMTP-D CARMTP-D AR
AGIEval-en185.8685.8685.8685.86
AGIEval-en445.4785.7552.9688.98
GSM8K465.2291.6172.7692.89
MATH437.3778.3946.4281.81
NaturalQuestions463.6393.9171.2295.34
SimpleQA439.1576.8246.2780.38
TriviaQA441.4580.8348.5282.85
SuperGPQA447.5984.8254.9887.94

表格说明: MTP-D 在所有 benchmark 的第 4 个 head 上 CAR 平均提升约 7.5%,AR 平均提升约 3.9%。

Table 1 (续): A1B MoE 结果

BenchmarkHeadMTP CARMTP ARMTP-D CARMTP-D AR
AGIEval-en448.8285.4152.4790.98
GSM8K461.9389.3871.4792.86
MATH430.3873.5737.8976.16
NaturalQuestions460.3590.8068.0592.92
SimpleQA442.3878.7549.1881.48
TriviaQA437.0378.1244.5180.91
SuperGPQA447.3183.3754.2886.44

表格说明: MoE 架构上 MTP-D 的提升更为显著,第 4 个 head 的 CAR 平均提升约 4.7%。

Table 2: 消融实验(K=1, 2B Dense)

配置Main Acc (↑)MTP Acc (↑)说明
MTP-D (default)90.0611.68基线配置
w/o Detach88.57 (-1.49)12.05主头性能受损
TopN=1,00090.4711.47覆盖率不足
Union (10K×2)90.5111.27额外收益有限
β_k=0.187.6610.99KL 权重不足
β_k=0.388.8311.31-
β_k=0.589.7211.47-
β_k=1.590.8011.64主头损失上升
Reverse KL90.2211.62Forward KL 更优
Hybrid KL88.9210.99性能下降

关键发现: (1) 梯度分离对保护主头至关重要(去除后主头 Acc -1.49);(2) β_k=1.0 是最佳平衡点;(3) Forward KL 优于 Reverse/Hybrid KL。

Table 5: 主头性能对比

配置NTP LossMain Acc
NTP baseline2.545710.96
MTP K=12.545811.28
MTP-D K=12.548211.68
MTP K=42.543011.06
MTP-D K=42.549510.96

表格说明: MTP-D 的 NTP loss 与 MTP 基线几乎一致(差异 <0.007),证明自蒸馏对主头无损。

Table 6-7: K=4 集成策略消融

策略Main Acc4th Head CAR
MTP-D (β=0.5)10.96baseline
β=0.311.13lower
Ensemble mean11.22improved
Ensemble split11.20improved

关键发现: 集成策略同时提升主头性能和 MTP 接受率,多元监督信号互补有益。

Table 8: 循环扩展加速结果汇总

配置方法平均加速比
K=1MTP-D1.142×
K=4MTP-D2.372×
1-to-8 (免训练)MTP1.769×
1-to-8 (免训练)MTP-D*1.769×
1-to-8 (继续预训练)MTP2.644×
1-to-8 (继续预训练)MTP-D2.846×
4-to-8MTP-D3.052×
1-to-16 (免训练)MTP-D*1.657×
4-to-16MTP-D3.204×

关键发现: (1) 4-to-16 MTP-D 达到最高 3.204× 加速;(2) 分组循环(4-to-8, 4-to-16)优于单 head 循环(1-to-8, 1-to-16);(3) 70B tokens 继续预训练即可获得显著提升。

实验

数据集

数据集规模特点用途
FineWeb-Edu-350BT350B tokens教育领域高质量网页文本预训练

评估基准

Benchmark特点评估能力
AGIEval-en人类标准化考试综合推理
GSM8K小学数学数学推理
MATH竞赛数学高级数学
NaturalQuestions知识密集 QA知识检索
SimpleQA短事实查询事实性
SuperGPQA研究生级知识深度推理
TriviaQA阅读理解阅读理解

实现细节

模型: 2B Dense (32 layers, hidden 2048, 16 heads) + N10B A1B MoE (22 layers, 128 experts, top-8)

优化器: AdamW(β1,β2)=(0.9,0.95)(\beta_1, \beta_2) = (0.9, 0.95)ϵ=1×108\epsilon = 1 \times 10^{-8}

学习率: max 3×1043 \times 10^{-4}Cosine Decay 到 0

Batch Size: 8192

序列长度: 4096

训练量: 预训练 350B tokens,循环扩展 70B tokens

硬件: 256 × NVIDIA H20 GPU(预训练约 30 天)

Warmup: 预训练 1000 steps,循环扩展 0 steps

正则化: weight decay 0.1, gradient clipping norm 1.0

推理: 贪心解码,单 batch,最大生成 100 tokens,无 KV cache

可视化结果

MTP-D 在所有 7 个 benchmark 上的 CAR 和 AR 均一致性提升

循环扩展到 16 head 后加速比饱和,受级联架构 CAR 递减限制

训练损失曲线显示 MTP-D 的 MTP head 损失始终更低,主头损失几乎不受影响

批判性思考

优点

设计简洁高效: 仅增加 KL 损失项即可显著提升 MTP 接受率,无需修改模型架构

理论动机清晰: 梯度分离保护主头 + TopN 选择解决效率/稳定性问题,每个设计选择都有充分的消融验证

循环扩展策略实用: 利用架构结构一致性,70B tokens 即可扩展 head 数,训练成本极低

消融实验全面: 对所有关键设计(detach、TopN、β_k、KL 变体、集成策略)都进行了系统消融

局限性

仅验证了预训练阶段: 未探索 SFT/RLHF 等 post-training 阶段的适用性

模型规模有限: 仅在 2B 和 10B 上验证,未在 70B+ 超大模型上测试

推理设置受限: 单 batch、无 KV cache、最大 100 tokens,不完全反映实际部署场景

α_k、β_k 与 K 的理论关系不充分: 最优超参数依赖经验调参,缺乏理论指导

CAR 递减问题未根本解决: 级联架构的固有限制使 16+ head 的边际收益递减

潜在改进方向

将 MTP-D 扩展到 post-training 阶段(SFT 微调、RLHF 对齐)

探索非级联架构(如 Meta 的并行 MTP)+ 自蒸馏的组合

引入自适应 β_k 调度策略(如与训练进度或 head 索引挂钩)

在实际部署场景(长文本生成、大 batch、KV cache)中评估加速效果

结合 EAGLE 等 speculative decoding 方法进一步提升验证效率

可复现性评估

  • 代码开源(未提供)
  • 预训练模型(未提供)
  • 训练细节完整(超参数、数据集、硬件均详细记录)
  • 数据集可获取(FineWeb-Edu 公开可用)

关联笔记

基于

DeepSeek-V3: 级联 MTP 架构的基础

Hinton KD: 知识蒸馏的经典框架(Hinton et al., 2015)

对比

Medusa: 另一种 MTP 推理加速框架

EAGLE: Speculative Sampling,重新思考特征不确定性

方法相关

Multi-Token Prediction: 核心预测范式

自蒸馏: 核心训练策略

Speculative Decoding: 推理加速的验证框架

KL Divergence: 蒸馏损失函数

Stop-Gradient: 保护主头的梯度分离技术

硬件/数据相关

FineWeb-Edu: 预训练数据集

NVIDIA H20: 训练硬件

速查卡片

Self-Distillation for Multi-Token Prediction

  • 核心: 梯度分离的自蒸馏提升 MTP head 接受率,循环扩展经济地增加 head 数
  • 方法: 主头 → stop-gradient → TopN logits → KL 蒸馏 → MTP head;循环权重复制 + 70B tokens 继续预训练
  • 结果: K=4 接受率 +7.5%(22.9% 加速),循环扩展到 16 head 达 3.2× 加速
  • 代码: 未公开

笔记创建时间: 2026-03-27