classification head

分类: 网络架构

Classification Head

定义

语言模型最后一层,将隐藏向量 hRd\mathbf{h} \in \mathbb{R}^d 映射到词表大小的 logit 向量 zRv\mathbf{z} \in \mathbb{R}^v,用于 next token 预测

数学形式

z=Eh\mathbf{z} = \mathbf{E} \cdot \mathbf{h}

核心要点

通常与 token embedding 矩阵共享权重(weight tying),参数量为 v×dv \times d

在小型语言模型中可占模型总参数的 60% 和推理计算的 50%

随着词表规模增长(如 Llama-3 的 128K),分类头成为推理瓶颈

计算复杂度 O(vd)O(vd),是单纯的矩阵-向量乘法

代表工作

FlashHead: 用两阶段球面聚类检索替代稠密分类头,参数减少 87.4%

SVD-Softmax: 用低秩分解加速分类头

Vocabulary Trimming: 直接裁剪词表缩小分类头

相关概念

Softmax

Greedy Decoding

Temperature Sampling