LSTM 完整推导与梯度消失分析

分类: 神经网络基础 · 难度: 进阶 · 关联讲座: L04

LSTM 完整推导与梯度消失分析

本文从 RNN 的前向传播与反向传播(BPTT)出发,推导梯度消失的数学根因,再引出 LSTM 的门控架构如何通过加法更新缓解这一问题。内容覆盖 RNN 的完整计算流程、梯度链的谱半径分析,以及 LSTM 四个门的完整推导与梯度流对比。


1. RNN 前向传播与 BPTT

📐 RNN 前向传播与 BPTT

前向传播(Forward Pass)

给定词嵌入序列 x(1),,x(T)x^{(1)}, \ldots, x^{(T)}(每个 x(t)Rdx^{(t)} \in \mathbb{R}^d):

步骤 1:初始化 h(0)=0h^{(0)} = \mathbf{0}(或随机初始化)

步骤 2:逐步更新隐状态: h(t)=σ(Whh(t1)+Wxx(t)+bh)h^{(t)} = \sigma(W_h h^{(t-1)} + W_x x^{(t)} + b_h)

其中 WhRn×nW_h \in \mathbb{R}^{n \times n}(隐藏到隐藏),WxRn×dW_x \in \mathbb{R}^{n \times d}(输入到隐藏)

步骤 3:每步输出分布(softmax 预测下一词): y^(t)=softmax(Wyh(t)+by)\hat{y}^{(t)} = \text{softmax}(W_y h^{(t)} + b_y)

损失J(t)=logP(x(t+1)x(t))=logy^x(t+1)(t)J^{(t)} = -\log P(x^{(t+1)} | x^{(\le t)}) = -\log \hat{y}^{(t)}_{x^{(t+1)}}(正确词的对数概率)

总损失J=1Tt=1TJ(t)J = \frac{1}{T}\sum_{t=1}^{T} J^{(t)}

反向传播(BPTT — Backpropagation Through Time)

WhW_h 求梯度时,需要跨时间步回传: J(t)Wh=k=1tJ(t)h(t)(j=k+1th(j)h(j1))h(k)Wh\frac{\partial J^{(t)}}{\partial W_h} = \sum_{k=1}^{t} \frac{\partial J^{(t)}}{\partial h^{(t)}} \cdot \left(\prod_{j=k+1}^{t} \frac{\partial h^{(j)}}{\partial h^{(j-1)}}\right) \cdot \frac{\partial h^{(k)}}{\partial W_h}

