2.4 Recurrent Neural Networks
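RNNs were among the earliest general-purpose frameworks for applying neural networks to sequences. Their core is not any particular cell but a factorization: the probability of a sequence unrolls over time, and at each step the model uses a hidden state to summarize the past.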
Running Example: Character Language Model
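Before looking at any formulas, suppose we want to train a very small character-level language model whose corpus is a single sentence (roughly, “deep learning is fun”):
深 度 学 习 很 有 趣
The training samples can be constructed as: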
| input prefix | target |
|---|---|
| 深 | 度 |
| 深 度 | 学 |
| 深 度 学 | 习 |
| 深 度 学 习 | 很 |
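The table above can be generated mechanically. The sketch below is plain Python; the helper name `prefix_target_pairs` is made up for illustration. Each prefix is paired with the character that follows it, and the table shows the first four of the six pairs.

```python
def prefix_target_pairs(tokens):
    """Return (prefix, target) pairs for next-token prediction."""
    return [(tokens[:t], tokens[t]) for t in range(1, len(tokens))]

corpus = "深 度 学 习 很 有 趣".split()
pairs = prefix_target_pairs(corpus)

for prefix, target in pairs[:4]:
    print(" ".join(prefix), "->", target)
```

A sentence of 7 characters yields 6 pairs, because the last character has no successor to predict.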
If tokens are one-hot encoded, each \(x_t\) is a vector of length \(V\) (the vocabulary size) with a single nonzero entry. The RNN reads the tokens one step at a time:
\[ h_1=f(x_1,h_0),\quad h_2=f(x_2,h_1),\quad h_3=f(x_3,h_2). \]
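When the model predicts \(x_{t+1}\) at step \(t\), \(h_t\) is its compressed memory of the prefix \(x_{\leq t}\). For example, after reading “深 度 学”, \(h_3\) should encode the context “this is probably about deep learning (深度学习)”, which raises the probability that the next token is “习”.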
RNN is not mysterious recurrence for its own sake. It is a way to turn a variable-length prefix into a fixed-size state so that a classifier can predict the next discrete token.
Sequence Factorization
For a sequence \(x_{1:T}\), autoregressive modeling writes \[ p(x_{1:T}) = \prod_{t=1}^{T}p(x_t\mid x_{<t}). \] The modeling problem is to build a state representation of \(x_{<t}\) that is useful for predicting \(x_t\).
The RNN's answer is the hidden state:
\[ h_t = \phi(W_{xh}x_t+W_{hh}h_{t-1}+b_h), \]
\[ \hat{y}_t = W_{hy}h_t+b_y. \]
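Here \(h_t\) is a recursive summary. It compresses a history of arbitrary length into a fixed-dimensional vector, which is both the elegance of the RNN and its bottleneck.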
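To make the “fixed-size summary” concrete, here is a minimal sketch of the recurrence in plain Python, with \(\phi=\tanh\) and one-hot inputs. The weights are arbitrary toy values, and the helper names (`matvec`, `rnn_step`, `encode`) are illustrative assumptions, not part of any library.

```python
import math

def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_step(x, h_prev, W_xh, W_hh, b_h):
    """h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)."""
    pre = [p + q + b for p, q, b in zip(matvec(W_xh, x), matvec(W_hh, h_prev), b_h)]
    return [math.tanh(a) for a in pre]

def one_hot(index, size):
    return [1.0 if i == index else 0.0 for i in range(size)]

def encode(token_ids, vocab_size, W_xh, W_hh, b_h):
    h = [0.0] * len(b_h)  # h_0
    for idx in token_ids:
        h = rnn_step(one_hot(idx, vocab_size), h, W_xh, W_hh, b_h)
    return h

# Hypothetical toy weights: V = 3 tokens, H = 2 hidden units.
W_xh = [[0.5, -0.3, 0.8], [0.1, 0.7, -0.2]]
W_hh = [[0.2, -0.1], [0.4, 0.3]]
b_h = [0.0, 0.1]

short_state = encode([0, 1], 3, W_xh, W_hh, b_h)
long_state = encode([0, 1, 2, 1, 0, 2], 3, W_xh, W_hh, b_h)
```

Both prefixes, of length 2 and length 6, end up as a state of the same size \(H=2\); only the values differ.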
Linear RNN as a Lens
To see the essence of recurrence, first drop the nonlinearity:
\[ h_t=Ah_{t-1}+Bx_t. \]
Unrolling gives
\[ h_t = A^t h_0 + \sum_{i=1}^{t}A^{t-i}Bx_i. \]
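This formula matters: the influence of a past input \(x_i\) on the current state is controlled by \(A^{t-i}\). If all eigenvalues of \(A\) have modulus below \(1\), distant inputs decay; if some eigenvalue has modulus above \(1\), the components of distant inputs along that direction can grow without bound.
To derive it, substitute the recurrence into itself once: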
\[ h_t=A(Ah_{t-2}+Bx_{t-1})+Bx_t =A^2h_{t-2}+ABx_{t-1}+Bx_t. \]
Continuing the expansion down to \(h_0\):
\[ h_t=A^th_0+\sum_{i=1}^{t}A^{t-i}Bx_i. \]
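The closed form can be checked numerically. The sketch below uses the scalar case, where \(A\) and \(B\) become plain numbers `a` and `b`; this is a simplifying assumption, and the function names are illustrative. It compares the step-by-step recurrence with the unrolled sum.

```python
def run_recurrence(a, b, h0, xs):
    """Iterate h_t = a * h_{t-1} + b * x_t."""
    h = h0
    for x in xs:
        h = a * h + b * x
    return h

def closed_form(a, b, h0, xs):
    """h_t = a^t h_0 + sum_{i=1..t} a^(t-i) b x_i, with xs[i-1] playing the role of x_i."""
    t = len(xs)
    return a ** t * h0 + sum(a ** (t - i) * b * xs[i - 1] for i in range(1, t + 1))

xs = [1.0, -2.0, 0.5, 3.0]
stepwise = run_recurrence(0.9, 0.5, 1.0, xs)
unrolled = closed_form(0.9, 0.5, 1.0, xs)
```

The two values agree up to floating-point rounding, and the weight on `xs[0]` is `a ** 3`, the scalar version of \(A^{t-i}\).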
If \(A=Q\Lambda Q^{-1}\) is diagonalizable, then
\[ A^k=Q\Lambda^kQ^{-1}. \]
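Each direction in \(\Lambda^k\) is scaled by \(\lambda_j^k\). The “memory length” of an RNN is therefore tied to the spectrum of its recurrent dynamics. A nonlinear RNN is not this simple, but the product of Jacobians obtained by local linearization follows the same kind of logic.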
For a linear recurrence \(h_t=Ah_{t-1}+Bx_t\), if the spectral radius \(\rho(A)<1\), then the influence of \(h_0\) decays exponentially; if \(\rho(A)>1\), some directions can grow exponentially.
For any \(\epsilon>0\) there exists an induced matrix norm (depending on \(A\) and \(\epsilon\)) such that
\[ \|A\|\le \rho(A)+\epsilon. \]
If \(\rho(A)<1\), choose \(\epsilon\) small enough that \(\rho(A)+\epsilon<1\); then, in the compatible vector norm,
\[ \|A^t h_0\| \le \|A\|^t\|h_0\| \le (\rho(A)+\epsilon)^t\|h_0\|. \]
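So the influence of the initial state decays exponentially. Conversely, if some eigenvalue satisfies \(|\lambda|>1\), then along a corresponding eigenvector \(v\) we have \(A^t v=\lambda^t v\), whose norm grows exponentially.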
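The proof needs a specially chosen norm because the ordinary 2-norm of \(A\) can exceed \(1\) even when \(\rho(A)<1\). The sketch below uses a hand-picked \(2\times2\) matrix, an illustrative assumption, whose eigenvalues are both \(0.5\): the state norm first grows and only later decays.

```python
import math

def matvec(A, v):
    return [sum(a * x for a, x in zip(row, v)) for row in A]

def norm2(v):
    return math.sqrt(sum(x * x for x in v))

# Upper-triangular matrix: both eigenvalues are 0.5, so rho(A) = 0.5 < 1,
# but the large off-diagonal entry makes the first steps expand the state.
A = [[0.5, 2.0], [0.0, 0.5]]
h = [0.0, 1.0]  # h_0 with norm 1

norms = []
for _ in range(40):
    h = matvec(A, h)
    norms.append(norm2(h))
```

Here `norms[0]` is larger than the initial norm of 1, while `norms[-1]` is many orders of magnitude smaller. The spectral radius governs the long-run rate, not the first few steps.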
Training by Backpropagation Through Time
Unrolled over time, an RNN is simply a deep network with shared parameters. The loss is
\[ \mathcal{L} = \sum_{t=1}^{T}\ell_t(\hat{y}_t,y_t). \]
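For the character model above, one teacher-forcing training step works as follows:
- feed the ground-truth prefix 深 度 学;
- the model outputs logits for the predictions 度, 学, 习, one per position;
- compute the cross entropy between each position and its true next token;
- sum the losses of the three positions and backpropagate.
Generation is different. After predicting a token, the model must feed its own prediction back in, whereas during training it is fed the ground-truth token. This mismatch is called exposure bias.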
During teacher forcing, the model always conditions on correct previous tokens. During generation, it conditions on its own sampled tokens. A small early mistake can shift the hidden state into an unfamiliar region.
Because \(h_t\) depends on \(h_{t-1}\), the gradient contains a product of Jacobians along the time axis:
\[ \frac{\partial \mathcal{L}}{\partial h_k} = \sum_{t\geq k} \frac{\partial \ell_t}{\partial h_t} \prod_{i=k+1}^{t} \frac{\partial h_i}{\partial h_{i-1}}. \]
An RNN in principle has access to all past tokens through \(h_t\), but in practice long-range information must survive repeated nonlinear transformations. This creates vanishing/exploding gradient problems and memory bottlenecks.
BPTT as a Recursive Error Signal
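Writing down the Jacobian product is not enough. When actually implementing BPTT, it helps to view the hidden-state gradient at each step as an error signal flowing back from the future.
Let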
\[ a_t=W_{xh}x_t+W_{hh}h_{t-1}+b_h, \qquad h_t=\phi(a_t), \qquad o_t=W_{hy}h_t+b_y. \]
Assume every step has a loss \(\ell_t(o_t,y_t)\). First define the output-side gradient:
\[ g_t^o=\frac{\partial \ell_t}{\partial o_t}. \]
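The total gradient with respect to the hidden state has two parts:
- the gradient of the current step's output loss with respect to \(h_t\);
- the gradient passed back from future steps through \(h_{t+1}\).
This gives the recursion: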
\[ \delta_t^h = W_{hy}^{\top}g_t^o + \left(W_{hh}^{\top}\delta_{t+1}^a\right), \]
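where \(\delta_t^h=\partial\mathcal{L}/\partial h_t\), the recursion starts from \(\delta_{T+1}^a=0\) at the final step, and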
\[ \delta_t^a = \delta_t^h\odot\phi'(a_t). \]
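The parameter gradients (with \(x_t\), \(h_t\), \(g_t^o\), and \(\delta_t^a\) all treated as column vectors, consistent with \(o_t=W_{hy}h_t+b_y\)) are
\[ \frac{\partial \mathcal{L}}{\partial W_{hy}} = \sum_t g_t^o h_t^{\top}, \qquad \frac{\partial \mathcal{L}}{\partial b_y} = \sum_t g_t^o, \]
\[ \frac{\partial \mathcal{L}}{\partial W_{xh}} = \sum_t \delta_t^a x_t^{\top}, \qquad \frac{\partial \mathcal{L}}{\partial W_{hh}} = \sum_t \delta_t^a h_{t-1}^{\top}, \qquad \frac{\partial \mathcal{L}}{\partial b_h} = \sum_t \delta_t^a. \]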
In the computation graph, \(h_t\) has two outgoing edges: one to the current output \(o_t\) and one to the next pre-activation \(a_{t+1}\). So
\[ \frac{\partial \mathcal{L}}{\partial h_t} = \frac{\partial \ell_t}{\partial h_t} + \frac{\partial \mathcal{L}_{>t}}{\partial h_t}. \]
The first term follows from \(o_t=W_{hy}h_t+b_y\):
\[ \frac{\partial \ell_t}{\partial h_t} = W_{hy}^{\top}g_t^o. \]
The second term flows back through
\[ a_{t+1}=W_{xh}x_{t+1}+W_{hh}h_t+b_h \]
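and therefore equals \(W_{hh}^{\top}\delta_{t+1}^a\). Multiplying the sum of the two terms elementwise by the activation derivative gives \(\delta_t^a\). The parameter gradients are obtained by summing, over all steps, the local contributions of the shared parameters.
This recursion matters: an RNN does not have \(T\) sets of parameters. It has one set that is reused at every step of the unrolled graph, so the gradients for \(W_{xh},W_{hh},W_{hy}\) from all steps must be added together.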
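The recursion can be checked against finite differences. The sketch below makes several simplifying assumptions that are not in the text: every quantity is a scalar, \(\phi=\tanh\), \(h_0=0\), there is no output bias, and each step uses a squared-error loss instead of cross entropy, so that \(g_t^o=o_t-y_t\). The function names are illustrative.

```python
import math

def forward(params, xs, ys):
    w_x, w_h, b, w_y = params
    h, hs, loss = 0.0, [0.0], 0.0           # hs[0] is h_0
    for x, y in zip(xs, ys):
        h = math.tanh(w_x * x + w_h * h + b)
        hs.append(h)
        loss += 0.5 * (w_y * h - y) ** 2    # per-step squared error
    return loss, hs

def bptt(params, xs, ys):
    w_x, w_h, b, w_y = params
    _, hs = forward(params, xs, ys)
    g_wx = g_wh = g_b = g_wy = 0.0
    delta_a_next = 0.0                       # delta^a_{T+1} = 0
    for t in range(len(xs), 0, -1):
        h_t, h_prev = hs[t], hs[t - 1]
        g_o = w_y * h_t - ys[t - 1]          # d loss_t / d o_t
        delta_h = w_y * g_o + w_h * delta_a_next
        delta_a = delta_h * (1.0 - h_t ** 2) # tanh'(a_t) = 1 - h_t^2
        g_wy += g_o * h_t                    # shared parameters: sum over steps
        g_wx += delta_a * xs[t - 1]
        g_wh += delta_a * h_prev
        g_b += delta_a
        delta_a_next = delta_a
    return [g_wx, g_wh, g_b, g_wy]

def numeric_grad(params, xs, ys, eps=1e-6):
    grads = []
    for i in range(len(params)):
        hi = list(params); hi[i] += eps
        lo = list(params); lo[i] -= eps
        grads.append((forward(hi, xs, ys)[0] - forward(lo, xs, ys)[0]) / (2 * eps))
    return grads

params = [0.7, -0.4, 0.1, 1.3]               # w_x, w_h, b, w_y (toy values)
xs = [1.0, -0.5, 0.25, 2.0]
ys = [0.2, -0.1, 0.4, 0.0]
analytic = bptt(params, xs, ys)
numeric = numeric_grad(params, xs, ys)
```

If the recursion or the summation over steps were wrong, for example by forgetting to accumulate `g_wh` across time, `analytic` and `numeric` would disagree.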
Output Loss and Label Shifting
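In language modeling, the RNN typically takes \(x_{1:T}\) as input and predicts \(x_{2:T+1}\), so the target at step \(t\) is \(y_t=x_{t+1}\). If the logits are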
\[ o_t=W_{hy}h_t+b_y\in\mathbb{R}^{V}, \]
the softmax probabilities are
\[ p_t(j)=\frac{\exp(o_{t,j})}{\sum_{k=1}^{V}\exp(o_{t,k})}. \]
For the target token \(y_t\), the cross entropy is
\[ \ell_t=-\log p_t(y_t). \]
Its gradient with respect to the logits has a very commonly used form:
\[ \frac{\partial \ell_t}{\partial o_{t,j}} = p_t(j)-\mathbf{1}[j=y_t]. \]
To see this, write the loss as
\[ \ell_t = -o_{t,y_t} + \log\sum_k \exp(o_{t,k}). \]
Differentiating with respect to \(o_{t,j}\):
\[ \frac{\partial \ell_t}{\partial o_{t,j}} = -\mathbf{1}[j=y_t] + \frac{\exp(o_{t,j})}{\sum_k\exp(o_{t,k})} = p_t(j)-\mathbf{1}[j=y_t]. \]
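A quick numerical check of this identity, as a plain-Python sketch with made-up logits and a made-up target index:

```python
import math

def softmax(logits):
    top = max(logits)                        # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    return -math.log(softmax(logits)[target])

def ce_grad(logits, target):
    """Analytic gradient: p(j) - 1[j == target]."""
    p = softmax(logits)
    return [p_j - (1.0 if j == target else 0.0) for j, p_j in enumerate(p)]

logits, target = [2.0, -1.0, 0.5, 0.0], 2
analytic = ce_grad(logits, target)

eps = 1e-6
numeric = []
for j in range(len(logits)):
    hi = list(logits); hi[j] += eps
    lo = list(logits); lo[j] -= eps
    numeric.append((cross_entropy(hi, target) - cross_entropy(lo, target)) / (2 * eps))
```

The entries of `analytic` sum to zero, since the probabilities sum to one and exactly one entry has 1 subtracted. Only the target entry is negative.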
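This is why \(g_t^o\) in the BPTT recursion above can be read as the “probability error” at the current position. In implementations, the most common mistake is shifting the labels by the wrong offset:

    # idx: [B, T] token ids
    inp = idx[:, :-1]  # positions 0 .. T-2
    tgt = idx[:, 1:]   # positions 1 .. T-1, the next token for each input position
    logits, _ = model(inp)  # [B, T-1, V]
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        tgt.reshape(-1),
    )

If idx[:, :-1] is aligned with idx[:, :-1] itself, the model learns an identity copy; if the targets are shifted by two positions, the loss may still go down, but generation will be visibly wrong.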
Vanishing and Exploding Gradients
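Roughly, if the recurrent Jacobians persistently shrink vectors (for example, their spectral norm stays below \(1\)), gradients vanish; if they persistently expand some direction, gradients explode. For a vanilla RNN,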
\[ \frac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}(\phi'(a_t))W_{hh}. \]
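For tanh, \(\phi'\le1\) with equality only at \(a=0\); for the sigmoid, \(\phi'\le1/4\). These factors shrink the Jacobian further, which makes long-range dependencies hard to train. This was a central motivation for LSTM and GRU.
More formally, if there is a constant \(\gamma<1\) such that at every step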
\[ \left\| \frac{\partial h_i}{\partial h_{i-1}} \right\|_2 \le \gamma, \]
then the norm of the Jacobian from step \(t\) back to step \(k\) satisfies
\[ \left\| \frac{\partial h_t}{\partial h_k} \right\|_2 \le \prod_{i=k+1}^{t} \left\| \frac{\partial h_i}{\partial h_{i-1}} \right\|_2 \le \gamma^{t-k}. \]
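The farther apart the two steps are, the more the gradient decays, exponentially in the distance. Conversely, if the average norm exceeds \(1\), the gradient can explode exponentially. Gradient clipping in RNN training is not decoration; it keeps this product along the time axis from suddenly blowing up.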
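The idea behind global-norm clipping can be sketched in a few lines of plain Python. This illustrates the principle only and is not the source of any library function; the name `clip_by_global_norm` and the small constant in the denominator are assumptions of the sketch.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale every gradient by one shared factor so the joint L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total + 1e-6))   # never scale up, only down
    return [[g * scale for g in vec] for vec in grads], total

grads = [[3.0, 4.0], [12.0]]                      # joint norm = sqrt(9 + 16 + 144) = 13
clipped, total = clip_by_global_norm(grads, max_norm=1.0)
clipped_norm = math.sqrt(sum(g * g for vec in clipped for g in vec))

small, _ = clip_by_global_norm([[0.1, -0.2]], max_norm=1.0)
```

Because all parameters share one scale factor, the direction of the update is preserved. Gradients that are already small, such as `small`, pass through unchanged.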
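Comparing this with the linear RNN shows that forward memory and backward gradient are two sides of the same coin. In the forward pass, \(A^{t-i}\) determines how long a past input is retained; in the backward pass, the Jacobian product determines whether a future loss can assign credit to a past state.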
Orthogonal recurrent weights can keep \(\|W_{hh}\|_2\) near one at initialization, but nonlinear gates, inputs, layer normalization, optimizer updates, and loss curvature still change the effective Jacobian during training.
A simple diagnostic is to log the global grad norm and the hidden norm for every batch:
| signal | likely cause | what to try or check |
|---|---|---|
| grad norm grows with sequence length | exploding BPTT path | clip, shorter TBPTT, smaller LR |
| grad norm near zero for long tasks | vanishing credit | gated cell, residual/attention, easier curriculum |
| hidden norm diverges | unstable recurrence | recurrent init, normalization, clipping |
| hidden norm collapses to constant | state ignored | too much regularization, bad gate bias |
LSTM
An LSTM maintains a hidden state \(h_t\) and a cell state \(c_t\): \[ i_t=\sigma(W_i[x_t,h_{t-1}]+b_i), \] \[ f_t=\sigma(W_f[x_t,h_{t-1}]+b_f), \] \[ o_t=\sigma(W_o[x_t,h_{t-1}]+b_o), \] \[ \tilde{c}_t=\tanh(W_c[x_t,h_{t-1}]+b_c), \] \[ c_t=f_t\odot c_{t-1}+i_t\odot\tilde{c}_t, \qquad h_t=o_t\odot\tanh(c_t). \]
The key to the LSTM is its additive memory path:
\[ c_t=f_t\odot c_{t-1}+i_t\odot\tilde{c}_t. \]
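Compared with the repeated matrix multiplication of a vanilla RNN, the cell state lets information propagate almost linearly through time. The forget gate decides how much old memory to keep, and the input gate decides how much new information to write.
In practice, implementations usually do not run four separate matrix multiplications for the four gates. They compute one vector of four times the hidden size and split it:

    # x_t: [B, D]; h_prev, c_prev: [B, H]
    # W_ih: [4H, D]; W_hh: [4H, H]; b: [4H]
    gates = x_t @ W_ih.T + h_prev @ W_hh.T + b
    i, f, g, o = gates.chunk(4, dim=-1)  # input, forget, candidate, output
    i = torch.sigmoid(i)
    f = torch.sigmoid(f)
    g = torch.tanh(g)  # candidate cell value, written \tilde{c}_t in the equations
    o = torch.sigmoid(o)
    c = f * c_prev + i * g
    h = o * torch.tanh(c)

Two engineering details sit behind this:
- gate fusion reduces kernel launches and memory traffic;
- the forget-gate bias is often initialized to a positive value, so that early in training the cell leans toward keeping its memory.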
For many LSTM tasks, initializing the forget-gate bias to a positive value such as \(1\) puts \(f_t\) near \(\sigma(1)\approx0.73\) at the start of training (while the weighted inputs are still close to zero), so the cell does not forget everything before it has learned useful gates.
Why LSTM Helps
Consider the derivative of the cell state:
\[ c_t=f_t\odot c_{t-1}+i_t\odot\tilde{c}_t. \]
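Ignoring the indirect dependence of the gates on \(c_{t-1}\) (which enters through \(h_{t-1}\)), and reading \(f_t\) as a diagonal matrix, we have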
\[ \frac{\partial c_t}{\partial c_{t-1}} \approx f_t. \]
so the gradient from \(c_t\) back to \(c_k\) is approximately
\[ \frac{\partial c_t}{\partial c_k} \approx \prod_{i=k+1}^{t}f_i. \]
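If the forget gate learns \(f_i\approx1\), information can be kept along the cell state over long distances; if it learns \(f_i\approx0\), the model actively forgets. This is far more flexible than the gradient path of a vanilla RNN, which is dictated by \(W_{hh}\) and the activation derivative.
The forget gate also explains the memory timescale. If one dimension keeps a constant gate value \(f\), then after \(k\) steps the fraction of old memory that remains is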
\[ f^k. \]
Setting \(f^k=e^{-1}\) gives the effective time constant
\[ \tau=-\frac{1}{\log f}. \]
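For example, \(\tau\approx9.5\) when \(f=0.9\), and \(\tau\approx99.5\) when \(f=0.99\). So an LSTM does not “automatically remember everything”; it learns, through its gates, a forgetting timescale for each dimension.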
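The two numbers above follow directly from the formula. A small plain-Python sketch, with illustrative function names:

```python
import math

def time_constant(f):
    """Steps until a constant forget gate f in (0, 1) shrinks old memory to 1/e."""
    return -1.0 / math.log(f)

def retained(f, k):
    """Fraction of old memory left after k steps with constant forget gate f."""
    return f ** k

tau_09 = time_constant(0.9)    # about 9.49
tau_099 = time_constant(0.99)  # about 99.50
```

Moving the gate from 0.9 to 0.99 looks like a small change in gate value but lengthens the memory timescale roughly tenfold. This is why the forget-gate bias initialization matters.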
An LSTM cell can learn an approximate identity path through \(c_t\) by setting forget gates near one and input gates near zero, allowing gradients to propagate across many steps without repeated multiplication by a dense recurrent matrix.
If \(f_i=\mathbf{1}\) and the input gate \(i_i=\mathbf{0}\) at every step \(i\) with \(k<i\le t\), then
\[ c_i=c_{i-1}. \]
and so
\[ \frac{\partial c_t}{\partial c_k}=I. \]
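In real training the gates are never exactly 0 or 1, but sigmoid gates let the model choose, per dimension, whether to keep, overwrite, or forget. Gradients can then travel along the additive state path without being multiplied by the same dense recurrent matrix at every step.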
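The product-of-forget-gates approximation is exact when the gates are held fixed. The scalar sketch below treats the gate values as given constants, which is the same simplification the text makes when it ignores their dependence on \(c_{t-1}\). The `write` list stands for \(i_t\odot\tilde{c}_t\), and all values are made up.

```python
def run_cell_state(c0, forget, write):
    """Iterate c_t = f_t * c_{t-1} + w_t, where w_t stands for i_t * candidate_t."""
    c = c0
    for f, w in zip(forget, write):
        c = f * c + w
    return c

forget = [0.99, 0.95, 1.0, 0.9]
write = [0.1, -0.2, 0.0, 0.3]

# Sensitivity of the final cell state to the initial one, by central differences.
eps = 1e-6
sensitivity = (run_cell_state(1.0 + eps, forget, write)
               - run_cell_state(1.0 - eps, forget, write)) / (2 * eps)

product = 1.0
for f in forget:
    product *= f

# Identity path: forget gate 1 and nothing written keeps the state unchanged.
kept = run_cell_state(2.5, [1.0] * 50, [0.0] * 50)
```

`sensitivity` matches `product`, and with all forget gates at 1 and no writes the state survives 50 steps untouched.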
GRU
GRU merges the LSTM's cell state and hidden state. One common formulation is:
\[ z_t=\sigma(W_z[x_t,h_{t-1}]), \qquad r_t=\sigma(W_r[x_t,h_{t-1}]), \]
\[ \tilde{h}_t=\tanh(W_h[x_t,r_t\odot h_{t-1}]), \]
\[ h_t=(1-z_t)\odot h_{t-1}+z_t\odot \tilde{h}_t. \]
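Here a larger \(z_t\) leans toward writing the candidate, and a smaller one toward keeping the old state. Some papers and frameworks use the opposite convention and read \(z_t\) as the fraction of the old state that is kept; for example, the form given in the PyTorch documentation is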
\[ h_t=(1-z_t)\odot n_t+z_t\odot h_{t-1}. \]
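The two update rules are equivalent up to swapping \(z\) and \(1-z\); only the meaning of \(z\) is flipped. (PyTorch's candidate \(n_t\) also applies the reset gate after the hidden-state matrix product rather than before it, which is a separate, minor difference.) When reading a paper or reproducing code, always check the gate convention.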
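The equivalence of the two conventions is easy to verify on scalars. In this sketch the function names are illustrative, and the candidate value is taken as given rather than computed from a reset gate.

```python
def update_write_convention(h_prev, cand, z):
    """z is the fraction written from the candidate (first form above)."""
    return (1.0 - z) * h_prev + z * cand

def update_keep_convention(h_prev, cand, z):
    """z is the fraction of the old state that is kept (PyTorch-style form)."""
    return (1.0 - z) * cand + z * h_prev

h_prev, cand = 0.8, -0.3
pairs = [
    (update_write_convention(h_prev, cand, z), update_keep_convention(h_prev, cand, 1.0 - z))
    for z in (0.0, 0.2, 0.7, 1.0)
]
```

Each pair holds the same number. A checkpoint trained under one convention therefore cannot be loaded into code that assumes the other without remapping the gate.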
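GRU has fewer parameters and is often used when data is limited or latency matters. It has no separate \(c_t\), so its state interface is simpler; but when very long memory is needed, the explicit cell path of the LSTM tends to be easier to interpret and debug.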
| cell | state | gates | strengths | common use |
|---|---|---|---|---|
| vanilla RNN | \(h_t\) | none | simplest, fast | toy models, short sequences |
| GRU | \(h_t\) | update/reset | fewer params than LSTM | moderate sequence tasks |
| LSTM | \((h_t,c_t)\) | input/forget/output | explicit memory path | long dependencies, older seq2seq |
Manual Scan Implementation
nn.GRU and nn.LSTM hide the time loop, but understanding the underlying scan is useful. A hand-written RNN cell typically looks like this:

    class RNNCell(nn.Module):
        def __init__(self, dim: int, hidden: int):
            super().__init__()
            self.xh = nn.Linear(dim, hidden, bias=True)
            self.hh = nn.Linear(hidden, hidden, bias=False)

        def forward(self, x_t, h):
            # x_t: [B, D], h: [B, H]
            return torch.tanh(self.xh(x_t) + self.hh(h))

    def scan(cell, x, h0):
        # x: [B, T, D], h0: [B, H]
        h = h0
        outs = []
        for t in range(x.size(1)):
            h = cell(x[:, t], h)
            outs.append(h)
        return torch.stack(outs, dim=1), h  # [B, T, H] and [B, H]

This loop has a few shape invariants:
| tensor | shape | meaning |
|---|---|---|
| x | [B, T, D] | embedded input sequence |
| h | [B, H] | current hidden state |
| outs | [B, T, H] | hidden states for all timesteps |
| logits | [B, T, V] | token logits per timestep |
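The sequential dependency of an RNN means the time dimension cannot be fully parallelized the way self-attention can. Training can batch over B, but the time loop remains a sequential dependency. This is one of the key reasons Transformers scale more easily on hardware.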
Truncated BPTT
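Full BPTT has to keep the computation graph for the entire length \(T\) until the backward pass. For long sequences this causes two problems:
- activation memory grows with \(T\);
- the gradient path is very long and numerically unstable.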
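Truncated BPTT cuts the graph at the hidden state every \(K\) steps:

    h = None
    # chunks(...) is assumed to yield consecutive (inputs, targets) pairs of length K
    for inputs, targets in chunks(sequence, length=K):
        logits, h = rnn(inputs, h)
        loss = criterion(logits, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(rnn.parameters(), 1.0)
        optimizer.step()
        h = h.detach()  # keep the value, drop the graph of the previous chunk

h.detach() keeps the hidden value as context for the next chunk but stops gradients from flowing back through the previous chunk's graph. The forward pass therefore still carries state across chunks, while the backward pass reaches back at most \(K\) steps.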
If the hidden state is detached every single token, the model cannot assign credit across time. If it is never detached on long streams, memory grows and gradients may become unstable.
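The LSTM state is a tuple (h, c), so both parts must be detached together:

    def detach_state(state):
        # LSTM state is a tuple (h, c); GRU and vanilla RNN state is a single tensor.
        if isinstance(state, tuple):
            return tuple(s.detach() for s in state)
        return state.detach()

In streaming language-model training, two things must also be kept apart:
- state carry: whether the next chunk's forward pass inherits the previous chunk's hidden state;
- gradient carry: whether the backward pass goes through the previous chunk's computation graph.
The usual TBPTT choice is to carry the state but not the gradient. If the state is reset at every chunk, the model degenerates to seeing only within-chunk context; if the graph is never truncated, memory grows with the stream length.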
| strategy | forward context | backward horizon | use case |
|---|---|---|---|
| reset every chunk | local only | \(K\) | independent samples |
| carry + detach | cross-chunk state | \(K\) | streaming LM default |
| carry no detach | full stream | stream length | very short streams only |
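One more common detail: if an epoch cuts a long corpus into consecutive chunks, shuffling must not scramble the order of chunks within a stream, otherwise the carried state gets attached to unrelated text. Shuffle whole streams, but keep the time order inside each stream.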
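One way to cut a stream into ordered TBPTT chunks is sketched below in plain Python. The helper name `chunk_stream` is an assumption, and the tokens are stand-in integers. Each chunk's targets are its inputs shifted by one, and the boundary token is reused so that no prediction is dropped between chunks.

```python
def chunk_stream(tokens, k):
    """Yield consecutive (inputs, targets) chunks of at most k steps, in stream order."""
    for start in range(0, len(tokens) - 1, k):
        end = min(start + k, len(tokens) - 1)
        yield tokens[start:end], tokens[start + 1:end + 1]

stream = list(range(10))                  # stand-in token ids 0..9
chunks = list(chunk_stream(stream, k=4))
for inputs, targets in chunks:
    print(inputs, "->", targets)
```

The last chunk may be shorter than `k`. Shuffling would then operate on whole streams, that is, on which list is passed to `chunk_stream`, and never on the chunks this generator yields.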
Stateful Batching and Reset Masks
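The h.detach() in TBPTT only answers “how long is the gradient graph”; it does not answer “whose hidden state is this”. A real streaming dataloader often cuts several long sequences into chunks and then places the chunks into batch slots: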
step 0: slot 0 -> doc A chunk 0, slot 1 -> doc B chunk 0
step 1: slot 0 -> doc A chunk 1, slot 1 -> doc B chunk 1
step 2: slot 0 -> doc C chunk 0, slot 1 -> doc B chunk 2
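At step 2, the hidden state of slot 0 must be reset; otherwise the beginning of doc C inherits the context of doc A. This kind of bug raises no shape error, yet it corrupts the training objective.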
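The reset decision can be derived from the slot assignment itself. The plain-Python sketch below assumes the dataloader exposes, for every step, which stream id feeds each slot; the function name `reset_masks` and that schedule format are assumptions made for illustration.

```python
def reset_masks(schedule):
    """schedule[step][slot] is the stream id feeding that slot at that step.

    Returns one list of booleans per step; True means the slot starts a new stream
    and its hidden state must be reset before the forward pass.
    """
    masks, previous = [], None
    for assignment in schedule:
        if previous is None:
            masks.append([True] * len(assignment))   # the first step always starts fresh
        else:
            masks.append([cur != prev for cur, prev in zip(assignment, previous)])
        previous = assignment
    return masks

# The three steps from the example above.
schedule = [["A", "B"], ["A", "B"], ["C", "B"]]
masks = reset_masks(schedule)
```

At step 2 only slot 0 is flagged, which matches the doc A to doc C switch in the example.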
A stateful batch slot is a batch row whose recurrent hidden state is carried across chunks only while the same logical stream remains assigned to that row.
Let \(r_{b,t}\in\{0,1\}\) indicate whether batch row \(b\) starts a new stream at time \(t\). The recurrence with a reset mask can be written as:
\[ \tilde{h}_{b,t-1} = (1-r_{b,t})h_{b,t-1} +r_{b,t}h_{\text{init}}, \]
\[ h_{b,t} = f_\theta(x_{b,t},\tilde{h}_{b,t-1}). \]
If \(r_{b,t}=1\), the hidden state is reset in the forward pass; if \(r_{b,t}=0\), it is carried as usual.
If \(r_{b,t}=1\), then \(h_{b,t}\) is independent of the previous stream’s hidden state \(h_{b,t-1}\), assuming \(h_{\text{init}}\) is fixed or depends only on the new stream.
When \(r_{b,t}=1\):
\[ \tilde{h}_{b,t-1} = (1-1)h_{b,t-1}+1\cdot h_{\text{init}} = h_{\text{init}}. \]
Therefore:
\[ h_{b,t} = f_\theta(x_{b,t},h_{\text{init}}). \]
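The right-hand side does not contain the old \(h_{b,t-1}\), so the current state does not depend on the hidden value of the previous stream. If \(h_{\text{init}}\) is a learnable initial state, it must likewise be broadcast for the new stream rather than read from the old one.
In PyTorch, the GRU state usually has shape [num_layers, B, H], and the LSTM state is a tuple (h, c) of two tensors of that shape. The reset mask has to be broadcast over the layer and hidden dimensions: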

    def apply_reset_mask(state, reset_mask, init_state=None, batch_dim=1):
        # reset_mask: bool tensor [B]; True means this row starts a new stream.
        # PyTorch GRU/LSTM state is usually [num_layers, B, H], so batch_dim=1.
        shape = [1] * state.ndim
        shape[batch_dim] = reset_mask.numel()
        reset = reset_mask.reshape(shape).to(device=state.device, dtype=torch.bool)
        if init_state is None:
            init_state = torch.zeros_like(state)
        # torch.where selects instead of multiplying by 0/1, so a NaN or inf
        # in a stale row cannot leak into the reset state.
        return torch.where(reset, init_state, state)

    def reset_recurrent_state(state, reset_mask, init_state=None, batch_dim=1):
        if state is None:
            return None
        if isinstance(state, tuple):
            if init_state is None:
                init_state = (None,) * len(state)
            return tuple(
                apply_reset_mask(s, reset_mask, init, batch_dim=batch_dim)
                for s, init in zip(state, init_state)
            )
        return apply_reset_mask(state, reset_mask, init_state, batch_dim=batch_dim)

A stateful TBPTT step usually runs in this order:
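
    # masked_lm_loss is an assumed helper: cross entropy averaged over positions where loss_mask is set
    state = reset_recurrent_state(state, batch["reset_mask"])
    logits, state = model(batch["input_ids"], state)
    loss = masked_lm_loss(logits, batch["labels"], batch["loss_mask"])
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    state = detach_state(state)

The order matters: reset before the forward pass, and detach after backward and the optimizer step. Detaching before resetting is usually numerically equivalent, but placing the reset right before the forward pass makes explicit which rows of the current batch start a new stream.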
Padding and Packed Sequences
Batched RNN training usually pads sequences of different lengths to a common length:
seq1: A B C D
seq2: E F PAD PAD
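If padding is fed straight into the RNN, the model keeps updating its hidden state on PAD. For classification this can contaminate the final hidden state; for language modeling it can produce a meaningless loss on PAD targets.
There are three common remedies:
- a loss mask that ignores padded positions;
- taking the last valid hidden state according to the true lengths;
- using pack_padded_sequence so that the RNN kernel skips padding.
Conceptually, the masked sequence loss is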
\[ \mathcal{L} = \frac{ \sum_{b,t}m_{b,t}\ell_{b,t} }{ \sum_{b,t}m_{b,t} }. \]
For padded batches, h[:, -1] may be the hidden state after padding. Use sequence lengths or packed sequences to select the last real timestep.
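The selection rule is simple enough to state in plain Python. In this sketch the hidden states are stand-in strings so the result is easy to read; the function name and the right-padding assumption are illustrative.

```python
def last_valid_states(outs, lengths):
    """outs[b][t] is the hidden state of row b at step t; lengths[b] counts real tokens.

    Assumes right padding, so the last real timestep of row b is lengths[b] - 1.
    """
    return [row[n - 1] for row, n in zip(outs, lengths)]

# Hidden states for the padded batch above: seq1 = A B C D, seq2 = E F PAD PAD.
outs = [["h_A", "h_B", "h_C", "h_D"],
        ["h_E", "h_F", "h_PAD1", "h_PAD2"]]
lengths = [4, 2]

selected = last_valid_states(outs, lengths)
naive = [row[-1] for row in outs]          # what h[:, -1] would pick
```

`naive` picks a PAD-contaminated state for the second row, while `selected` picks the state after F. On tensors the same indexing is typically done with the batch index and `lengths - 1`.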
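When writing a language-model loss with a loss mask, it is best to control the denominator explicitly:

    loss_per_token = F.cross_entropy(
        logits.reshape(-1, vocab),
        targets.reshape(-1),
        reduction="none",
    ).view_as(targets)
    mask = targets.ne(pad_id)
    loss = (loss_per_token * mask).sum() / mask.sum().clamp_min(1)

This is more robust than calling mean() directly, because the padding ratio can differ across batches. If you use ignore_index=pad_id instead, also confirm that the framework's reduction averages over non-ignored tokens (PyTorch's F.cross_entropy with the default mean reduction does).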
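A typical use of pack_padded_sequence:

    # attention_mask: [B, T], 1 for real tokens; assumes right padding and no empty sequences
    lengths = attention_mask.sum(dim=1).cpu()
    x = embedding(input_ids)
    packed = nn.utils.rnn.pack_padded_sequence(
        x,
        lengths,
        batch_first=True,
        enforce_sorted=False,
    )
    packed_out, h_n = rnn(packed)
    # out is padded back to the longest length in the batch
    out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

A packed sequence skips the recurrent computation on padding, but it brings its own trouble: the lengths must live on the CPU, the output order and lengths have to be realigned, and combining it with custom attention or residual blocks is more error-prone. Many modern codebases therefore prefer a plain padded tensor plus a mask, trading a little computation for a clearer implementation.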
Sequence-to-Sequence and Attention
RNNs are not only language models. The classic seq2seq model uses an encoder RNN to compress the input sequence into a context, and a decoder RNN to generate the output step by step:
\[ h_t^{\text{enc}}=f_{\text{enc}}(x_t,h_{t-1}^{\text{enc}}), \qquad s_u=f_{\text{dec}}(y_{u-1},s_{u-1},c_u). \]
The earliest encoder-decoder models often used the last encoder state as the context:
\[ c=h_T^{\text{enc}}. \]
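This struggles on long sentences, because all the input information is squeezed into one vector. Attention instead lets the decoder read from all encoder states at every step: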
\[ e_{u,t}=v^\top\tanh(W_s s_{u-1}+W_h h_t^{\text{enc}}), \]
\[ \alpha_{u,t} = \frac{\exp(e_{u,t})}{\sum_j\exp(e_{u,j})}, \qquad c_u=\sum_t \alpha_{u,t}h_t^{\text{enc}}. \]
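This is already very close to the core idea of the Transformer: do not compress the whole sequence into one fixed vector; read from the past states with learned weights when needed.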
Additive attention scores each encoder state with a learned compatibility function, normalizes scores with softmax, and forms a context vector as a weighted sum of encoder states.
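The three steps (score, normalize, weighted sum) can be sketched for a single decoder step in plain Python. All weights and states below are hypothetical toy values, and the helper names are illustrative.

```python
import math

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def additive_attention(s_prev, enc_states, W_s, W_h, v):
    """Return (alpha, context) for one decoder step."""
    proj_s = matvec(W_s, s_prev)
    scores = []
    for h in enc_states:
        proj_h = matvec(W_h, h)
        # e = v^T tanh(W_s s + W_h h)
        scores.append(sum(v_k * math.tanh(a + b) for v_k, a, b in zip(v, proj_s, proj_h)))
    top = max(scores)                                   # stabilize the softmax
    exps = [math.exp(e - top) for e in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    dim = len(enc_states[0])
    context = [sum(a * h[d] for a, h in zip(alpha, enc_states)) for d in range(dim)]
    return alpha, context

# Toy setup: 3 encoder states of size 2, decoder state of size 2, attention size 2.
enc_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
s_prev = [0.5, -0.5]
W_s = [[1.0, 0.0], [0.0, 1.0]]
W_h = [[2.0, 0.0], [0.0, 2.0]]
v = [1.0, 1.0]

alpha, context = additive_attention(s_prev, enc_states, W_s, W_h, v)
```

The weights `alpha` sum to one, so `context` is a convex combination of the encoder states. Its size is fixed, but its content is recomputed at every decoder step.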
The decoder is still usually trained with teacher forcing:
decoder input: <bos> 我 喜欢 深度 学习
target: 我 喜欢 深度 学习 <eos>
Generation is then autoregressive decoding. As with the RNN language model, this gap between training and inference is one source of exposure bias.
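The generation side of that gap looks like the loop below: each predicted token becomes the next input. This is a plain-Python sketch in which `step_fn` is a hypothetical stand-in for one decoder step, returning a score per candidate token and a new state; the toy model simply replays the example sentence.

```python
def greedy_decode(step_fn, bos, eos, max_len):
    """Feed the model's own previous prediction back in until eos or max_len."""
    tokens, state = [bos], None
    for _ in range(max_len):
        scores, state = step_fn(tokens[-1], state)
        next_token = max(scores, key=scores.get)   # greedy choice
        tokens.append(next_token)
        if next_token == eos:
            break
    return tokens[1:]

# Toy "model": deterministically continues the example sentence.
chain = {"<bos>": "我", "我": "喜欢", "喜欢": "深度", "深度": "学习", "学习": "<eos>"}

def toy_step(prev_token, state):
    scores = {tok: 0.0 for tok in chain.values()}
    scores[chain[prev_token]] = 1.0
    return scores, state

decoded = greedy_decode(toy_step, "<bos>", "<eos>", max_len=10)
```

With a real model, one wrong `next_token` early on changes every later input, and teacher forcing never exposes the model to that situation during training.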
RNN as a Training Paradigm
RNN training introduced several training paradigms that remain important:
| Paradigm | RNN form | Later echo |
|---|---|---|
| teacher forcing | feed ground-truth previous token | LM cross entropy |
| truncated BPTT | backprop only for \(K\) steps | memory-efficient training |
| sequence-to-sequence | encoder state conditions decoder | Transformer encoder-decoder |
| autoregressive generation | sample one token at a time | GPT-style decoding |
| scheduled sampling | mix ground-truth and model tokens | exposure bias mitigation |
The Transformer replaced the RNN's sequential computation, but not the autoregressive factorization. GPT's training objective today is still the next-token prediction that was already well understood in the RNN era.
Minimal PyTorch Sketch

    import torch
    import torch.nn.functional as F
    from torch import nn

    class TinyRNNLM(nn.Module):
        def __init__(self, vocab: int, dim: int, hidden: int):
            super().__init__()
            self.emb = nn.Embedding(vocab, dim)
            self.rnn = nn.GRU(dim, hidden, batch_first=True)
            self.head = nn.Linear(hidden, vocab)

        def forward(self, idx, state=None):
            x = self.emb(idx)              # [B, T] -> [B, T, dim]
            h, state = self.rnn(x, state)  # h: [B, T, hidden]
            return self.head(h), state     # logits: [B, T, vocab]

This minimal model already exposes the main interface of a language model: token ids in, next-token logits at every position out.
One training step can be written as:

    def lm_step(model, batch, pad_id):
        # batch: [B, T] token ids
        inp = batch[:, :-1]
        tgt = batch[:, 1:]
        logits, _ = model(inp)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            tgt.reshape(-1),
            ignore_index=pad_id,
        )
        return loss

For TBPTT, the training loop has to manage the state explicitly:

    state = None
    # stream is assumed to yield consecutive [B, K+1] chunks of the same streams, in order
    for chunk in stream:
        logits, state = model(chunk[:, :-1], state)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            chunk[:, 1:].reshape(-1),
            ignore_index=pad_id,
        )
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        state = detach_state(state)

Implementation Checklist
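When implementing or debugging an RNN, check the following in order:
- are input and target shifted by exactly one position;
- is the logits shape [B, T, V] and the target shape [B, T];
- does the loss ignore PAD and normalize by the number of non-PAD tokens;
- for classification, is the last valid timestep used rather than a padded column;
- does TBPTT carry the state but detach the graph;
- for LSTM, are both parts of (h, c) detached;
- is a hidden state wrongly reused across different samples;
- are stateful batch slots reset when a new stream starts;
- if batch rows are reused, is the hidden state routed by stream id;
- does the recurrent grad norm need clipping;
- at generation time, is the model's own output fed back in;
- is the teacher forcing ratio or scheduled sampling applied only during training;
- is a bidirectional RNN being misused for causal language modeling;
- is the packed sequence output aligned with the original batch order.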
Three smoke tests:

    # 1. label shifting changes target length by one
    batch = torch.randint(0, vocab, (2, 8))
    inp, tgt = batch[:, :-1], batch[:, 1:]
    logits, _ = model(inp)
    assert logits.shape[:2] == tgt.shape

    # 2. TBPTT detach breaks graph across chunks
    _, state = model(inp)
    state = detach_state(state)
    if isinstance(state, tuple):
        assert all(not s.requires_grad for s in state)
    else:
        assert not state.requires_grad

    # 3. reset mask blocks hidden-state leakage across streams
    # assert_reset_blocks_state_leak is a helper you write yourself: for example, check that
    # with every row reset, the output equals a forward pass that starts from state=None.
    if isinstance(state, tuple):
        old_state = tuple(torch.randn_like(s) for s in state)
    else:
        old_state = torch.randn_like(state)
    assert_reset_blocks_state_leak(model, inp, old_state)

Tests like these do not prove that the model will learn well, but they catch the most common RNN engineering mistakes: misaligned labels, PAD contamination, hidden-state leakage, and long-sequence graphs that are not truncated properly.