Multi-Token Prediction

分类: 高效推理与部署

Multi-Token Prediction

定义

一种 LLM 推理加速技术,让模型在单次前向传播中同时预测多个未来 token,而非传统的逐 token 自回归生成

数学形式

P(xt+1,xt+2,,xt+kxt)i=1kPheadi(xt+ixt)P(x_{t+1}, x_{t+2}, \ldots, x_{t+k} \mid x_{\leq t}) \approx \prod_{i=1}^{k} P_{\text{head}_i}(x_{t+i} \mid x_{\leq t}) 其中 kk 为预测头数量,每个 head 独立预测对应位置的 token

核心要点

标准自回归模型每步只预测 next token,MTP 通过多个预测头并行预测 kk 个 token

推理时需配合 speculative decoding 策略:先用 MTP heads 草拟 kk 个 token,再由主模型验证

acceptance rate(接受率)是关键指标——MTP heads 预测被主模型接受的比例直接决定加速效果

DeepSeek V3 是 MTP 的代表性工业实现,将 MTP 作为训练辅助任务提升模型质量

训练时各 head 独立学习可能导致表示不一致,MTP-D 等方法用 自蒸馏 对齐各 head

代表工作

Meta MTP (2024): 最早系统性提出 multi-token prediction 作为语言模型训练目标

DeepSeek V3: 工业级 MTP 实现,用于训练和推理加速

MTP-D: 用自蒸馏策略对齐多个预测头,提升 acceptance rate

相关概念

自蒸馏 — MTP-D 用自蒸馏对齐多个预测头

early exit — 同属推理加速范畴,但策略不同

知识蒸馏 — MTP 训练中 teacher 信号的上位概念