Transformer-XL


Transformer 的语言模型训练是在数百字符的独立定长片段上进行的,上下文长度固定,片段间不存在任何信息流动,模型无法捕捉超出预设上下文范围的长期依赖关系。 Transformer-XL 架构能够在保持时序连贯性的同时突破固定长度限制来学习依赖关系。

The Fixed-Context Problem

标准 Transformer LM 训练时会把长文本切成长度为 \(L\) 的片段:

segment 1: x_1, ..., x_L
segment 2: x_{L+1}, ..., x_{2L}
segment 3: x_{2L+1}, ...

如果每个 segment 独立训练,那么 segment 2 的第一个 token 无法看到 segment 1 的内容。这样会产生两个问题:

  1. context fragmentation: 跨 segment 的长期依赖被切断;
  2. inefficient evaluation: 推理时如果滑动窗口每次重算整段,重复计算很多。

Transformer-XL 的核心是 segment-level recurrence:上一段的 hidden states 被缓存,作为下一段的额外 memory。

NoteDefinition: Context Fragmentation

Context fragmentation is the loss of dependencies across independently processed fixed-length segments. It occurs when a language model cannot condition on tokens before the current segment boundary.

如果 segment 长度是 \(L\),标准固定窗口训练实际优化的是

\[ \prod_{\tau} \prod_{i=1}^{L} p_\theta(x_{\tau,i}\mid x_{\tau,<i}), \]

而不是完整文本的

\[ \prod_{t} p_\theta(x_t\mid x_{<t}). \]

这里的差别不是小实现细节。每个 segment 开头的 token 被迫在缺失前文的情况下预测,模型学到的是被截断的条件分布。

Segment-Level Recurrence

令第 \(\tau\) 个 segment 的输入为

\[ X_\tau=[x_{\tau,1},\ldots,x_{\tau,L}]. \]

\(n\) 层 hidden states 记为

\[ H_\tau^{(n)}\in\mathbb{R}^{L\times d}. \]

Transformer-XL 在计算当前 segment 第 \(n\) 层时,把上一段同层 hidden states 缓存下来:

\[ \tilde{H}_\tau^{(n-1)} = [\operatorname{SG}(H_{\tau-1}^{(n-1)}); H_\tau^{(n-1)}], \]

其中 \(\operatorname{SG}\) 表示 stop-gradient。当前 segment 的 queries 来自 \(H_\tau^{(n-1)}\),keys/values 来自拼接后的 memory:

\[ Q=H_\tau^{(n-1)}W^Q, \qquad K=\tilde{H}_\tau^{(n-1)}W^K, \qquad V=\tilde{H}_\tau^{(n-1)}W^V. \]

这像 RNN 的 hidden state,但粒度更高:缓存的不是单个向量,而是一段 token-level hidden states。

若 memory 长度为 \(M\),当前 segment 长度为 \(L\),第 \(n\) 层 attention 的 score shape 是

\[ A_\tau^{(n)} \in \mathbb{R}^{L\times(M+L)}. \]

当前 segment 的每个 query 可以看见前一段 memory 和当前 segment 中不违反 causal mask 的位置。训练时常只缓存前一段或固定数量的历史 hidden states;推理时可以滚动保留最近 \(M\) 个 states。

更完整地,Transformer-XL 的每层 memory 可以写成一个 list:

\[ \operatorname{mems}_\tau = \left[ M_\tau^{(0)},M_\tau^{(1)},\ldots,M_\tau^{(N)} \right], \]

其中 \(N\) 是层数,\(M_\tau^{(n)}\in\mathbb{R}^{M\times d}\)。第 \(n\) 层计算后更新:

\[ M_{\tau+1}^{(n)} = \operatorname{Tail}_M \left( [M_\tau^{(n)};H_\tau^{(n)}] \right). \]

这里 \(\operatorname{Tail}_M\) 表示只保留最后 \(M\) 个 hidden states。注意 memory 不是一个全局单向量,而是每层都有自己的历史 hidden states;这也是它比“把上一段最后一个 token 当 RNN state”更强的地方。

