Diet Your LLM: Dimension-wise Global Pruning of LLMs via Merging Task-specific Importance Score
论文笔记:Diet Your LLM: Dimension-wise Global Pruning of LLMs via Merging Task-specific Importance Score
元信息
| 项目 | 内容 |
|---|---|
| 机构 | 未明确标注 |
| 日期 | March 2026 |
| 项目主页 | — |
| 对比基线 | SliceGPT, PuDDing, Magnitude Pruning |
| 链接 | arXiv / Code |
一句话总结
提出 DieT,一种无需训练的维度级全局剪枝方法,通过跨任务激活投票构建统一剪枝 mask,在 20% 稀疏度下比现有方法提升约 10% 准确率。
核心贡献
维度级全局剪枝框架: 产生单一可部署 mask,统一应用于 embedding、attention、MLP 和 LM head 的所有线性层
基于激活的任务感知 profiling: 每任务仅需 100 个样本,通过多数投票聚合跨任务重要性分数,完全免训练
方差校正硬剪枝: 通过 缩放因子补偿维度缩减带来的方差变化,实现真实加速而不损失精度
问题背景
要解决的问题
LLM 部署面临计算和存储瓶颈,结构化剪枝可以直接移除完整维度以获得实际加速,但如何在不训练的前提下选择应当移除的维度是核心难题
现有方法的局限
任务无关方法(如 SliceGPT): 基于 PCA 旋转后切片,不考虑下游任务特征,剪枝后性能骤降
任务感知方法(如 PuDDing): 需要训练路由器,增加额外成本
LLM-Pruner: 依赖梯度信息 + LoRA 微调恢复,流程复杂
本文的动机
不同任务对相同维度的”重要性”存在共识——如果一个维度被多数任务认为不重要,则可安全移除;通过多数投票机制聚合这一共识,无需训练即可获得全局 mask
方法详解
整体框架
DieT 采用 三步流水线 架构:
- Step 1 — 任务级激活 profiling: 对每个任务 ,用 100 个样本前向传播,收集 MLP 模块 输出的激活幅度
- Step 2 — 全局 mask 构建: 每个任务独立产生二值选择器 ,通过 多数投票 聚合为全局 mask
- Step 3 — mask 应用: 将 统一应用于所有残差连接的线性层,将隐藏维度从 缩减至
核心模块
模块1: 维度级 Masking
设计动机: 利用 结构化剪枝 的维度粒度,在所有层共享同一 mask,避免逐层优化
输入维度 masking:
- 对线性层 ,通过对角 mask 零化 的列
输出维度 masking:
- 通过 零化 的行和对应偏置
全局 mask 使有效维度从 降至
模块2: 任务感知激活 Profiling
设计动机: 激活幅度直接反映维度在特定任务上的信息承载量
具体实现:
- 对每个任务 ,收集所有 MLP 层输出激活的绝对值均值
- 跨层聚合得到每个维度 的重要性分数
- 重要性最低的维度被标记为剪枝候选,产生二值选择器
模块3: 多数投票聚合
设计动机: 利用跨任务共识过滤掉单任务偏见,提升剪枝鲁棒性
具体实现:
- 将所有任务的选择器堆叠为矩阵
- 对每个维度统计投票数
- 选取投票最多的 个维度作为待移除集合
模块4: 方差校正硬剪枝
设计动机: 物理移除维度后输入维度减少,导致激活方差按 比例缩小,需补偿
具体实现:
- 计算缩放因子
- 对切片后的权重矩阵统一缩放
关键公式
公式1: 输入维度 Masking
含义: 通过对角 mask 矩阵选择性零化权重矩阵的列,等价于移除输入中对应的维度
符号说明:
- : 线性层权重矩阵
- : 对角 mask 矩阵
- : 二值 mask 向量
公式2: 输出维度 Masking
含义: 零化输出维度中不重要的分量,等价于移除对应的行
符号说明:
- : 输出端对角 mask
- : 输出维度的二值 mask
公式3: 跨 batch 激活均值
含义: 对每个任务 ,将 MLP 层 在位置 、维度 的激活绝对值在 batch 维度上取平均
符号说明:
- : 任务 在第 层、第 个 batch 的 MLP 输出
- : batch 数
- : 第 层的序列长度
公式4: 维度重要性分数
含义: 维度 在任务 上的重要性等于其跨层、跨位置的平均激活幅度
符号说明:
- : MLP 层总数
- : 维度 对任务 的重要性分数
公式5: 多数投票选择待移除维度
含义: 选取被最多任务投票为”不重要”的前 个维度作为移除集合
符号说明:
- : 目标稀疏率
- : 维度 收到的总投票数
- : 待移除维度索引集合
公式6: 方差校正缩放
含义: 物理移除维度后,输入维度从 变为 ,激活方差按 缩小;乘以 补偿方差,保持数值稳定
符号说明:
- : 原始隐藏维度
- : 保留的维度数
- : 移除对应行/列后的权重子矩阵
关键图表
Figure 1: DieT 整体流程
{:width 600}
说明: DieT 的三步流程。Step 1: 在密集 LLM 上对每个任务进行激活 profiling(隐藏宽度 )。Step 2: 通过跨任务多数投票构建全局剪枝 mask——每个任务 的二值选择器标记低重要性维度,投票聚合生成全局 mask。Step 3: mask 统一应用于所有层,将残差宽度从 缩减至 。
Figure 2: 跨任务投票共识直方图
{:width 600}
说明: Gemma-2 2B 上 7 个任务的维度投票共识分布。投票数 7-6 的维度约占 10%(~234/2304),对应 10% 稀疏度时高共识移除;投票数 7-4 约覆盖 20%,支撑 20% 稀疏度目标。
Figure 3: Profiling 样本量 vs 准确率
{:width 600}
说明: Gemma-2 2B 上 7 个 zero-shot 任务的准确率随 profiling 样本量变化。N=10 时 44.2%,N=100 时 45.0%,与使用全量数据(46.1%)仅差 1.1%,表明 100 个样本即可近似最优。带 * 的 benchmark 使用 acc_norm 指标。
Table 1: Gemma-2 2B 主要结果
| Method | Sparsity | BoolQ | RTE | HellaSwag* | WinoGrande* | ARC-E* | ARC-C* | OBQA* | Avg |
|---|---|---|---|---|---|---|---|---|---|
| Dense | 0% | 77.7 | 63.5 | 73.0 | 71.2 | 77.9 | 52.6 | 42.2 | 65.4 |
| Magnitude-Dim | 10% | 58.7 | 53.1 | 40.1 | 52.6 | 47.5 | 29.4 | 31.0 | 44.6 |
| SliceGPT | 10% | 37.8 | 47.3 | 26.0 | 51.5 | 26.4 | 24.4 | 31.4 | 35.0 |
| PuDDing | 10% | 38.5 | 47.3 | 25.5 | 50.0 | 28.3 | 22.7 | 38.8 | 35.9 |
| DieT | 10% | 70.6 | 51.6 | 62.7 | 62.8 | 59.5 | 39.2 | 47.6 | 56.3 |
| Magnitude-Dim | 20% | 41.2 | 53.4 | 27.4 | 51.5 | 29.6 | 23.7 | 31.8 | 36.9 |
| SliceGPT | 20% | 38.5 | 47.3 | 25.9 | 49.2 | 26.7 | 24.1 | 33.2 | 35.0 |
| PuDDing | 20% | 49.8 | 47.3 | 25.2 | 48.3 | 26.3 | 24.7 | 29.2 | 35.8 |
| DieT | 20% | 62.2 | 47.3 | 44.3 | 56.0 | 44.5 | 29.3 | 31.6 | 45.0 |
| Magnitude-Dim | 30% | 38.5 | 47.3 | 25.4 | 48.9 | 25.5 | 22.4 | 30.8 | 34.1 |
| SliceGPT | 30% | 53.4 | 47.3 | 25.4 | 48.1 | 25.0 | 22.4 | 30.8 | 36.1 |
| PuDDing | 30% | 53.5 | 51.6 | 25.5 | 50.4 | 26.8 | 23.6 | 33.4 | 37.8 |
| DieT | 30% | 59.4 | 47.3 | 30.5 | 50.7 | 32.6 | 25.3 | 31.6 | 39.6 |
说明: DieT 在所有稀疏度下均显著优于 baseline。10% 稀疏度时平均准确率 56.3%,比 SliceGPT 高 21.3%,比 PuDDing 高 20.4%。
Table 2: Gemma-2 9B 主要结果
| Method | Sparsity | BoolQ | RTE | HellaSwag* | WinoGrande* | ARC-E* | ARC-C* | OBQA* | Avg |
|---|---|---|---|---|---|---|---|---|---|
| Dense | 0% | 85.0 | 74.4 | 82.5 | 77.3 | 85.8 | 64.8 | 47.2 | 73.9 |
| SliceGPT | 10% | 37.8 | 53.4 | 26.1 | 49.2 | 26.6 | 22.4 | 30.8 | 35.2 |
| PuDDing | 10% | 47.2 | 47.3 | 47.6 | 59.3 | 66.2 | 44.4 | 35.8 | 49.7 |
| DieT | 10% | 81.6 | 53.8 | 74.5 | 67.6 | 72.2 | 51.1 | 51.6 | 64.6 |
| SliceGPT | 20% | 46.5 | 47.3 | 26.1 | 49.3 | 26.2 | 22.4 | 29.2 | 35.3 |
| PuDDing | 20% | 55.1 | 47.3 | 49.3 | 59.7 | 66.7 | 40.3 | 23.8 | 48.9 |
| DieT | 20% | 70.4 | 47.3 | 55.3 | 56.4 | 53.3 | 36.1 | 42.2 | 51.6 |
说明: 9B 模型上优势更加明显。10% 稀疏度时 DieT 比 SliceGPT 高 29.4%,比 PuDDing 高 14.9%。
Table 3: 硬剪枝效率(Gemma-2 2B, 20% sparsity)
| Metric | Dense | DieT-Hard |
|---|---|---|
| Avg Latency (ms) | 41.52 | 39.58 |
| FLOPs (×10¹²) | 3.95 | 3.16 |
说明: 硬剪枝在 FLOPs 上减少约 20%,延迟有一定降低,GPU 显存峰值也一致性下降。
Table 4: Zero-masking vs 硬剪枝
| Variant | Avg Accuracy |
|---|---|
| DieT (zero-masked) | 45.0% |
| DieT-Hard (with rescaling) | 45.4% |
说明: 方差校正后的硬剪枝不仅保持精度,甚至略有提升(+0.4%),验证了缩放因子 的有效性。
Table 5: 全局 mask vs 单任务 mask(Gemma-2 2B, 20%)
| Mask Source | Avg Accuracy |
|---|---|
| Global (DieT) | 45.0% |
| ARC-Easy only | 44.0% |
| ARC-Challenge only | 43.6% |
说明: 多任务聚合的全局 mask 优于任何单一任务的 mask,验证了跨任务投票的增益。
Table 6: 连续分数 vs 二值投票
| Merging Strategy | Avg Accuracy |
|---|---|
| Continuous score | 44.6% |
| Binary voting (DieT) | 45.0% |
说明: 二值投票略优于连续分数加权,表明投票机制对分数量化粒度具有鲁棒性。
Table 7: 跨架构泛化(10% / 20% sparsity)
| Model | 10% Avg | 20% Avg |
|---|---|---|
| Qwen2.5-7B | 53.7 | 33.5 |
| Phi-4-mini-reasoning | 50.2 | 21.4 |
说明: 10% 稀疏度下跨架构泛化良好;20% 时非 Gemma 架构退化明显,暗示需要自适应稀疏度分配或逐层策略。
实验
数据集
| 数据集 | 类型 | 特点 | 指标 |
|---|---|---|---|
| BoolQ | 阅读理解 | Yes/No 问题 | acc |
| RTE | 文本蕴含 | 前提-假设对 | acc |
| HellaSwag | 常识推理 | 句子补全 | acc_norm |
| WinoGrande | 常识推理 | 代词消歧 | acc_norm |
| ARC-Easy | 科学问答 | 简单题目 | acc_norm |
| ARC-Challenge | 科学问答 | 困难题目 | acc_norm |
| OpenBookQA | 多跳推理 | 开放知识 | acc_norm |
实现细节
模型: Gemma-2 2B / 9B,以及 Qwen2.5-7B、Phi-4-mini-reasoning
Profiling 样本: 每任务 100 个随机样本(来自训练集)
随机种子: 42
稀疏度: 10%, 20%, 30%
评估工具: lm-evaluation-harness,zero-shot 模式
硬件: 未明确标注
可视化结果
投票共识: 高共识维度(7 票满票)集中于一小部分,表明各任务对”无用维度”有强共识
样本效率: 100 样本已足够,继续增加边际收益递减
批判性思考
优点
极简且实用: 不需要训练、不需要梯度、每任务仅需 100 样本,流程非常轻量
统一 mask 设计: 所有层共享同一 mask,部署极为简单,无需逐层管理
方差校正硬剪枝: 缩放是一个优雅的理论修正,使硬剪枝无精度损失
局限性
评估范围偏窄: 仅 7 个英语 zero-shot benchmark,缺少生成任务(如 perplexity、翻译、摘要)评估
高稀疏度泛化差: 非 Gemma 架构在 20% 稀疏度时显著退化(Qwen2.5: 33.5%,Phi-4-mini: 21.4%),普适性存疑
所有层共享 mask: 虽然简单,但忽略了不同层对维度重要性可能有差异的事实;逐层自适应 mask 可能进一步提升性能
baseline 选择: PuDDing 主要做 depth pruning(块级),与 DieT 的维度级剪枝不完全对等;缺少与 LLM-Pruner、Wanda 等同类维度级方法的直接对比
潜在改进方向
逐层自适应稀疏率: 不同层使用不同稀疏度,可能更好地适配非 Gemma 架构
任务权重投票: 引入任务难度/分布感知的加权投票(而非简单的多数投票)
结合 perplexity 校准: 在投票后用少量 calibration 数据微调 mask 边界
与恢复训练结合: 在硬剪枝后加入轻量 LoRA 微调,有望在高稀疏度下大幅提升
可复现性评估
- 代码开源(https://github.com/Jimmy145123/DIET)
- 预训练模型
- 训练细节完整(无需训练,profiling 细节清晰)
- 数据集可获取(所有 benchmark 均公开)
关联笔记
基于
SliceGPT: PCA 旋转+切片的维度剪枝基线
LLM-Pruner: 依赖图+梯度重要性的结构化剪枝
对比
SliceGPT: 同为维度级剪枝但不使用任务信息,DieT 全面超越
PuDDing: 需训练路由器的 prompt 感知 depth pruning,DieT 免训练且效果更优
Magnitude Pruning: 基于权重幅度的经典方法,作为维度级剪枝的朴素基线
方法相关
结构化剪枝: 核心方法类别
width pruning: DieT 本质是维度级宽度剪枝
Gemma: 主要实验模型架构
硬件/数据相关
BoolQ: 阅读理解 benchmark
HellaSwag: 常识推理 benchmark
WinoGrande: 代词消歧 benchmark
速查卡片
Diet Your LLM (DieT)
- 核心: 免训练维度级全局剪枝,跨任务激活投票构建统一 mask
- 方法: 每任务 100 样本 profiling MLP 激活 → 二值重要性选择器 → 多数投票聚合 → 方差校正硬剪枝
- 结果: Gemma-2 2B 10% 稀疏度 56.3%(比 SliceGPT +21.3%);20% 稀疏度 45.0%(+10.0%)
- 代码: https://github.com/Jimmy145123/DIET
笔记创建时间: 2026-03-27