梯度连乘与条件语言模型推导

分类: 神经网络基础 · 难度: 中级 · 关联讲座: L05

梯度连乘与条件语言模型推导

本文整理 RNN 梯度消失/爆炸的完整数学推导(链式法则→Jacobian 连乘→奇异值分析),以及 Seq2Seq 条件语言模型的概率分解与训练目标推导。这两个推导是理解”为什么需要注意力机制和 Transformer”的数学基础。


1. 梯度连乘

📐 梯度连乘:完整推导

变量定义

  • h(t)h^{(t)} = 时间步 tt 的隐状态
  • WhW_h = 隐状态到隐状态的权重矩阵
  • σ\sigma = 激活函数(如 tanh)
  • z(j)=Whh(j1)+Wxx(j)z^{(j)} = W_h h^{(j-1)} + W_x x^{(j)} = 前激活值

推导过程

第 1 步:从损失 J(t)J^{(t)}h(k)h^{(k)} 的梯度,需要经过链式法则逐步回传:

J(t)h(k)=J(t)h(t)h(t)h(k)\frac{\partial J^{(t)}}{\partial h^{(k)}} = \frac{\partial J^{(t)}}{\partial h^{(t)}} \cdot \frac{\partial h^{(t)}}{\partial h^{(k)}}

第 2 步:将 h(t)h(k)\frac{\partial h^{(t)}}{\partial h^{(k)}} 展开为连乘积(每一步用链式法则):

h(t)h(k)=j=k+1th(j)h(j1)\frac{\partial h^{(t)}}{\partial h^{(k)}} = \prod_{j=k+1}^{t} \frac{\partial h^{(j)}}{\partial h^{(j-1)}}

第 3 步:计算单步 Jacobian,由 h(j)=σ(Whh(j1)+Wxx(j))h^{(j)} = \sigma(W_h h^{(j-1)} + W_x x^{(j)}) 得:

h(j)h(j1)=WhTdiag ⁣(σ(z(j)))\frac{\partial h^{(j)}}{\partial h^{(j-1)}} = W_h^T \cdot \text{diag}\!\left(\sigma'(z^{(j)})\right)

第 4 步:代入连乘积,得完整梯度公式:

h(t)h(k)=j=k+1tWhTdiag ⁣(σ(z(j)))\frac{\partial h^{(t)}}{\partial h^{(k)}} = \prod_{j=k+1}^{t} W_h^T \cdot \text{diag}\!\left(\sigma'(z^{(j)})\right)

第 5 步:奇异值分析,设 WhW_h 的最大奇异值为 λ1\lambda_1σ\sigma'[0,1][0,1] 有界:

  • λ1<1\lambda_1 < 1:乘积随 tkt - k 指数趋零 → 梯度消失
  • λ1>1\lambda_1 > 1:乘积随 tkt - k 指数爆炸 → 梯度爆炸

Gradient Clipping 算法

g^{thresholdg^g^若 g^>thresholdg^否则\hat{g} \leftarrow \begin{cases} \dfrac{\text{threshold}}{\|\hat{g}\|} \hat{g} & \text{若 } \|\hat{g}\| > \text{threshold} \\ \hat{g} & \text{否则} \end{cases}

直觉:将梯度向量缩放到固定长度以内,保持方向不变,只压缩幅度。


2. 条件语言模型

📐 条件语言模型:完整推导

变量定义

  • x=(x1,,xm)x = (x_1, \ldots, x_m) = 源句子(长度 mm
  • y=(y1,,yT)y = (y_1, \ldots, y_T) = 目标句子(长度 TT
  • cc = 上下文向量(encoder 最终隐状态)
  • s(t)s^{(t)} = decoder 在时间步 tt 的隐状态

推导过程

第 1 步:目标是对 P(yx)P(y|x) 建模,用概率链式法则分解联合概率:

P(yx)=P(y1,y2,,yTx)=t=1TP(yty1,,yt1,x)P(y|x) = P(y_1, y_2, \ldots, y_T | x) = \prod_{t=1}^{T} P(y_t | y_1, \ldots, y_{t-1}, x)

第 2 步:Encoder 将 xx 压缩为固定向量 cc(Encoder RNN 最后一步隐状态):

c=henc(m)=fenc(x1,,xm)c = h^{(m)}_\text{enc} = f_\text{enc}(x_1, \ldots, x_m)

第 3 步:Decoder 在每一步利用 cc 和之前生成的词:

s(t)=fdec(s(t1),yt1,c)s^{(t)} = f_\text{dec}(s^{(t-1)}, y_{t-1}, c)

P(yty<t,x)=softmax(Wos(t))[yt-index]P(y_t | y_{<t}, x) = \text{softmax}(W_o s^{(t)}) \big[y_t\text{-index}\big]

第 4 步:训练目标——最大化对数似然(对训练集上每个 (x,y)(x, y) 对求和):

L=(x,y)DlogP(yx)=(x,y)Dt=1TlogP(yty<t,x)\mathcal{L} = \sum_{(x,y) \in \mathcal{D}} \log P(y|x) = \sum_{(x,y) \in \mathcal{D}} \sum_{t=1}^{T} \log P(y_t | y_{<t}, x)

第 5 步:推断时用贪心解码或 Beam Search 寻找最高概率序列:

y^=argmaxyP(yx)\hat{y} = \arg\max_{y} P(y|x)