NoteDefinition: Segment-Level Recurrence

Segment-level recurrence reuses hidden states from previous segments as attention memory for the current segment while stopping gradients through those cached states.

Batched Tensor Contract

实际实现通常带 batch 维。令 batch size 为 \(B\),segment length 为 \(L\),memory length 为 \(M\),hidden size 为 \(d\),layer 数为 \(N\)。第 \(n\) 层 memory 和当前 hidden 的 shape 是:

\[ M_\tau^{(n)}\in\mathbb{R}^{B\times M\times d}, \qquad H_\tau^{(n)}\in\mathbb{R}^{B\times L\times d}. \]

进入第 \(n\) 层 attention 前,把 memory 和当前 hidden 在 sequence 维拼接:

\[ \tilde{H}_\tau^{(n-1)} = \left[ \operatorname{SG}(M_\tau^{(n-1)}), H_\tau^{(n-1)} \right] \in \mathbb{R}^{B\times(M+L)\times d}. \]

于是 projection 后:

\[ Q\in\mathbb{R}^{B\times H\times L\times d_h}, \qquad K,V\in\mathbb{R}^{B\times H\times(M+L)\times d_h}. \]

注意 query 只来自当前 segment,key/value 来自 memory + current。若误把 memory 也放进 query,就会让旧 token 再次产生输出和 loss,训练目标变成另一种长序列 Transformer。

ImportantContract: Transformer-XL Queries Are Current-Segment Only

Transformer-XL reuses previous hidden states as keys and values, but it computes queries and loss only for the current segment.

Memory 更新也有明确 contract。每层保留最新 \(M\) 个 hidden states:

def update_mems(mems, hids, mem_len):
    # mems[n]: [B, M, D], hids[n]: [B, L, D]
    if mem_len == 0:
        return [h[:, :0].detach() for h in hids]

    new_mems = []
    for old, cur in zip(mems, hids):
        cat = torch.cat([old, cur], dim=1)
        new_mems.append(cat[:, -mem_len:].detach())
    return new_mems

这里 detach 应在写入下一步 memory 前完成。若忘记 detach,训练可能短时间能跑,但计算图会跨 segment 增长;若忘记裁剪,attention 成本会随训练步数线性增长。

Causal Mask With Memory

当前 segment 长度为 \(L\),memory 长度为 \(M\)。对 query position \(i\in\{0,\ldots,L-1\}\),key positions 包括:

memory:  -M, ..., -1
current:  0, ..., L-1

causal mask 允许当前 token attend 到全部 memory 和当前 segment 中不超过自己的位置:

\[ \operatorname{allow}(i,j) = \begin{cases} 1, & j<0,\\ 1, & 0\leq j\leq i,\\ 0, & j>i. \end{cases} \]

所以 attention logits 的 mask shape 是 \(L\times(M+L)\)。这和普通 causal mask 的 \(L\times L\) 不同,前面多出来的 \(M\) 列通常全可见。

WarningPitfall: Transformer-XL Mask Is Not a Plain Triangular Matrix

With memory, the attention mask is rectangular: current queries attend to all memory keys plus causal current keys. Using a plain \(L\times L\) triangular mask silently drops the recurrence benefit.

一个最小 visible mask 可以这样构造:

def make_txl_mask(q_len, mem_len, device):
    # True means visible. Shape: [q_len, mem_len + q_len]
    mem_visible = torch.ones(q_len, mem_len, dtype=torch.bool, device=device)
    cur_visible = torch.ones(q_len, q_len, dtype=torch.bool, device=device).tril()
    return torch.cat([mem_visible, cur_visible], dim=-1)

例如 \(M=3,L=4\)

query 0: [m m m | x . . .]
query 1: [m m m | x x . .]
query 2: [m m m | x x x .]
query 3: [m m m | x x x x]

其中 m 表示 memory visible,x 表示当前 segment 中可见 token,. 表示未来 token 被 mask。这个矩形可见性正是 Transformer-XL 和普通 causal Transformer 的关键差别。

