SAM
分类: 训练优化
SAM (Sharpness-Aware Minimization)
定义
SAM 是一种优化方法,通过同时最小化损失值和损失曲面的锐度(sharpness)来寻找更平坦的极小值,从而提升模型的泛化能力
数学形式
内层最大化:在参数 的 -邻域内找到使损失最大的扰动
外层最小化:优化参数使得”最坏情况下的损失”最小
近似求解:,然后在 处计算梯度更新
核心要点
平坦极小值(flat minima)比尖锐极小值(sharp minima)泛化更好,这是 SAM 的理论基础
每步需要两次前向-反向传播:第一次计算扰动方向,第二次在扰动点计算实际梯度
训练开销约为标准 SGD 的 2 倍
后续改进:ASAM(自适应 SAM)、LookSAM(周期性 SAM)、GSAM(梯度引导 SAM)降低计算开销
代表工作
Foret et al., 2021 — 原始 SAM 论文(ICLR 2021)
Kwon et al., 2021 — ASAM,参数自适应的 SAM
Long-Tailed Loss Landscape — 分组 SAM (GSA) 用于长尾学习
相关概念
Hessian — 锐度与 Hessian 特征值密切相关
AdamW — 标准优化器,SAM 可叠加在其上使用