4.3 Masks and Positional Encoding


Transformer 本身对输入顺序没有天然感知。顺序来自 positional encoding;可见性来自 attention mask。不同 mask 和 position 设计,定义了模型能学习什么条件分布。

Running Example: Why Causal Mask Exists

训练句子:

I love machine learning

如果目标是 next-token prediction,那么在位置 machine 预测下一个 token learning 时,模型可以看 I love machine,但不能看 learning 本身。否则训练会退化成作弊。

对长度 \(4\) 的序列,causal mask 对应可见矩阵:

\[ \begin{bmatrix} 1&0&0&0\\ 1&1&0&0\\ 1&1&1&0\\ 1&1&1&1 \end{bmatrix}. \]

\(i\) 行表示第 \(i\) 个 token 可以看哪些位置。这个矩阵就是 GPT 和 BERT 训练目标差异的核心之一:GPT 用下三角 mask,BERT 用双向可见再随机 mask token。

Masks

padding mask 用于忽略无效 token。causal mask 用于禁止模型看未来:

\[ M_{ij} = \begin{cases} 0, & j\leq i,\\ -\infty, & j>i. \end{cases} \]

attention logits 变为

\[ A = \frac{QK^\top}{\sqrt{d_k}}+M. \]

NoteDefinition: Causal Language Modeling

Causal language modeling trains \[ p_\theta(x_{1:T}) = \prod_{t=1}^{T}p_\theta(x_t\mid x_{<t}), \] implemented by a causal attention mask that prevents position \(t\) from attending to positions \(>t\).

mask 的理论意义是把 attention computation 限制到某个 \(\sigma\)-field,也就是“当前位置允许知道的信息”。如果第 \(t\) 个位置的 hidden state 可以依赖 \(x_{>t}\),那么它建模的就不是

\[ p_\theta(x_t\mid x_{<t}), \]

而更接近一个泄漏未来的条件分布

\[ p_\theta(x_t\mid x_{<t},x_{>t}). \]

这会让 next-token training 的 likelihood 分解失效。也就是说,causal mask 不是一个实现小技巧,而是 autoregressive factorization 的计算图约束。

ImportantTheorem: Future Visibility Invalidates AR Maximum Likelihood

If the representation used to predict \(x_t\) depends on any future token \(x_s\) with \(s>t\), then the model is not optimizing the autoregressive factorization \(\prod_t p_\theta(x_t\mid x_{<t})\) for that computation graph.

AR maximum likelihood 要求第 \(t\) 项 logits 是某个函数

\[ z_t=f_\theta(x_{<t}) \]

或在常见 label-shift 实现中等价地由当前位置之前的可见 token 决定。若 attention 允许 \(x_s,s>t\) 进入 hidden state,则 logits 变为

\[ z_t=f_\theta(x_{<t},x_s,\ldots). \]

此时 loss 里的第 \(t\) 项虽然形式上仍是 cross entropy,

\[ -\log p_\theta(x_t\mid z_t), \]

\(z_t\) 已经包含未来 token 信息。因此它对应的条件变量集合不是 \(x_{<t}\),不能再解释为 AR factorization 的第 \(t\) 项。

不同 mask 对应不同 factorization:

visibility factorization / objective typical use
causal \(\prod_t p(x_t\mid x_{<t})\) GPT-style LM
bidirectional denoising / classification style objective BERT encoder
prefix-LM \(p(y\mid x_{\text{prefix}})\) with causal target instruction/conditional generation
block diagonal independent examples inside one packed sequence efficient pretraining batches

所以调 mask 时要先问:我到底想让模型学习哪个条件分布?如果只是为了“shape 能跑”而改 mask,很容易把训练目标悄悄改掉。

Additive Mask Implementation

实践里 mask 通常不是在 softmax 后把概率乘 0,而是在 softmax 前把 logits 加上一个大负数:

\[ \alpha_{ij} = \operatorname{softmax}_j(s_{ij}+m_{ij}). \]

如果 \(m_{ij}=-\infty\),则

\[ \exp(s_{ij}+m_{ij})=0, \]

