Diet Your LLM: Dimension-wise Global Pruning of LLMs via Merging Task-specific Importance Score

作者: Jimyung Hong, Jaehyung Kim 年份: 2026 会议: arXiv 分类: 剪枝与稀疏化

论文笔记:Diet Your LLM: Dimension-wise Global Pruning of LLMs via Merging Task-specific Importance Score

元信息

项目内容
机构未明确标注
日期March 2026
项目主页
对比基线SliceGPT, PuDDing, Magnitude Pruning
链接arXiv / Code

一句话总结

提出 DieT,一种无需训练的维度级全局剪枝方法,通过跨任务激活投票构建统一剪枝 mask,在 20% 稀疏度下比现有方法提升约 10% 准确率。

核心贡献

维度级全局剪枝框架: 产生单一可部署 mask,统一应用于 embedding、attention、MLP 和 LM head 的所有线性层

基于激活的任务感知 profiling: 每任务仅需 100 个样本,通过多数投票聚合跨任务重要性分数,完全免训练

方差校正硬剪枝: 通过 α=d/d\alpha = \sqrt{d/d'} 缩放因子补偿维度缩减带来的方差变化,实现真实加速而不损失精度

问题背景

要解决的问题

LLM 部署面临计算和存储瓶颈,结构化剪枝可以直接移除完整维度以获得实际加速,但如何在不训练的前提下选择应当移除的维度是核心难题

现有方法的局限

任务无关方法(如 SliceGPT): 基于 PCA 旋转后切片,不考虑下游任务特征,剪枝后性能骤降

任务感知方法(如 PuDDing): 需要训练路由器,增加额外成本

LLM-Pruner: 依赖梯度信息 + LoRA 微调恢复,流程复杂

本文的动机

不同任务对相同维度的”重要性”存在共识——如果一个维度被多数任务认为不重要,则可安全移除;通过多数投票机制聚合这一共识,无需训练即可获得全局 mask

方法详解

整体框架

DieT 采用 三步流水线 架构:

  • Step 1 — 任务级激活 profiling: 对每个任务 tt,用 100 个样本前向传播,收集 MLP 模块 输出的激活幅度
  • Step 2 — 全局 mask 构建: 每个任务独立产生二值选择器 z(t)z^{(t)},通过 多数投票 聚合为全局 mask uu
  • Step 3 — mask 应用: 将 uu 统一应用于所有残差连接的线性层,将隐藏维度从 dd 缩减至 dd'

核心模块

模块1: 维度级 Masking

设计动机: 利用 结构化剪枝 的维度粒度,在所有层共享同一 mask,避免逐层优化

输入维度 masking:

  • 对线性层 y=Wx+by = Wx + b,通过对角 mask Pz=diag(z)P_z = \text{diag}(z) 零化 WW 的列

输出维度 masking:

  • 通过 P~z=diag(z)\tilde{P}_z = \text{diag}(z) 零化 WW 的行和对应偏置

全局 mask z{0,1}dz \in \{0,1\}^d 使有效维度从 dd 降至 d=jzjd' = \sum_j z_j

模块2: 任务感知激活 Profiling

设计动机: 激活幅度直接反映维度在特定任务上的信息承载量

具体实现:

  • 对每个任务 tt,收集所有 MLP 层输出激活的绝对值均值
  • 跨层聚合得到每个维度 kk 的重要性分数 ak(t)a_k^{(t)}
  • 重要性最低的维度被标记为剪枝候选,产生二值选择器 z(t)z^{(t)}

模块3: 多数投票聚合

设计动机: 利用跨任务共识过滤掉单任务偏见,提升剪枝鲁棒性

具体实现:

  • 将所有任务的选择器堆叠为矩阵 Z{0,1}d×TZ \in \{0,1\}^{d \times T}
  • 对每个维度统计投票数 ck=tTzk(t)c_k = \sum_{t \in T} z_k^{(t)}
  • 选取投票最多的 sd\lfloor s \cdot d \rfloor 个维度作为待移除集合

模块4: 方差校正硬剪枝

设计动机: 物理移除维度后输入维度减少,导致激活方差按 d/dd'/d 比例缩小,需补偿

具体实现:

  • 计算缩放因子 α=d/d\alpha = \sqrt{d/d'}
  • 对切片后的权重矩阵统一缩放 Wscaled=αWslicedW_{\text{scaled}} = \alpha \cdot W_{\text{sliced}}

关键公式

公式1: 输入维度 Masking

y=WPzx+b,Pz=diag(z)y = W P_z x + b, \quad P_z = \text{diag}(z)

含义: 通过对角 mask 矩阵选择性零化权重矩阵的列,等价于移除输入中对应的维度

符号说明:

  • WRm×nW \in \mathbb{R}^{m \times n}: 线性层权重矩阵
  • Pz=diag(z)P_z = \text{diag}(z): 对角 mask 矩阵
  • z{0,1}nz \in \{0,1\}^n: 二值 mask 向量

公式2: 输出维度 Masking

y~=P~z(Wx+b),P~z=diag(z)\tilde{y} = \tilde{P}_z (Wx + b), \quad \tilde{P}_z = \text{diag}(z)

含义: 零化输出维度中不重要的分量,等价于移除对应的行

符号说明:

  • P~z\tilde{P}_z: 输出端对角 mask
  • z{0,1}mz \in \{0,1\}^m: 输出维度的二值 mask

公式3: 跨 batch 激活均值

hˉi(t)[j,k]=1BbBhi,b(t)[j,k]\bar{h}_i^{(t)}[j,k] = \frac{1}{B} \sum_{b \in B} |h_{i,b}^{(t)}[j,k]|

含义: 对每个任务 tt,将 MLP 层 ii 在位置 jj、维度 kk 的激活绝对值在 batch 维度上取平均

符号说明:

  • hi,b(t)RLi×dh_{i,b}^{(t)} \in \mathbb{R}^{L_i \times d}: 任务 tt 在第 ii 层、第 bb 个 batch 的 MLP 输出
  • BB: batch 数
  • LiL_i: 第 ii 层的序列长度

公式4: 维度重要性分数

ak(t)=i=1Nj=1Lihˉi(t)[j,k]i=1NLia_k^{(t)} = \frac{\sum_{i=1}^{N} \sum_{j=1}^{L_i} \bar{h}_i^{(t)}[j,k]}{\sum_{i=1}^{N} L_i}

含义: 维度 kk 在任务 tt 上的重要性等于其跨层、跨位置的平均激活幅度

符号说明:

  • NN: MLP 层总数
  • ak(t)a_k^{(t)}: 维度 kk 对任务 tt 的重要性分数

公式5: 多数投票选择待移除维度

O=argtopsd{ck}k=1d,ck=tTzk(t)O = \arg\text{top}_{\lfloor s \cdot d \rfloor} \{c_k\}_{k=1}^d, \quad c_k = \sum_{t \in T} z_k^{(t)}

含义: 选取被最多任务投票为”不重要”的前 sd\lfloor s \cdot d \rfloor 个维度作为移除集合

符号说明:

  • ss: 目标稀疏率
  • ckc_k: 维度 kk 收到的总投票数
  • OO: 待移除维度索引集合

公式6: 方差校正缩放

α=dd,Wscaled=αWsliced\alpha = \sqrt{\frac{d}{d'}}, \quad W_{\text{scaled}} = \alpha \cdot W_{\text{sliced}}

含义: 物理移除维度后,输入维度从 dd 变为 dd',激活方差按 d/dd'/d 缩小;乘以 α\alpha 补偿方差,保持数值稳定

符号说明:

  • dd: 原始隐藏维度
  • d=jujd' = \sum_j u_j: 保留的维度数
  • WslicedW_{\text{sliced}}: 移除对应行/列后的权重子矩阵

关键图表

Figure 1: DieT 整体流程

Figure 1: DieT Pipeline{:width 600}

说明: DieT 的三步流程。Step 1: 在密集 LLM 上对每个任务进行激活 profiling(隐藏宽度 dd)。Step 2: 通过跨任务多数投票构建全局剪枝 mask——每个任务 tt 的二值选择器标记低重要性维度,投票聚合生成全局 mask。Step 3: mask 统一应用于所有层,将残差宽度从 dd 缩减至 dd'

Figure 2: 跨任务投票共识直方图

Figure 2: Vote Histogram{:width 600}

说明: Gemma-2 2B 上 7 个任务的维度投票共识分布。投票数 7-6 的维度约占 10%(~234/2304),对应 10% 稀疏度时高共识移除;投票数 7-4 约覆盖 20%,支撑 20% 稀疏度目标。

Figure 3: Profiling 样本量 vs 准确率

Figure 3: Sample Size{:width 600}

说明: Gemma-2 2B 上 7 个 zero-shot 任务的准确率随 profiling 样本量变化。N=10 时 44.2%,N=100 时 45.0%,与使用全量数据(46.1%)仅差 1.1%,表明 100 个样本即可近似最优。带 * 的 benchmark 使用 acc_norm 指标。

Table 1: Gemma-2 2B 主要结果

MethodSparsityBoolQRTEHellaSwag*WinoGrande*ARC-E*ARC-C*OBQA*Avg
Dense0%77.763.573.071.277.952.642.265.4
Magnitude-Dim10%58.753.140.152.647.529.431.044.6
SliceGPT10%37.847.326.051.526.424.431.435.0
PuDDing10%38.547.325.550.028.322.738.835.9
DieT10%70.651.662.762.859.539.247.656.3
Magnitude-Dim20%41.253.427.451.529.623.731.836.9
SliceGPT20%38.547.325.949.226.724.133.235.0
PuDDing20%49.847.325.248.326.324.729.235.8
DieT20%62.247.344.356.044.529.331.645.0
Magnitude-Dim30%38.547.325.448.925.522.430.834.1
SliceGPT30%53.447.325.448.125.022.430.836.1
PuDDing30%53.551.625.550.426.823.633.437.8
DieT30%59.447.330.550.732.625.331.639.6

说明: DieT 在所有稀疏度下均显著优于 baseline。10% 稀疏度时平均准确率 56.3%,比 SliceGPT 高 21.3%,比 PuDDing 高 20.4%。

Table 2: Gemma-2 9B 主要结果

MethodSparsityBoolQRTEHellaSwag*WinoGrande*ARC-E*ARC-C*OBQA*Avg
Dense0%85.074.482.577.385.864.847.273.9
SliceGPT10%37.853.426.149.226.622.430.835.2
PuDDing10%47.247.347.659.366.244.435.849.7
DieT10%81.653.874.567.672.251.151.664.6
SliceGPT20%46.547.326.149.326.222.429.235.3
PuDDing20%55.147.349.359.766.740.323.848.9
DieT20%70.447.355.356.453.336.142.251.6

说明: 9B 模型上优势更加明显。10% 稀疏度时 DieT 比 SliceGPT 高 29.4%,比 PuDDing 高 14.9%。

Table 3: 硬剪枝效率(Gemma-2 2B, 20% sparsity)

MetricDenseDieT-Hard
Avg Latency (ms)41.5239.58
FLOPs (×10¹²)3.953.16

说明: 硬剪枝在 FLOPs 上减少约 20%,延迟有一定降低,GPU 显存峰值也一致性下降。

Table 4: Zero-masking vs 硬剪枝

VariantAvg Accuracy
DieT (zero-masked)45.0%
DieT-Hard (with α\alpha rescaling)45.4%

说明: 方差校正后的硬剪枝不仅保持精度,甚至略有提升(+0.4%),验证了缩放因子 α\alpha 的有效性。

Table 5: 全局 mask vs 单任务 mask(Gemma-2 2B, 20%)

Mask SourceAvg Accuracy
Global (DieT)45.0%
ARC-Easy only44.0%
ARC-Challenge only43.6%

说明: 多任务聚合的全局 mask 优于任何单一任务的 mask,验证了跨任务投票的增益。

Table 6: 连续分数 vs 二值投票

Merging StrategyAvg Accuracy
Continuous score44.6%
Binary voting (DieT)45.0%

说明: 二值投票略优于连续分数加权,表明投票机制对分数量化粒度具有鲁棒性。

Table 7: 跨架构泛化(10% / 20% sparsity)

Model10% Avg20% Avg
Qwen2.5-7B53.733.5
Phi-4-mini-reasoning50.221.4

说明: 10% 稀疏度下跨架构泛化良好;20% 时非 Gemma 架构退化明显,暗示需要自适应稀疏度分配或逐层策略。

实验

数据集

数据集类型特点指标
BoolQ阅读理解Yes/No 问题acc
RTE文本蕴含前提-假设对acc
HellaSwag常识推理句子补全acc_norm
WinoGrande常识推理代词消歧acc_norm
ARC-Easy科学问答简单题目acc_norm
ARC-Challenge科学问答困难题目acc_norm
OpenBookQA多跳推理开放知识acc_norm

实现细节

模型: Gemma-2 2B / 9B,以及 Qwen2.5-7B、Phi-4-mini-reasoning

Profiling 样本: 每任务 100 个随机样本(来自训练集)

随机种子: 42

稀疏度: 10%, 20%, 30%

评估工具: lm-evaluation-harness,zero-shot 模式

硬件: 未明确标注

可视化结果

投票共识: 高共识维度(7 票满票)集中于一小部分,表明各任务对”无用维度”有强共识

样本效率: 100 样本已足够,继续增加边际收益递减

批判性思考

优点

极简且实用: 不需要训练、不需要梯度、每任务仅需 100 样本,流程非常轻量

统一 mask 设计: 所有层共享同一 mask,部署极为简单,无需逐层管理

方差校正硬剪枝: α\alpha 缩放是一个优雅的理论修正,使硬剪枝无精度损失

局限性

评估范围偏窄: 仅 7 个英语 zero-shot benchmark,缺少生成任务(如 perplexity、翻译、摘要)评估

高稀疏度泛化差: 非 Gemma 架构在 20% 稀疏度时显著退化(Qwen2.5: 33.5%,Phi-4-mini: 21.4%),普适性存疑

所有层共享 mask: 虽然简单,但忽略了不同层对维度重要性可能有差异的事实;逐层自适应 mask 可能进一步提升性能

baseline 选择: PuDDing 主要做 depth pruning(块级),与 DieT 的维度级剪枝不完全对等;缺少与 LLM-Pruner、Wanda 等同类维度级方法的直接对比

潜在改进方向

逐层自适应稀疏率: 不同层使用不同稀疏度,可能更好地适配非 Gemma 架构

任务权重投票: 引入任务难度/分布感知的加权投票(而非简单的多数投票)

结合 perplexity 校准: 在投票后用少量 calibration 数据微调 mask 边界

与恢复训练结合: 在硬剪枝后加入轻量 LoRA 微调,有望在高稀疏度下大幅提升

可复现性评估

关联笔记

基于

SliceGPT: PCA 旋转+切片的维度剪枝基线

LLM-Pruner: 依赖图+梯度重要性的结构化剪枝

对比

SliceGPT: 同为维度级剪枝但不使用任务信息,DieT 全面超越

PuDDing: 需训练路由器的 prompt 感知 depth pruning,DieT 免训练且效果更优

Magnitude Pruning: 基于权重幅度的经典方法,作为维度级剪枝的朴素基线

方法相关

结构化剪枝: 核心方法类别

width pruning: DieT 本质是维度级宽度剪枝

Gemma: 主要实验模型架构

硬件/数据相关

BoolQ: 阅读理解 benchmark

HellaSwag: 常识推理 benchmark

WinoGrande: 代词消歧 benchmark

速查卡片

Diet Your LLM (DieT)

  • 核心: 免训练维度级全局剪枝,跨任务激活投票构建统一 mask
  • 方法: 每任务 100 样本 profiling MLP 激活 → 二值重要性选择器 → 多数投票聚合 → 方差校正硬剪枝
  • 结果: Gemma-2 2B 10% 稀疏度 56.3%(比 SliceGPT +21.3%);20% 稀疏度 45.0%(+10.0%)
  • 代码: https://github.com/Jimmy145123/DIET

笔记创建时间: 2026-03-27