Stop-Gradient

分类: 深度学习基础

Stop-Gradient

定义

在计算图中阻断梯度回传的操作,使得被标记的张量在反向传播时不产生梯度

数学形式

sg(x)=x,sg(x)x=0\operatorname{sg}(x) = x, \quad \frac{\partial \operatorname{sg}(x)}{\partial x} = 0

核心要点

前向传播中值不变,反向传播中梯度为零

常用于自蒸馏/对比学习中保护 teacher/momentum encoder

PyTorch 中通过 .detach()torch.no_grad() 实现

代表工作

MTP-D: 对主头 logits 应用 stop-gradient 保护主头性能

SimSiam, BYOL: 对比学习中的动量编码器

相关概念

自蒸馏

KL Divergence