TIDE: Token-Informed Depth Execution for Per-Token Early Exit in LLM Inference
论文笔记:TIDE: Token-Informed Depth Execution for Per-Token Early Exit in LLM Inference
元信息
| 项目 | 内容 |
|---|---|
| 机构 | RightNow AI |
| 日期 | March 2026 |
| 项目主页 | — |
| 对比基线 | DeeBERT, CALM, LayerSkip, EE-LLM, SkipDecode, MoD |
| 链接 | arXiv / Code |
一句话总结
提出一种无需重训练的 per-token 提前退出系统,通过在检查点层附加轻量 router MLP 实现 LLM 推理加速
核心贡献
Post-training 提前退出: 无需修改原始模型权重,仅通过离线校准训练轻量 router 即可实现 per-token early exit
Universal Model Adapter: 通过 17 条属性路径自动探测 Transformer 架构,支持 LLaMA/Mistral/Qwen/GPT-2 等主流模型
融合 CUDA 内核: 将 RMSNorm + router 评估融合为单次 kernel launch,原生支持 fp16/bf16,覆盖 8 种模板特化
问题背景
要解决的问题
LLM 推理中,不同 token 的”难度”不同——简单 token(如冠词”the”)无需经过所有 Transformer 层即可收敛,但标准推理仍让每个 token 走完全部层,造成计算浪费
现有方法的局限
DeeBERT 等 encoder 方法不处理 KV Cache 和自回归生成
CALM 针对 encoder-decoder(T5),需修改训练流程
LayerSkip 和 EE-LLM 需要从头训练或修改预训练 pipeline
SkipDecode 不支持 per-token 决策,所有 token 在同一层退出
Mixture-of-Depths 需在训练阶段引入 routing,无法 post-hoc 使用
本文的动机
如果能在不修改模型权重的前提下,通过轻量的 learned router 判断每个 token 的隐状态是否已收敛,就能在保持精度的同时减少冗余计算。关键洞察:token 在中间层的 余弦相似度 与最终层很高时,后续层对该 token 几乎没有贡献
方法详解
系统架构
TIDE 采用 两阶段架构(离线校准 + 在线推理):
- 离线校准: 在冻结模型上收集隐状态 → 计算 余弦相似度 → 训练二分类 router MLP
- 在线推理: 完整前向传播(保持 KV Cache 完整性)→ router post-hoc 评估 → 选择最早收敛层计算 logits
- Router 参数量: 每个 checkpoint 仅 个参数(对 ,约 524K 参数)
- 总 checkpoint 大小: ~4 MB
核心模块
模块1: 离线校准(Calibration)
设计动机: 利用 余弦相似度 度量隐状态收敛程度,避免启发式 confidence 阈值
具体实现:
- 在 2,000 条 WikiText-103 文本上运行冻结模型,收集所有 checkpoint 层的隐状态
- 计算每个 token 在 checkpoint 层 与最终层 的 余弦相似度
- 当相似度超过阈值 时标记为”已收敛”()
- 用 Binary Cross-Entropy 损失训练二层 MLP router,Adam 优化器(),100 epochs
- DeepSeek R1 Distill 8B 在 A100 上校准仅需 170 秒(339,853 tokens)
模块2: Universal Model Adapter
设计动机: 不同 HuggingFace 模型的内部命名规范各异,需要自动适配
具体实现:
- 探测 5 类组件的 17 条属性路径:
- Layers:
model.layers,transformer.h,transformer.layers,gpt_neox.layers,model.decoder.layers+ largest-ModuleList fallback - Final Norm: 5 条路径 + sibling-of-layers 启发式
- LM Head: 2 条路径 + vocab-size shape matching
- Embedding: 5 条路径 + vocab-size shape matching
- Hidden Dim:
model.config.hidden_size(HuggingFace 通用)
- Layers:
- 支持架构:LLaMA, Mistral, Qwen, GPT-2, GPT-NeoX, Phi, Falcon, OPT, Gemma
- 用户可通过
register_adapter()注册自定义适配器
模块3: Post-Hoc Generation
设计动机: 保持 KV Cache 完整性是自回归生成的刚性约束,直接跳层会破坏缓存一致性
具体实现:
- 完整前向传播,收集所有 checkpoint 层隐状态
- 按层序遍历 router:若 (对 batch 中所有 token),则用 计算 logits
- 无 early exit 触发时回退到标准 logits
- 优势: 所有层都执行,KV Cache 始终完整填充;兼容任意 transformers 库版本
模块4: 融合 CUDA 内核
设计动机: router 评估引入的额外开销需通过 kernel 融合最小化
具体实现:
- Fused LayerNorm + Route: 单 kernel 完成 RMSNorm → down-projection → SiLU → up-projection → Sigmoid,使用 warp-level reduction(
__shfl_xor_sync) - Batch Compact: 分离 continuing/exiting tokens,小 batch 用 warp balloting(
__ballot_sync),大 batch 用 prefix-sum scatter - Exit Scatter: 将 exited tokens 的隐状态复制回原始输出 buffer 位置
- Exit Projection: 融合 RMSNorm + scatter
- 所有 kernel 模板化支持 float32 / fp16 / bf16,累积运算始终为 float32
关键公式
公式1: 隐状态收敛度量
含义: 衡量 token 在 checkpoint 层 的隐状态与最终层 的相似程度,值越接近 1 说明该 token 越早收敛
符号说明:
- : token 在第 层的隐状态向量
- : token 在最终层 的隐状态向量
- : 收敛度分数
公式2: 收敛标签生成
含义: 当余弦相似度超过阈值 时,标记该 token 在该层已收敛
符号说明:
- : 默认收敛阈值(较严格)
- : 二分类标签
公式3: Router MLP
含义: 每个 checkpoint 层的二层 MLP router,输出退出概率
符号说明:
公式4: Post-Hoc 退出决策
含义: 选择满足阈值条件的最早 checkpoint 层,用该层隐状态计算 logits
符号说明:
- : checkpoint 层集合(间隔 )
- : 推理时退出阈值(可调,不同于训练阈值 )
- : 最小退出层(防止过早退出)
- : 第 层的隐状态
关键图表
Figure 1: TIDE System Overview / 系统概览
{:width 600}
说明: TIDE 系统概览。左侧为一次性离线校准流程:在冻结模型上收集隐状态 → 计算 per-token 与最终层的 余弦相似度 → 在每个 checkpoint 训练 binary router MLP。右侧为推理流程:完整前向传播(保持 KV Cache)→ router post-hoc 评估 → 选择最早收敛层计算 logits。编号步骤显示执行顺序。
Figure 2: Per-Token Early Exit Visualization / 逐 Token 提前退出可视化
说明: 在 32 层模型中的 per-token early exit 示意。Token “the” 在第 8 层收敛,跳过后续 24 层;Token “cat” 在第 16 层退出;只有 “sat” 需要完整深度。三个 token 总计 次 layer 操作(vs. baseline 的 ),减少 42% 计算。TIDE 在完整前向传播后 post-hoc 评估 router 。
Table 1: Early Exit 系统对比
| System | Decoder | Post-training | Per-token | CUDA | Universal |
|---|---|---|---|---|---|
| DeeBERT | ✗ | ✓ | ✓ | ✗ | ✗ |
| CALM | Enc-Dec | ✗ | ✓ | ✗ | ✗ |
| LayerSkip | ✓ | ✗ | ✓ | ✗ | ✗ |
| EE-LLM | ✓ | ✗ | ✓ | ✗ | ✗ |
| SkipDecode | ✓ | ✓ | ✗ | ✗ | ✗ |
| MoD | ✓ | ✗ | ✓ | ✗ | ✗ |
| TIDE | ✓ | ✓ | ✓ | ✓ | ✓ |
说明: TIDE 是唯一同时满足 decoder 支持、post-training、per-token 决策、CUDA 优化和通用架构支持的系统
Table 2: CUDA Kernel 模板特化
| 目标模型 | Hidden Dim | Bottleneck |
|---|---|---|
| Phi-2 / TinyLlama | 2048 | 64 |
| Phi-2 / TinyLlama | 2048 | 128 |
| Phi-3-mini | 3072 | 128 |
| LLaMA-8B / Mistral-7B | 4096 | 128 |
| LLaMA-8B / Mistral-7B | 4096 | 256 |
| LLaMA-13B | 5120 | 128 |
| LLaMA-70B / Qwen-72B | 8192 | 128 |
| LLaMA-70B / Qwen-72B | 8192 | 256 |
说明: 8 种编译时特化覆盖主流模型维度,未命中时 fallback 到通用实现
Table 3: Prefill Exit Rates / 预填充退出率
| Model | Layers | θ | Tokens | Exit% | Distribution |
|---|---|---|---|---|---|
| DeepSeek R1 Distill 8B | 32 | 0.85 | 322 | 100% | L11:16, L31:306 |
| DeepSeek R1 Distill 8B | 32 | 0.50 | 322 | 100% | L11:16, L31:306 |
| Qwen3 8B | 36 | 0.85 | 155 | 100% | L35:155 |
| Qwen3 8B | 36 | 0.50 | 155 | 100% | L11:11, L23:5, L35:139 |
说明: 所有 token 均能实现提前退出。DeepSeek R1 中约 5% 的 token 在 L11 退出,其余集中在倒数第二个 checkpoint(L31)。Qwen3 在低阈值下退出分布更分散
Table 4: Prefill Latency and Throughput / 预填充延迟与吞吐量
| Model | Metric | Baseline | TIDE | Δ |
|---|---|---|---|---|
| DeepSeek R1 8B | Latency (θ=0.85) | 39.08 ms | 36.94 ms | −5.5% |
| DeepSeek R1 8B | Latency (θ=0.50) | 39.08 ms | 36.26 ms | −7.2% |
| DeepSeek R1 8B | Throughput BS=1 | 973 tok/s | 1,037 tok/s | +6.6% |
| DeepSeek R1 8B | Throughput BS=8 | 8,668 tok/s | 7,252 tok/s | −16.3% |
| Qwen3 8B | Latency (θ=0.85) | 46.82 ms | 44.14 ms | −5.7% |
| Qwen3 8B | Throughput BS=1 | 258 tok/s | 271 tok/s | +5.0% |
| Qwen3 8B | Throughput BS=8 | 1,781 tok/s | 1,926 tok/s | +8.1% |
说明: BS=1 时 TIDE 一致提升吞吐量(5-6.6%)和降低延迟(5.5-7.2%)。但 DeepSeek R1 在 BS=8 时吞吐反降 16.3%,说明 hidden state 收集在大 batch 时成为瓶颈。Qwen3 在 BS=8 时仍有 8.1% 增益
Table 5: Decode Exit Rates and Output Quality / 解码退出率与输出质量
| θ | Exit Rate | Unique Tokens | Correct | Exit Layer |
|---|---|---|---|---|
| 1.0 (off) | 0% | 99 | ✓ | — |
| 0.85 | 98.4% | 95 | ✓ | L31 |
| 0.70 | 99.2% | 95 | ✓ | L31 |
| 0.50 | 99.6% | 95 | ✓ | L31 |
说明: 在数学推理任务上,即使 99.6% 的 token 提前退出,仍能得到正确答案。退出 token 全部集中在 L31(倒数第二 checkpoint),说明 router 非常保守
实验
数据集
| 数据集 | 规模 | 特点 | 用途 |
|---|---|---|---|
| WikiText-103 | 2,000 samples (~340K tokens) | 英文 Wikipedia 长文本 | 校准 |
| 自定义 prompts | 16 条(8 推理/数学 + 8 通用) | 多样性评测 | 评估 |
实现细节
硬件: NVIDIA A100-SXM4-40GB (sm_80, 80 GB HBM2e)
软件: CUDA 12.4, PyTorch 2.10, transformers 5.3
精度: bfloat16
校准配置: checkpoint 间隔 ,收敛阈值 ,bottleneck
Router 训练: Adam (lr=1e-3), 100 epochs, Binary Cross-Entropy loss
Latency 测量: 20 次平均(3 次 warmup)
代码规模: 3,097 行(Python 1,308 + CUDA/C++ 1,081 + 测试 708),74 个通过测试
收敛分析
DeepSeek R1 Distill 8B: 339,853 tokens,100% 在 L31 收敛
Qwen3 8B: 314,530 tokens,100% 在 L35 收敛
GPT-2 (124M, 12 layers): 78,843 tokens,100% 在 L11 收敛
严格阈值 导致退出集中在倒数第二 checkpoint
可视化结果
3-token 示例(“the”/“cat”/“sat”)展示了理想情况下 42% 的计算减少
实际实验中,由于保守阈值,大部分 token 退出在最后 checkpoint 附近,真实加速较温和(5-8%)
批判性思考
优点
零训练成本: 完全 post-training,不修改原始权重,校准仅需 ~3 分钟,降低了使用门槛
工程完整度高: 自带融合 CUDA kernel、自动架构探测、pip install 即用,74 个测试覆盖,工程质量远超多数学术 early exit 工作
KV Cache 安全: post-hoc 策略保证 KV Cache 完整性,这是实际部署的刚需
通用性好: 一套代码支持 LLaMA/Mistral/Qwen/GPT-2 等 9 种架构系列
局限性
加速幅度有限: 实际加速仅 5-8%,远低于 Figure 2 中 42% 理论减少。根本原因在于 post-hoc 模式仍需执行所有层,只是选择更早的 logit 输出——这与”跳过层”有本质区别
大 batch 退化: DeepSeek R1 在 BS=8 时吞吐反降 16.3%,因为 output_hidden_states=True 收集所有层隐状态的开销在大 batch 下不可忽略
退出分布极保守: 几乎所有 token 集中在倒数第二 checkpoint 退出,说明 过于严格,真正”跳过大量层”的 token 极少
评估规模偏小: 仅在 8B 模型上测试,缺少 13B/70B 级别验证;评估仅用 16 条 prompt,统计显著性不足
无标准 benchmark: 未在 MMLU/GSM8K/HumanEval 等公认 benchmark 上评测,难以与其他方法直接比较
潜在改进方向
真正的 layer-skipping: 当前 post-hoc 模式的根本瓶颈在于仍执行所有层。若能设计 KV Cache 感知的 speculative skip(如结合 LayerSkip 的自投机解码),可能获得真正的层跳过加速
自适应阈值: 用可学习或基于 calibration 统计的 per-layer 阈值替代全局 ,让浅层 checkpoint 也能有效退出
更大规模验证: 在 70B+ 模型上测试,这类模型层数更多,early exit 的加速空间理论上更大
与 Speculative Decoding 结合: 用 early exit 层作为 draft model,可能同时获得 early exit 和投机解码的双重加速
可复现性评估
- 代码开源(Apache 2.0, GitHub + PyPI)
- 预训练模型(router checkpoint 未公开发布)
- 训练细节完整(校准配置全部公开)
- 数据集可获取(WikiText-103 公开)
关联笔记
基于
CALM: confidence-based early exit 的开创性工作,TIDE 的 learned router 替代了 confidence heuristic
DeeBERT: BERT early exit 先驱,TIDE 将思路拓展到 decoder-only 自回归模型
Adaptive Computation Time: Graves 提出的自适应计算时间,TIDE 的 per-token depth 决策是同一思想的 post-training 实现
对比
LayerSkip: 需要从头训练 + layer dropout schedule,TIDE 无需重训
SkipDecode: 同为 post-training,但不支持 per-token 决策
Mixture-of-Depths: 训练时 routing tokens 跳层,TIDE 为推理时 post-hoc
SpecEE: 结合 speculative decoding + early exit,与 TIDE 互补
ADEPT: 用 draft model 决定 token depth,TIDE 用 learned router
方法相关
early exit: 核心方法论
KV Cache: post-hoc 策略的关键约束
RMSNorm: router 内置归一化 + CUDA kernel 融合
余弦相似度: 收敛度量的核心指标
adaptive computation: 广义的自适应计算范式
Mixture-of-Experts: 同属条件计算(conditional computation)大类
硬件/数据相关
WikiText-103: 校准数据集
A100 GPU: 实验硬件
速查卡片
TIDE: Token-Informed Depth Execution for Per-Token Early Exit in LLM Inference
- 核心: 无需重训练的 per-token 提前退出,通过 learned router post-hoc 选择最早收敛层
- 方法: 离线校准训练轻量 MLP router(~4MB)+ 推理时 post-hoc 评估 + 融合 CUDA kernel
- 结果: DeepSeek R1 8B 上 7.2% prefill 延迟降低 / 6.6% 吞吐提升(BS=1),98-99% token 提前退出且精度无损
- 代码: GitHub /
pip install tide-inference
笔记创建时间: 2026-03-25