所以该位置概率严格为 0。代码里常见 shape:

Mask Shape Broadcast target
causal \([1,1,T,T]\) \([B,H,T,T]\)
padding \([B,1,1,T]\) \([B,H,T,T]\)
prefix/block \([B,1,T,T]\) \([B,H,T,T]\)

组合 mask 时直接相加:

\[ M=M_{\text{causal}}+M_{\text{pad}}+M_{\text{task}}. \]

WarningPitfall: Mask Convention Must Match the Kernel

Some APIs use boolean masks where True means visible; others use True means masked. Some use additive masks. Always check the exact convention before passing masks to fused attention kernels.

更具体地,常见实现会遇到三种 mask convention:

convention dtype meaning of allowed position typical operation
additive float value is 0 scores = scores + mask
boolean-visible bool True means can attend scores.masked_fill(~mask, -inf)
boolean-blocked bool False means can attend or API-specific depends on kernel

安全写法是先构造一个“可见性矩阵” \(V\in\{0,1\}^{T_q\times T_k}\),再转换成目标 kernel 要求的格式:

def visibility_to_additive(visible, dtype):
    # visible: True means attention is allowed
    neg = torch.finfo(dtype).min
    return torch.where(
        visible,
        torch.zeros((), dtype=dtype, device=visible.device),
        torch.full((), neg, dtype=dtype, device=visible.device),
    )

注意这里用 torch.finfo(dtype).min 或足够大的负数,而不是在 FP16/BF16 里手写 -1e30。过大的负数可能溢出成 -inf,而某些 fused kernel 对全 masked 行的 -inf softmax 会产生 NaN。更稳的策略是保证每一行至少有一个可见位置,或者在 kernel 允许时使用内置 causal/padding mask。

PyTorch scaled_dot_product_attention 还会有一个常见分叉:

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=mask,
    is_causal=False,
)

如果已经传了自定义 attn_mask,再把 is_causal=True 打开,可能导致重复 masking 或 API 不允许二者同时使用。工程上最好把 mask 构造路径写成明确的二选一:

  1. 纯 decoder causal self-attention:用 is_causal=True,不传自定义 causal mask;
  2. prefix/block/packing attention:显式传 attn_mask,并关闭 is_causal
  3. padding-only attention:使用 padding mask 或把 padding 合并进 additive mask。
TipDebug Trick: Print One Attention Row

For a tiny sequence, print the visible indices for one query row before calling the fused kernel. It is easier to catch mask convention bugs before softmax.

Visibility Algebra and Kernel Invariants

把 mask 写成一堆 if-else 很容易失控。更稳的方式是先定义一个 visibility relation:

\[ V_{b,i,j}=1 \quad\Longleftrightarrow\quad \text{query }i\text{ in sample }b\text{ is allowed to read key }j. \]

这里 \(i\) 遍历 query positions,\(j\) 遍历 key positions。prefill self-attention 中通常有 \(T_q=T_k=T\);cached decode 中则常见 \(T_q=1\)\(T_k=T_{\text{past}}+1\)。所以工程里不要默认 mask 一定是方阵。更一般地,若 query 的逻辑位置为 \(p_i\),key 的逻辑位置为 \(r_j\),causal visibility 是

\[ V_{i,j}^{\text{causal}} = \mathbf{1}[r_j\le p_i]. \]

在普通 prefill 中 \(p_i=i,r_j=j\),这就退化成下三角矩阵;在 decode 中新 query 的 \(p_0=T_{\text{past}}\),所有缓存 key 的 \(r_j\le T_{\text{past}}\) 都可见。

NoteDefinition: Visibility Contract

A safe attention call should make the visibility relation explicit before converting it to a boolean or additive mask. The contract is about semantic positions, not only tensor columns.

左 padding、packed sequence、prefix LM 都是在这个基础 relation 上继续取交集。设 key validity 为 \(K_{b,j}\in\{0,1\}\),task visibility 为 \(G_{b,i,j}\),则

\[ V_{b,i,j} = K_{b,j}\cdot G_{b,i,j}. \]