如果 batch 中某些 row 刚进入新文档,则对应 row 的 memory 应该清空或 mask 掉。否则即使矩形 mask 正确,也会把上一篇文档当作当前文档的前文。

Why Stop Gradient

如果不 stop-gradient,反向传播会跨越无限多 segment,显存不可控。Transformer-XL 只把旧 segment 当作 memory 读,不把梯度传回旧 segment:

\[ \frac{\partial H_\tau}{\partial H_{\tau-1}} \text{ is not backpropagated through memory.} \]

这样训练成本仍然接近固定窗口 Transformer,但模型能在前向计算中利用更长上下文。

更准确地说,Transformer-XL 把前向上下文长度和反向传播长度解耦:

Quantity Meaning
segment length \(L\) tokens receiving loss and gradients in this step
memory length \(M\) previous hidden states visible to attention
BPTT length approximately \(L\), because memory is stop-gradient

如果不 stop-gradient,计算图会跨 segment 链接:

segment 1 -> segment 2 -> segment 3 -> ...

显存随历史长度增长,训练退化成长序列 BPTT。stop-gradient 则变成:

old hidden states --read only--> current segment
WarningPitfall: Memory Is Context, Not Trainable Past

Transformer-XL memory improves the forward conditioning context, but gradients do not update old segment activations. It is not the same as backpropagating through the entire document.

Truncated BPTT Analogy

Transformer-XL 很像 RNN 的 truncated BPTT。RNN 会把 hidden state 传到下一段,但 detach hidden state:

h = h.detach()

Transformer-XL 做的是 token-level hidden states 的 detach:

mems = [m.detach() for m in new_mems]

不同点是,RNN memory 通常是每层一个向量;Transformer-XL memory 是每层一段 token states。于是当前 token 可以通过 attention 有选择地读取历史多个位置,而不是只读压缩后的最后状态。

ImportantTheorem: Stop-Gradient Bounds Backward Graph Length

If cached memory states are detached at every segment boundary, gradients from segment \(\tau\) do not propagate into computations that produced segment \(\tau-1\) memory states.

设当前 segment loss 为 \(\mathcal{L}_\tau\),memory 输入为

\[ \tilde{M}_{\tau}^{(n)} = \operatorname{SG}(M_{\tau}^{(n)}). \]

stop-gradient 的定义是:

\[ \frac{\partial \operatorname{SG}(z)}{\partial z}=0. \]

因此链式法则中,从 \(\mathcal{L}_\tau\)\(M_\tau^{(n)}\) 的路径包含该零 Jacobian:

\[ \frac{\partial \mathcal{L}_\tau}{\partial M_\tau^{(n)}} = \frac{\partial \mathcal{L}_\tau}{\partial \operatorname{SG}(M_\tau^{(n)})} \frac{\partial \operatorname{SG}(M_\tau^{(n)})}{\partial M_\tau^{(n)}} =0. \]

所以当前 segment 的梯度不会回到产生上一段 memory 的计算图。

Relative Positional Encoding

直接复用 absolute position 会出问题:同一个 token 在不同 segment 中的 absolute index 改变,memory 的位置意义会混乱。Transformer-XL 使用 relative positional encoding,让 attention 依赖相对距离而不是绝对编号。

一个简化的相对 attention score 可以写作

\[ A_{ij} = q_i^\top k_j + q_i^\top r_{i-j}, \]

其中 \(r_{i-j}\) 表示 query position \(i\) 与 key position \(j\) 的相对位移 embedding。

更完整的 Transformer-XL 分解把 content-content、content-position 和 global bias 分开,使模型知道:

  1. 当前 token 内容和 memory token 内容是否相关;
  2. 两者距离多远;
  3. 不同 head 对内容/位置的全局偏好。

Transformer-XL 的 relative attention score 可写成四项:

\[ A_{ij} = \underbrace{q_i^\top k_j}_{\text{content-content}} + \underbrace{q_i^\top r_{i-j}}_{\text{content-position}} + \underbrace{u^\top k_j}_{\text{global content bias}} + \underbrace{v^\top r_{i-j}}_{\text{global position bias}}. \]

