2.4 Recurrent Neural Networks


RNN 是把 neural network 放到序列上的最早一批通用框架。它的核心不是某个特定 cell,而是一个 factorization:序列概率可以按时间展开,模型在每个时刻用 hidden state 总结过去。

Running Example: Character Language Model

先不要急着看公式。假设我们要训练一个很小的 character-level language model,语料只有一句:

深 度 学 习 很 有 趣

训练样本可以构造成:

input prefix target
深 度
深 度 学
深 度 学 习

如果用 one-hot 表示 token,那么每个 \(x_t\) 是词表大小 \(V\) 的离散向量。RNN 做的事情是逐步读入 token:

\[ h_1=f(x_1,h_0),\quad h_2=f(x_2,h_1),\quad h_3=f(x_3,h_2). \]

当模型在第 \(t\) 步预测 \(x_{t+1}\) 时,\(h_t\) 就是它对前缀 \(x_{\leq t}\) 的压缩记忆。比如看到“深 度 学”以后,\(h_3\) 应该包含“这可能是在讲深度学习”这个上下文,从而提高下一个 token 是“习”的概率。

NotePedagogical Key

RNN is not mysterious recurrence for its own sake. It is a way to turn a variable-length prefix into a fixed-size state so that a classifier can predict the next discrete token.

Sequence Factorization

NoteDefinition: Autoregressive Factorization

For a sequence \(x_{1:T}\), autoregressive modeling writes \[ p(x_{1:T}) = \prod_{t=1}^{T}p(x_t\mid x_{<t}). \] The modeling problem is to build a state representation of \(x_{<t}\) that is useful for predicting \(x_t\).

RNN 的答案是 hidden state:

\[ h_t = \phi(W_{xh}x_t+W_{hh}h_{t-1}+b_h), \]

\[ \hat{y}_t = W_{hy}h_t+b_y. \]

这里 \(h_t\) 是一个递归 summary。它把任意长度的历史压进固定维度向量,这既是 RNN 的优雅之处,也是它的瓶颈。

Hidden State as Compression

从概率建模角度看,RNN 在近似一个理想状态:

\[ p(x_t\mid x_{<t}) \approx p_\theta(x_t\mid h_{t-1}), \qquad h_{t-1}=F_\theta(x_{<t}). \]

如果 \(h_{t-1}\) 足够表达历史中对预测未来有用的信息,那么它就是一个 learned sufficient statistic。问题是这个 statistic 的维度固定,所以 RNN 必须在每一步做压缩:哪些信息留下,哪些信息丢掉。

NoteDefinition: Recurrent State

A recurrent state is a fixed-dimensional summary \(h_t\) updated by a shared transition map \[ h_t=f_\theta(x_t,h_{t-1}). \] It induces a conditional model by replacing the full prefix \(x_{\le t}\) with \(h_t\).

这个定义很朴素,但它直接解释了 RNN 的两个核心困难:

  1. optimization difficulty: 梯度必须穿过很多次同一个 transition;
  2. representation bottleneck: 所有历史信息都被压到 \(h_t\) 中。

Transformer 后来把第一点改成并行 attention,把第二点改成显式保留所有 token 的 states;但 next-token factorization 本身没有变。

Linear RNN as a Lens

为了看清 recurrence 的本质,先去掉非线性:

\[ h_t=Ah_{t-1}+Bx_t. \]

展开可得

\[ h_t = A^t h_0 + \sum_{i=1}^{t}A^{t-i}Bx_i. \]

这条公式非常重要:过去输入 \(x_i\) 对当前 state 的影响由 \(A^{t-i}\) 控制。若 \(A\) 的特征值模长小于 \(1\),远处输入会衰减;若大于 \(1\),远处输入会爆炸。

递推展开:

\[ h_t=A(Ah_{t-2}+Bx_{t-1})+Bx_t =A^2h_{t-2}+ABx_{t-1}+Bx_t. \]

继续展开到 \(h_0\)

\[ h_t=A^th_0+\sum_{i=1}^{t}A^{t-i}Bx_i. \]

\(A=Q\Lambda Q^{-1}\) 可对角化,则

\[ A^k=Q\Lambda^kQ^{-1}. \]

\(\Lambda^k\) 中每个方向按 \(\lambda_j^k\) 缩放。于是 RNN 的“记忆长度”本质上和 recurrent dynamics 的谱有关。非线性 RNN 虽然没有这么简单,但局部线性化后的 Jacobian product 仍然遵循同一类逻辑。