如果 query row 本身是 padding,输出通常会在后续 loss 或 hidden-state selection 中被丢弃。但 attention kernel 仍然会对这一行做 softmax。于是出现一个很实际的不变量:

\[ \sum_j V_{b,i,j}\ge 1 \quad \text{for every query row sent to softmax}. \]

若某一整行全是不可见位置,softmax 分母为 \(0\),某些实现会返回 NaN,另一些 fused kernel 可能返回未定义值。对 padding query row,常见做法是让它至少看见某个 dummy pad key,并保证该 row 的输出不参与 loss;对真实 query row,则必须保证至少能看见自己或历史 token。

WarningPitfall: All-Masked Rows Are Numerical Bugs

Even if padded query outputs are later ignored, an all-masked row can poison the forward pass with NaN. Mask correctness includes row non-emptiness.

一个可复用的构造函数可以从逻辑位置出发:

def assert_real_rows_nonempty(visible, query_valid=None):
    ok = visible.any(dim=-1)
    if query_valid is not None:
        ok = ok | ~query_valid
    if not ok.all():
        raise ValueError("attention mask contains an all-masked real query row")

def rectangular_causal_visible(q_pos, k_pos, key_valid, query_valid=None):
    # q_pos: [B, Tq], logical position of each query
    # k_pos: [B, Tk], logical position of each key
    # key_valid: [B, Tk], True for real keys
    visible = k_pos[:, None, :] <= q_pos[:, :, None]
    visible = visible & key_valid[:, None, :]
    assert_real_rows_nonempty(visible, query_valid)
    return visible

cached decode with a chunk of \(T_q\) new tokens is just the rectangular version:

\[ p_i=T_{\text{past}}+i,\qquad r_j=j,\qquad V_{i,j}=\mathbf{1}[j\le T_{\text{past}}+i]. \]

This is why cache offset belongs to the model input. A decode mask that restarts query positions at zero is a valid tensor and an invalid computation graph.

ImportantProposition: Rectangular Causality Preserves the AR Graph

For cached decoding, if each new query at logical position \(p_i=T_{\text{past}}+i\) can attend only to keys with logical positions \(r_j\le p_i\), then the hidden state of that query depends only on tokens in the autoregressive prefix.

The cache stores hidden states or projected keys/values computed from previous tokens. By induction, cached key/value at logical position \(r_j<T_{\text{past}}\) depends only on tokens \(x_{\le r_j}\). The new query at \(p_i\) is allowed to read only keys with \(r_j\le p_i\), including earlier cached tokens and previous tokens in the same decode chunk. Therefore every read path into the query state is contained in \(x_{\le p_i}\), which is exactly the autoregressive prefix for predicting the next token after \(p_i\).

Padding Side and Loss Mask

对 decoder-only LLM,padding 有两个相互纠缠的问题:

  1. attention 能不能看见 pad token;
  2. loss 会不会在 pad target 上计算。

右 padding:

[A, B, C, EOS]
[D, E, EOS, PAD]

左 padding:

[A, B, C, EOS]
[PAD, D, E, EOS]

如果使用 absolute position embedding,左 padding 会改变真实 token 的 position ids,除非手动重新编号。对 RoPE,position ids 仍然影响旋转角度,所以也要保证真实 token 的位置语义正确。

训练时 labels 常用 ignore index:

\[ \ell_{b,t} = \begin{cases} -\log p_\theta(y_{b,t}\mid x_{b,\le t}),& y_{b,t}\ne -100,\\ 0,& y_{b,t}=-100. \end{cases} \]

masked loss:

\[ L = \frac{\sum_{b,t}m_{b,t}\ell_{b,t}} {\sum_{b,t}m_{b,t}}. \]

WarningPitfall: Attention Mask and Loss Mask Are Different

The attention mask controls what tokens can be read. The loss mask controls which targets contribute to training. Padding usually needs both.

position ids 是第三个独立对象。对一个 batch,常常同时有:

tensor controls example
attention_mask 哪些 input token 可被读取 pad token 不可见
labels == -100 哪些 target token 算 loss prompt/pad 不算
position_ids 每个真实 token 的位置编号 左 padding 后真实 token 从 0 开始

