残差连接

分类: 网络架构

残差连接

定义

在深度网络中让每层的输入直接加到输出上,形成 identity shortcut,使得梯度可以无损地反向传播到浅层

数学形式

hl=hl1+fl1(hl1)\boldsymbol{h}_l = \boldsymbol{h}_{l-1} + f_{l-1}(\boldsymbol{h}_{l-1})

展开后:hl=h1+i=1l1fi(hi)\boldsymbol{h}_l = \boldsymbol{h}_1 + \sum_{i=1}^{l-1} f_i(\boldsymbol{h}_i)

核心要点

解决深度网络的梯度消失/爆炸问题,使千层网络训练成为可能

提供 identity mapping 作为 gradient highway

固定权重 1 的累加等价于 depth-wise linear attention(全 1 下三角混合矩阵)

与 PreNorm 结合时导致隐状态幅度以 O(L)O(L) 增长(PreNorm dilution)

代表工作

ResNet: He et al. 2015,首次提出残差学习框架

AttnRes: 用 depth-wise softmax attention 替代固定权重累加

Highway Network: element-wise gating 的残差变体

DenseFormer: 对所有前序层的 learned scalar 加权

相关概念

PreNorm

Highway Network

identity mapping