L04: Language Models and Recurrent Neural Networks

Week 2 · Thu Jan 15 2026 08:00:00 GMT+0800 (中国标准时间)

进度: 0/22 (0%)
下载 PDF
/ 0
100%
正在加载 PDF...

L04: Language Models and Recurrent Neural Networks

Slides

中英交替版(推荐)

L04 双语 (PDF)

英文原版

L04 EN (PDF)

中文翻译版

L04 ZH (PDF)

核心知识点

1. 语言模型(Language Modeling)— 本课最重要的概念

Slide 1 Slide 2 Slide 3 Slide 4 Slide 5 Slide 6 Slide 7 Slide 8
  • 定义:预测下一个词的任务
    • 给定词序列 x(1),x(2),,x(t)x^{(1)}, x^{(2)}, \ldots, x^{(t)},计算 P(x(t+1)x(t),,x(1))P(x^{(t+1)} | x^{(t)}, \ldots, x^{(1)})
  • 等价地:为一段文本赋予概率
    • P(x(1),,x(T))=t=1TP(x(t)x(t1),,x(1))P(x^{(1)}, \ldots, x^{(T)}) = \prod_{t=1}^{T} P(x^{(t)} | x^{(t-1)}, \ldots, x^{(1)})
  • 无处不在的应用:预测输入、搜索补全、语音识别、拼写纠错、机器翻译、对话、摘要…
  • ChatGPT 本质就是一个语言模型!
  • 下一词预测能解决的任务:常识(trivia)、句法(syntax)、共指(coreference)、情感(sentiment)、推理(reasoning)等

📐 语言模型概率的链式分解

联合概率 → 条件概率的乘积(链式法则 Chain Rule of Probability):

P(x(1),x(2),,x(T))=P(x(1))P(x(2)x(1))P(x(3)x(1),x(2))P(x^{(1)}, x^{(2)}, \ldots, x^{(T)}) = P(x^{(1)}) \cdot P(x^{(2)} | x^{(1)}) \cdot P(x^{(3)} | x^{(1)}, x^{(2)}) \cdots

推导

P(A,B,C)=P(A)P(BA)P(CA,B)P(A, B, C) = P(A) \cdot P(B|A) \cdot P(C|A,B)

这是概率论的基本恒等式(条件概率定义:P(BA)=P(A,B)/P(A)P(B|A) = P(A,B)/P(A)),无任何假设。

扩展到序列:

P(x(1),,x(T))=t=1TP(x(t)x(t1),,x(1))P(x^{(1)}, \ldots, x^{(T)}) = \prod_{t=1}^{T} P(x^{(t)} | x^{(t-1)}, \ldots, x^{(1)})

语言模型的任务:对每个 tt,学习 P(x(t)x(t1),,x(1))P(x^{(t)} | x^{(t-1)}, \ldots, x^{(1)}),即给定历史预测下一个词。

困惑度(Perplexity):测量语言模型”有多惊讶”于给定文本:

PPL=exp(1Tt=1TlogP(x(t+1)x(t)))=exp(J)\text{PPL} = \exp\left(-\frac{1}{T}\sum_{t=1}^T \log P(x^{(t+1)} | x^{(\le t)})\right) = \exp(J)

PPL=k\text{PPL} = k 意味着模型在每个位置的平均不确定性相当于在 kk 个词中均匀随机猜测。PPL 越低,模型越好。

📚 已收录至 拓展阅读知识库

🔢 困惑度计算示例

设定:测试文本 “the cat sat”(3 个词,加上结束符共 3 步预测),模型给出的条件概率:

时间步预测目标条件概率 P(x(t+1)x(t))P(x^{(t+1)} \mid x^{(\le t)})
t=0t=0“the”0.10(高频词)
t=1t=1“cat”0.05
t=2t=2”sat”0.20

计算:

  1. 平均对数概率 = 13[log(0.10)+log(0.05)+log(0.20)]\frac{1}{3}[\log(0.10) + \log(0.05) + \log(0.20)]
  2. =13[2.3032.9961.609]=6.9083=2.303= \frac{1}{3}[-2.303 - 2.996 - 1.609] = \frac{-6.908}{3} = -2.303
  3. PPL=exp(2.303)10.0\text{PPL} = \exp(2.303) \approx \mathbf{10.0}

