LayerScale

分类: 网络架构

LayerScale

定义

用可学习的 element-wise 对角缩放矩阵调节每层输出的幅度,初始化为小值以稳定深层 ViT 训练

数学形式

hl=hl1+diag(λl)fl1(hl1)\boldsymbol{h}_l = \boldsymbol{h}_{l-1} + \text{diag}(\boldsymbol{\lambda}_l) \cdot f_{l-1}(\boldsymbol{h}_{l-1})

其中 λl\boldsymbol{\lambda}_l 初始化为小常数(如 10410^{-4}

核心要点

比 ReZero 更细粒度(per-channel vs per-layer)

权重为 static,训练后固定

仅访问前一层 hl1\boldsymbol{h}_{l-1},无 cross-layer access

首次在 CaiT (Going Deeper with Image Transformers) 中提出

代表工作

Touvron et al. 2021: Going Deeper with Image Transformers (CaiT)

相关概念

残差连接

ReZero

PreNorm