ImportantTheorem: Linear Memory Decays or Explodes Exponentially

For a linear recurrence \(h_t=Ah_{t-1}+Bx_t\), if the spectral radius \(\rho(A)<1\), then the influence of \(h_0\) decays exponentially; if \(\rho(A)>1\), some directions can grow exponentially.

对任意 \(\epsilon>0\),存在矩阵范数使得

\[ \|A\|\le \rho(A)+\epsilon. \]

\(\rho(A)<1\),取足够小的 \(\epsilon\) 使 \(\rho(A)+\epsilon<1\),则

\[ \|A^t h_0\| \le \|A\|^t\|h_0\| \le (\rho(A)+\epsilon)^t\|h_0\|. \]

这说明初始状态影响指数衰减。若存在特征值 \(|\lambda|>1\),沿对应特征方向有 \(A^t v=\lambda^t v\),范数指数增长。

Training by Backpropagation Through Time

把 RNN 沿时间展开后,它就是一个参数共享的深层网络。loss 为

\[ \mathcal{L} = \sum_{t=1}^{T}\ell_t(\hat{y}_t,y_t). \]

对上面的字符模型,teacher forcing 的训练过程是:

  1. 输入真实前缀 深 度 学
  2. 模型分别输出预测 的 logits;
  3. 每个位置都和真实下一个 token 做 cross entropy;
  4. 把三个位置的 loss 加起来反传。

这和生成时不一样。生成时模型预测出一个 token 后,要把自己的预测喂回去;训练时则把真实 token 喂回去。这种 mismatch 称为 exposure bias。

WarningPitfall: Teacher Forcing Hides Generation Errors

During teacher forcing, the model always conditions on correct previous tokens. During generation, it conditions on its own sampled tokens. A small early mistake can shift the hidden state into an unfamiliar region.

由于 \(h_t\) 依赖 \(h_{t-1}\),梯度包含时间方向的 Jacobian product:

\[ \frac{\partial \mathcal{L}}{\partial h_k} = \sum_{t\geq k} \frac{\partial \ell_t}{\partial h_t} \prod_{i=k+1}^{t} \frac{\partial h_i}{\partial h_{i-1}}. \]

WarningPitfall: Long Context Is Not Free

RNN theoretically has access to all past tokens through \(h_t\), but in practice long-range information must survive repeated nonlinear transformations. This creates vanishing/exploding gradient problems and memory bottlenecks.

BPTT as a Recursive Error Signal

仅仅写 Jacobian product 还不够。真正实现 BPTT 时,可以把每个时刻的 hidden gradient 看成一个从未来回流的 error signal。

\[ a_t=W_{xh}x_t+W_{hh}h_{t-1}+b_h, \qquad h_t=\phi(a_t), \qquad o_t=W_{hy}h_t+b_y. \]

假设每个时刻都有 loss \(\ell_t(o_t,y_t)\)。先定义输出侧梯度:

\[ g_t^o=\frac{\partial \ell_t}{\partial o_t}. \]

对 hidden state 的总梯度包含两部分:

  1. 当前时刻输出 loss 对 \(h_t\) 的梯度;
  2. 未来时刻通过 \(h_{t+1}\) 传回来的梯度。

因此有递推:

\[ \delta_t^h = W_{hy}^{\top}g_t^o + \left(W_{hh}^{\top}\delta_{t+1}^a\right), \]

其中

