L04: Language Models and Recurrent Neural Networks

Week 2 · Thu Jan 15 2026 08:00:00 GMT+0800 (中国标准时间)

进度: 0/22 (0%)
下载 PDF
/ 0
100%
正在加载 PDF...

L04: Language Models and Recurrent Neural Networks

Slides

中英交替版(推荐)

L04 双语 (PDF)

英文原版

L04 EN (PDF)

中文翻译版

L04 ZH (PDF)

核心知识点

1. 语言模型(Language Modeling)— 本课最重要的概念

Slide 1 Slide 2 Slide 3 Slide 4 Slide 5 Slide 6 Slide 7 Slide 8
  • 定义:预测下一个词的任务
    • 给定词序列 x(1),x(2),,x(t)x^{(1)}, x^{(2)}, \ldots, x^{(t)},计算 P(x(t+1)x(t),,x(1))P(x^{(t+1)} | x^{(t)}, \ldots, x^{(1)})
  • 等价地:为一段文本赋予概率
    • P(x(1),,x(T))=t=1TP(x(t)x(t1),,x(1))P(x^{(1)}, \ldots, x^{(T)}) = \prod_{t=1}^{T} P(x^{(t)} | x^{(t-1)}, \ldots, x^{(1)})
  • 无处不在的应用:预测输入、搜索补全、语音识别、拼写纠错、机器翻译、对话、摘要…
  • ChatGPT 本质就是一个语言模型!
  • 下一词预测能解决的任务:常识(trivia)、句法(syntax)、共指(coreference)、情感(sentiment)、推理(reasoning)等

📐 语言模型概率的链式分解

联合概率 → 条件概率的乘积(链式法则 Chain Rule of Probability):

P(x(1),x(2),,x(T))=P(x(1))P(x(2)x(1))P(x(3)x(1),x(2))P(x^{(1)}, x^{(2)}, \ldots, x^{(T)}) = P(x^{(1)}) \cdot P(x^{(2)} | x^{(1)}) \cdot P(x^{(3)} | x^{(1)}, x^{(2)}) \cdots

推导

P(A,B,C)=P(A)P(BA)P(CA,B)P(A, B, C) = P(A) \cdot P(B|A) \cdot P(C|A,B)

这是概率论的基本恒等式(条件概率定义:P(BA)=P(A,B)/P(A)P(B|A) = P(A,B)/P(A)),无任何假设。

扩展到序列:

P(x(1),,x(T))=t=1TP(x(t)x(t1),,x(1))P(x^{(1)}, \ldots, x^{(T)}) = \prod_{t=1}^{T} P(x^{(t)} | x^{(t-1)}, \ldots, x^{(1)})

语言模型的任务:对每个 tt,学习 P(x(t)x(t1),,x(1))P(x^{(t)} | x^{(t-1)}, \ldots, x^{(1)}),即给定历史预测下一个词。

困惑度(Perplexity):测量语言模型”有多惊讶”于给定文本:

PPL=exp(1Tt=1TlogP(x(t+1)x(t)))=exp(J)\text{PPL} = \exp\left(-\frac{1}{T}\sum_{t=1}^T \log P(x^{(t+1)} | x^{(\le t)})\right) = \exp(J)

PPL=k\text{PPL} = k 意味着模型在每个位置的平均不确定性相当于在 kk 个词中均匀随机猜测。PPL 越低,模型越好。

📚 已收录至 拓展阅读知识库

🔢 困惑度计算示例

设定:测试文本 “the cat sat”(3 个词,加上结束符共 3 步预测),模型给出的条件概率:

时间步预测目标条件概率 P(x(t+1)x(t))P(x^{(t+1)} \mid x^{(\le t)})
t=0t=0“the”0.10(高频词)
t=1t=1“cat”0.05
t=2t=2”sat”0.20

计算:

  1. 平均对数概率 = 13[log(0.10)+log(0.05)+log(0.20)]\frac{1}{3}[\log(0.10) + \log(0.05) + \log(0.20)]
  2. =13[2.3032.9961.609]=6.9083=2.303= \frac{1}{3}[-2.303 - 2.996 - 1.609] = \frac{-6.908}{3} = -2.303
  3. PPL=exp(2.303)10.0\text{PPL} = \exp(2.303) \approx \mathbf{10.0}