左 padding 的 generation batch 中,position_ids 通常应由 cumulative sum 得到:

attention_mask = input_ids.ne(pad_id).long()
position_ids = attention_mask.cumsum(dim=-1) - 1
position_ids = position_ids.masked_fill(attention_mask.eq(0), 0)

例如:

input:        [PAD, PAD, A, B, C]
attn_mask:   [0,   0,   1, 1, 1]
position_id: [0,   0,   0, 1, 2]

这样真实 token A,B,C 的位置仍然是 \(0,1,2\)。如果直接用 [0,1,2,3,4],那么 A 会被当成位置 \(2\),绝对位置 embedding 和 RoPE 角度都会偏移。

sequence packing 还会引入 block boundary。把两条样本拼成:

[A, B, EOS, D, E, EOS]

若希望它们是两个独立样本,那么第二条样本的 token 不应该 attend 第一条样本:

\[ V= \begin{bmatrix} 1&0&0&0&0&0\\ 1&1&0&0&0&0\\ 1&1&1&0&0&0\\ 0&0&0&1&0&0\\ 0&0&0&1&1&0\\ 0&0&0&1&1&1 \end{bmatrix}. \]

这不是普通 causal mask,而是 block-diagonal causal mask。若不加 block boundary,模型会把前一篇文档当作后一篇文档的 prefix,训练目标就被污染。

WarningPitfall: Packed Sequences Need Boundary-Aware Masks

Packing improves token utilization, but without block-diagonal causal masks and reset-aware position ids, examples leak context into each other.

更一般地,packing 需要同时维护两个语义坐标:

tensor meaning example for [A,B,EOS,D,E,EOS]
doc_ids token 属于哪条原始样本 [0,0,0,1,1,1]
pos_in_doc token 在原始样本内部的位置 [0,1,2,0,1,2]

block causal visibility 不是按 packed column index 判断,而是按 (doc_id, pos_in_doc) 判断:

def packed_causal_visible(doc_ids, pos_in_doc, key_valid, query_valid=None):
    # doc_ids, pos_in_doc, key_valid: [B, T]
    same_doc = doc_ids[:, :, None] == doc_ids[:, None, :]
    causal = pos_in_doc[:, None, :] <= pos_in_doc[:, :, None]
    visible = same_doc & causal & key_valid[:, None, :]
    assert_real_rows_nonempty(visible, query_valid)
    return visible

这里最容易写错的是 causal 的两个轴。query 维度在中间,key 维度在最后,所以判断是

\[ \text{pos\_in\_doc}_{b,j}\le \text{pos\_in\_doc}_{b,i}. \]

同时,position_ids 应该使用 pos_in_doc,而不是 packed offset:

packed token: [A, B, EOS, D, E, EOS]
packed index: [0, 1, 2,   3, 4, 5]
position_id: [0, 1, 2,   0, 1, 2]

如果第二篇文档的 D,E,EOS 使用位置 3,4,5,RoPE/absolute embedding 会把它当作长 prefix 后的延续;如果 mask 又没有 block boundary,模型不仅能读到前一篇文档,还会在位置几何上把两篇文档连接起来。

ImportantProposition: Packed Likelihood Factorizes by Documents Only with Boundary Masks

For a packed sequence containing independent documents, the packed training objective equals the sum of per-document causal language-model objectives only if attention visibility is block diagonal by document and positions reset within each document.

设 packed 序列由文档 \(x^{(1)},\ldots,x^{(m)}\) 拼接而成。独立文档的目标是

\[ \sum_{u=1}^m \sum_t \log p_\theta(x_t^{(u)}\mid x_{<t}^{(u)}). \]

若第 \(u\) 篇文档的 token 可以 attend 第 \(v\ne u\) 篇文档的 token,则其 logits 是

\[ z_t^{(u)} = f_\theta(x_{<t}^{(u)}, x^{(v)}, \ldots), \]

对应的条件变量集合已经不是 \(x_{<t}^{(u)}\)。若 position ids 不在文档边界处重置,则即使 visibility 不泄漏,模型看到的位置几何也不同于单独训练该文档时的位置几何。因此两者都需要满足,packed loss 才能解释为 per-document causal losses 的高效批处理实现。