每一步的局部雅可比:h(j)h(j1)=WhTdiag(σ(z(j)))\frac{\partial h^{(j)}}{\partial h^{(j-1)}} = W_h^T \text{diag}(\sigma'(z^{(j)}))(其中 z(j)=Whh(j1)+z^{(j)} = W_h h^{(j-1)} + \ldots

🔢 小型 RNN 前向传播示例

设定d=2d=2(词向量维度),n=2n=2(隐藏层维度),词汇表 V=3|V|=3

参数(小数值方便计算):

  • Wh=[0.5000.5]W_h = \begin{bmatrix}0.5 & 0 \\ 0 & 0.5\end{bmatrix}Wx=[1001]W_x = \begin{bmatrix}1 & 0 \\ 0 & 1\end{bmatrix}bh=0b_h = \mathbf{0}
  • 输入序列:x(1)=[0.8,0.2]Tx^{(1)} = [0.8, 0.2]^Tx(2)=[0.3,0.7]Tx^{(2)} = [0.3, 0.7]^T

计算

  1. t=1t=1h(1)=tanh(Wh0+Wxx(1))=tanh([0.8,0.2]T)=[0.664,0.197]Th^{(1)} = \tanh(W_h \cdot \mathbf{0} + W_x \cdot x^{(1)}) = \tanh([0.8, 0.2]^T) = [0.664, 0.197]^T
  2. t=2t=2z(2)=Whh(1)+Wxx(2)=[0.332+0.3,0.099+0.7]T=[0.632,0.799]Tz^{(2)} = W_h h^{(1)} + W_x x^{(2)} = [0.332+0.3, 0.099+0.7]^T = [0.632, 0.799]^T h(2)=tanh([0.632,0.799]T)=[0.559,0.665]Th^{(2)} = \tanh([0.632, 0.799]^T) = [0.559, 0.665]^T

观察h(2)h^{(2)} 同时包含了 x(1)x^{(1)}x(2)x^{(2)} 的信息(通过 h(1)h^{(1)} 间接传入)。

⚠️ 常见误区

  1. 误区:RNN 的权重 Wh,WxW_h, W_x 在每个时间步都不同 → 正确:RNN 在所有时间步共享同一套权重(这是 RNN 的核心特性!)。共享权重使模型大小不随序列长度增长。
  2. 误区:BPTT 会计算所有历史时间步 → 正确:实践中常用”截断 BPTT(Truncated BPTT)“,只回传固定步数(如 100 步),平衡计算量和梯度信息。

2. 梯度消失的数学原因

📐 梯度消失的数学原因

梯度链的展开(损失 J(t)J^{(t)} 对早期隐状态 h(k)h^{(k)} 的梯度):

J(t)h(k)=J(t)h(t)j=k+1th(j)h(j1)\frac{\partial J^{(t)}}{\partial h^{(k)}} = \frac{\partial J^{(t)}}{\partial h^{(t)}} \cdot \prod_{j=k+1}^{t} \frac{\partial h^{(j)}}{\partial h^{(j-1)}}

每一步的 Jacobian: h(j)h(j1)=WhTdiag(σ(z(j)))\frac{\partial h^{(j)}}{\partial h^{(j-1)}} = W_h^T \cdot \text{diag}(\sigma'(z^{(j)}))

谱半径决定命运:设 WhTW_h^T 的最大奇异值为 σ1\sigma_1,激活函数的最大导数为 γ\gamma

  • 如果 σ1γ<1\sigma_1 \gamma < 1tkt - k 步后梯度 (σ1γ)tk0\sim (\sigma_1 \gamma)^{t-k} \to 0梯度消失
  • 如果 σ1γ>1\sigma_1 \gamma > 1:梯度 (σ1γ)tk\sim (\sigma_1 \gamma)^{t-k} \to \infty梯度爆炸
  • 只有精确 =1= 1 时梯度才稳定(实践中极难维持)

为什么激活函数导数也是关键

  • tanh(z)=1tanh2(z)(0,1]\tanh'(z) = 1 - \tanh^2(z) \in (0, 1],在饱和区趋向 0
  • sigmoid 的导数最大值仅 0.25,梯度消失更严重
  • ReLU 在 z>0z > 0 区域导数 = 1,但负值区域梯度 = 0

Gradient Clipping 只治”爆炸”if g^threshold:g^thresholdg^g^\text{if } \|\hat{g}\| \ge \text{threshold}: \quad \hat{g} \leftarrow \frac{\text{threshold}}{\|\hat{g}\|} \hat{g}

这等比例缩小梯度向量,保持方向但限制大小。不能增大消失的梯度。

🔢 梯度消失的数值示意

标量简化:设 Wh=0.9W_h = 0.9tanh(z(j))=0.8\tanh'(z^{(j)}) = 0.8(常数,简化):

每步缩减因子=0.9×0.8=0.72\text{每步缩减因子} = 0.9 \times 0.8 = 0.72

距离(时间步数)梯度大小
tk=1t - k = 10.721=0.7200.72^1 = 0.720
tk=5t - k = 50.725=0.1930.72^5 = 0.193
tk=10t - k = 100.7210=0.0370.72^{10} = 0.037
tk=20t - k = 200.7220=0.0010.72^{20} = 0.001

20 步之前的信号梯度已衰减到 0.1%,实际上无法更新相关参数。

对比爆炸(Wh=1.1W_h = 1.1σ=1.0\sigma' = 1.0):

  • tk=10t - k = 101.110=2.591.1^{10} = 2.59
  • tk=20t - k = 201.120=6.731.1^{20} = 6.73
  • tk=50t - k = 501.150=1171.1^{50} = 117 → NaN

⚠️ 常见误区

  1. 误区:梯度消失 = 梯度为 0 → 正确:梯度消失是指数衰减,不是突然归零。距离 20 步时可能只有 0.001 的大小,小到无法驱动参数更新,但并不完全是 0。
  2. 误区:梯度裁剪可以防止梯度消失 → 正确:裁剪只防止爆炸(缩小过大的梯度)。防止消失需要架构改变(LSTM 的加法更新、残差连接)。

3. LSTM 完整推导——为什么门控能缓解梯度消失?

📐 LSTM 完整推导——为什么门控能缓解梯度消失?

LSTM 的核心区别:引入记忆单元(cell state) c(t)c^{(t)},通过加法更新而非乘法更新传递信息。

所有门的计算(共 4 个):

[ifoc~](t)=[σσσtanh](W[h(t1)x(t)]+b)\begin{bmatrix} i \\ f \\ o \\ \tilde{c} \end{bmatrix}^{(t)} = \begin{bmatrix} \sigma \\ \sigma \\ \sigma \\ \tanh \end{bmatrix}\left(W \begin{bmatrix} h^{(t-1)} \\ x^{(t)} \end{bmatrix} + b\right)

展开为:

  • 遗忘门:f(t)=σ(Wf[h(t1),x(t)]+bf)(0,1)nf^{(t)} = \sigma(W_f [h^{(t-1)}, x^{(t)}] + b_f) \in (0, 1)^n
  • 输入门:i(t)=σ(Wi[h(t1),x(t)]+bi)(0,1)ni^{(t)} = \sigma(W_i [h^{(t-1)}, x^{(t)}] + b_i) \in (0, 1)^n
  • 候选记忆:c~(t)=tanh(Wc[h(t1),x(t)]+bc)(1,1)n\tilde{c}^{(t)} = \tanh(W_c [h^{(t-1)}, x^{(t)}] + b_c) \in (-1, 1)^n
  • 输出门:o(t)=σ(Wo[h(t1),x(t)]+bo)(0,1)no^{(t)} = \sigma(W_o [h^{(t-1)}, x^{(t)}] + b_o) \in (0, 1)^n

记忆单元(加法更新——关键!)c(t)=f(t)c(t1)+i(t)c~(t)c^{(t)} = f^{(t)} \odot c^{(t-1)} + i^{(t)} \odot \tilde{c}^{(t)}

隐状态(输出)h(t)=o(t)tanh(c(t))h^{(t)} = o^{(t)} \odot \tanh(c^{(t)})

为什么加法更新能缓解梯度消失?

c(t)c^{(t)}c(t1)c^{(t-1)} 的梯度: c(t)c(t1)=f(t)\frac{\partial c^{(t)}}{\partial c^{(t-1)}} = f^{(t)}

(忽略高阶项)。这是逐元素乘以遗忘门,而不是乘以大矩阵!

对比标准 RNN 的 h(t)h(t1)=WhTdiag(σ)\frac{\partial h^{(t)}}{\partial h^{(t-1)}} = W_h^T \text{diag}(\sigma')(矩阵乘法,连乘后快速衰减)。

如果遗忘门 f(t)1f^{(t)} \approx 1(不遗忘),梯度 ≈ 1,不消失也不爆炸。这类似于 ResNet 的残差连接!

🔢 LSTM 一步更新数值示例

设定n=2n=2(隐藏维度),当前状态:

  • h(t1)=[0.5,0.3]Th^{(t-1)} = [0.5, -0.3]^Tc(t1)=[0.8,0.2]Tc^{(t-1)} = [0.8, 0.2]^T
  • 输入 x(t)=[1.0,0.0]Tx^{(t)} = [1.0, 0.0]^T

假设(为简化,直接给出门的激活值):

  • f(t)=[0.9,0.8]Tf^{(t)} = [0.9, 0.8]^T(遗忘大部分记忆)
  • i(t)=[0.7,0.4]Ti^{(t)} = [0.7, 0.4]^T(写入部分新信息)
  • c~(t)=[0.6,0.5]T\tilde{c}^{(t)} = [0.6, -0.5]^T(候选新记忆)
  • o(t)=[0.6,0.7]To^{(t)} = [0.6, 0.7]^T(输出多少到隐状态)

计算

  1. 新记忆c(t)=fc(t1)+ic~c^{(t)} = f \odot c^{(t-1)} + i \odot \tilde{c} =[0.9×0.8,0.8×0.2]+[0.7×0.6,0.4×(0.5)]= [0.9 \times 0.8, 0.8 \times 0.2] + [0.7 \times 0.6, 0.4 \times (-0.5)] =[0.72,0.16]+[0.42,0.20]=[1.14,0.04]= [0.72, 0.16] + [0.42, -0.20] = [1.14, -0.04]

  2. 新隐状态h(t)=otanh(c(t))h^{(t)} = o \odot \tanh(c^{(t)}) tanh([1.14,0.04])=[0.817,0.040]\tanh([1.14, -0.04]) = [0.817, -0.040] h(t)=[0.6×0.817,0.7×(0.040)]=[0.490,0.028]h^{(t)} = [0.6 \times 0.817, 0.7 \times (-0.040)] = [0.490, -0.028]

观察:遗忘门 f=0.9f = 0.9(几乎保留过去记忆),输入门写入新信息,输出门控制最终输出。

⚠️ 常见误区

  1. 误区:LSTM 完全解决了梯度消失 → 正确:LSTM 大幅缓解梯度消失(通过记忆单元的加法路径),但并不完全消除。在极长序列(数千步)上仍然困难。Transformer 通过注意力机制进一步解决这个问题。
  2. 误区:遗忘门为 0 = 清除所有记忆 → 正确f=0f = 0c(t)=ic~(t)c^{(t)} = i \odot \tilde{c}^{(t)},完全丢弃过去记忆,完全用新信息。f=1f = 1 时完全保留记忆(加法更新不改变 cc)。中间值 = 部分遗忘。
  3. 误区:LSTM 有 4 组权重矩阵,参数是 RNN 的 4 倍 → 正确:可以将 4 个门的权重矩阵拼为一个大矩阵 WR4n×(n+d)W \in \mathbb{R}^{4n \times (n+d)},一次矩阵乘法完成所有门的计算(更高效)。

知识关联

🔗 知识关联

  • → L05 注意力机制:LSTM 虽然缓解了梯度消失,但仍受限于顺序计算。注意力机制(Bahdanau et al., 2015)让 decoder 直接访问所有 encoder 隐状态,进一步解决长距离依赖。
  • → L05 Transformer:完全放弃 RNN 的循环结构,改用自注意力实现并行计算和直接的长距离连接。
  • ← L03 反向传播:BPTT 是 L03 中讲的计算图反向传播在时间序列上的展开应用。
  • ← L02 Word2Vec:RNN-LM 的输入是词嵌入向量,与 L02 中的词向量训练直接对接。
  • ↔ ResNet 残差连接:LSTM 的 cell state 加法更新与 ResNet 的 skip connection 原理相同——都通过加法路径保持梯度流通。