解读:平均每步相当于在 10 个词里猜一个。

对比:随机猜测(词汇表 50,000 词)的 PPL = 50,000;GPT-4 在 PTB 数据集 PPL ≈ 2050;人类理解文本约 2040。

⚠️ 常见误区

  1. 误区:语言模型 = 生成式模型 → 正确:语言模型也可以是判别模型(如 masked LM,BERT)。经典语言模型是自回归的(给定前文预测下一词),但并非唯一形式。
  2. 误区:PPL 越低一定越好 → 正确:PPL 依赖测试集。在训练集上极低 PPL = 过拟合。此外,PPL 低不代表生成文本的质量好(可能重复、语义空洞)。

2. N-gram 语言模型

Slide 9 Slide 10 Slide 11 Slide 12 Slide 13 Slide 14 Slide 15 Slide 16 Slide 17 Slide 18 Slide 19 Slide 20 Slide 21
  • N-gram:nn 个连续词的片段
    • unigram / bigram / trigram / 4-gram …
  • 马尔可夫假设x(t+1)x^{(t+1)} 只依赖前 n1n-1 个词
    • P(x(t+1)x(t),,x(1))P(x(t+1)x(t),,x(tn+2))P(x^{(t+1)} | x^{(t)}, \ldots, x^{(1)}) \approx P(x^{(t+1)} | x^{(t)}, \ldots, x^{(t-n+2)})
  • 概率估计:计数法
    • P(wcontext)count(context,w)count(context)P(w | \text{context}) \approx \frac{\text{count}(\text{context}, w)}{\text{count}(\text{context})}
  • 示例(4-gram):“students opened their” 出现 1000 次
    • “students opened their books” 400 次 \to P(books)=0.4P(\text{books}) = 0.4
    • “students opened their exams” 100 次 \to P(exams)=0.1P(\text{exams}) = 0.1
  • 稀疏性问题
    • 问题 1:nn-gram 未在语料中出现 \to 概率为 0 \to 平滑(smoothing):给每个词加小 δ\delta
    • 问题 2:(n1)(n-1)-gram 未出现 \to 无法计算 \to 回退(backoff):用更短的 nn-gram
  • 存储问题:需要存储所有见过的 nn-gram 的计数
  • 生成文本:惊人地符合语法,但语义不连贯(上下文窗口太小)

📐 N-gram 概率估计与平滑

马尔可夫假设(N-gram 的核心):

P(x(t+1)x(t),,x(1))P(x(t+1)x(t),,x(tn+2))P(x^{(t+1)} | x^{(t)}, \ldots, x^{(1)}) \approx P(x^{(t+1)} | x^{(t)}, \ldots, x^{(t-n+2)})

只保留最近 n1n-1 个词的历史,使得条件历史有限,可以通过计数估计。

最大似然估计(MLE):

P(wwtn+2,,wt)C(wtn+2,,wt,w)C(wtn+2,,wt)P(w | w_{t-n+2}, \ldots, w_t) \approx \frac{C(w_{t-n+2}, \ldots, w_t, w)}{C(w_{t-n+2}, \ldots, w_t)}

其中 C()C(\cdot) 是语料库中的计数。

稀疏性问题:如果分子为 0(nn-gram 从未出现),概率为 0,且导致任何包含此 nn-gram 的句子概率为 0。

解决方法 1 — Laplace 平滑(加一平滑)

Psmooth(wh)=C(h,w)+δw[C(h,w)+δ]=C(h,w)+δC(h)+δVP_{smooth}(w | h) = \frac{C(h, w) + \delta}{\sum_{w'} [C(h, w') + \delta]} = \frac{C(h, w) + \delta}{C(h) + \delta |V|}

通过给每个词加小量 δ=1\delta = 1,保证所有概率 > 0。

解决方法 2 — Kneser-Ney 回退(Backoff): 当 nn-gram 未出现时,回退到 (n1)(n-1)-gram(递归直到 unigram):

