梯度连乘与条件语言模型推导
分类: 神经网络基础 · 难度: 中级 · 关联讲座: L05
梯度连乘与条件语言模型推导
本文整理 RNN 梯度消失/爆炸的完整数学推导(链式法则→Jacobian 连乘→奇异值分析),以及 Seq2Seq 条件语言模型的概率分解与训练目标推导。这两个推导是理解”为什么需要注意力机制和 Transformer”的数学基础。
1. 梯度连乘
📐 梯度连乘:完整推导
变量定义:
- h(t) = 时间步 t 的隐状态
- Wh = 隐状态到隐状态的权重矩阵
- σ = 激活函数(如 tanh)
- z(j)=Whh(j−1)+Wxx(j) = 前激活值
推导过程:
第 1 步:从损失 J(t) 到 h(k) 的梯度,需要经过链式法则逐步回传:
∂h(k)∂J(t)=∂h(t)∂J(t)⋅∂h(k)∂h(t)
第 2 步:将 ∂h(k)∂h(t) 展开为连乘积(每一步用链式法则):
∂h(k)∂h(t)=∏j=k+1t∂h(j−1)∂h(j)
第 3 步:计算单步 Jacobian,由 h(j)=σ(Whh(j−1)+Wxx(j)) 得:
∂h(j−1)∂h(j)=WhT⋅diag(σ′(z(j)))
第 4 步:代入连乘积,得完整梯度公式:
∂h(k)∂h(t)=∏j=k+1tWhT⋅diag(σ′(z(j)))
第 5 步:奇异值分析,设 Wh 的最大奇异值为 λ1,σ′ 在 [0,1] 有界:
- 若 λ1<1:乘积随 t−k 指数趋零 → 梯度消失
- 若 λ1>1:乘积随 t−k 指数爆炸 → 梯度爆炸
Gradient Clipping 算法:
g^←⎩⎨⎧∥g^∥thresholdg^g^若 ∥g^∥>threshold否则
直觉:将梯度向量缩放到固定长度以内,保持方向不变,只压缩幅度。
2. 条件语言模型
📐 条件语言模型:完整推导
变量定义:
- x=(x1,…,xm) = 源句子(长度 m)
- y=(y1,…,yT) = 目标句子(长度 T)
- c = 上下文向量(encoder 最终隐状态)
- s(t) = decoder 在时间步 t 的隐状态
推导过程:
第 1 步:目标是对 P(y∣x) 建模,用概率链式法则分解联合概率:
P(y∣x)=P(y1,y2,…,yT∣x)=∏t=1TP(yt∣y1,…,yt−1,x)
第 2 步:Encoder 将 x 压缩为固定向量 c(Encoder RNN 最后一步隐状态):
c=henc(m)=fenc(x1,…,xm)
第 3 步:Decoder 在每一步利用 c 和之前生成的词:
s(t)=fdec(s(t−1),yt−1,c)
P(yt∣y<t,x)=softmax(Wos(t))[yt-index]
第 4 步:训练目标——最大化对数似然(对训练集上每个 (x,y) 对求和):
L=∑(x,y)∈DlogP(y∣x)=∑(x,y)∈D∑t=1TlogP(yt∣y<t,x)
第 5 步:推断时用贪心解码或 Beam Search 寻找最高概率序列:
y^=argmaxyP(y∣x)