解读:平均每步相当于在 10 个词里猜一个。

对比:随机猜测(词汇表 50,000 词)的 PPL = 50,000;GPT-4 在 PTB 数据集 PPL ≈ 2050;人类理解文本约 2040。

⚠️ 常见误区

  1. 误区:语言模型 = 生成式模型 → 正确:语言模型也可以是判别模型(如 masked LM,BERT)。经典语言模型是自回归的(给定前文预测下一词),但并非唯一形式。
  2. 误区:PPL 越低一定越好 → 正确:PPL 依赖测试集。在训练集上极低 PPL = 过拟合。此外,PPL 低不代表生成文本的质量好(可能重复、语义空洞)。

2. N-gram 语言模型

Slide 9 Slide 10 Slide 11 Slide 12 Slide 13 Slide 14 Slide 15 Slide 16 Slide 17 Slide 18 Slide 19 Slide 20
  • N-gram:nn 个连续词的片段
    • unigram / bigram / trigram / 4-gram …
  • 马尔可夫假设x(t+1)x^{(t+1)} 只依赖前 n1n-1 个词
    • P(x(t+1)x(t),,x(1))P(x(t+1)x(t),,x(tn+2))P(x^{(t+1)} | x^{(t)}, \ldots, x^{(1)}) \approx P(x^{(t+1)} | x^{(t)}, \ldots, x^{(t-n+2)})
  • 概率估计:计数法
    • P(wcontext)count(context,w)count(context)P(w | \text{context}) \approx \frac{\text{count}(\text{context}, w)}{\text{count}(\text{context})}
  • 示例(4-gram):“students opened their” 出现 1000 次
    • “students opened their books” 400 次 \to P(books)=0.4P(\text{books}) = 0.4
    • “students opened their exams” 100 次 \to P(exams)=0.1P(\text{exams}) = 0.1
  • 稀疏性问题
    • 问题 1:nn-gram 未在语料中出现 \to 概率为 0 \to 平滑(smoothing):给每个词加小 δ\delta
    • 问题 2:(n1)(n-1)-gram 未出现 \to 无法计算 \to 回退(backoff):用更短的 nn-gram
  • 存储问题:需要存储所有见过的 nn-gram 的计数
  • 生成文本:惊人地符合语法,但语义不连贯(上下文窗口太小)

📐 N-gram 概率估计与平滑

马尔可夫假设(N-gram 的核心):

P(x(t+1)x(t),,x(1))P(x(t+1)x(t),,x(tn+2))P(x^{(t+1)} | x^{(t)}, \ldots, x^{(1)}) \approx P(x^{(t+1)} | x^{(t)}, \ldots, x^{(t-n+2)})

只保留最近 n1n-1 个词的历史,使得条件历史有限,可以通过计数估计。

最大似然估计(MLE):

P(wwtn+2,,wt)C(wtn+2,,wt,w)C(wtn+2,,wt)P(w | w_{t-n+2}, \ldots, w_t) \approx \frac{C(w_{t-n+2}, \ldots, w_t, w)}{C(w_{t-n+2}, \ldots, w_t)}

其中 C()C(\cdot) 是语料库中的计数。

稀疏性问题:如果分子为 0(nn-gram 从未出现),概率为 0,且导致任何包含此 nn-gram 的句子概率为 0。

解决方法 1 — Laplace 平滑(加一平滑)

Psmooth(wh)=C(h,w)+δw[C(h,w)+δ]=C(h,w)+δC(h)+δVP_{smooth}(w | h) = \frac{C(h, w) + \delta}{\sum_{w'} [C(h, w') + \delta]} = \frac{C(h, w) + \delta}{C(h) + \delta |V|}

通过给每个词加小量 δ=1\delta = 1,保证所有概率 > 0。

解决方法 2 — Kneser-Ney 回退(Backoff): 当 nn-gram 未出现时,回退到 (n1)(n-1)-gram(递归直到 unigram):

