Grow, Don't Overwrite: Fine-tuning Without Forgetting

作者: Dyah Adila 年份: 2026 会议: arXiv 分类: 模型增长

Grow, Don’t Overwrite: Fine-tuning Without Forgetting

一句话总结

通过 function-preserving expansion 在 Transformer 子模块中复制预训练参数并施加缩放校正,使扩展后模型在初始化时与原模型数学等价,从而在微调新任务时完全消除灾难性遗忘

核心问题

标准微调覆盖预训练参数 → 内部表征漂移 → 灾难性遗忘。正则化方法(如 EWC)在固定容量内做零和博弈——分配给记忆的资源无法用于学习。

形式化:给定预训练模型 M0M_0(参数 θ0\theta_0),目标是:

minθTLT(θT)s.t.LPT(θT)LPT(θ0)\min_{\theta_T} \mathcal{L}_T(\theta_T) \quad \text{s.t.} \quad \mathcal{L}_{PT}(\theta_T) \leq \mathcal{L}_{PT}(\theta_0)

关键公式

MLP 扩展的 Function-Preserving 变换

原始 MLP 层:

MLPn(X)=ReLU(XWn(1)+Bn(1))Wn(2)+Bn(2)\text{MLP}_n(\mathbf{X}) = \text{ReLU}(\mathbf{X}\mathbf{W}_n^{(1)} + \mathbf{B}_n^{(1)}) \cdot \mathbf{W}_n^{(2)} + \mathbf{B}_n^{(2)}

其中 Wn(1)Rh×p\mathbf{W}_n^{(1)} \in \mathbb{R}^{h \times p}Wn(2)Rp×h\mathbf{W}_n^{(2)} \in \mathbb{R}^{p \times h}

Step 1: 复制 Up-Projection(隐藏维度 p2pp \to 2p):

Wn(1)W^n(1):=[Wn(1)Wn(1)]Rh×2p\mathbf{W}_n^{(1)} \mapsto \hat{\mathbf{W}}_n^{(1)} := [\mathbf{W}_n^{(1)} \mid \mathbf{W}_n^{(1)}] \in \mathbb{R}^{h \times 2p}

Step 2: 缩放 Down-Projection(除以复制因子 k=2k=2):

Wn(2)W^n(2):=[12Wn(2)12Wn(2)]R2p×h\mathbf{W}_n^{(2)} \mapsto \hat{\mathbf{W}}_n^{(2)} := \begin{bmatrix} \frac{1}{2}\mathbf{W}_n^{(2)} \\ \frac{1}{2}\mathbf{W}_n^{(2)} \end{bmatrix} \in \mathbb{R}^{2p \times h}

Function-Preserving 证明

Y=ReLU(XWn(1))\mathbf{Y} = \text{ReLU}(\mathbf{X}\mathbf{W}_n^{(1)}),扩展后输出:

[YY][12Wn(2)12Wn(2)]=12YWn(2)+12YWn(2)=YWn(2)[\mathbf{Y} \mid \mathbf{Y}] \begin{bmatrix} \frac{1}{2}\mathbf{W}_n^{(2)} \\ \frac{1}{2}\mathbf{W}_n^{(2)} \end{bmatrix} = \frac{1}{2}\mathbf{Y}\mathbf{W}_n^{(2)} + \frac{1}{2}\mathbf{Y}\mathbf{W}_n^{(2)} = \mathbf{Y}\mathbf{W}_n^{(2)}

任意扩展因子 kk:Up-projection 拼接 kk 次,Down-projection 每份缩放 1/k1/k,function-preserving 对任意 k2k \geq 2 成立。

方法详解

微调策略

策略训练参数适用场景
G-Freeze仅新增权重(复制的 up-proj 和 down-proj)默认方案,简单任务
G-Train整个扩展后的 up-projection复杂任务(如 MathQA)

关键洞察:事实性知识定位在 down-projection 层,因此 G-Train 只解锁 up-projection。

