TIDE: Token-Informed Depth Execution for Per-Token Early Exit in LLM Inference

作者: Jaber Jaber, Osama Jaber 年份: 2026 会议: arXiv 分类: 高效推理与部署

论文笔记:TIDE: Token-Informed Depth Execution for Per-Token Early Exit in LLM Inference

元信息

项目内容
机构RightNow AI
日期March 2026
项目主页
对比基线DeeBERT, CALM, LayerSkip, EE-LLM, SkipDecode, MoD
链接arXiv / Code

一句话总结

提出一种无需重训练的 per-token 提前退出系统,通过在检查点层附加轻量 router MLP 实现 LLM 推理加速

核心贡献

Post-training 提前退出: 无需修改原始模型权重,仅通过离线校准训练轻量 router 即可实现 per-token early exit

Universal Model Adapter: 通过 17 条属性路径自动探测 Transformer 架构,支持 LLaMA/Mistral/Qwen/GPT-2 等主流模型

融合 CUDA 内核: 将 RMSNorm + router 评估融合为单次 kernel launch,原生支持 fp16/bf16,覆盖 8 种模板特化

问题背景

要解决的问题

LLM 推理中,不同 token 的”难度”不同——简单 token(如冠词”the”)无需经过所有 Transformer 层即可收敛,但标准推理仍让每个 token 走完全部层,造成计算浪费

现有方法的局限

DeeBERT 等 encoder 方法不处理 KV Cache 和自回归生成

CALM 针对 encoder-decoder(T5),需修改训练流程

LayerSkipEE-LLM 需要从头训练或修改预训练 pipeline

SkipDecode 不支持 per-token 决策,所有 token 在同一层退出

Mixture-of-Depths 需在训练阶段引入 routing,无法 post-hoc 使用

本文的动机

如果能在不修改模型权重的前提下,通过轻量的 learned router 判断每个 token 的隐状态是否已收敛,就能在保持精度的同时减少冗余计算。关键洞察:token 在中间层的 余弦相似度 与最终层很高时,后续层对该 token 几乎没有贡献

方法详解

系统架构

TIDE 采用 两阶段架构(离线校准 + 在线推理):

  • 离线校准: 在冻结模型上收集隐状态 → 计算 余弦相似度 → 训练二分类 router MLP
  • 在线推理: 完整前向传播(保持 KV Cache 完整性)→ router post-hoc 评估 → 选择最早收敛层计算 logits
  • Router 参数量: 每个 checkpoint 仅 d×128+128d \times 128 + 128 个参数(对 d=4096d=4096,约 524K 参数)
  • 总 checkpoint 大小: ~4 MB

核心模块

模块1: 离线校准(Calibration)

设计动机: 利用 余弦相似度 度量隐状态收敛程度,避免启发式 confidence 阈值