其中 \(u,v\) 是每个 head 学到的全局 bias。四项分别回答:

  1. 当前内容和历史内容是否相关;
  2. 当前内容是否偏好某种相对距离;
  3. 某些历史内容是否总体更容易被读;
  4. 某些相对距离是否总体更重要。

Dimension Walkthrough

对单个 attention head:

\[ Q\in\mathbb{R}^{L\times d_h}, \qquad K\in\mathbb{R}^{(M+L)\times d_h}. \]

relative position embedding 覆盖从最远 memory 到当前 segment 的距离:

\[ R\in\mathbb{R}^{(M+L)\times d_h}. \]

content-content term:

\[ QK^\top\in\mathbb{R}^{L\times(M+L)}. \]

content-position term:

\[ QR^\top\in\mathbb{R}^{L\times(M+L)}. \]

global content bias 可以写成:

\[ \mathbf{1}_L (u^\top K^\top) \in \mathbb{R}^{L\times(M+L)}. \]

global position bias:

\[ \mathbf{1}_L (v^\top R^\top) \in \mathbb{R}^{L\times(M+L)}. \]

四项 shape 相同,才能相加后进入 softmax:

\[ \operatorname{Attn}(Q,K,V) = \operatorname{softmax} \left( \frac{A+M_{\text{causal}}}{\sqrt{d_h}} \right)V. \]

ImportantTheorem: Relative Positions Are Stable Across Segments

If attention scores depend on \(i-j\) rather than absolute indices, shifting both current and memory positions by the same segment offset leaves the positional part of the score unchanged.

设 segment offset 为 \(c\)。原来的 query/key positions 是 \(i,j\),平移后是 \(i'=i+c,j'=j+c\)。相对距离为

\[ i'-j' = (i+c)-(j+c) = i-j. \]

因此任何只依赖 \(i-j\) 的 positional term 都不随 segment offset 改变。对于复用 memory 的模型,这正是需要的性质:同一段历史被放在不同 absolute index 下,仍有一致的相对距离语义。

Relative Shift Trick

朴素计算 \(q_i^\top r_{i-j}\) 需要构造每个 query-key pair 的相对位置 embedding。工程实现中会先计算

\[ Q R^\top \]

得到一个包含多种相对位移的矩阵,再通过 reshape/slice 做 relative shift,让第 \(i\) 行第 \(j\) 列对齐到 \(r_{i-j}\)

概念上:

raw relative logits -> pad -> reshape -> slice -> aligned relative logits

这个 trick 不改变数学,只是避免显式为每个 pair gather 相对 embedding。它是 Transformer-XL 比普通“讲公式”更工程化的地方:长上下文不仅要数学可行,也要张量布局可行。

一个简化的 relative shift 可以想成:

import torch


def relative_shift(x):
    # x: [batch, heads, q_len, k_len]
    b, h, q_len, k_len = x.shape
    zero = x.new_zeros(b, h, q_len, 1)
    x = torch.cat([zero, x], dim=-1)
    x = x.view(b, h, k_len + 1, q_len)
    x = x[:, :, 1:].view(b, h, q_len, k_len)
    return x

真实实现会因 layout 不同而略有差异,但核心都是通过 reshape 把“按 relative distance 排列的 logits”对齐成“按 key position 排列的 logits”。

WarningPitfall: Relative Shift Bugs Are Shape-Correct

Relative-shift mistakes often preserve tensor shapes but assign the wrong distance embedding to each query-key pair. The model still trains, but position information is corrupted.

Relative Shift Golden Test

relative shift 最容易写成“shape 正确但语义错”。一个小测试是用整数标记 relative-distance 列,检查 shift 后第 \((i,j)\) 个位置是否对应目标距离。

假设 key positions 是

memory:  -M, ..., -1
current:  0, ..., L-1

query position 是 \(i\),key position 是 \(j\),目标相对距离是 \(i-j\)。可以构造一个 reference:

def reference_rel_ids(q_len, mem_len):
    key_pos = torch.arange(-mem_len, q_len)
    query_pos = torch.arange(q_len)
    return query_pos[:, None] - key_pos[None, :]

然后让 relative-shift 路径输出的每个位置携带同样的 distance id。若两者不一致,说明 reshape/slice 方向错了。这个测试比只看 loss 更早暴露 bug,因为位置错配的模型仍可能靠内容项训练下降。

WarningPitfall: Relative Position Tests Need Semantic Labels

Testing only tensor shapes cannot validate relative shift. Use a tiny distance-labeled example to verify that each query-key pair receives the intended relative position.

Training and Inference

训练时,每个 batch 处理一个 segment,并缓存前一 segment 的 hidden states。推理时,memory 可以滚动更新:

read current segment
attend to memory + current tokens
produce hidden states
append current hidden states into memory
drop oldest states if memory too long
NoteDefinition: Effective Context Length

Transformer-XL’s effective context length is approximately current segment length plus memory length: \[ L_{\text{eff}} = L_{\text{segment}} + L_{\text{mem}}. \]

Complexity

普通 segment Transformer 每层 attention 成本:

\[ O(L^2d). \]

Transformer-XL 每层 attention 成本:

\[ O(L(M+L)d). \]

\(M=L\),大约是普通 segment 的 \(2\times\) attention score 成本,但 effective context 变为 \(2L\)。如果用普通滑动窗口每次处理 \(2L\) 长度,则成本是:

\[ O((2L)^2d)=O(4L^2d), \]

而且相邻窗口会重复计算大量历史 hidden states。Transformer-XL 的 memory 复用让历史 hidden states 不必每次重算。

Memory storage per layer per sequence 约为:

\[ M\cdot d\cdot s, \]

其中 \(s\) 是每个元素字节数。总共 \(N\) 层:

\[ N M d s. \]

这比 KV cache 形式不同,但同样会随 layer、memory length 和 hidden size 线性增长。

Training Loop Sketch

一个最小训练 loop 可以想成:

mems = None

for segment, target, reset in stream:
    if reset:
        mems = None
    logits, new_mems = model(segment, mems=mems)
    loss = cross_entropy(logits, target)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    mems = [m.detach() for m in new_mems]

关键点是:

  1. mems 是每层 hidden states;
  2. 当前 loss 只对当前 segment 计算;
  3. detach() 实现 stop-gradient;
  4. memory 长度要裁剪,否则显存和 attention 成本都会增长。
WarningPitfall: Reusing Memory Across Documents Can Leak Context

When batching multiple documents, memory must be reset at document boundaries unless cross-document conditioning is intended. Otherwise the model can condition on unrelated previous text.

Batched Streams

batch 中每个 row 可能来自不同文档流:

batch row 0: doc A segment 1 -> doc A segment 2 -> ...
batch row 1: doc B segment 1 -> doc B segment 2 -> ...

因此 reset flag 应该是 per-row,而不是整个 batch 一个布尔值。若 row 0 进入新文档,只应清空 row 0 的 memory;row 1 若仍在同一文档,应保留 memory。

概念上:

for layer, mem in enumerate(mems):
    mem[reset_rows] = 0

实际实现还要处理不同文档长度、packed stream、distributed sampler 和 worker 边界。Transformer-XL 的 recurrence 强依赖数据管线:如果 batch 顺序被 shuffle 到不连续,memory 就不再代表真实前文。

WarningPitfall: Shuffling Can Destroy Recurrence

Segment-level recurrence assumes consecutive segments from the same stream arrive in order. Ordinary example-level shuffling breaks this assumption unless the dataloader preserves stream state.

Stream Sampler Contract

Transformer-XL 的 dataloader 更像“多条连续文本流的状态机”,而不是普通独立样本 shuffle。每个 batch row 应维护:

State Meaning
doc_id 当前 row 属于哪篇文档或哪条 token stream
offset 当前 segment 在 stream 中的起点
reset 当前 segment 是否没有合法前文
mems[row] 与该 row 前文对应的 per-layer memory