\[ \delta_t^a = \delta_t^h\odot\phi'(a_t). \]

参数梯度为

\[ \frac{\partial \mathcal{L}}{\partial W_{hy}} = \sum_t h_t^{\top}g_t^o, \qquad \frac{\partial \mathcal{L}}{\partial b_y} = \sum_t g_t^o, \]

\[ \frac{\partial \mathcal{L}}{\partial W_{xh}} = \sum_t x_t^{\top}\delta_t^a, \qquad \frac{\partial \mathcal{L}}{\partial W_{hh}} = \sum_t h_{t-1}^{\top}\delta_t^a. \]

从计算图看,\(h_t\) 有两条出边:一条到当前输出 \(o_t\),一条到下一时刻 pre-activation \(a_{t+1}\)。所以

\[ \frac{\partial \mathcal{L}}{\partial h_t} = \frac{\partial \ell_t}{\partial h_t} + \frac{\partial \mathcal{L}_{>t}}{\partial h_t}. \]

第一项由 \(o_t=W_{hy}h_t+b_y\) 给出

\[ \frac{\partial \ell_t}{\partial h_t} = W_{hy}^{\top}g_t^o. \]

第二项通过

\[ a_{t+1}=W_{xh}x_{t+1}+W_{hh}h_t+b_h \]

传回,因此为 \(W_{hh}^{\top}\delta_{t+1}^a\)。再乘以 activation derivative 得到 \(\delta_t^a\)。参数梯度则由每个时刻共享参数的局部 Jacobian 求和得到。

这个递推很重要:RNN 不是有 \(T\) 套参数,而是一套参数被时间展开重复使用,所以所有时刻对 \(W_{xh},W_{hh},W_{hy}\) 的梯度都要相加。

Output Loss and Label Shifting

语言模型里,RNN 通常输入 \(x_{1:T}\),预测 \(x_{2:T+1}\)。若 logits 为

\[ o_t=W_{hy}h_t+b_y\in\mathbb{R}^{V}, \]

softmax 概率为

\[ p_t(j)=\frac{\exp(o_{t,j})}{\sum_{k=1}^{V}\exp(o_{t,k})}. \]

对目标 token \(y_t\),cross entropy 是

\[ \ell_t=-\log p_t(y_t). \]

它对 logits 的梯度有一个非常常用的形式:

\[ \frac{\partial \ell_t}{\partial o_{t,j}} = p_t(j)-\mathbf{1}[j=y_t]. \]

把 loss 写成

\[ \ell_t = -o_{t,y_t} + \log\sum_k \exp(o_{t,k}). \]

\(o_{t,j}\) 求导:

\[ \frac{\partial \ell_t}{\partial o_{t,j}} = -\mathbf{1}[j=y_t] + \frac{\exp(o_{t,j})}{\sum_k\exp(o_{t,k})} = p_t(j)-\mathbf{1}[j=y_t]. \]

这就是为什么上面 BPTT 里可以把 \(g_t^o\) 看成当前位置的“概率误差”。实现时最常见的错误是 label shift 错一格:

# idx: [B, T]
inp = idx[:, :-1]
tgt = idx[:, 1:]
logits, _ = model(inp)           # [B, T-1, V]
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),
    tgt.reshape(-1),
)

如果把 idx[:, :-1]idx[:, :-1] 自己对齐,模型会学 identity copy;如果把 targets 错移两格,loss 也许还能下降,但生成会明显不对。

Vanishing and Exploding Gradients

若 recurrent Jacobian 的谱半径长期小于 \(1\),梯度消失;若大于 \(1\),梯度爆炸。对于 vanilla RNN,