选择性层扩展

不需要扩展所有层——先做一次 SFT 预实验,按权重更新幅度排序各层,选 top-N 层扩展:

选择 9-10 层即可匹配全量扩展性能

可训练参数从 ~60%(全量)降至 ~30%(选择性)

任务复杂度与扩展需求

通过分析权重更新矩阵 ΔW(1)=Wt(1)Wt1(1)\Delta\mathbf{W}^{(1)} = \mathbf{W}_t^{(1)} - \mathbf{W}_{t-1}^{(1)} 的 effective rank:

简单任务(entailment):高秩更新集中在少数层

复杂推理(MathQA):高秩更新分散在几乎所有层 → 需要更广泛的扩展

关键图表

Figure 1: Growing Approach 概览

复制预训练参数 + 缩放校正 → 初始化时数学等价 → 微调时新增容量吸收新知识。

Figure 2: 遗忘消除实验

展示四个下游任务(翻译、蕴涵、科学QA、数学推理)上的结果。SFT baseline 在大领域迁移时原始性能近乎归零;G-Freeze/G-Train 完全保持原始性能,同时匹配或超越 SFT 的新任务表现。

Figure 3: 选择性扩展

9-10 层选择性扩展 ≈ 全量扩展性能,可训练参数减半。

实验结果

主实验:Gemma3-1B

任务SFT 新任务SFT 原始域G-Freeze 新任务G-Freeze 原始域
法语翻译近零✓(匹配)保持
蕴涵推理严重退化✓(匹配)保持
科学QA轻度退化✓(匹配)保持
数学推理中度退化✓(G-Train超越)保持

Function Vector 保持分析

数据集方法交叉因果头FV 余弦相似度
蕴涵SFT2/100.28
蕴涵Ours5/100.95
翻译SFT3/100.58
翻译Ours5/100.76

方法保持了预训练模型的内部神经回路(Function Vectors),直接阻止了表征漂移。

消融实验

初始化策略对比

零初始化新层:新任务学习失败

零初始化 + 冻结原始参数:灾难性遗忘

Function-preserving(本文):两者兼顾

扩展哪个组件?

扩展目标效果
MLP 隐藏维度最优
Attention head 维度效果差
Attention head 数量效果差
MLP + Attention无额外收益

结论:MLP 扩展是最参数高效且性能最好的策略。

与已有工作的关系

方法思路与本文区别
Net2NetFunction-preserving wider/deeper用于训练加速,非抗遗忘
LoRA低秩增量微调不保证 function-preserving
GROWN%3A Grow Only When Necessary for Continual Learning持续学习中的增长关注何时增长,本文关注如何增长
EWC / L2-SP正则化约束固定容量零和博弈
Progressive Neural Networks任务增加新列不共享参数,不 function-preserving

对我们工作的启示

  1. 核心方向直接相关:这是 function-preserving expansion 从训练加速向抗遗忘微调的扩展,验证了 function-preserving 约束的通用价值
  2. 缩放校正的具体实现1/k1/k 缩放是最简洁的 function-preserving 方案,可直接复用
  3. 选择性扩展:不需要扩展所有层,按权重更新幅度选层 → 可以迁移到模型增长的”在哪里增长”问题
  4. MLP > Attention:MLP 扩展效果最好,与 AMP 剪枝论文的发现一致(MLP 是 ViT 的参数大户)
  5. Effective rank 分析:任务复杂度与需要的扩展广度正相关 → 可以作为自适应增长策略的理论基础

局限性

只在语言模型(Gemma3-1B/4B)上验证,缺少视觉模型实验

选择性扩展的层选择依赖 SFT 预实验,不是端到端自动化的

未与 LoRA + regularization 的组合方案对比

扩展因子 k=2k=2 是经验选择,缺少理论指导

相关概念

function-preserving

Net2Net

LoRA

持续学习