LSTM 完整推导与梯度消失分析
分类: 神经网络基础 · 难度: 进阶 · 关联讲座: L04
LSTM 完整推导与梯度消失分析
本文从 RNN 的前向传播与反向传播(BPTT)出发,推导梯度消失的数学根因,再引出 LSTM 的门控架构如何通过加法更新缓解这一问题。内容覆盖 RNN 的完整计算流程、梯度链的谱半径分析,以及 LSTM 四个门的完整推导与梯度流对比。
1. RNN 前向传播与 BPTT
📐 RNN 前向传播与 BPTT
前向传播(Forward Pass):
给定词嵌入序列 x(1),…,x(T)(每个 x(t)∈Rd):
步骤 1:初始化 h(0)=0(或随机初始化)
步骤 2:逐步更新隐状态:
h(t)=σ(Whh(t−1)+Wxx(t)+bh)
其中 Wh∈Rn×n(隐藏到隐藏),Wx∈Rn×d(输入到隐藏)
步骤 3:每步输出分布(softmax 预测下一词):
y^(t)=softmax(Wyh(t)+by)
损失:J(t)=−logP(x(t+1)∣x(≤t))=−logy^x(t+1)(t)(正确词的对数概率)
总损失:J=T1∑t=1TJ(t)
反向传播(BPTT — Backpropagation Through Time):
对 Wh 求梯度时,需要跨时间步回传:
∂Wh∂J(t)=∑k=1t∂h(t)∂J(t)⋅(∏j=k+1t∂h(j−1)∂h(j))⋅∂Wh∂h(k)
每一步的局部雅可比:∂h(j−1)∂h(j)=WhTdiag(σ′(z(j)))(其中 z(j)=Whh(j−1)+…)
🔢 小型 RNN 前向传播示例
设定:d=2(词向量维度),n=2(隐藏层维度),词汇表 ∣V∣=3
参数(小数值方便计算):
- Wh=[0.5000.5],Wx=[1001],bh=0
- 输入序列:x(1)=[0.8,0.2]T,x(2)=[0.3,0.7]T
计算:
- t=1:h(1)=tanh(Wh⋅0+Wx⋅x(1))=tanh([0.8,0.2]T)=[0.664,0.197]T
- t=2:z(2)=Whh(1)+Wxx(2)=[0.332+0.3,0.099+0.7]T=[0.632,0.799]T
h(2)=tanh([0.632,0.799]T)=[0.559,0.665]T
观察:h(2) 同时包含了 x(1) 和 x(2) 的信息(通过 h(1) 间接传入)。
⚠️ 常见误区
- 误区:RNN 的权重 Wh,Wx 在每个时间步都不同 → 正确:RNN 在所有时间步共享同一套权重(这是 RNN 的核心特性!)。共享权重使模型大小不随序列长度增长。
- 误区:BPTT 会计算所有历史时间步 → 正确:实践中常用”截断 BPTT(Truncated BPTT)“,只回传固定步数(如 100 步),平衡计算量和梯度信息。
2. 梯度消失的数学原因
📐 梯度消失的数学原因
梯度链的展开(损失 J(t) 对早期隐状态 h(k) 的梯度):
∂h(k)∂J(t)=∂h(t)∂J(t)⋅∏j=k+1t∂h(j−1)∂h(j)
每一步的 Jacobian:
∂h(j−1)∂h(j)=WhT⋅diag(σ′(z(j)))
谱半径决定命运:设 WhT 的最大奇异值为 σ1,激活函数的最大导数为 γ:
- 如果 σ1γ<1:t−k 步后梯度 ∼(σ1γ)t−k→0(梯度消失)
- 如果 σ1γ>1:梯度 ∼(σ1γ)t−k→∞(梯度爆炸)
- 只有精确 =1 时梯度才稳定(实践中极难维持)
为什么激活函数导数也是关键:
- tanh′(z)=1−tanh2(z)∈(0,1],在饱和区趋向 0
- sigmoid 的导数最大值仅 0.25,梯度消失更严重
- ReLU 在 z>0 区域导数 = 1,但负值区域梯度 = 0
Gradient Clipping 只治”爆炸”:
if ∥g^∥≥threshold:g^←∥g^∥thresholdg^
这等比例缩小梯度向量,保持方向但限制大小。不能增大消失的梯度。
🔢 梯度消失的数值示意
标量简化:设 Wh=0.9,tanh′(z(j))=0.8(常数,简化):
每步缩减因子=0.9×0.8=0.72
| 距离(时间步数) | 梯度大小 |
|---|
| t−k=1 | 0.721=0.720 |
| t−k=5 | 0.725=0.193 |
| t−k=10 | 0.7210=0.037 |
| t−k=20 | 0.7220=0.001 |
20 步之前的信号梯度已衰减到 0.1%,实际上无法更新相关参数。
对比爆炸(Wh=1.1,σ′=1.0):
- t−k=10:1.110=2.59
- t−k=20:1.120=6.73
- t−k=50:1.150=117 → NaN
⚠️ 常见误区
- 误区:梯度消失 = 梯度为 0 → 正确:梯度消失是指数衰减,不是突然归零。距离 20 步时可能只有 0.001 的大小,小到无法驱动参数更新,但并不完全是 0。
- 误区:梯度裁剪可以防止梯度消失 → 正确:裁剪只防止爆炸(缩小过大的梯度)。防止消失需要架构改变(LSTM 的加法更新、残差连接)。
3. LSTM 完整推导——为什么门控能缓解梯度消失?
📐 LSTM 完整推导——为什么门控能缓解梯度消失?
LSTM 的核心区别:引入记忆单元(cell state) c(t),通过加法更新而非乘法更新传递信息。
所有门的计算(共 4 个):
ifoc~(t)=σσσtanh(W[h(t−1)x(t)]+b)
展开为:
- 遗忘门:f(t)=σ(Wf[h(t−1),x(t)]+bf)∈(0,1)n
- 输入门:i(t)=σ(Wi[h(t−1),x(t)]+bi)∈(0,1)n
- 候选记忆:c~(t)=tanh(Wc[h(t−1),x(t)]+bc)∈(−1,1)n
- 输出门:o(t)=σ(Wo[h(t−1),x(t)]+bo)∈(0,1)n
记忆单元(加法更新——关键!):
c(t)=f(t)⊙c(t−1)+i(t)⊙c~(t)
隐状态(输出):
h(t)=o(t)⊙tanh(c(t))
为什么加法更新能缓解梯度消失?
c(t) 对 c(t−1) 的梯度:
∂c(t−1)∂c(t)=f(t)
(忽略高阶项)。这是逐元素乘以遗忘门,而不是乘以大矩阵!
对比标准 RNN 的 ∂h(t−1)∂h(t)=WhTdiag(σ′)(矩阵乘法,连乘后快速衰减)。
如果遗忘门 f(t)≈1(不遗忘),梯度 ≈ 1,不消失也不爆炸。这类似于 ResNet 的残差连接!
🔢 LSTM 一步更新数值示例
设定:n=2(隐藏维度),当前状态:
- h(t−1)=[0.5,−0.3]T,c(t−1)=[0.8,0.2]T
- 输入 x(t)=[1.0,0.0]T
假设(为简化,直接给出门的激活值):
- f(t)=[0.9,0.8]T(遗忘大部分记忆)
- i(t)=[0.7,0.4]T(写入部分新信息)
- c~(t)=[0.6,−0.5]T(候选新记忆)
- o(t)=[0.6,0.7]T(输出多少到隐状态)
计算:
-
新记忆:c(t)=f⊙c(t−1)+i⊙c~
=[0.9×0.8,0.8×0.2]+[0.7×0.6,0.4×(−0.5)]
=[0.72,0.16]+[0.42,−0.20]=[1.14,−0.04]
-
新隐状态:h(t)=o⊙tanh(c(t))
tanh([1.14,−0.04])=[0.817,−0.040]
h(t)=[0.6×0.817,0.7×(−0.040)]=[0.490,−0.028]
观察:遗忘门 f=0.9(几乎保留过去记忆),输入门写入新信息,输出门控制最终输出。
⚠️ 常见误区
- 误区:LSTM 完全解决了梯度消失 → 正确:LSTM 大幅缓解梯度消失(通过记忆单元的加法路径),但并不完全消除。在极长序列(数千步)上仍然困难。Transformer 通过注意力机制进一步解决这个问题。
- 误区:遗忘门为 0 = 清除所有记忆 → 正确:f=0 时 c(t)=i⊙c~(t),完全丢弃过去记忆,完全用新信息。f=1 时完全保留记忆(加法更新不改变 c)。中间值 = 部分遗忘。
- 误区:LSTM 有 4 组权重矩阵,参数是 RNN 的 4 倍 → 正确:可以将 4 个门的权重矩阵拼为一个大矩阵 W∈R4n×(n+d),一次矩阵乘法完成所有门的计算(更高效)。
知识关联
🔗 知识关联
- → L05 注意力机制:LSTM 虽然缓解了梯度消失,但仍受限于顺序计算。注意力机制(Bahdanau et al., 2015)让 decoder 直接访问所有 encoder 隐状态,进一步解决长距离依赖。
- → L05 Transformer:完全放弃 RNN 的循环结构,改用自注意力实现并行计算和直接的长距离连接。
- ← L03 反向传播:BPTT 是 L03 中讲的计算图反向传播在时间序列上的展开应用。
- ← L02 Word2Vec:RNN-LM 的输入是词嵌入向量,与 L02 中的词向量训练直接对接。
- ↔ ResNet 残差连接:LSTM 的 cell state 加法更新与 ResNet 的 skip connection 原理相同——都通过加法路径保持梯度流通。