Encoder, Decoder, and Prefix Masks

Mask type Visible context Typical model
bidirectional all non-padding tokens BERT encoder
causal previous tokens only GPT decoder
encoder-decoder target attends source + past target translation
prefix LM prefix bidirectional, suffix causal conditional generation

mask 是训练目标的一部分。BERT 和 GPT 不只是 architecture 不同,而是可见性结构和 likelihood factorization 不同。

Prefix LM Mask

Prefix LM 允许 prefix 内部双向可见,suffix 使用 causal mask。设 prefix 长度为 \(P\),总长度为 \(T\)。可见性为

\[ V_{ij} = \begin{cases} 1,& i<P\ \text{and}\ j<P,\\ 1,& i\ge P\ \text{and}\ j\le i,\\ 0,& \text{otherwise}. \end{cases} \]

这适合条件生成:prefix 是条件,suffix 是要生成的目标。比如图文生成、文档续写或 instruction + answer 格式,都可以看作某种 prefix-conditioned generation。

对应 additive mask:

\[ M_{ij} = \begin{cases} 0,& V_{ij}=1,\\ -\infty,& V_{ij}=0. \end{cases} \]

mask 是概率分解的工程实现。你改了 mask,就改了模型能访问的信息,也就改了训练目标。

代码里可以先构造 visibility,再转 additive mask:

def prefix_lm_visibility(seq_len, prefix_len, device):
    idx = torch.arange(seq_len, device=device)
    q = idx[:, None]
    k = idx[None, :]

    prefix_to_prefix = (q < prefix_len) & (k < prefix_len)
    suffix_causal = (q >= prefix_len) & (k <= q)
    return prefix_to_prefix | suffix_causal

如果 batch 内每条样本的 prefix_len 不同,visibility 就需要 shape [B, T, T],再 broadcast 到 [B, H, T, T]。这类 task-specific mask 往往无法只靠 is_causal=True 表达,需要显式传给 attention kernel。

Absolute Position Embeddings

最直接的做法是学习位置表:

\[ h_t^{(0)}=E[x_t]+P[t]. \]

缺点是外推到更长 context 困难,因为训练中未见过的位置没有可靠表示。

如果最大长度是 \(T_{\max}\),位置表参数量是

\[ T_{\max}d. \]

这部分参数不大,但外推弱:训练中没有更新过的 \(P[t]\) 不知道应该表示什么。GPT-2 使用 learned absolute position embeddings,所以把 context 从 1024 直接改到更长不是单纯改 config 就能可靠工作。

Sinusoidal Position Encoding

原始 Transformer 使用

\[ PE_{(pos,2i)}=\sin\left(\frac{pos}{10000^{2i/d}}\right), \]

\[ PE_{(pos,2i+1)}=\cos\left(\frac{pos}{10000^{2i/d}}\right). \]

它的好处是相对位移可以由线性变换表达,且不需要学习无限长的位置表。

更具体地,频率为 \(\omega_i=10000^{-2i/d}\)。某个二维 pair 是

\[ [\sin(\omega_i pos),\cos(\omega_i pos)]. \]

位置 \(pos+\Delta\) 可以由旋转矩阵作用在位置 \(pos\) 上得到:

\[ \begin{bmatrix} \sin(\omega(pos+\Delta))\\ \cos(\omega(pos+\Delta)) \end{bmatrix} = \begin{bmatrix} \cos(\omega\Delta)&\sin(\omega\Delta)\\ -\sin(\omega\Delta)&\cos(\omega\Delta) \end{bmatrix} \begin{bmatrix} \sin(\omega pos)\\ \cos(\omega pos) \end{bmatrix}. \]

这说明 sinusoidal encoding 中相对位移可以通过线性变换表达。

RoPE

Rotary Position Embedding 把位置信息注入到 \(Q,K\) 的旋转中。对二维 pair,可写作