\[ \frac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}(\phi'(a_t))W_{hh}. \]

tanh/sigmoid 的 \(\phi'\) 通常小于 \(1\),因此长程依赖很难训练。这是 LSTM/GRU 出现的直接原因。

更形式化地,若存在常数 \(\gamma<1\),使得每一步

\[ \left\| \frac{\partial h_i}{\partial h_{i-1}} \right\|_2 \le \gamma, \]

则从 \(t\) 传回 \(k\) 的梯度范数满足

\[ \left\| \frac{\partial h_t}{\partial h_k} \right\|_2 \le \prod_{i=k+1}^{t} \left\| \frac{\partial h_i}{\partial h_{i-1}} \right\|_2 \le \gamma^{t-k}. \]

距离越远,梯度指数衰减。反过来若平均范数大于 \(1\),梯度可能指数爆炸。RNN 训练里的 gradient clipping 不是装饰,而是为了防止这种时间方向的乘积突然失控。

把这个结论和 linear RNN 对照,可以看到 forward memory 和 backward gradient 是同一枚硬币的两面。forward 中 \(A^{t-i}\) 决定过去输入保留多久;backward 中 Jacobian product 决定未来 loss 能否给过去状态分配 credit。

WarningPitfall: Orthogonal Initialization Helps but Does Not Solve Everything

Orthogonal recurrent weights can keep \(\|W_{hh}\|_2\) near one at initialization, but nonlinear gates, inputs, layer normalization, optimizer updates, and loss curvature still change the effective Jacobian during training.

一个简单诊断是记录每个 batch 的 global grad norm 和 hidden norm:

signal symptom likely issue
grad norm grows with sequence length exploding BPTT path clip, shorter TBPTT, smaller LR
grad norm near zero for long tasks vanishing credit gated cell, residual/attention, easier curriculum
hidden norm diverges unstable recurrence recurrent init, normalization, clipping
hidden norm collapses to constant state ignored too much regularization, bad gate bias

LSTM

NoteDefinition: LSTM Cell

An LSTM maintains a hidden state \(h_t\) and a cell state \(c_t\): \[ i_t=\sigma(W_i[x_t,h_{t-1}]+b_i), \] \[ f_t=\sigma(W_f[x_t,h_{t-1}]+b_f), \] \[ o_t=\sigma(W_o[x_t,h_{t-1}]+b_o), \] \[ \tilde{c}_t=\tanh(W_c[x_t,h_{t-1}]+b_c), \] \[ c_t=f_t\odot c_{t-1}+i_t\odot\tilde{c}_t, \qquad h_t=o_t\odot\tanh(c_t). \]

LSTM 的关键是 additive memory path:

\[ c_t=f_t\odot c_{t-1}+i_t\odot\tilde{c}_t. \]

相比 vanilla RNN 的反复矩阵乘法,cell state 让信息可以沿时间近似线性传播。forget gate 决定保留多少旧记忆,input gate 决定写入多少新信息。

实际实现通常不会给四个 gate 分别做四次矩阵乘法,而是一次性算出四倍 hidden size 的向量:

gates = x_t @ W_ih.T + h_prev @ W_hh.T + b
i, f, g, o = gates.chunk(4, dim=-1)
i = torch.sigmoid(i)
f = torch.sigmoid(f)
g = torch.tanh(g)
o = torch.sigmoid(o)
c = f * c_prev + i * g
h = o * torch.tanh(c)

这背后有两个工程细节:

  1. gate fusion 减少 kernel launch 和内存读写;
  2. forget gate bias 常初始化为正数,让训练初期更倾向保留记忆。
TipImplementation Tip: Initialize Forget Bias Positive

For many LSTM tasks, initializing the forget-gate bias to a positive value such as \(1\) makes \(f_t=\sigma(b_f)\) closer to \(0.73\) at the start, so the cell does not forget everything before learning useful gates.

Why LSTM Helps

看 cell state 的导数:

\[ c_t=f_t\odot c_{t-1}+i_t\odot\tilde{c}_t. \]

忽略 gates 对 \(c_{t-1}\) 的间接依赖时,有

\[ \frac{\partial c_t}{\partial c_{t-1}} \approx f_t. \]

于是从 \(c_t\) 传回 \(c_k\) 的梯度近似为

\[ \frac{\partial c_t}{\partial c_k} \approx \prod_{i=k+1}^{t}f_i. \]

如果 forget gate 学到 \(f_i\approx1\),信息可以沿 cell state 长距离保留;如果学到 \(f_i\approx0\),模型主动遗忘。这比 vanilla RNN 里固定由 \(W_{hh}\) 和 activation derivative 决定的梯度路径灵活得多。

forget gate 还可以解释 memory timescale。若某一维长期有常数 \(f\),旧记忆经过 \(k\) 步后剩下

\[ f^k. \]

\(f^k=e^{-1}\),得到有效时间常数

\[ \tau=-\frac{1}{\log f}. \]

例如 \(f=0.9\)\(\tau\approx9.5\)\(f=0.99\)\(\tau\approx99.5\)。这说明 LSTM 不是“自动记住一切”,而是通过 gate 学每个维度的遗忘时间尺度。

ImportantTheorem: Additive Memory Path Reduces Gradient Decay

An LSTM cell can learn an approximate identity path through \(c_t\) by setting forget gates near one and input gates near zero, allowing gradients to propagate across many steps without repeated multiplication by a dense recurrent matrix.

\(f_i=\mathbf{1}\)\(i_i=\mathbf{0}\),则

\[ c_i=c_{i-1}. \]

所以

\[ \frac{\partial c_t}{\partial c_k}=I. \]

实际训练中 gates 不会总是精确取 0 或 1,但 sigmoid gates 允许模型在不同维度上选择保留、覆盖或遗忘。这样梯度可以沿 additive state path 传播,而不必每一步都乘以同一个 dense recurrent matrix。

GRU

GRU 把 LSTM 的 cell state 和 hidden state 合并。常见写法之一是:

\[ z_t=\sigma(W_z[x_t,h_{t-1}]), \qquad r_t=\sigma(W_r[x_t,h_{t-1}]), \]

\[ \tilde{h}_t=\tanh(W_h[x_t,r_t\odot h_{t-1}]), \]

\[ h_t=(1-z_t)\odot h_{t-1}+z_t\odot \tilde{h}_t. \]

这里 \(z_t\) 越大,越偏向写入 candidate;越小,越保留旧状态。也有文献和框架使用相反约定,把 \(z_t\) 解释成保留旧状态的比例,例如 PyTorch 文档里常见形式是

\[ h_t=(1-z_t)\odot n_t+z_t\odot h_{t-1}. \]

两者本质等价,只是 \(z\) 的语义相反。读论文或复现代码时一定要看清楚 gate convention。

GRU 参数更少,常在数据规模较小或延迟敏感场景中使用。它没有单独的 \(c_t\),所以状态接口更简单;但在需要非常长的记忆时,LSTM 的显式 cell path 往往更容易解释和调试。

cell state gates strengths common use
vanilla RNN \(h_t\) none simplest, fast toy models, short sequences
GRU \(h_t\) update/reset fewer params than LSTM moderate sequence tasks
LSTM \((h_t,c_t)\) input/forget/output explicit memory path long dependencies, older seq2seq

Manual Scan Implementation

nn.GRUnn.LSTM 隐藏了时间循环,但理解底层 scan 很有用。一个手写 RNN cell 通常长这样:

class RNNCell(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.xh = nn.Linear(dim, hidden, bias=True)
        self.hh = nn.Linear(hidden, hidden, bias=False)

    def forward(self, x_t, h):
        return torch.tanh(self.xh(x_t) + self.hh(h))


def scan(cell, x, h0):
    # x: [B, T, D]
    h = h0
    outs = []
    for t in range(x.size(1)):
        h = cell(x[:, t], h)
        outs.append(h)
    return torch.stack(outs, dim=1), h

这个 loop 有几个 shape invariant:

tensor shape meaning
x [B, T, D] embedded input sequence
h [B, H] current hidden state
outs [B, T, H] hidden states for all timesteps
logits [B, T, V] token logits per timestep

RNN 的顺序依赖意味着无法像 self-attention 那样完全并行化时间维。训练可以 batch over B,但时间循环仍然是 sequential dependency。这也是 Transformer 在硬件上更容易扩展的关键原因之一。

Truncated BPTT

完整 BPTT 需要把整个长度 \(T\) 的计算图都留到 backward。长序列时这会带来两个问题:

  1. activation memory 随 \(T\) 增长;
  2. 梯度路径太长,数值不稳定。

Truncated BPTT 每隔 \(K\) 步截断一次 hidden state:

h = None

for chunk in chunks(sequence, length=K):
    logits, h = rnn(chunk, h)
    loss = criterion(logits, targets)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    h = h.detach()

h.detach() 的含义是:保留 hidden value 作为下一段上下文,但不让梯度继续穿过上一段图。于是模型 forward 仍有跨 chunk 的状态,backward 只回传 \(K\) 步。

WarningPitfall: Detaching Too Early Destroys Temporal Credit Assignment

If the hidden state is detached every single token, the model cannot assign credit across time. If it is never detached on long streams, memory grows and gradients may become unstable.

LSTM 的 hidden state 是 tuple,所以要同时 detach:

def detach_state(state):
    if isinstance(state, tuple):
        return tuple(s.detach() for s in state)
    return state.detach()

在语言模型 streaming 训练里,还要区分两件事:

  1. state carry: 下一段 forward 是否继承上一段 hidden;
  2. gradient carry: backward 是否穿过上一段计算图。

TBPTT 的常见选择是 carry state but not gradient。若每个 chunk 都重置 state,模型退化成只能看 chunk 内上下文;若完全不截断,则显存随流长度增长。

strategy forward context backward horizon use case
reset every chunk local only \(K\) independent samples
carry + detach cross-chunk state \(K\) streaming LM default
carry no detach full stream stream length very short streams only

还有一个常见细节:如果一个 epoch 里把长语料切成连续 chunks,shuffle 的单位不能随便打乱 token 顺序,否则 carry state 会接到不相关文本上。可以 shuffle streams,但每条 stream 内部保持时间顺序。

Stateful Batching and Reset Masks

TBPTT 的 h.detach() 只解决“梯度图有多长”的问题,不自动解决“这个 hidden state 属于谁”的问题。真实 streaming dataloader 常把多个长序列切成 chunks,再把 chunks 放进 batch slots:

step 0: slot 0 -> doc A chunk 0, slot 1 -> doc B chunk 0
step 1: slot 0 -> doc A chunk 1, slot 1 -> doc B chunk 1
step 2: slot 0 -> doc C chunk 0, slot 1 -> doc B chunk 2

在 step 2,slot 0 的 hidden state 必须重置;否则 doc C 的开头会继承 doc A 的上下文。这种 bug 不会产生 shape error,却会污染训练目标。

NoteDefinition: Stateful Batch Slot

A stateful batch slot is a batch row whose recurrent hidden state is carried across chunks only while the same logical stream remains assigned to that row.

\(r_{b,t}\in\{0,1\}\) 表示 batch row \(b\) 在时间 \(t\) 是否开启新 stream。带 reset mask 的 recurrence 可以写成:

\[ \tilde{h}_{b,t-1} = (1-r_{b,t})h_{b,t-1} +r_{b,t}h_{\text{init}}, \]

\[ h_{b,t} = f_\theta(x_{b,t},\tilde{h}_{b,t-1}). \]

\(r_{b,t}=1\),hidden state 在 forward 上被重置;若 \(r_{b,t}=0\),正常 carry。

ImportantTheorem: Reset Masks Prevent Cross-Stream Dependence

If \(r_{b,t}=1\), then \(h_{b,t}\) is independent of the previous stream’s hidden state \(h_{b,t-1}\), assuming \(h_{\text{init}}\) is fixed or depends only on the new stream.

\(r_{b,t}=1\) 时:

\[ \tilde{h}_{b,t-1} = (1-1)h_{b,t-1}+1\cdot h_{\text{init}} = h_{\text{init}}. \]

因此:

\[ h_{b,t} = f_\theta(x_{b,t},h_{\text{init}}). \]

右侧不含旧的 \(h_{b,t-1}\),所以当前 state 不依赖上一条 stream 的 hidden value。若 \(h_{\text{init}}\) 是可学习初始状态,也必须按新 stream 广播,而不能从旧 stream 读取。

PyTorch 里 GRU 的 state 形状常是 [num_layers, B, H],LSTM 是 (h, c) 两个同形状 tensor。reset mask 要广播到 layer 和 hidden 维:

def apply_reset_mask(state, reset_mask, init_state=None, batch_dim=1):
    # reset_mask: [B], True means this row starts a new stream.
    # PyTorch GRU/LSTM state is usually [num_layers, B, H], so batch_dim=1.
    shape = [1] * state.ndim
    shape[batch_dim] = reset_mask.numel()
    keep = (~reset_mask).reshape(shape)
    keep = keep.to(device=state.device, dtype=state.dtype)

    if init_state is None:
        init_state = torch.zeros_like(state)
    return state * keep + init_state * (1.0 - keep)


def reset_recurrent_state(state, reset_mask, init_state=None, batch_dim=1):
    if state is None:
        return None
    if isinstance(state, tuple):
        if init_state is None:
            init_state = (None, None)
        return tuple(
            apply_reset_mask(s, reset_mask, init, batch_dim=batch_dim)
            for s, init in zip(state, init_state)
        )
    return apply_reset_mask(state, reset_mask, init_state, batch_dim=batch_dim)

一个 stateful TBPTT step 的顺序通常是:

state = reset_recurrent_state(state, batch["reset_mask"])
logits, state = model(batch["input_ids"], state)
loss = masked_lm_loss(logits, batch["labels"], batch["loss_mask"])

opt.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()

state = detach_state(state)

顺序很重要:先 reset 再 forward;backward/step 后再 detach。若先 detach 再 reset,数值上通常等价;但工程上把 reset 放在 forward 前可以明确表达“当前 batch 的哪些行是新 stream”。

Hidden-State Routing by Stream Id

如果 dataloader 不能保证同一个 stream 一直占同一个 batch slot,就需要按 stream id 路由 hidden state。一个简单 state bank 可以写成:

class StateBank:
    def __init__(self):
        self.store = {}

    def get(self, stream_ids, template_state):
        rows = []
        for sid in stream_ids:
            key = int(sid)
            if key in self.store:
                rows.append(self.store[key])
            else:
                rows.append(torch.zeros_like(template_state[:, :1]))
        return torch.cat(rows, dim=1)

    def put(self, stream_ids, state):
        state = state.detach()
        for row, sid in enumerate(stream_ids):
            self.store[int(sid)] = state[:, row : row + 1].contiguous()

    def drop(self, finished_stream_ids):
        for sid in finished_stream_ids:
            self.store.pop(int(sid), None)

LSTM 版本要同时存 (h, c),并且所有 state tensor 都应该在同一 device/dtype。这个设计适合少量长 streams;如果 stream 很多,state bank 会变成内存问题,需要限制活跃 stream 数或把 state 移回 CPU。

WarningPitfall: Batch Row Is Not Stream Identity

In stateful RNN training, the batch row index is only a temporary slot. The logical identity is the stream id. Reusing a slot for a new stream without resetting or rerouting hidden state leaks context across examples.

一个最小 smoke test 是:两个 stream 用相同 token chunk,但不同 initial state;重置后输出应与旧 state 无关。

def assert_reset_blocks_state_leak(model, chunk, old_state):
    reset = torch.ones(chunk.size(0), dtype=torch.bool, device=chunk.device)
    zero_state = reset_recurrent_state(old_state, reset)

    out_a, _ = model(chunk, zero_state)
    out_b, _ = model(chunk, None)

    assert torch.allclose(out_a, out_b, atol=1e-5, rtol=1e-4)

若这个测试失败,说明 reset path 没有真正清掉旧 state,或者模型在 forward 里还有其他跨样本缓存。

Padding and Packed Sequences

RNN batch 训练通常要把不同长度序列 pad 到同一长度:

seq1: A B C D
seq2: E F PAD PAD

如果直接把 padding 送进 RNN,模型会在 PAD 上继续更新 hidden state。对分类任务,这可能污染最后 hidden;对语言建模,可能在 PAD target 上产生无意义 loss。

有三种常见处理:

  1. loss mask 忽略 padded positions;
  2. 根据真实长度取最后有效 hidden;
  3. 使用 pack_padded_sequence 让 RNN kernel 跳过 padding。

概念上,mask 后的序列 loss 是

\[ \mathcal{L} = \frac{ \sum_{b,t}m_{b,t}\ell_{b,t} }{ \sum_{b,t}m_{b,t} }. \]

WarningPitfall: Last Hidden Is Not Always the Last Column

For padded batches, h[:, -1] may be the hidden state after padding. Use sequence lengths or packed sequences to select the last real timestep.

用 loss mask 写语言模型 loss 时,推荐显式控制 denominator:

loss_per_token = F.cross_entropy(
    logits.reshape(-1, vocab),
    targets.reshape(-1),
    reduction="none",
).view_as(targets)

mask = targets.ne(pad_id)
loss = (loss_per_token * mask).sum() / mask.sum().clamp_min(1)

这比直接 mean() 更稳,因为不同 batch 的 padding 比例可能不同。若用 ignore_index=pad_id,也要确认框架的 reduction 是否按非 ignore token 平均。

pack_padded_sequence 的典型写法:

lengths = attention_mask.sum(dim=1).cpu()
x = embedding(input_ids)
packed = nn.utils.rnn.pack_padded_sequence(
    x,
    lengths,
    batch_first=True,
    enforce_sorted=False,
)
packed_out, h_n = rnn(packed)
out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

packed sequence 能跳过 padding 的 recurrent computation,但它也带来一些麻烦:长度必须在 CPU 上,输出顺序/长度要对齐,和自定义 attention 或 residual block 组合时更容易出错。因此很多现代代码会选择普通 padded tensor + mask,牺牲一点计算换取实现清晰。

Sequence-to-Sequence and Attention

RNN 不只做 language model。经典 seq2seq 模型用 encoder RNN 把输入序列压成一个 context,再用 decoder RNN 逐步生成输出:

\[ h_t^{\text{enc}}=f_{\text{enc}}(x_t,h_{t-1}^{\text{enc}}), \qquad s_u=f_{\text{dec}}(y_{u-1},s_{u-1},c_u). \]

最早的 encoder-decoder 常用最后一个 encoder state 作为 context:

\[ c=h_T^{\text{enc}}. \]

这对长句很吃力,因为所有输入信息都挤在一个向量里。attention 的做法是让 decoder 每一步从所有 encoder states 中读:

\[ e_{u,t}=v^\top\tanh(W_s s_{u-1}+W_h h_t^{\text{enc}}), \]

\[ \alpha_{u,t} = \frac{\exp(e_{u,t})}{\sum_j\exp(e_{u,j})}, \qquad c_u=\sum_t \alpha_{u,t}h_t^{\text{enc}}. \]

这已经很接近 Transformer 的核心思想:不要把序列全部压进一个固定向量,而是在需要时按权重读取历史 states。

NoteDefinition: Additive Attention

Additive attention scores each encoder state with a learned compatibility function, normalizes scores with softmax, and forms a context vector as a weighted sum of encoder states.

训练 decoder 时通常仍然 teacher forcing:

decoder input:  <bos> 我 喜欢 深度 学习
target:         我    喜欢 深度 学习 <eos>

生成时则 autoregressive decode。这个训练/推理差异和 RNN language model 一样,是 exposure bias 的来源之一。

RNN as a Training Paradigm

RNN 训练引入了几个后来仍然重要的训练范式:

Paradigm RNN form Later echo
teacher forcing feed ground-truth previous token LM cross entropy
truncated BPTT backprop only for \(K\) steps memory-efficient training
sequence-to-sequence encoder state conditions decoder Transformer encoder-decoder
autoregressive generation sample one token at a time GPT-style decoding
scheduled sampling mix ground-truth and model tokens exposure bias mitigation

Transformer 替代了 RNN 的 sequential computation,但没有替代 autoregressive factorization。今天 GPT 的训练目标仍然是 RNN 时代就清楚的 next-token prediction。

Minimal PyTorch Sketch

import torch
import torch.nn.functional as F
from torch import nn


class TinyRNNLM(nn.Module):
    def __init__(self, vocab: int, dim: int, hidden: int):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.rnn = nn.GRU(dim, hidden, batch_first=True)
        self.head = nn.Linear(hidden, vocab)

    def forward(self, idx, state=None):
        x = self.emb(idx)
        h, state = self.rnn(x, state)
        return self.head(h), state

这个最小模型已经包含语言模型的主要接口:输入 token ids,输出每个位置的 next-token logits。

训练一步可以写成:

def lm_step(model, batch, pad_id):
    # batch: [B, T]
    inp = batch[:, :-1]
    tgt = batch[:, 1:]

    logits, _ = model(inp)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        tgt.reshape(-1),
        ignore_index=pad_id,
    )
    return loss

若做 TBPTT,训练 loop 要显式管理 state:

state = None
for chunk in stream:
    logits, state = model(chunk[:, :-1], state)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        chunk[:, 1:].reshape(-1),
        ignore_index=pad_id,
    )
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()
    state = detach_state(state)

Implementation Checklist

实现或调试 RNN 时,可以按下面顺序查:

  1. 输入/target 是否正确 shift 一位;
  2. logits shape 是否是 [B, T, V],target shape 是否是 [B, T]
  3. loss 是否忽略 PAD,并按非 PAD token 归一化;
  4. 分类任务是否取最后有效 timestep,而不是 padded column;
  5. TBPTT 是否 carry state but detach graph;
  6. LSTM 是否同时 detach (h, c)
  7. hidden state 是否跨不同样本错误复用;
  8. stateful batch slots 是否在新 stream 开始时 reset;
  9. 若 batch row 会复用,是否用 stream id 路由 hidden state;
  10. recurrent grad norm 是否需要 clipping;
  11. generation 时是否把模型自己的输出喂回去;
  12. teacher forcing ratio 或 scheduled sampling 是否只在训练阶段使用;
  13. bidirectional RNN 是否被误用于因果语言建模;
  14. packed sequence 输出是否和原 batch 顺序对齐。

两个 smoke tests:

# 1. label shifting changes target length by one
batch = torch.randint(0, vocab, (2, 8))
inp, tgt = batch[:, :-1], batch[:, 1:]
logits, _ = model(inp)
assert logits.shape[:2] == tgt.shape

# 2. TBPTT detach breaks graph across chunks
_, state = model(inp)
state = detach_state(state)
if isinstance(state, tuple):
    assert all(not s.requires_grad for s in state)
else:
    assert not state.requires_grad

# 3. reset mask blocks hidden-state leakage across streams
old_state = torch.randn_like(state) if not isinstance(state, tuple) else tuple(torch.randn_like(s) for s in state)
assert_reset_blocks_state_leak(model, inp, old_state)

这类测试不证明模型会学好,但能抓住最常见的 RNN 工程错误:错位、PAD 污染、hidden state 泄漏,以及长序列图没有被正确截断。