PKN(wh)={max(C(h,w)d,0)/C(h)+λ(h)PKN(wh1)if C(h,w)>0PKN(wh1)otherwiseP_{KN}(w | h) = \begin{cases} \max(C(h,w) - d, 0) / C(h) + \lambda(h) \cdot P_{KN}(w | h_{-1}) & \text{if } C(h,w) > 0 \\ P_{KN}(w | h_{-1}) & \text{otherwise} \end{cases}

📚 已收录至 拓展阅读知识库

🔢 Bigram 概率计算示例

语料(玩具语料):

  • “I like cats” (1 次)
  • “I like dogs” (2 次)
  • “cats like fish” (1 次)

Bigram 计数:(I, like)=3, (like, cats)=1, (like, dogs)=2, (cats, like)=1

计算 P(dogslike)P(\text{dogs} | \text{like})

P(dogslike)=C(like, dogs)C(like)=230.67P(\text{dogs} | \text{like}) = \frac{C(\text{like, dogs})}{C(\text{like})} = \frac{2}{3} \approx 0.67

计算 P(fishlike)P(\text{fish} | \text{like})

P(fishlike)=C(like, fish)C(like)=03=0P(\text{fish} | \text{like}) = \frac{C(\text{like, fish})}{C(\text{like})} = \frac{0}{3} = 0

加 Laplace 平滑(词汇表 V=5|V|=5δ=1\delta=1):

Psmooth(fishlike)=0+13+5=18=0.125P_{smooth}(\text{fish} | \text{like}) = \frac{0 + 1}{3 + 5} = \frac{1}{8} = 0.125

⚠️ 常见误区

  1. 误区nn 越大越好 → 正确nn 越大,数据稀疏性越严重(多数 5-gram 从未在语料中出现)。实践中 n=5n=5 已经是上限,且需要大量平滑。
  2. 误区:N-gram 已经过时 → 正确:N-gram 在工业界的特定场景(输入法、拼写检查)中依然使用,因为极其高效且可解释。理解 N-gram 有助于理解神经语言模型为什么更好。

3. 固定窗口神经语言模型

Slide 21
Slide 21
Slide 22
Slide 22
Slide 23
Slide 23
Slide 24
Slide 24
  • 改进 nn-gram:用神经网络替代计数
  • Bengio et al. (2003):首个神经语言模型
    • 输入:固定窗口内词向量的拼接
    • 隐藏层 + softmax 输出
  • 优点:无稀疏性问题、无需存储所有 nn-gram
  • 问题:窗口大小固定,无法利用更远的上下文

💡 为什么从 N-gram → 神经语言模型是进步?

N-gram 的本质问题:用计数(频率)估计概率。问题是”the students opened their ___“和”the professors opened their ___“在语言学上极其相似,但 N-gram 把它们当成完全不同的 4-gram,无法共享任何信息。

神经语言模型的洞察:用词向量表示历史!

“students” 和 “professors” 的词向量很相似(都是人),所以网络自然会对这两个 context 给出相似的预测分布。

这就是”泛化”——神经网络通过连续表示,把相似的输入映射到相似的输出,而 N-gram 做不到。

局限:固定窗口(如 5 词)仍然无法捕捉远距离依赖,且不同位置使用不同的权重矩阵(W1,W2,W_1, W_2, \ldots),无法共享参数。

⚠️ 常见误区

误区:固定窗口神经 LM 和 N-gram 一样 → 正确:固定窗口神经 LM 的”有效上下文”是词向量而非 one-hot,因此可以泛化到未见过的 nn-gram。但两者都受固定窗口大小限制,这是 RNN 要解决的问题。

4. 循环神经网络(RNN)

Slide 25 Slide 26 Slide 27 Slide 28 Slide 29 Slide 30 Slide 31 Slide 32 Slide 33 Slide 34 Slide 35 Slide 36 Slide 37 Slide 38 Slide 39
  • 核心思想:处理任意长度的序列,在每个时间步共享权重
  • 隐状态更新:
    • h(t)=σ(Whh(t1)+Wxx(t)+b)h^{(t)} = \sigma(W_h h^{(t-1)} + W_x x^{(t)} + b)
  • 输出分布:y^(t)=softmax(Uh(t)+b2)\hat{y}^{(t)} = \text{softmax}(U h^{(t)} + b_2)
  • 训练:对每个时间步计算交叉熵损失,总损失取平均
    • J(θ)=1Tt=1TJ(t)=1Tt=1TlogP(x(t+1)x(t),,x(1))J(\theta) = \frac{1}{T} \sum_{t=1}^{T} J^{(t)} = -\frac{1}{T} \sum_{t=1}^{T} \log P(x^{(t+1)} | x^{(t)}, \ldots, x^{(1)})
  • 评估指标:困惑度(Perplexity)
    • PPL=exp(J(θ))=exp(1Tt=1TlogP(x(t+1)))\text{PPL} = \exp(J(\theta)) = \exp\left(-\frac{1}{T}\sum_{t=1}^T \log P(x^{(t+1)} | \ldots)\right)
    • PPL 越低越好;等价于交叉熵损失的指数形式
  • 优点:可处理任意长序列、理论上可利用远距离信息、模型大小不随输入增长
  • 缺点:循环计算导致训练慢、实际中难以学习长距离依赖

📐 RNN 前向传播与 BPTT

前向传播(Forward Pass)

给定词嵌入序列 x(1),,x(T)x^{(1)}, \ldots, x^{(T)}(每个 x(t)Rdx^{(t)} \in \mathbb{R}^d):

步骤 1:初始化 h(0)=0h^{(0)} = \mathbf{0}(或随机初始化)

步骤 2:逐步更新隐状态:

h(t)=σ(Whh(t1)+Wxx(t)+bh)h^{(t)} = \sigma(W_h h^{(t-1)} + W_x x^{(t)} + b_h)

其中 WhRn×nW_h \in \mathbb{R}^{n \times n}(隐藏到隐藏),WxRn×dW_x \in \mathbb{R}^{n \times d}(输入到隐藏)

步骤 3:每步输出分布(softmax 预测下一词):

y^(t)=softmax(Wyh(t)+by)\hat{y}^{(t)} = \text{softmax}(W_y h^{(t)} + b_y)

损失J(t)=logP(x(t+1)x(t))=logy^x(t+1)(t)J^{(t)} = -\log P(x^{(t+1)} | x^{(\le t)}) = -\log \hat{y}^{(t)}_{x^{(t+1)}}(正确词的对数概率)

总损失J=1Tt=1TJ(t)J = \frac{1}{T}\sum_{t=1}^{T} J^{(t)}

反向传播(BPTT — Backpropagation Through Time)

WhW_h 求梯度时,需要跨时间步回传:

J(t)Wh=k=1tJ(t)h(t)(j=k+1th(j)h(j1))h(k)Wh\frac{\partial J^{(t)}}{\partial W_h} = \sum_{k=1}^{t} \frac{\partial J^{(t)}}{\partial h^{(t)}} \cdot \left(\prod_{j=k+1}^{t} \frac{\partial h^{(j)}}{\partial h^{(j-1)}}\right) \cdot \frac{\partial h^{(k)}}{\partial W_h}

每一步的局部雅可比:h(j)h(j1)=WhTdiag(σ(z(j)))\frac{\partial h^{(j)}}{\partial h^{(j-1)}} = W_h^T \text{diag}(\sigma'(z^{(j)}))(其中 z(j)=Whh(j1)+z^{(j)} = W_h h^{(j-1)} + \ldots

📚 已收录至 拓展阅读知识库

🔢 小型 RNN 前向传播示例

设定d=2d=2(词向量维度),n=2n=2(隐藏层维度),词汇表 V=3|V|=3

参数(小数值方便计算):

  • Wh=[0.5000.5]W_h = \begin{bmatrix}0.5 & 0 \\ 0 & 0.5\end{bmatrix}Wx=[1001]W_x = \begin{bmatrix}1 & 0 \\ 0 & 1\end{bmatrix}bh=0b_h = \mathbf{0}
  • 输入序列:x(1)=[0.8,0.2]Tx^{(1)} = [0.8, 0.2]^Tx(2)=[0.3,0.7]Tx^{(2)} = [0.3, 0.7]^T

计算

  1. t=1t=1h(1)=tanh(Wh0+Wxx(1))=tanh([0.8,0.2]T)=[0.664,0.197]Th^{(1)} = \tanh(W_h \cdot \mathbf{0} + W_x \cdot x^{(1)}) = \tanh([0.8, 0.2]^T) = [0.664, 0.197]^T
  2. t=2t=2z(2)=Whh(1)+Wxx(2)=[0.332+0.3,0.099+0.7]T=[0.632,0.799]Tz^{(2)} = W_h h^{(1)} + W_x x^{(2)} = [0.332+0.3, 0.099+0.7]^T = [0.632, 0.799]^T h(2)=tanh([0.632,0.799]T)=[0.559,0.665]Th^{(2)} = \tanh([0.632, 0.799]^T) = [0.559, 0.665]^T

观察h(2)h^{(2)} 同时包含了 x(1)x^{(1)}x(2)x^{(2)} 的信息(通过 h(1)h^{(1)} 间接传入)。

⚠️ 常见误区

  1. 误区:RNN 的权重 Wh,WxW_h, W_x 在每个时间步都不同 → 正确:RNN 在所有时间步共享同一套权重(这是 RNN 的核心特性!)。共享权重使模型大小不随序列长度增长。
  2. 误区:BPTT 会计算所有历史时间步 → 正确:实践中常用”截断 BPTT(Truncated BPTT)“,只回传固定步数(如 100 步),平衡计算量和梯度信息。

5. 梯度消失与梯度爆炸

Slide 40 Slide 41 Slide 42 Slide 43 Slide 44 Slide 45 Slide 46 Slide 47 Slide 48 Slide 49
  • 梯度消失
    • J(t)h(1)=h(2)h(1)×h(3)h(2)××J(t)h(t)\frac{\partial J^{(t)}}{\partial h^{(1)}} = \frac{\partial h^{(2)}}{\partial h^{(1)}} \times \frac{\partial h^{(3)}}{\partial h^{(2)}} \times \cdots \times \frac{\partial J^{(t)}}{\partial h^{(t)}}
    • 每一步乘以 WhW_h 的转置,如果特征值 < 1,连乘后趋近 0
    • 结果:远距离信号丢失,模型只能学到近距离依赖
  • 梯度爆炸
    • 特征值 > 1 时,梯度指数增长
    • 解决方案:梯度裁剪(Gradient Clipping)
      • 如果 g^threshold\|\hat{g}\| \ge \text{threshold},则 g^thresholdg^g^\hat{g} \leftarrow \frac{\text{threshold}}{\|\hat{g}\|} \hat{g}
  • 对 RNN-LM 的影响:无法学习长距离依赖(如 “tickets…tickets” 的跨句指代)

📐 梯度消失的数学原因

梯度链的展开(损失 J(t)J^{(t)} 对早期隐状态 h(k)h^{(k)} 的梯度):

J(t)h(k)=J(t)h(t)j=k+1th(j)h(j1)\frac{\partial J^{(t)}}{\partial h^{(k)}} = \frac{\partial J^{(t)}}{\partial h^{(t)}} \cdot \prod_{j=k+1}^{t} \frac{\partial h^{(j)}}{\partial h^{(j-1)}}

每一步的 Jacobian:

h(j)h(j1)=WhTdiag(σ(z(j)))\frac{\partial h^{(j)}}{\partial h^{(j-1)}} = W_h^T \cdot \text{diag}(\sigma'(z^{(j)}))

谱半径决定命运:设 WhTW_h^T 的最大奇异值为 σ1\sigma_1,激活函数的最大导数为 γ\gamma

  • 如果 σ1γ<1\sigma_1 \gamma < 1tkt - k 步后梯度 (σ1γ)tk0\sim (\sigma_1 \gamma)^{t-k} \to 0梯度消失
  • 如果 σ1γ>1\sigma_1 \gamma > 1:梯度 (σ1γ)tk\sim (\sigma_1 \gamma)^{t-k} \to \infty梯度爆炸
  • 只有精确 =1= 1 时梯度才稳定(实践中极难维持)

为什么激活函数导数也是关键

  • tanh(z)=1tanh2(z)(0,1]\tanh'(z) = 1 - \tanh^2(z) \in (0, 1],在饱和区趋向 0
  • sigmoid 的导数最大值仅 0.25,梯度消失更严重
  • ReLU 在 z>0z > 0 区域导数 = 1,但负值区域梯度 = 0

Gradient Clipping 只治”爆炸”

if g^threshold:g^thresholdg^g^\text{if } \|\hat{g}\| \ge \text{threshold}: \quad \hat{g} \leftarrow \frac{\text{threshold}}{\|\hat{g}\|} \hat{g}

这等比例缩小梯度向量,保持方向但限制大小。不能增大消失的梯度。

📚 已收录至 拓展阅读知识库

🔢 梯度消失的数值示意

标量简化:设 Wh=0.9W_h = 0.9tanh(z(j))=0.8\tanh'(z^{(j)}) = 0.8(常数,简化):

每步缩减因子=0.9×0.8=0.72\text{每步缩减因子} = 0.9 \times 0.8 = 0.72

距离(时间步数)梯度大小
tk=1t - k = 10.721=0.7200.72^1 = 0.720
tk=5t - k = 50.725=0.1930.72^5 = 0.193
tk=10t - k = 100.7210=0.0370.72^{10} = 0.037
tk=20t - k = 200.7220=0.0010.72^{20} = 0.001

20 步之前的信号梯度已衰减到 0.1%,实际上无法更新相关参数。

对比爆炸(Wh=1.1W_h = 1.1σ=1.0\sigma' = 1.0):

  • tk=10t - k = 101.110=2.591.1^{10} = 2.59
  • tk=20t - k = 201.120=6.731.1^{20} = 6.73
  • tk=50t - k = 501.150=1171.1^{50} = 117 → NaN

⚠️ 常见误区

  1. 误区:梯度消失 = 梯度为 0 → 正确:梯度消失是指数衰减,不是突然归零。距离 20 步时可能只有 0.001 的大小,小到无法驱动参数更新,但并不完全是 0。
  2. 误区:梯度裁剪可以防止梯度消失 → 正确:裁剪只防止爆炸(缩小过大的梯度)。防止消失需要架构改变(LSTM 的加法更新、残差连接)。

6. 解决梯度消失:LSTM

Slide 50
Slide 50
  • LSTM(Long Short-Term Memory):引入独立的记忆单元 c(t)c^{(t)}
  • 门控机制:
    • 遗忘门(Forget gate):f(t)=σ(Wfh(t1)+Ufx(t)+bf)f^{(t)} = \sigma(W_f h^{(t-1)} + U_f x^{(t)} + b_f)
    • 输入门(Input gate):i(t)=σ(Wih(t1)+Uix(t)+bi)i^{(t)} = \sigma(W_i h^{(t-1)} + U_i x^{(t)} + b_i)
    • 输出门(Output gate):o(t)=σ(Woh(t1)+Uox(t)+bo)o^{(t)} = \sigma(W_o h^{(t-1)} + U_o x^{(t)} + b_o)
  • 记忆更新:c(t)=f(t)c(t1)+i(t)c~(t)c^{(t)} = f^{(t)} \odot c^{(t-1)} + i^{(t)} \odot \tilde{c}^{(t)}
  • 隐状态:h(t)=o(t)tanh(c(t))h^{(t)} = o^{(t)} \odot \tanh(c^{(t)})
  • 关键:记忆单元的加法更新(而非乘法)使梯度更容易流过长序列
  • 其他方案:GRU(更简化的门控)、残差连接、注意力机制

📐 LSTM 完整推导——为什么门控能缓解梯度消失?

LSTM 的核心区别:引入记忆单元(cell state) c(t)c^{(t)},通过加法更新而非乘法更新传递信息。

所有门的计算(共 4 个):

[ifoc~](t)=[σσσtanh](W[h(t1)x(t)]+b)\begin{bmatrix} i \\ f \\ o \\ \tilde{c} \end{bmatrix}^{(t)} = \begin{bmatrix} \sigma \\ \sigma \\ \sigma \\ \tanh \end{bmatrix}\left(W \begin{bmatrix} h^{(t-1)} \\ x^{(t)} \end{bmatrix} + b\right)

展开为:

  • 遗忘门:f(t)=σ(Wf[h(t1),x(t)]+bf)(0,1)nf^{(t)} = \sigma(W_f [h^{(t-1)}, x^{(t)}] + b_f) \in (0, 1)^n
  • 输入门:i(t)=σ(Wi[h(t1),x(t)]+bi)(0,1)ni^{(t)} = \sigma(W_i [h^{(t-1)}, x^{(t)}] + b_i) \in (0, 1)^n
  • 候选记忆:c~(t)=tanh(Wc[h(t1),x(t)]+bc)(1,1)n\tilde{c}^{(t)} = \tanh(W_c [h^{(t-1)}, x^{(t)}] + b_c) \in (-1, 1)^n
  • 输出门:o(t)=σ(Wo[h(t1),x(t)]+bo)(0,1)no^{(t)} = \sigma(W_o [h^{(t-1)}, x^{(t)}] + b_o) \in (0, 1)^n

记忆单元(加法更新——关键!)

c(t)=f(t)c(t1)+i(t)c~(t)c^{(t)} = f^{(t)} \odot c^{(t-1)} + i^{(t)} \odot \tilde{c}^{(t)}

隐状态(输出)

h(t)=o(t)tanh(c(t))h^{(t)} = o^{(t)} \odot \tanh(c^{(t)})

为什么加法更新能缓解梯度消失?

c(t)c^{(t)}c(t1)c^{(t-1)} 的梯度:

c(t)c(t1)=f(t)\frac{\partial c^{(t)}}{\partial c^{(t-1)}} = f^{(t)}

(忽略高阶项)。这是逐元素乘以遗忘门,而不是乘以大矩阵!

对比标准 RNN 的 h(t)h(t1)=WhTdiag(σ)\frac{\partial h^{(t)}}{\partial h^{(t-1)}} = W_h^T \text{diag}(\sigma')(矩阵乘法,连乘后快速衰减)。

如果遗忘门 f(t)1f^{(t)} \approx 1(不遗忘),梯度 ≈ 1,不消失也不爆炸。这类似于 ResNet 的残差连接!

📚 已收录至 拓展阅读知识库

🔢 LSTM 一步更新数值示例

设定n=2n=2(隐藏维度),当前状态:

  • h(t1)=[0.5,0.3]Th^{(t-1)} = [0.5, -0.3]^Tc(t1)=[0.8,0.2]Tc^{(t-1)} = [0.8, 0.2]^T
  • 输入 x(t)=[1.0,0.0]Tx^{(t)} = [1.0, 0.0]^T

假设(为简化,直接给出门的激活值):

  • f(t)=[0.9,0.8]Tf^{(t)} = [0.9, 0.8]^T(遗忘大部分记忆)
  • i(t)=[0.7,0.4]Ti^{(t)} = [0.7, 0.4]^T(写入部分新信息)
  • c~(t)=[0.6,0.5]T\tilde{c}^{(t)} = [0.6, -0.5]^T(候选新记忆)
  • o(t)=[0.6,0.7]To^{(t)} = [0.6, 0.7]^T(输出多少到隐状态)

计算

  1. 新记忆c(t)=fc(t1)+ic~c^{(t)} = f \odot c^{(t-1)} + i \odot \tilde{c} =[0.9×0.8,0.8×0.2]+[0.7×0.6,0.4×(0.5)]= [0.9 \times 0.8, 0.8 \times 0.2] + [0.7 \times 0.6, 0.4 \times (-0.5)] =[0.72,0.16]+[0.42,0.20]=[1.14,0.04]= [0.72, 0.16] + [0.42, -0.20] = [1.14, -0.04]

  2. 新隐状态h(t)=otanh(c(t))h^{(t)} = o \odot \tanh(c^{(t)}) tanh([1.14,0.04])=[0.817,0.040]\tanh([1.14, -0.04]) = [0.817, -0.040] h(t)=[0.6×0.817,0.7×(0.040)]=[0.490,0.028]h^{(t)} = [0.6 \times 0.817, 0.7 \times (-0.040)] = [0.490, -0.028]

观察:遗忘门 f=0.9f = 0.9(几乎保留过去记忆),输入门写入新信息,输出门控制最终输出。

⚠️ 常见误区

  1. 误区:LSTM 完全解决了梯度消失 → 正确:LSTM 大幅缓解梯度消失(通过记忆单元的加法路径),但并不完全消除。在极长序列(数千步)上仍然困难。Transformer 通过注意力机制进一步解决这个问题。
  2. 误区:遗忘门为 0 = 清除所有记忆 → 正确f=0f = 0c(t)=ic~(t)c^{(t)} = i \odot \tilde{c}^{(t)},完全丢弃过去记忆,完全用新信息。f=1f = 1 时完全保留记忆(加法更新不改变 cc)。中间值 = 部分遗忘。
  3. 误区:LSTM 有 4 组权重矩阵,参数是 RNN 的 4 倍 → 正确:可以将 4 个门的权重矩阵拼为一个大矩阵 WR4n×(n+d)W \in \mathbb{R}^{4n \times (n+d)},一次矩阵乘法完成所有门的计算(更高效)。

7. 机器翻译简介

Slide 51 Slide 52 Slide 53 Slide 54 Slide 55 Slide 56 Slide 57 Slide 58 Slide 59
  • 从源语言 xx 到目标语言 yy 的翻译
  • NMT 的成功:2014 年 seq2seq 首篇论文 \to 2016 年 Google 全面替换 SMT
  • Encoder-Decoder 架构:
    • Encoder RNN 编码源句 \to 上下文向量
    • Decoder RNN 以此为条件生成目标句
  • 条件语言模型:P(yx)=P(y1x)P(y2y1,x)P(yTy1,,yT1,x)P(y|x) = P(y_1|x) P(y_2|y_1, x) \cdots P(y_T|y_1, \ldots, y_{T-1}, x)

📐 条件语言模型的训练目标

目标:最大化给定源句 xx 时,目标句 yy 的对数似然:

L=(x,y)datalogP(yx;θ)=(x,y)t=1ylogP(yty<t,x;θ)\mathcal{L} = \sum_{(x,y) \in \text{data}} \log P(y|x; \theta) = \sum_{(x,y)} \sum_{t=1}^{|y|} \log P(y_t | y_{<t}, x; \theta)

Encoder-Decoder 分解

  • Encoderhenc=RNN(x1,,xm)h_{\text{enc}} = \text{RNN}(x_1, \ldots, x_m),最终状态 henc(m)h_{\text{enc}}^{(m)} 作为初始 decoder 状态
  • Decoderhdec(t)=RNN(yt1,hdec(t1))h_{\text{dec}}^{(t)} = \text{RNN}(y_{t-1}, h_{\text{dec}}^{(t-1)})P(yt)=softmax(Whdec(t))P(y_t | \ldots) = \text{softmax}(W h_{\text{dec}}^{(t)})

瓶颈问题的数学表现

  • 整个源句 x1,,xmx_1, \ldots, x_m 被压缩为一个向量 henc(m)Rnh_{\text{enc}}^{(m)} \in \mathbb{R}^n(固定维度)
  • 无论源句多长(m=5m = 5 还是 m=100m = 100),都要压缩到同样大小的向量
  • 信息论:固定容量的”瓶颈”对长句信息损失严重

解决方案(引出注意力):让 decoder 在每一步直接访问所有 encoder 隐状态 {henc(1),,henc(m)}\{h_{\text{enc}}^{(1)}, \ldots, h_{\text{enc}}^{(m)}\},而非只用最后一个状态。

📚 已收录至 拓展阅读知识库

🔗 知识关联

  • → L05 注意力机制:本节的瓶颈问题直接引发了注意力机制的发明(Bahdanau et al., 2015)
  • → L05 Transformer:完全放弃 RNN,改用自注意力解决长距离依赖问题
  • ← L03 反向传播:BPTT 是 L03 中讲的计算图反向传播在序列上的应用
  • ← L02 Word2Vec:RNN-LM 的输入是词嵌入,与 L02 中的词向量直接对接

⚠️ 常见误区

  1. 误区:Seq2Seq 中 encoder 和 decoder 不共享权重 → 正确:这是设计选择,有些模型共享,有些不共享。通常 encoder 和 decoder 是不同的 RNN(不同参数)。
  2. 误区:Teacher forcing = 作弊 → 正确:Teacher forcing(训练时用真实目标词作为 decoder 输入,而非上一步预测词)是训练的标准方法,虽然会造成”暴露偏差”(训练 vs 推理时的分布差异),但实践中效果好。

推荐阅读

关联概念

作业提醒

  • A2 截止:Jan 22 (Thu)

个人笔记