一个简化的 stream sampler:

class StreamState:
    def __init__(self, docs, seg_len):
        self.docs = docs
        self.seg_len = seg_len
        self.doc_id = 0
        self.offset = 0
        self.reset = True

    def next_segment(self):
        doc = self.docs[self.doc_id]
        start = self.offset
        end = min(start + self.seg_len, len(doc))
        segment = doc[start:end]
        reset = self.reset

        self.offset = end
        self.reset = False
        if self.offset >= len(doc):
            self.doc_id = (self.doc_id + 1) % len(self.docs)
            self.offset = 0
            self.reset = True

        return segment, reset

真实训练会加 distributed sharding、padding、drop-last、随机起点和 epoch 边界,但核心 invariant 不变:row 的 memory 必须对应 row 的真实前文。如果 dataloader 为了打乱数据把 row 的下一个 segment 换成另一篇文档,就必须同时给 reset=True 并清空该 row memory。

ImportantContract: Memory Belongs to a Stream Row

In Transformer-XL training, memory is not a global cache. It is per layer and per batch row, and it is valid only while that row continues the same text stream.

Relation to KV Cache

Transformer-XL memory 和现代 decoder-only LLM 的 KV cache 很像,但语境不同:

Mechanism What is cached Purpose
Transformer-XL memory previous segment hidden states long-range training/evaluation
KV cache per-layer keys and values fast autoregressive decoding

两者都在避免重复计算,也都承认一个事实:长上下文不是只改一个 max_length,而是 memory representation、position encoding、训练方式和推理系统共同决定的。

更细的差别:

Aspect Transformer-XL memory Decoder KV cache
stored object hidden states before projection projected keys/values
training use yes, across segments usually no, mainly inference
gradient stopped through memory inference has no gradient
position scheme relative attention needed RoPE/ALiBi/absolute variants
update unit segment token or chunk

Transformer-XL memory 要重新经过当前层的 \(W^K,W^V\) 投影;KV cache 已经保存投影后的 keys/values。前者更像“复用历史 hidden states”,后者更像“复用 attention computation products”。

Limitations

Transformer-XL 缓解了 fixed segment 的断裂,但它仍然有注意力成本。当前 segment 的每个 query 要 attend 到 memory + current tokens:

\[ O(L_{\text{segment}}(L_{\text{segment}}+L_{\text{mem}})d). \]

因此它不是无限上下文方案,而是用缓存和相对位置编码在固定成本附近扩展可用上下文。

Transformer-XL vs Modern Long-Context LLMs

Transformer-XL 的思想在今天仍然有教学价值,但现代 decoder-only LLM 通常采用不同组合:

Method Main idea Training/inference role
Transformer-XL cache previous hidden states across segments long-range LM with recurrence
KV cache cache per-layer keys/values fast AR decoding
RoPE scaling stretch relative position geometry extend context length
sliding-window attention restrict attention to recent window bound compute and memory
recurrence/SSM hybrid compress long history into state reduce cache growth

Transformer-XL 的核心贡献不是“现在最强长上下文方案”,而是它清楚地提出了一个问题:固定长度 attention 会切断文本连续性,必须让模型在 segment 边界之外保留某种 state。

Implementation Checklist

实现或阅读 Transformer-XL 时,重点检查:

  1. 每层 memory shape 是否是 [batch, mem_len, hidden]
  2. memory 是否在 segment boundary detach;
  3. memory 是否按 mem_len 裁剪;
  4. attention mask 是否是 [q_len, mem_len + q_len]
  5. relative position logits 是否通过 shift 对齐;
  6. 文档边界是否 reset memory;
  7. dataloader 是否保留连续 segment 顺序;
  8. loss 是否只对当前 segment tokens 计算;
  9. evaluation 是否复用 memory,而不是每段从空上下文开始。

如果这几项有一项错,模型仍可能正常跑、loss 也下降,但学到的就不是 Transformer-XL 论文里的 recurrence 机制。

References