具体实现:

  • 在 2,000 条 WikiText-103 文本上运行冻结模型,收集所有 checkpoint 层的隐状态 hk(i)Rd\mathbf{h}_k^{(i)} \in \mathbb{R}^d
  • 计算每个 token ii 在 checkpoint 层 kk 与最终层 LL余弦相似度
  • 当相似度超过阈值 τ=0.98\tau = 0.98 时标记为”已收敛”(yk(i)=1y_k^{(i)} = 1
  • Binary Cross-Entropy 损失训练二层 MLP router,Adam 优化器(lr=103\text{lr} = 10^{-3}),100 epochs
  • DeepSeek R1 Distill 8B 在 A100 上校准仅需 170 秒(339,853 tokens)

模块2: Universal Model Adapter

设计动机: 不同 HuggingFace 模型的内部命名规范各异,需要自动适配

具体实现:

  • 探测 5 类组件的 17 条属性路径:
    • Layers: model.layers, transformer.h, transformer.layers, gpt_neox.layers, model.decoder.layers + largest-ModuleList fallback
    • Final Norm: 5 条路径 + sibling-of-layers 启发式
    • LM Head: 2 条路径 + vocab-size shape matching
    • Embedding: 5 条路径 + vocab-size shape matching
    • Hidden Dim: model.config.hidden_size(HuggingFace 通用)
  • 支持架构:LLaMA, Mistral, Qwen, GPT-2, GPT-NeoX, Phi, Falcon, OPT, Gemma
  • 用户可通过 register_adapter() 注册自定义适配器

模块3: Post-Hoc Generation

设计动机: 保持 KV Cache 完整性是自回归生成的刚性约束,直接跳层会破坏缓存一致性

具体实现:

  • 完整前向传播,收集所有 checkpoint 层隐状态 H=[h0,h1,,hL]H = [\mathbf{h}_0, \mathbf{h}_1, \ldots, \mathbf{h}_L]
  • 按层序遍历 router:若 ϕk(H[k+1])>θ\phi_k(H[k+1]) > \theta(对 batch 中所有 token),则用 LMHead(RMSNorm(H[k+1]))\text{LMHead}(\text{RMSNorm}(H[k+1])) 计算 logits
  • 无 early exit 触发时回退到标准 logits
  • 优势: 所有层都执行,KV Cache 始终完整填充;兼容任意 transformers 库版本

模块4: 融合 CUDA 内核

设计动机: router 评估引入的额外开销需通过 kernel 融合最小化

具体实现:

  • Fused LayerNorm + Route: 单 kernel 完成 RMSNorm → down-projection → SiLU → up-projection → Sigmoid,使用 warp-level reduction(__shfl_xor_sync
  • Batch Compact: 分离 continuing/exiting tokens,小 batch 用 warp balloting(__ballot_sync),大 batch 用 prefix-sum scatter
  • Exit Scatter: 将 exited tokens 的隐状态复制回原始输出 buffer 位置
  • Exit Projection: 融合 RMSNorm + scatter
  • 所有 kernel 模板化支持 float32 / fp16 / bf16,累积运算始终为 float32

关键公式

公式1: 隐状态收敛度量

sk(i)=hk(i)hL(i)hk(i)hL(i)s_k^{(i)} = \frac{\mathbf{h}_k^{(i)} \cdot \mathbf{h}_L^{(i)}}{\|\mathbf{h}_k^{(i)}\| \, \|\mathbf{h}_L^{(i)}\|}

含义: 衡量 token ii 在 checkpoint 层 kk 的隐状态与最终层 LL 的相似程度,值越接近 1 说明该 token 越早收敛

符号说明:

  • hk(i)Rd\mathbf{h}_k^{(i)} \in \mathbb{R}^d: token ii 在第 kk 层的隐状态向量
  • hL(i)Rd\mathbf{h}_L^{(i)} \in \mathbb{R}^d: token ii 在最终层 LL 的隐状态向量
  • sk(i)[0,1]s_k^{(i)} \in [0, 1]: 收敛度分数

公式2: 收敛标签生成

yk(i)=1[sk(i)>τ]y_k^{(i)} = \mathbf{1}[s_k^{(i)} > \tau]

含义: 当余弦相似度超过阈值 τ\tau 时,标记该 token 在该层已收敛

符号说明:

  • τ=0.98\tau = 0.98: 默认收敛阈值(较严格)
  • yk(i){0,1}y_k^{(i)} \in \{0, 1\}: 二分类标签

公式3: Router MLP

ϕk(h)=σ(WupSiLU(WdownRMSNorm(h)))\phi_k(\mathbf{h}) = \sigma\left(\mathbf{W}_{\text{up}} \, \text{SiLU}\left(\mathbf{W}_{\text{down}} \, \text{RMSNorm}(\mathbf{h})\right)\right)

含义: 每个 checkpoint 层的二层 MLP router,输出退出概率

符号说明:

  • WdownRb×d\mathbf{W}_{\text{down}} \in \mathbb{R}^{b \times d}: 降维矩阵,b=128b=128 为瓶颈维度
  • WupR1×b\mathbf{W}_{\text{up}} \in \mathbb{R}^{1 \times b}: 升维到标量
  • SiLU(x)=xσ(x)\text{SiLU}(x) = x \cdot \sigma(x): SiLU 激活函数
  • σ\sigma: Sigmoid 函数
  • RMSNorm: 输入归一化

公式4: Post-Hoc 退出决策

k=min{kC:ϕk(H[k+1])>θ,  kkmin}k^* = \min \{ k \in \mathcal{C} : \phi_k(H[k+1]) > \theta, \; k \geq k_{\min} \} logits=LMHead(RMSNorm(H[k+1]))\text{logits} = \text{LMHead}(\text{RMSNorm}(H[k^*+1]))

含义: 选择满足阈值条件的最早 checkpoint 层,用该层隐状态计算 logits

符号说明:

  • C\mathcal{C}: checkpoint 层集合(间隔 c=4c=4
  • θ\theta: 推理时退出阈值(可调,不同于训练阈值 τ\tau
  • kmink_{\min}: 最小退出层(防止过早退出)
  • H[k+1]H[k+1]: 第 k+1k+1 层的隐状态

关键图表

Figure 1: TIDE System Overview / 系统概览

Figure 1: TIDE System Overview{:width 600}

说明: TIDE 系统概览。左侧为一次性离线校准流程:在冻结模型上收集隐状态 → 计算 per-token 与最终层的 余弦相似度 → 在每个 checkpoint 训练 binary router MLP。右侧为推理流程:完整前向传播(保持 KV Cache)→ router post-hoc 评估 → 选择最早收敛层计算 logits。编号步骤显示执行顺序。

Figure 2: Per-Token Early Exit Visualization / 逐 Token 提前退出可视化

说明: 在 32 层模型中的 per-token early exit 示意。Token “the” 在第 8 层收敛,跳过后续 24 层;Token “cat” 在第 16 层退出;只有 “sat” 需要完整深度。三个 token 总计 8+16+32=568+16+32=56 次 layer 操作(vs. baseline 的 3×32=963 \times 32 = 96),减少 42% 计算。TIDE 在完整前向传播后 post-hoc 评估 router ϕk\phi_k

Table 1: Early Exit 系统对比

SystemDecoderPost-trainingPer-tokenCUDAUniversal
DeeBERT
CALMEnc-Dec
LayerSkip
EE-LLM
SkipDecode
MoD
TIDE

说明: TIDE 是唯一同时满足 decoder 支持、post-training、per-token 决策、CUDA 优化和通用架构支持的系统

Table 2: CUDA Kernel 模板特化

目标模型Hidden DimBottleneck
Phi-2 / TinyLlama204864
Phi-2 / TinyLlama2048128
Phi-3-mini3072128
LLaMA-8B / Mistral-7B4096128
LLaMA-8B / Mistral-7B4096256
LLaMA-13B5120128
LLaMA-70B / Qwen-72B8192128
LLaMA-70B / Qwen-72B8192256

说明: 8 种编译时特化覆盖主流模型维度,未命中时 fallback 到通用实现

Table 3: Prefill Exit Rates / 预填充退出率

ModelLayersθTokensExit%Distribution
DeepSeek R1 Distill 8B320.85322100%L11:16, L31:306
DeepSeek R1 Distill 8B320.50322100%L11:16, L31:306
Qwen3 8B360.85155100%L35:155
Qwen3 8B360.50155100%L11:11, L23:5, L35:139

说明: 所有 token 均能实现提前退出。DeepSeek R1 中约 5% 的 token 在 L11 退出,其余集中在倒数第二个 checkpoint(L31)。Qwen3 在低阈值下退出分布更分散

Table 4: Prefill Latency and Throughput / 预填充延迟与吞吐量

ModelMetricBaselineTIDEΔ
DeepSeek R1 8BLatency (θ=0.85)39.08 ms36.94 ms−5.5%
DeepSeek R1 8BLatency (θ=0.50)39.08 ms36.26 ms−7.2%
DeepSeek R1 8BThroughput BS=1973 tok/s1,037 tok/s+6.6%
DeepSeek R1 8BThroughput BS=88,668 tok/s7,252 tok/s−16.3%
Qwen3 8BLatency (θ=0.85)46.82 ms44.14 ms−5.7%
Qwen3 8BThroughput BS=1258 tok/s271 tok/s+5.0%
Qwen3 8BThroughput BS=81,781 tok/s1,926 tok/s+8.1%

说明: BS=1 时 TIDE 一致提升吞吐量(5-6.6%)和降低延迟(5.5-7.2%)。但 DeepSeek R1 在 BS=8 时吞吐反降 16.3%,说明 hidden state 收集在大 batch 时成为瓶颈。Qwen3 在 BS=8 时仍有 8.1% 增益

Table 5: Decode Exit Rates and Output Quality / 解码退出率与输出质量

θExit RateUnique TokensCorrectExit Layer
1.0 (off)0%99
0.8598.4%95L31
0.7099.2%95L31
0.5099.6%95L31

说明: 在数学推理任务上,即使 99.6% 的 token 提前退出,仍能得到正确答案。退出 token 全部集中在 L31(倒数第二 checkpoint),说明 router 非常保守

实验

数据集

数据集规模特点用途
WikiText-1032,000 samples (~340K tokens)英文 Wikipedia 长文本校准
自定义 prompts16 条(8 推理/数学 + 8 通用)多样性评测评估

实现细节

硬件: NVIDIA A100-SXM4-40GB (sm_80, 80 GB HBM2e)

软件: CUDA 12.4, PyTorch 2.10, transformers 5.3

精度: bfloat16

校准配置: checkpoint 间隔 c=4c=4,收敛阈值 τ=0.98\tau=0.98,bottleneck b=128b=128

Router 训练: Adam (lr=1e-3), 100 epochs, Binary Cross-Entropy loss

Latency 测量: 20 次平均(3 次 warmup)

代码规模: 3,097 行(Python 1,308 + CUDA/C++ 1,081 + 测试 708),74 个通过测试

收敛分析

DeepSeek R1 Distill 8B: 339,853 tokens,100% 在 L31 收敛

Qwen3 8B: 314,530 tokens,100% 在 L35 收敛

GPT-2 (124M, 12 layers): 78,843 tokens,100% 在 L11 收敛

严格阈值 τ=0.98\tau=0.98 导致退出集中在倒数第二 checkpoint

可视化结果

3-token 示例(“the”/“cat”/“sat”)展示了理想情况下 42% 的计算减少

实际实验中,由于保守阈值,大部分 token 退出在最后 checkpoint 附近,真实加速较温和(5-8%)

批判性思考

优点

零训练成本: 完全 post-training,不修改原始权重,校准仅需 ~3 分钟,降低了使用门槛

工程完整度高: 自带融合 CUDA kernel、自动架构探测、pip install 即用,74 个测试覆盖,工程质量远超多数学术 early exit 工作

KV Cache 安全: post-hoc 策略保证 KV Cache 完整性,这是实际部署的刚需

通用性好: 一套代码支持 LLaMA/Mistral/Qwen/GPT-2 等 9 种架构系列

局限性

加速幅度有限: 实际加速仅 5-8%,远低于 Figure 2 中 42% 理论减少。根本原因在于 post-hoc 模式仍需执行所有层,只是选择更早的 logit 输出——这与”跳过层”有本质区别

大 batch 退化: DeepSeek R1 在 BS=8 时吞吐反降 16.3%,因为 output_hidden_states=True 收集所有层隐状态的开销在大 batch 下不可忽略

退出分布极保守: 几乎所有 token 集中在倒数第二 checkpoint 退出,说明 τ=0.98\tau=0.98 过于严格,真正”跳过大量层”的 token 极少

评估规模偏小: 仅在 8B 模型上测试,缺少 13B/70B 级别验证;评估仅用 16 条 prompt,统计显著性不足

无标准 benchmark: 未在 MMLU/GSM8K/HumanEval 等公认 benchmark 上评测,难以与其他方法直接比较

潜在改进方向

真正的 layer-skipping: 当前 post-hoc 模式的根本瓶颈在于仍执行所有层。若能设计 KV Cache 感知的 speculative skip(如结合 LayerSkip 的自投机解码),可能获得真正的层跳过加速

自适应阈值: 用可学习或基于 calibration 统计的 per-layer 阈值替代全局 τ\tau,让浅层 checkpoint 也能有效退出

更大规模验证: 在 70B+ 模型上测试,这类模型层数更多,early exit 的加速空间理论上更大

Speculative Decoding 结合: 用 early exit 层作为 draft model,可能同时获得 early exit 和投机解码的双重加速

可复现性评估

  • 代码开源(Apache 2.0, GitHub + PyPI)
  • 预训练模型(router checkpoint 未公开发布)
  • 训练细节完整(校准配置全部公开)
  • 数据集可获取(WikiText-103 公开)

关联笔记

基于

CALM: confidence-based early exit 的开创性工作,TIDE 的 learned router 替代了 confidence heuristic

DeeBERT: BERT early exit 先驱,TIDE 将思路拓展到 decoder-only 自回归模型

Adaptive Computation Time: Graves 提出的自适应计算时间,TIDE 的 per-token depth 决策是同一思想的 post-training 实现

对比

LayerSkip: 需要从头训练 + layer dropout schedule,TIDE 无需重训

SkipDecode: 同为 post-training,但不支持 per-token 决策

Mixture-of-Depths: 训练时 routing tokens 跳层,TIDE 为推理时 post-hoc

SpecEE: 结合 speculative decoding + early exit,与 TIDE 互补

ADEPT: 用 draft model 决定 token depth,TIDE 用 learned router

方法相关

early exit: 核心方法论

KV Cache: post-hoc 策略的关键约束

RMSNorm: router 内置归一化 + CUDA kernel 融合

余弦相似度: 收敛度量的核心指标

adaptive computation: 广义的自适应计算范式

Mixture-of-Experts: 同属条件计算(conditional computation)大类

硬件/数据相关

WikiText-103: 校准数据集

A100 GPU: 实验硬件

速查卡片

TIDE: Token-Informed Depth Execution for Per-Token Early Exit in LLM Inference

  • 核心: 无需重训练的 per-token 提前退出,通过 learned router post-hoc 选择最早收敛层
  • 方法: 离线校准训练轻量 MLP router(~4MB)+ 推理时 post-hoc 评估 + 融合 CUDA kernel
  • 结果: DeepSeek R1 8B 上 7.2% prefill 延迟降低 / 6.6% 吞吐提升(BS=1),98-99% token 提前退出且精度无损
  • 代码: GitHub / pip install tide-inference

笔记创建时间: 2026-03-25