特征蒸馏

分类: 知识蒸馏

特征蒸馏(Feature Distillation)

定义

知识蒸馏的变体,在中间特征层(而非最终 logit)进行知识迁移:要求学生模型的中间特征表示与教师模型对齐,通常以 MSE 或余弦相似度为损失。

数学形式

Lfeat=MSE ⁣(WprojTfs(x),  ft(x))\mathcal{L}_{\text{feat}} = \ell_{\text{MSE}}\!\left(W_{\text{proj}}^T f_s(x),\; f_t(x)\right)

其中 WprojW_{\text{proj}} 为可学习投影矩阵(当学生与教师特征维度不同时使用),fsf_sftf_t 分别为学生和教师的中间特征。

核心要点

相比 logit-level 蒸馏,特征蒸馏传递更丰富的”中间表示”信息,通常在小规模数据上更有效

需要处理维度不匹配问题:常用可学习投影层(线性层或 1×1 卷积)

与 logit 蒸馏配合使用通常效果更佳(GRACE Table 4 证明两者互补)

FitNets 是最早系统使用特征蒸馏的工作

代表工作

GRACE: 在 Compress 阶段用 Lfeat=MSE(WprojTϕst,[ϕmerge,ϕprov])\mathcal{L}_{\text{feat}} = \ell_{\text{MSE}}(W_{\text{proj}}^T \phi_{\text{st}}, [\phi_{\text{merge}}, \phi_{\text{prov}}]) 对齐学生与教师的拼接特征

FitNets: 引导学生网络中间层匹配教师中间层

相关概念

知识蒸馏

类增量学习