PKN(wh)={max(C(h,w)d,0)/C(h)+λ(h)PKN(wh1)if C(h,w)>0PKN(wh1)otherwiseP_{KN}(w | h) = \begin{cases} \max(C(h,w) - d, 0) / C(h) + \lambda(h) \cdot P_{KN}(w | h_{-1}) & \text{if } C(h,w) > 0 \\ P_{KN}(w | h_{-1}) & \text{otherwise} \end{cases}

📚 已收录至 拓展阅读知识库

🔢 Bigram 概率计算示例

语料(玩具语料):

  • “I like cats” (1 次)
  • “I like dogs” (2 次)
  • “cats like fish” (1 次)

Bigram 计数:(I, like)=3, (like, cats)=1, (like, dogs)=2, (cats, like)=1

计算 P(dogslike)P(\text{dogs} | \text{like})

P(dogslike)=C(like, dogs)C(like)=230.67P(\text{dogs} | \text{like}) = \frac{C(\text{like, dogs})}{C(\text{like})} = \frac{2}{3} \approx 0.67

计算 P(fishlike)P(\text{fish} | \text{like})

P(fishlike)=C(like, fish)C(like)=03=0P(\text{fish} | \text{like}) = \frac{C(\text{like, fish})}{C(\text{like})} = \frac{0}{3} = 0

加 Laplace 平滑(词汇表 V=5|V|=5δ=1\delta=1):

Psmooth(fishlike)=0+13+5=18=0.125P_{smooth}(\text{fish} | \text{like}) = \frac{0 + 1}{3 + 5} = \frac{1}{8} = 0.125

⚠️ 常见误区

  1. 误区nn 越大越好 → 正确nn 越大,数据稀疏性越严重(多数 5-gram 从未在语料中出现)。实践中 n=5n=5 已经是上限,且需要大量平滑。
  2. 误区:N-gram 已经过时 → 正确:N-gram 在工业界的特定场景(输入法、拼写检查)中依然使用,因为极其高效且可解释。理解 N-gram 有助于理解神经语言模型为什么更好。

3. 固定窗口神经语言模型

Slide 22
Slide 22
Slide 23
Slide 23
Slide 24
Slide 24
  • 改进 nn-gram:用神经网络替代计数
  • Bengio et al. (2003):首个神经语言模型
    • 输入:固定窗口内词向量的拼接
    • 隐藏层 + softmax 输出
  • 优点:无稀疏性问题、无需存储所有 nn-gram
  • 问题:窗口大小固定,无法利用更远的上下文

💡 为什么从 N-gram → 神经语言模型是进步?

N-gram 的本质问题:用计数(频率)估计概率。问题是”the students opened their ___“和”the professors opened their ___“在语言学上极其相似,但 N-gram 把它们当成完全不同的 4-gram,无法共享任何信息。

神经语言模型的洞察:用词向量表示历史!

“students” 和 “professors” 的词向量很相似(都是人),所以网络自然会对这两个 context 给出相似的预测分布。

这就是”泛化”——神经网络通过连续表示,把相似的输入映射到相似的输出,而 N-gram 做不到。

局限:固定窗口(如 5 词)仍然无法捕捉远距离依赖,且不同位置使用不同的权重矩阵(W1,W2,W_1, W_2, \ldots),无法共享参数。

⚠️ 常见误区

误区:固定窗口神经 LM 和 N-gram 一样 → 正确:固定窗口神经 LM 的”有效上下文”是词向量而非 one-hot,因此可以泛化到未见过的 nn-gram。但两者都受固定窗口大小限制,这是 RNN 要解决的问题。

4. 循环神经网络(RNN)

Slide 25 Slide 26 Slide 27 Slide 28 Slide 29 Slide 30 Slide 31 Slide 32 Slide 33 Slide 34
  • 核心思想:处理任意长度的序列,在每个时间步共享权重
  • 隐状态更新:
    • h(t)=σ(Whh(t1)+Wxx(t)+b)h^{(t)} = \sigma(W_h h^{(t-1)} + W_x x^{(t)} + b)
  • 输出分布:y^(t)=softmax(Uh(t)+b2)\hat{y}^{(t)} = \text{softmax}(U h^{(t)} + b_2)
  • 训练:对每个时间步计算交叉熵损失,总损失取平均
    • J(θ)=1Tt=1TJ(t)=1Tt=1TlogP(x(t+1)x(t),,x(1))J(\theta) = \frac{1}{T} \sum_{t=1}^{T} J^{(t)} = -\frac{1}{T} \sum_{t=1}^{T} \log P(x^{(t+1)} | x^{(t)}, \ldots, x^{(1)})
  • 评估指标:困惑度(Perplexity)
    • PPL=exp(J(θ))=exp(1Tt=1TlogP(x(t+1)))\text{PPL} = \exp(J(\theta)) = \exp\left(-\frac{1}{T}\sum_{t=1}^T \log P(x^{(t+1)} | \ldots)\right)
    • PPL 越低越好;等价于交叉熵损失的指数形式
  • 优点:可处理任意长序列、理论上可利用远距离信息、模型大小不随输入增长
  • 缺点:循环计算导致训练慢、实际中难以学习长距离依赖

📐 RNN 前向传播与 BPTT

前向传播(Forward Pass)

给定词嵌入序列 x(1),,x(T)x^{(1)}, \ldots, x^{(T)}(每个 x(t)Rdx^{(t)} \in \mathbb{R}^d):

步骤 1:初始化 h(0)=0h^{(0)} = \mathbf{0}(或随机初始化)

步骤 2:逐步更新隐状态:

h(t)=σ(Whh(t1)+Wxx(t)+bh)h^{(t)} = \sigma(W_h h^{(t-1)} + W_x x^{(t)} + b_h)

其中 WhRn×nW_h \in \mathbb{R}^{n \times n}(隐藏到隐藏),WxRn×dW_x \in \mathbb{R}^{n \times d}(输入到隐藏)

步骤 3:每步输出分布(softmax 预测下一词):

y^(t)=softmax(Wyh(t)+by)\hat{y}^{(t)} = \text{softmax}(W_y h^{(t)} + b_y)

损失J(t)=logP(x(t+1)x(t))=logy^x(t+1)(t)J^{(t)} = -\log P(x^{(t+1)} | x^{(\le t)}) = -\log \hat{y}^{(t)}_{x^{(t+1)}}(正确词的对数概率)

总损失J=1Tt=1TJ(t)J = \frac{1}{T}\sum_{t=1}^{T} J^{(t)}

反向传播(BPTT — Backpropagation Through Time)

WhW_h 求梯度时,需要跨时间步回传:

J(t)Wh=k=1tJ(t)h(t)(j=k+1th(j)h(j1))h(k)Wh\frac{\partial J^{(t)}}{\partial W_h} = \sum_{k=1}^{t} \frac{\partial J^{(t)}}{\partial h^{(t)}} \cdot \left(\prod_{j=k+1}^{t} \frac{\partial h^{(j)}}{\partial h^{(j-1)}}\right) \cdot \frac{\partial h^{(k)}}{\partial W_h}

每一步的局部雅可比:h(j)h(j1)=WhTdiag(σ(z(j)))\frac{\partial h^{(j)}}{\partial h^{(j-1)}} = W_h^T \text{diag}(\sigma'(z^{(j)}))(其中 z(j)=Whh(j1)+z^{(j)} = W_h h^{(j-1)} + \ldots

📚 已收录至 拓展阅读知识库

🔢 小型 RNN 前向传播示例

设定d=2d=2(词向量维度),n=2n=2(隐藏层维度),词汇表 V=3|V|=3

参数(小数值方便计算):

  • Wh=[0.5000.5]W_h = \begin{bmatrix}0.5 & 0 \\ 0 & 0.5\end{bmatrix}Wx=[1001]W_x = \begin{bmatrix}1 & 0 \\ 0 & 1\end{bmatrix}bh=0b_h = \mathbf{0}
  • 输入序列:x(1)=[0.8,0.2]Tx^{(1)} = [0.8, 0.2]^Tx(2)=[0.3,0.7]Tx^{(2)} = [0.3, 0.7]^T

计算

  1. t=1t=1h(1)=tanh(Wh0+Wxx(1))=tanh([0.8,0.2]T)=[0.664,0.197]Th^{(1)} = \tanh(W_h \cdot \mathbf{0} + W_x \cdot x^{(1)}) = \tanh([0.8, 0.2]^T) = [0.664, 0.197]^T
  2. t=2t=2z(2)=Whh(1)+Wxx(2)=[0.332+0.3,0.099+0.7]T=[0.632,0.799]Tz^{(2)} = W_h h^{(1)} + W_x x^{(2)} = [0.332+0.3, 0.099+0.7]^T = [0.632, 0.799]^T h(2)=tanh([0.632,0.799]T)=[0.559,0.665]Th^{(2)} = \tanh([0.632, 0.799]^T) = [0.559, 0.665]^T

观察h(2)h^{(2)} 同时包含了 x(1)x^{(1)}x(2)x^{(2)} 的信息(通过 h(1)h^{(1)} 间接传入)。

⚠️ 常见误区

  1. 误区:RNN 的权重 Wh,WxW_h, W_x 在每个时间步都不同 → 正确:RNN 在所有时间步共享同一套权重(这是 RNN 的核心特性!)。共享权重使模型大小不随序列长度增长。
  2. 误区:BPTT 会计算所有历史时间步 → 正确:实践中常用”截断 BPTT(Truncated BPTT)“,只回传固定步数(如 100 步),平衡计算量和梯度信息。

5. 梯度消失与梯度爆炸

Slide 35 Slide 36 Slide 37 Slide 38 Slide 39 Slide 40 Slide 41
  • 梯度消失
    • J(t)h(1)=h(2)h(1)×h(3)h(2)××J(t)h(t)\frac{\partial J^{(t)}}{\partial h^{(1)}} = \frac{\partial h^{(2)}}{\partial h^{(1)}} \times \frac{\partial h^{(3)}}{\partial h^{(2)}} \times \cdots \times \frac{\partial J^{(t)}}{\partial h^{(t)}}
    • 每一步乘以 WhW_h 的转置,如果特征值 < 1,连乘后趋近 0
    • 结果:远距离信号丢失,模型只能学到近距离依赖
  • 梯度爆炸
    • 特征值 > 1 时,梯度指数增长
    • 解决方案:梯度裁剪(Gradient Clipping)
      • 如果 g^threshold\|\hat{g}\| \ge \text{threshold},则 g^thresholdg^g^\hat{g} \leftarrow \frac{\text{threshold}}{\|\hat{g}\|} \hat{g}
  • 对 RNN-LM 的影响:无法学习长距离依赖(如 “tickets…tickets” 的跨句指代)

📐 梯度消失的数学原因

梯度链的展开(损失 J(t)J^{(t)} 对早期隐状态 h(k)h^{(k)} 的梯度):

J(t)h(k)=J(t)h(t)j=k+1th(j)h(j1)\frac{\partial J^{(t)}}{\partial h^{(k)}} = \frac{\partial J^{(t)}}{\partial h^{(t)}} \cdot \prod_{j=k+1}^{t} \frac{\partial h^{(j)}}{\partial h^{(j-1)}}

每一步的 Jacobian:

h(j)h(j1)=WhTdiag(σ(z(j)))\frac{\partial h^{(j)}}{\partial h^{(j-1)}} = W_h^T \cdot \text{diag}(\sigma'(z^{(j)}))

谱半径决定命运:设 WhTW_h^T 的最大奇异值为 σ1\sigma_1,激活函数的最大导数为 γ\gamma

  • 如果 σ1γ<1\sigma_1 \gamma < 1tkt - k 步后梯度 (σ1γ)tk0\sim (\sigma_1 \gamma)^{t-k} \to 0梯度消失
  • 如果 σ1γ>1\sigma_1 \gamma > 1:梯度 (σ1γ)tk\sim (\sigma_1 \gamma)^{t-k} \to \infty梯度爆炸
  • 只有精确 =1= 1 时梯度才稳定(实践中极难维持)

为什么激活函数导数也是关键

  • tanh(z)=1tanh2(z)(0,1]\tanh'(z) = 1 - \tanh^2(z) \in (0, 1],在饱和区趋向 0
  • sigmoid 的导数最大值仅 0.25,梯度消失更严重
  • ReLU 在 z>0z > 0 区域导数 = 1,但负值区域梯度 = 0

Gradient Clipping 只治”爆炸”

if g^threshold:g^thresholdg^g^\text{if } \|\hat{g}\| \ge \text{threshold}: \quad \hat{g} \leftarrow \frac{\text{threshold}}{\|\hat{g}\|} \hat{g}

这等比例缩小梯度向量,保持方向但限制大小。不能增大消失的梯度。

📚 已收录至 拓展阅读知识库

🔢 梯度消失的数值示意

标量简化:设 Wh=0.9W_h = 0.9tanh(z(j))=0.8\tanh'(z^{(j)}) = 0.8(常数,简化):

每步缩减因子=0.9×0.8=0.72\text{每步缩减因子} = 0.9 \times 0.8 = 0.72

距离(时间步数)梯度大小
tk=1t - k = 10.721=0.7200.72^1 = 0.720
tk=5t - k = 50.725=0.1930.72^5 = 0.193
tk=10t - k = 100.7210=0.0370.72^{10} = 0.037
tk=20t - k = 200.7220=0.0010.72^{20} = 0.001

20 步之前的信号梯度已衰减到 0.1%,实际上无法更新相关参数。

对比爆炸(Wh=1.1W_h = 1.1σ=1.0\sigma' = 1.0):

  • tk=10t - k = 101.110=2.591.1^{10} = 2.59
  • tk=20t - k = 201.120=6.731.1^{20} = 6.73
  • tk=50t - k = 501.150=1171.1^{50} = 117 → NaN

⚠️ 常见误区

  1. 误区:梯度消失 = 梯度为 0 → 正确:梯度消失是指数衰减,不是突然归零。距离 20 步时可能只有 0.001 的大小,小到无法驱动参数更新,但并不完全是 0。
  2. 误区:梯度裁剪可以防止梯度消失 → 正确:裁剪只防止爆炸(缩小过大的梯度)。防止消失需要架构改变(LSTM 的加法更新、残差连接)。

6. 解决梯度消失:LSTM

Slide 42 Slide 43 Slide 44 Slide 45 Slide 46 Slide 47 Slide 48 Slide 49 Slide 50
  • LSTM(Long Short-Term Memory):引入独立的记忆单元 c(t)c^{(t)}
  • 门控机制:
    • 遗忘门(Forget gate):f(t)=σ(Wfh(t1)+Ufx(t)+bf)f^{(t)} = \sigma(W_f h^{(t-1)} + U_f x^{(t)} + b_f)
    • 输入门(Input gate):i(t)=σ(Wih(t1)+Uix(t)+bi)i^{(t)} = \sigma(W_i h^{(t-1)} + U_i x^{(t)} + b_i)
    • 输出门(Output gate):o(t)=σ(Woh(t1)+Uox(t)+bo)o^{(t)} = \sigma(W_o h^{(t-1)} + U_o x^{(t)} + b_o)
  • 记忆更新:c(t)=f(t)c(t1)+i(t)c~(t)c^{(t)} = f^{(t)} \odot c^{(t-1)} + i^{(t)} \odot \tilde{c}^{(t)}
  • 隐状态:h(t)=o(t)tanh(c(t))h^{(t)} = o^{(t)} \odot \tanh(c^{(t)})
  • 关键:记忆单元的加法更新(而非乘法)使梯度更容易流过长序列
  • 其他方案:GRU(更简化的门控)、残差连接、注意力机制

📐 LSTM 完整推导——为什么门控能缓解梯度消失?

LSTM 的核心区别:引入记忆单元(cell state) c(t)c^{(t)},通过加法更新而非乘法更新传递信息。

所有门的计算(共 4 个):

[ifoc~](t)=[σσσtanh](W[h(t1)x(t)]+b)\begin{bmatrix} i \\ f \\ o \\ \tilde{c} \end{bmatrix}^{(t)} = \begin{bmatrix} \sigma \\ \sigma \\ \sigma \\ \tanh \end{bmatrix}\left(W \begin{bmatrix} h^{(t-1)} \\ x^{(t)} \end{bmatrix} + b\right)

展开为:

  • 遗忘门:f(t)=σ(Wf[h(t1),x(t)]+bf)(0,1)nf^{(t)} = \sigma(W_f [h^{(t-1)}, x^{(t)}] + b_f) \in (0, 1)^n
  • 输入门:i(t)=σ(Wi[h(t1),x(t)]+bi)(0,1)ni^{(t)} = \sigma(W_i [h^{(t-1)}, x^{(t)}] + b_i) \in (0, 1)^n
  • 候选记忆:c~(t)=tanh(Wc[h(t1),x(t)]+bc)(1,1)n\tilde{c}^{(t)} = \tanh(W_c [h^{(t-1)}, x^{(t)}] + b_c) \in (-1, 1)^n
  • 输出门:o(t)=σ(Wo[h(t1),x(t)]+bo)(0,1)no^{(t)} = \sigma(W_o [h^{(t-1)}, x^{(t)}] + b_o) \in (0, 1)^n

记忆单元(加法更新——关键!)

c(t)=f(t)c(t1)+i(t)c~(t)c^{(t)} = f^{(t)} \odot c^{(t-1)} + i^{(t)} \odot \tilde{c}^{(t)}

隐状态(输出)

h(t)=o(t)tanh(c(t))h^{(t)} = o^{(t)} \odot \tanh(c^{(t)})

为什么加法更新能缓解梯度消失?

c(t)c^{(t)}c(t1)c^{(t-1)} 的梯度:

c(t)c(t1)=f(t)\frac{\partial c^{(t)}}{\partial c^{(t-1)}} = f^{(t)}

(忽略高阶项)。这是逐元素乘以遗忘门,而不是乘以大矩阵!

对比标准 RNN 的 h(t)h(t1)=WhTdiag(σ)\frac{\partial h^{(t)}}{\partial h^{(t-1)}} = W_h^T \text{diag}(\sigma')(矩阵乘法,连乘后快速衰减)。

如果遗忘门 f(t)1f^{(t)} \approx 1(不遗忘),梯度 ≈ 1,不消失也不爆炸。这类似于 ResNet 的残差连接!

📚 已收录至 拓展阅读知识库

🔢 LSTM 一步更新数值示例

设定n=2n=2(隐藏维度),当前状态:

  • h(t1)=[0.5,0.3]Th^{(t-1)} = [0.5, -0.3]^Tc(t1)=[0.8,0.2]Tc^{(t-1)} = [0.8, 0.2]^T
  • 输入 x(t)=[1.0,0.0]Tx^{(t)} = [1.0, 0.0]^T

假设(为简化,直接给出门的激活值):

  • f(t)=[0.9,0.8]Tf^{(t)} = [0.9, 0.8]^T(遗忘大部分记忆)
  • i(t)=[0.7,0.4]Ti^{(t)} = [0.7, 0.4]^T(写入部分新信息)
  • c~(t)=[0.6,0.5]T\tilde{c}^{(t)} = [0.6, -0.5]^T(候选新记忆)
  • o(t)=[0.6,0.7]To^{(t)} = [0.6, 0.7]^T(输出多少到隐状态)

计算

  1. 新记忆c(t)=fc(t1)+ic~c^{(t)} = f \odot c^{(t-1)} + i \odot \tilde{c} =[0.9×0.8,0.8×0.2]+[0.7×0.6,0.4×(0.5)]= [0.9 \times 0.8, 0.8 \times 0.2] + [0.7 \times 0.6, 0.4 \times (-0.5)] =[0.72,0.16]+[0.42,0.20]=[1.14,0.04]= [0.72, 0.16] + [0.42, -0.20] = [1.14, -0.04]

  2. 新隐状态h(t)=otanh(c(t))h^{(t)} = o \odot \tanh(c^{(t)}) tanh([1.14,0.04])=[0.817,0.040]\tanh([1.14, -0.04]) = [0.817, -0.040] h(t)=[0.6×0.817,0.7×(0.040)]=[0.490,0.028]h^{(t)} = [0.6 \times 0.817, 0.7 \times (-0.040)] = [0.490, -0.028]

观察:遗忘门 f=0.9f = 0.9(几乎保留过去记忆),输入门写入新信息,输出门控制最终输出。

⚠️ 常见误区

  1. 误区:LSTM 完全解决了梯度消失 → 正确:LSTM 大幅缓解梯度消失(通过记忆单元的加法路径),但并不完全消除。在极长序列(数千步)上仍然困难。Transformer 通过注意力机制进一步解决这个问题。
  2. 误区:遗忘门为 0 = 清除所有记忆 → 正确f=0f = 0c(t)=ic~(t)c^{(t)} = i \odot \tilde{c}^{(t)},完全丢弃过去记忆,完全用新信息。f=1f = 1 时完全保留记忆(加法更新不改变 cc)。中间值 = 部分遗忘。
  3. 误区:LSTM 有 4 组权重矩阵,参数是 RNN 的 4 倍 → 正确:可以将 4 个门的权重矩阵拼为一个大矩阵 WR4n×(n+d)W \in \mathbb{R}^{4n \times (n+d)},一次矩阵乘法完成所有门的计算(更高效)。

7. 机器翻译简介

Slide 51 Slide 52 Slide 53 Slide 54 Slide 55 Slide 56 Slide 57 Slide 58 Slide 59
  • 从源语言 xx 到目标语言 yy 的翻译
  • NMT 的成功:2014 年 seq2seq 首篇论文 \to 2016 年 Google 全面替换 SMT
  • Encoder-Decoder 架构:
    • Encoder RNN 编码源句 \to 上下文向量
    • Decoder RNN 以此为条件生成目标句
  • 条件语言模型:P(yx)=P(y1x)P(y2y1,x)P(yTy1,,yT1,x)P(y|x) = P(y_1|x) P(y_2|y_1, x) \cdots P(y_T|y_1, \ldots, y_{T-1}, x)

📐 条件语言模型的训练目标

目标:最大化给定源句 xx 时,目标句 yy 的对数似然:

L=(x,y)datalogP(yx;θ)=(x,y)t=1ylogP(yty<t,x;θ)\mathcal{L} = \sum_{(x,y) \in \text{data}} \log P(y|x; \theta) = \sum_{(x,y)} \sum_{t=1}^{|y|} \log P(y_t | y_{<t}, x; \theta)

Encoder-Decoder 分解

  • Encoderhenc=RNN(x1,,xm)h_{\text{enc}} = \text{RNN}(x_1, \ldots, x_m),最终状态 henc(m)h_{\text{enc}}^{(m)} 作为初始 decoder 状态
  • Decoderhdec(t)=RNN(yt1,hdec(t1))h_{\text{dec}}^{(t)} = \text{RNN}(y_{t-1}, h_{\text{dec}}^{(t-1)})P(yt)=softmax(Whdec(t))P(y_t | \ldots) = \text{softmax}(W h_{\text{dec}}^{(t)})

瓶颈问题的数学表现

  • 整个源句 x1,,xmx_1, \ldots, x_m 被压缩为一个向量 henc(m)Rnh_{\text{enc}}^{(m)} \in \mathbb{R}^n(固定维度)
  • 无论源句多长(m=5m = 5 还是 m=100m = 100),都要压缩到同样大小的向量
  • 信息论:固定容量的”瓶颈”对长句信息损失严重

解决方案(引出注意力):让 decoder 在每一步直接访问所有 encoder 隐状态 {henc(1),,henc(m)}\{h_{\text{enc}}^{(1)}, \ldots, h_{\text{enc}}^{(m)}\},而非只用最后一个状态。

📚 已收录至 拓展阅读知识库

🔗 知识关联

  • → L05 注意力机制:本节的瓶颈问题直接引发了注意力机制的发明(Bahdanau et al., 2015)
  • → L05 Transformer:完全放弃 RNN,改用自注意力解决长距离依赖问题
  • ← L03 反向传播:BPTT 是 L03 中讲的计算图反向传播在序列上的应用
  • ← L02 Word2Vec:RNN-LM 的输入是词嵌入,与 L02 中的词向量直接对接

⚠️ 常见误区

  1. 误区:Seq2Seq 中 encoder 和 decoder 不共享权重 → 正确:这是设计选择,有些模型共享,有些不共享。通常 encoder 和 decoder 是不同的 RNN(不同参数)。
  2. 误区:Teacher forcing = 作弊 → 正确:Teacher forcing(训练时用真实目标词作为 decoder 输入,而非上一步预测词)是训练的标准方法,虽然会造成”暴露偏差”(训练 vs 推理时的分布差异),但实践中效果好。

推荐阅读

关联概念

作业提醒

  • A2 截止:Jan 22 (Thu)

个人笔记