\[ \begin{bmatrix} q_{2i}'\\ q_{2i+1}' \end{bmatrix} = \begin{bmatrix} \cos \theta_{pos,i} & -\sin \theta_{pos,i}\\ \sin \theta_{pos,i} & \cos \theta_{pos,i} \end{bmatrix} \begin{bmatrix} q_{2i}\\ q_{2i+1} \end{bmatrix}. \]

这样 \(q_m^\top k_n\) 会自然包含相对位置 \(m-n\) 的信息。现代 LLM 大量使用 RoPE 及其 context extension 变体。

ImportantTheorem: RoPE Dot Product Depends on Relative Position

For each two-dimensional rotary pair, if \(R_m\) and \(R_n\) are rotation matrices for positions \(m\) and \(n\), then \[ (R_m q)^\top(R_n k)=q^\top R_{n-m}k. \] Thus the attention score contains relative position \(n-m\).

Rotation matrices satisfy

\[ R_m^\top R_n=R_{n-m}. \]

Therefore

\[ (R_m q)^\top(R_n k) = q^\top R_m^\top R_n k = q^\top R_{n-m}k. \]

RoPE applies this independently to many two-dimensional pairs with different frequencies, so each pair contributes a relative-position-dependent term to the dot product.

工程实现通常不显式构造旋转矩阵,而是预先缓存 cossin

def rotate_half(x):
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rope(x, cos, sin):
    # x: [B, H, T, D], cos/sin: [1, 1, T, D]
    return x * cos + rotate_half(x) * sin

如果 head dimension 为 \(D\),RoPE 通常对偶数/奇数维成对旋转。频率为

\[ \omega_i=\theta^{-2i/D}, \qquad i=0,\ldots,D/2-1, \]

位置 \(p\) 的角度是

\[ \phi_{p,i}=p\omega_i. \]

实现中会把每个频率复制到 pair 的两个维度:

inv_freq = base ** (-torch.arange(0, head_dim, 2) / head_dim)
freqs = torch.outer(position_ids.float(), inv_freq)
emb = torch.repeat_interleave(freqs, repeats=2, dim=-1)
cos = emb.cos()[None, None, :, :]
sin = emb.sin()[None, None, :, :]

RoPE 的几个实现细节:

  1. 只旋转 \(Q,K\),不旋转 \(V\)
  2. decode 时新 token 的 position_id 必须包含 cache offset;
  3. padding token 的 position id 不应该影响真实 token;
  4. 若使用 partial rotary,只旋转 head dimension 的前一部分;
  5. Q 和 K 必须使用同一套 position convention,否则相对位置结构会断掉。
WarningPitfall: RoPE Bugs Are Often Numerically Plausible

Wrong position offsets or mismatched Q/K rotations usually do not crash. They produce valid-shaped logits with degraded long-context behavior.

RoPE Cache Contract

RoPE 的工程契约可以写得更精确:cos/sin 不是简单按当前 tensor column 取,而应按 position_ids gather。这样同一套实现才能同时支持左 padding、packing 和 cached decode。

def build_rope_cache(max_pos, rotary_dim, base, device):
    idx = torch.arange(0, rotary_dim, 2, device=device)
    inv_freq = base ** (-idx.float() / rotary_dim)
    pos = torch.arange(max_pos, device=device).float()
    freqs = torch.outer(pos, inv_freq)
    emb = torch.repeat_interleave(freqs, repeats=2, dim=-1)
    return emb.cos(), emb.sin()

def gather_rope(cos_cache, sin_cache, position_ids):
    # position_ids: [B, T]
    cos = cos_cache[position_ids]  # [B, T, Drot]
    sin = sin_cache[position_ids]
    return cos[:, None, :, :], sin[:, None, :, :]

partial rotary 还要求 rotary_dim 是偶数,并且不能超过 head_dim

def apply_partial_rope(x, cos, sin, rotary_dim):
    if rotary_dim % 2 != 0:
        raise ValueError("rotary_dim must be even")
    if rotary_dim > x.size(-1):
        raise ValueError("rotary_dim cannot exceed head_dim")

    x_rot = x[..., :rotary_dim]
    x_pass = x[..., rotary_dim:]
    x_rot = x_rot * cos + rotate_half(x_rot) * sin
    return torch.cat([x_rot, x_pass], dim=-1)

