SAM

分类: 训练优化

SAM (Sharpness-Aware Minimization)

定义

SAM 是一种优化方法,通过同时最小化损失值和损失曲面的锐度(sharpness)来寻找更平坦的极小值,从而提升模型的泛化能力

数学形式

minwmaxϵρL(w+ϵ)\min_{\mathbf{w}} \max_{\|\boldsymbol{\epsilon}\| \leq \rho} L(\mathbf{w} + \boldsymbol{\epsilon})

内层最大化:在参数 w\mathbf{w}ρ\rho-邻域内找到使损失最大的扰动 ϵ\boldsymbol{\epsilon}

外层最小化:优化参数使得”最坏情况下的损失”最小

近似求解:ϵ^=ρwL(w)wL(w)\hat{\boldsymbol{\epsilon}} = \rho \frac{\nabla_{\mathbf{w}} L(\mathbf{w})}{\|\nabla_{\mathbf{w}} L(\mathbf{w})\|},然后在 w+ϵ^\mathbf{w} + \hat{\boldsymbol{\epsilon}} 处计算梯度更新

核心要点

平坦极小值(flat minima)比尖锐极小值(sharp minima)泛化更好,这是 SAM 的理论基础

每步需要两次前向-反向传播:第一次计算扰动方向,第二次在扰动点计算实际梯度

训练开销约为标准 SGD 的 2 倍

后续改进:ASAM(自适应 SAM)、LookSAM(周期性 SAM)、GSAM(梯度引导 SAM)降低计算开销

代表工作

Foret et al., 2021 — 原始 SAM 论文(ICLR 2021)

Kwon et al., 2021 — ASAM,参数自适应的 SAM

Long-Tailed Loss Landscape — 分组 SAM (GSA) 用于长尾学习

相关概念

Hessian — 锐度与 Hessian 特征值密切相关

AdamW — 标准优化器,SAM 可叠加在其上使用