GQA/MQA 下,\(Q\)\(K\) 的 head 数不同:

q: [B, Hq,  T, D]
k: [B, Hkv, T, D]

但 position convention 仍然必须相同。RoPE 不是对 head index 编码,而是对 token position 编码;所以 cos/sin 可以 broadcast 到不同 head 数:

cos, sin = gather_rope(cos_cache, sin_cache, position_ids)
q = apply_partial_rope(q, cos, sin, rotary_dim)
k = apply_partial_rope(k, cos, sin, rotary_dim)

若实现先 repeat KV heads 再旋转,或先旋转再 repeat KV heads,数学上通常等价,因为 repeat 只复制 head 维度;但不能让 Q 使用左 padding 修正后的 position_ids,K 使用 raw column index。这会破坏

\[ (R_m q)^\top(R_n k)=q^\top R_{n-m}k \]

中的相对位移 \(n-m\)

TipSmoke Test: RoPE Relative Score

Pick two token pairs with the same relative distance and compare their rotary dot products. With identical unrotated vectors and correct position ids, scores should match up to numerical tolerance.

RoPE Scaling and Context Extension

把 RoPE 模型从训练长度 \(L_{\text{train}}\) 扩到更长 \(L_{\text{test}}\),核心问题是角度

\[ \phi_{p,i}=p\omega_i \]

在未训练过的 \(p\) 上会进入新的相位区域。不同 scaling 方法本质上是在重新定义位置 \(p\) 或频率 \(\omega_i\)

最简单的 linear position interpolation 是:

\[ p' = p\cdot \frac{L_{\text{train}}}{L_{\text{test}}}, \]

然后用 \(p'\) 计算 RoPE 角度。它让测试时最大位置被压回训练范围,但会压缩短距离分辨率。另一类做法调整 RoPE base 或按频率分段缩放,使低频维度负责长程外推,高频维度尽量保留局部分辨率。

method family changes trade-off
position interpolation replace \(p\) by scaled \(p'\) stable but compresses local distances
NTK/base scaling changes effective \(\theta\) / frequencies preserves more structure but needs tuning
YaRN-like mixed scaling different treatment for low/high frequencies better long context, more hyperparameters
continued long-context training adapts parameters to new length expensive but most reliable

context extension 不是“把 max_position_embeddings 改大”。位置编码、attention memory、训练数据长度、评测任务和 serving cache offset 必须一致。

ALiBi and Relative Bias

另一类做法是不修改 \(Q,K\),而是在 attention logits 里加相对位置 bias。ALiBi 的形式可以概括为

\[ s_{ij}' = s_{ij} - m_h(i-j), \qquad j\le i, \]

其中 \(m_h\) 是 head-specific slope。距离越远,bias 越负,模型天然偏向近邻。

ALiBi 的效果可以直接从 odds ratio 看出来。若同一 query \(i\) 比较两个 key \(j_1,j_2\),忽略 value 后 attention 权重比为

\[ \frac{\alpha_{ij_1}}{\alpha_{ij_2}} = \exp\left( s_{ij_1}-s_{ij_2} -m_h[(i-j_1)-(i-j_2)] \right). \]

\(j_1\) 更远时,\((i-j_1)\) 更大,bias 会指数级压低远距离 key 的 odds。不同 head 使用不同 slope,相当于让某些 head 更局部,某些 head 更愿意看远。

和 RoPE 不同,ALiBi 不改变 \(Q,K\) 的几何,只在 score 上加一个确定性相对距离先验。因此它外推长上下文时不需要新位置 embedding 表,但也会把“距离越远越不该看”这个 inductive bias 强加给所有样本。

Method Injected into Extrapolation intuition
learned absolute input embeddings weak beyond trained table
sinusoidal input embeddings analytic positions
RoPE Q/K rotation relative dot-product structure
ALiBi attention logits linear distance bias

不同 position 方法不是可随意替换的小模块。它们影响 attention score 的几何结构,进而影响长上下文行为。

Position State During Generation

prefill 阶段一次处理 prompt:

positions: 0 1 2 ... T-1

decode 阶段每生成一个 token,position id 必须递增:

new token at position T
new token at position T+1
...

KV cache 中已经存了过去 positions 的 rotated keys 或 position-conditioned states。若 decode 时 position id 从 0 重新开始,shape 不会报错,但 attention score 的相对位置全部错乱。

WarningPitfall: Cache Offset Is Part of the Model Input

During cached decoding, the next token’s position is not 0; it is the current sequence length. Cache offset bugs often degrade generation quality without throwing shape errors.

一个 cached decode step 可以这样理解:

past_len = past_key_values.get_seq_length()
position_ids = torch.arange(
    past_len,
    past_len + input_ids.size(1),
    device=input_ids.device,
)[None, :]

out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    use_cache=True,
)

如果 batch 内不同样本已有 cache 长度不同,runtime 还要为每条样本维护自己的 offset。PagedAttention/continuous batching 之所以复杂,一部分原因就是每个 request 的 logical position、KV block location 和 finished status 都不同。

WarningPitfall: Context Extension Is Not Just Changing a Number

Increasing max context length requires position encoding, attention memory, training distribution, and evaluation to agree. RoPE scaling can extend usable length, but models may still fail on long-range retrieval if not trained or adapted for it.

Implementation Checklist

调试 masks and positions 时,可以按下面顺序自检:

  1. causal mask 的第 \(i\) 行是否只允许 \(j\le i\)
  2. padding mask 是否同时作用于 attention 和 loss;
  3. left padding generation 是否重新计算真实 token 的 position_ids
  4. sequence packing 是否使用 block-diagonal causal mask;
  5. packed sequence 的 position_ids 是否按文档重置;
  6. prefix-LM mask 是否允许 prefix 双向、suffix causal;
  7. bool mask 的 True/False 语义是否匹配目标 API;
  8. additive mask 的 dtype 是否和 attention logits 兼容;
  9. 每个送入 softmax 的 query row 是否至少有一个可见 key;
  10. fused attention 中是否重复使用 attn_maskis_causal=True
  11. RoPE 是否只作用在 \(Q,K\),且 Q/K 使用同一 position convention;
  12. RoPE cache 是否按 position_ids gather,而不是按 raw column index;
  13. cached decode 的新 token position 是否从 past_len 开始;
  14. long-context scaling 是否同步更新训练、推理和评测配置;
  15. 小样本 attention row 可见位置是否和预期矩阵一致。

几个最小测试很有效:

# 1. future-token leakage test
logits1 = model(input_ids).logits[:, :-1]
changed = input_ids.clone()
changed[:, -1] = other_token_id
logits2 = model(changed).logits[:, :-1]
assert torch.allclose(logits1, logits2, atol=1e-5)

# 2. left-padding position test
ids_a = torch.tensor([[A, B, C]])
ids_b = torch.tensor([[PAD, PAD, A, B, C]])
# With corrected position_ids and masks, logits on A/B/C should match closely.

# 3. packed-document boundary test
ids = torch.tensor([[A, B, EOS, D, E, EOS]])
changed = ids.clone()
changed[:, 1] = other_token_id
# Logits inside document 2 should not change when document 1 changes.

# 4. prefill-vs-cached-decode test
full = model(prompt).logits[:, -1]
prefill = model(prompt[:, :-1], use_cache=True)
step = model(
    prompt[:, -1:],
    past_key_values=prefill.past_key_values,
    position_ids=torch.tensor([[prompt.size(1) - 1]], device=prompt.device),
).logits[:, -1]
assert torch.allclose(full, step, atol=1e-4)

# 5. mask row-nonempty test
visible = make_visible(...)
assert visible.any(dim=-1).all()

第一个测试检查 causal mask 是否阻止最后一个 token 影响之前位置;第二个测试检查左 padding 是否只是 batch 对齐,而没有改变真实 token 的位置语义;第三个测试检查 packing 是否真的隔离文档;第四个测试检查 cache offset 和 RoPE position ids;第五个测试检查 fused attention 前最基本的数值安全。