4.6 Mathematical Foundations of Next-Token Prediction


Next-token prediction 看起来只是“预测下一个 token”的分类任务,但它能成为 LLM 预训练核心,是因为它同时连接了概率链式法则、最大似然、cross entropy、KL divergence、teacher forcing、perplexity 和 instruction tuning。

这一节把它当成数学对象,而不是 API 里的 labels=input_ids

Chain Rule of Probability

NoteDefinition: Chain Rule of Probability

For a discrete sequence \(x_{1:T}\), \[ p(x_{1:T}) = \prod_{t=1}^{T}p(x_t\mid x_{<t}). \] This identity holds for any joint distribution.

自回归语言模型的做法是用 neural network 参数化每个条件分布:

\[ p_\theta(x_t\mid x_{<t}) = \operatorname{softmax}(f_\theta(x_{<t}))_{x_t}. \]

Transformer decoder 的 causal mask 保证 \(f_\theta\) 在位置 \(t\) 只能依赖 prefix,而不能看未来。

Causal Mask as a Conditional Independence Constraint

在实现中,causal mask 通常是加到 attention logits 上的上三角矩阵:

\[ M_{ij} = \begin{cases} 0, & j\leq i,\\ -\infty, & j>i. \end{cases} \]

于是第 \(i\) 个位置的 attention 权重为:

\[ \alpha_{ij} = \frac{\exp((q_i^\top k_j)/\sqrt{d}+M_{ij})} {\sum_{\ell}\exp((q_i^\top k_\ell)/\sqrt{d}+M_{i\ell})}. \]

\(j>i\),则 \(\exp(-\infty)=0\),未来 token 对当前位置 hidden state 没有贡献。这不是一个“为了训练方便”的 mask,而是在结构上把模型限制为:

\[ h_i=f_\theta(x_{\leq i}), \qquad p_\theta(x_{i+1}\mid x_{\leq i}) = \operatorname{softmax}(W h_i)_{x_{i+1}}. \]

如果 mask 泄漏未来 token,训练 loss 会异常低,但模型学到的是一个不可能在生成时使用的条件分布。

Maximum Likelihood

给定训练集 \(\mathcal{D}=\{x^{(n)}_{1:T_n}\}_{n=1}^{N}\),最大似然估计是:

\[ \theta^\star = \arg\max_\theta \sum_{n=1}^{N} \log p_\theta(x^{(n)}_{1:T_n}). \]

代入 chain rule:

\[ \theta^\star = \arg\max_\theta \sum_{n=1}^{N} \sum_{t=1}^{T_n} \log p_\theta(x_t^{(n)}\mid x_{<t}^{(n)}). \]

因此 next-token prediction 不是启发式,而是序列分布的 maximum likelihood。

Empirical Distribution View

训练集定义了一个 empirical distribution:

\[ \hat{p}_{\mathcal{D}}(x_{1:T}) = \frac{1}{N}\sum_{n=1}^N \mathbf{1}[x_{1:T}=x^{(n)}_{1:T_n}]. \]

最大似然等价于最小化 empirical negative log-likelihood:

\[ \hat{\mathcal{L}}(\theta) = \mathbb{E}_{x\sim \hat{p}_{\mathcal{D}}} \left[ -\log p_\theta(x) \right]. \]

代入 chain rule 后,训练样本不再只是 \(N\) 条 sequence,而是大量 prefix-target pairs:

\[ (x_{<1},x_1),\ (x_{<2},x_2),\ldots,\ (x_{<T},x_T). \]

这解释了为什么 next-token prediction 数据效率高:一段长度为 \(T\) 的文本提供 \(T\) 个监督位置,而不是一个 sequence-level label。

ImportantTheorem: Token Cross Entropy Equals Negative Log-Likelihood

For one-hot labels, summing token-level cross entropy over all positions equals the negative log-likelihood of an autoregressive sequence model.

对位置 \(t\),真实 token 是 \(y=x_t\)。模型分布为

\[ q_\theta(v\mid x_{<t}) = \operatorname{softmax}(z_t)_v. \]

one-hot empirical distribution:

\[ \hat{p}(v\mid x_{<t})=\mathbf{1}[v=y]. \]

cross entropy:

\[ H(\hat{p},q_\theta) = -\sum_{v\in\mathcal{V}} \hat{p}(v\mid x_{<t}) \log q_\theta(v\mid x_{<t}) = -\log q_\theta(y\mid x_{<t}). \]

对所有 token 求和:

\[ \sum_t H(\hat{p}_t,q_{\theta,t}) = -\sum_t\log p_\theta(x_t\mid x_{<t}) = -\log p_\theta(x_{1:T}). \]

Cross Entropy and KL

population objective:

\[ \mathcal{L}(\theta) = \mathbb{E}_{x\sim p_{\text{data}}} \left[ -\sum_t\log p_\theta(x_t\mid x_{<t}) \right]. \]

写成条件分布的 cross entropy:

\[ \mathcal{L}(\theta) = \sum_t \mathbb{E}_{x_{<t}\sim p_{\text{data}}} H\left( p_{\text{data}}(\cdot\mid x_{<t}), p_\theta(\cdot\mid x_{<t}) \right). \]

由于

\[ H(p,q)=H(p)+\operatorname{KL}(p\|q), \]

最小化 cross entropy 等价于最小化真实条件分布和模型条件分布之间的 KL divergence。

KL divergence:

\[ \operatorname{KL}(p\|q) = \sum_x p(x)\log\frac{p(x)}{q(x)} = \sum_x p(x)\log p(x)-\sum_xp(x)\log q(x). \]

\[ H(p)=-\sum_xp(x)\log p(x), \qquad H(p,q)=-\sum_xp(x)\log q(x). \]

所以

\[ H(p,q)=H(p)+\operatorname{KL}(p\|q). \]

若数据无限、模型容量无限、优化达到全局最优,则:

\[ p_\theta(\cdot\mid x_{<t}) = p_{\text{data}}(\cdot\mid x_{<t}) \]

在数据分布支持集上成立。

KL Direction Matters

next-token MLE 最小化的是:

\[ \operatorname{KL} \left( p_{\text{data}}(\cdot\mid x_{<t}) \middle\| p_\theta(\cdot\mid x_{<t}) \right), \]

也就是 forward KL。它对“真实分布中出现但模型给低概率”的事件惩罚很重,因为 \(-\log p_\theta(y)\) 会爆炸。反过来的 reverse KL,

\[ \operatorname{KL}(p_\theta\|p_{\text{data}}), \]

更偏向 mode-seeking,会强烈惩罚模型把概率放到数据分布认为不可能的地方。预训练常用 forward KL 形式,是因为数据只给我们样本 \(y\sim p_{\text{data}}\),而不是完整的 \(p_{\text{data}}\) 密度表。

所以“next-token prediction 只是在模仿数据”这句话要精确一点:它在 empirical prefix distribution 上,用 teacher-forced samples 估计 forward KL 的梯度。

Softmax Gradient

\[ p_i=\frac{e^{z_i}}{\sum_j e^{z_j}}, \qquad \mathcal{L}=-\log p_y. \]

则:

\[ \mathcal{L} = -z_y+\log\sum_j e^{z_j}. \]

\(z_i\) 求导:

\[ \frac{\partial\mathcal{L}}{\partial z_i} = p_i-\mathbf{1}[i=y]. \]

这解释了分类式 LM loss 的梯度含义:真实 token 的 logit 被往上推,模型过度赋概率的错误 token 被往下压。

若 LM head 和 input embedding tied,logits 可以写成:

\[ z_t=E h_t, \qquad E\in\mathbb{R}^{|\mathcal{V}|\times d}. \]

\(e_y\) 是 one-hot label,则:

\[ \frac{\partial\mathcal{L}_t}{\partial z_t} = p_t-e_y, \]

并且

\[ \frac{\partial\mathcal{L}_t}{\partial h_t} = E^\top(p_t-e_y). \]

这说明 token-level CE 不只是更新最后一层分类器。错误分布通过 \(E^\top\) 回传到 hidden state,再穿过 Transformer block,最后影响 attention、MLP 和 embedding。若某个错误 token 和真实 token embedding 很接近,梯度方向也会带有这种几何结构。

NoteDefinition: Tied LM Head

A tied language-model head reuses the input embedding matrix as the output projection, usually computing logits with \(z_t=E h_t\) or \(z_t=h_t E^\top\) depending on tensor convention.

工程上这也解释了为什么新增 tokenizer token 后必须 resize embedding:LM head 的类别数就是 vocabulary size。

Tiny Gradient Example

用一个三词表例子看 CE 在做什么。设词表是:

["cat", "dog", "runs"]

当前 prefix 的真实下一个 token 是 "dog",模型 logits 为:

\[ z=[2.0,1.0,0.0]. \]

softmax 约为:

\[ p\approx[0.665,0.245,0.090]. \]

label one-hot 是:

\[ e_{\text{dog}}=[0,1,0]. \]

所以 logits 梯度是:

\[ \nabla_z\mathcal{L} = p-e_{\text{dog}} \approx [0.665,-0.755,0.090]. \]

梯度下降更新 logits 的方向是 \(-\nabla_z\mathcal{L}\):降低 "cat""runs",提高 "dog"。但注意它不是只惩罚 top-1 错误 "cat";所有 token 都按当前概率分到梯度。模型越自信地把概率放错,错 token 的负向更新越大。

这就是 next-token loss 比 accuracy 更细的地方。两个模型都预测错 top-1,但一个给正确 token \(0.40\),另一个给 \(10^{-6}\),CE 会把它们区分开:

\[ -\log 0.40\approx0.916, \qquad -\log 10^{-6}\approx13.82. \]

Hessian and Local Geometry

softmax CE 的 Hessian 是:

\[ \frac{\partial^2\mathcal{L}} {\partial z\,\partial z^\top} = \operatorname{diag}(p)-pp^\top. \]

对任意向量 \(u\)

\[ u^\top(\operatorname{diag}(p)-pp^\top)u = \mathbb{E}_{i\sim p}[u_i^2] -\left(\mathbb{E}_{i\sim p}[u_i]\right)^2 = \operatorname{Var}_{i\sim p}(u_i) \geq0. \]

所以对 logits 而言,单个 token 的 CE 是 convex 的。深度网络整体非凸,是因为 logits \(z_\theta(x)\) 是参数 \(\theta\) 的非线性函数;但最后一层分类损失本身的几何很干净。

NoteDefinition: Logit-Space Convexity

Cross entropy with softmax is convex in logits, but a Transformer language model is not convex in its weights because logits are nonlinear functions of the weights.

这也解释了为什么很多优化诊断会看 logits、entropy、margin 和 calibration:它们位于损失最直接作用的空间,比直接看所有权重更可解释。

Fisher View of the Same Hessian

对 categorical distribution,log-likelihood 的 score function 是:

\[ \nabla_z\log p_y = e_y-p. \]

Fisher information 是 score 的二阶矩:

\[ F_z = \mathbb{E}_{y\sim p} \left[ (e_y-p)(e_y-p)^\top \right]. \]

展开第 \(i,j\) 项:

\[ \mathbb{E}[e_{y,i}e_{y,j}] -p_i\mathbb{E}[e_{y,j}] -p_j\mathbb{E}[e_{y,i}] +p_ip_j. \]

因为 \(\mathbb{E}[e_y]=p\),且 \(\mathbb{E}[e_{y,i}e_{y,j}]=p_i\mathbf{1}[i=j]\),所以:

\[ F_z = \operatorname{diag}(p)-pp^\top. \]

这和 softmax CE 对 logits 的 Hessian 相同。于是 CE 的局部二阶几何可以理解成:模型越不确定的 token 子空间,曲率越大;所有 logits 同时加同一个常数的方向曲率为零,因为 softmax 对常数平移不变:

\[ \operatorname{softmax}(z+c\mathbf{1}) = \operatorname{softmax}(z). \]

NoteDefinition: Fisher Information

Fisher information is the expected outer product of the score function. For categorical logits, it equals the softmax covariance matrix \(\operatorname{diag}(p)-pp^\top\).

这也是 natural gradient 的直觉来源:普通梯度在参数空间里走,natural gradient 用 Fisher 修正,使一步更新更接近“在分布空间里移动固定距离”。完整 LLM 的 Fisher 太大,但 K-FAC、Shampoo、Adam 的二阶/预条件影子都可以从这里读出来。

Binary Logistic Regression as One-Token LM

把 vocabulary 缩到两个 token,softmax CE 就退化为 logistic loss。设真实 label \(y\in\{0,1\}\),logit difference 为 \(a=z_1-z_0\),则:

\[ p_\theta(y=1\mid x)=\sigma(a)=\frac{1}{1+e^{-a}}. \]

负对数似然为:

\[ \ell(a,y) = -y\log\sigma(a) -(1-y)\log(1-\sigma(a)). \]

若把 \(s=2y-1\in\{-1,+1\}\),它也可以写成:

\[ \ell(a,s)=\log(1+\exp(-sa)). \]

导数是:

\[ \frac{\partial \ell}{\partial a} = \sigma(a)-y, \qquad \frac{\partial^2 \ell}{\partial a^2} = \sigma(a)(1-\sigma(a)). \]

这个最小例子说明 next-token CE 并不神秘:它就是多类别 logistic regression 在每个 prefix 上重复一次。Transformer 复杂在如何把 prefix 映射成 logits;loss 对 logits 的数学结构非常清楚。

Label Shift and Masked Loss

训练代码里常写 labels=input_ids,但数学上模型在位置 \(t\) 预测的是下一个 token。对 token ids

x0 x1 x2 x3

监督关系是:

hidden at x0 -> label x1
hidden at x1 -> label x2
hidden at x2 -> label x3

第一种实现显式 shift:

inputs = input_ids[:, :-1]
labels = input_ids[:, 1:]
logits = model(inputs).logits
loss = cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1))

第二种实现把完整 input_idslabels=input_ids 传给模型,由 model 内部做 shift。两者数学目标相同,但调试时必须知道 shift 在哪一层发生。

带 padding、prompt mask、assistant mask 的一般形式是:

\[ \mathcal{L} = -\frac{1}{\sum_{n,t}m_t^{(n)}} \sum_{n,t} m_t^{(n)} \log p_\theta \left( x_{t+1}^{(n)} \mid x_{\leq t}^{(n)} \right), \]

其中 \(m_t^{(n)}\in\{0,1\}\) 表示该位置是否计入 loss。padding、prompt-only token、system token、tool observation token 都可能被 mask 掉。

WarningPitfall: Off-by-One Label Bugs Are Silent

If labels are shifted twice, the model learns to predict \(x_{t+2}\) from \(x_{\leq t}\). If labels are not shifted at all, the task may leak the current token. Both bugs can produce plausible tensor shapes.

最小检查方法是拿一条短序列手工打印:

input:  [BOS, A, B, C]
target: [A,   B, C, EOS]
mask:   [1,   1, 1, 1]

如果 batch 中有 padding:

input:  [BOS, A, B, PAD]
target: [A,   B, EOS, PAD]
mask:   [1,   1, 1,   0]

PAD 位置必须既不被 attention 当成真实内容,也不被 loss 当成预测目标。

Loss Denominator Is Part of the Objective

同一批 logits 和 labels,只要 denominator 不同,优化目标就不同。下面三种写法表面都叫 “mean loss”,但梯度尺度不同:

Reduction Formula When acceptable
token mean \(\sum m_t\ell_t / \sum m_t\) most LM/SFT training
sequence mean \(\frac{1}{B}\sum_n \frac{\sum_t m_{nt}\ell_{nt}}{\sum_t m_{nt}}\) each example should have equal weight
batch-length mean \(\sum m_t\ell_t /(BT)\) only when padding ratio is fixed and intentional

若做 gradient accumulation,正确的 token mean 是 accumulation window 上的全局 denominator:

\[ \mathcal{L} = \frac{ \sum_{k=1}^{K} \sum_{n,t} m_{nt}^{(k)}\ell_{nt}^{(k)} }{ \sum_{k=1}^{K} \sum_{n,t} m_{nt}^{(k)} }. \]

如果每个 micro-batch 先各自除以有效 token 数,再除以 \(K\),短样本 micro-batch 会被赋予更大权重。这个 bug 不一定让 loss 爆炸,但会悄悄改变 instruction 数据、长文档数据和 padding-heavy 数据的训练比例。

loss_sum = 0.0
token_sum = 0

for batch in accumulation_window:
    logits = model(batch["input_ids"]).logits
    losses = cross_entropy_per_token(logits, batch["labels"])
    mask = batch["labels"].ne(-100)
    loss_sum = loss_sum + (losses * mask).sum()
    token_sum = token_sum + mask.sum()

loss = loss_sum / token_sum.clamp_min(1)
loss.backward()

实践中为了省显存,很多框架仍采用 micro-batch 内部 backward;那就至少要确认不同 micro-batch 的有效 token 数差异不大,或使用框架提供的 token-normalized accumulation。

Worked Denominator Example

考虑两个 micro-batches。第一批只有一个有效 assistant token,loss 为 \(4.0\);第二批有九个有效 token,每个 loss 为 \(1.0\)

全局 token mean 是:

\[ \frac{4.0+9\times1.0}{1+9} = 1.3. \]

若先对每个 micro-batch 求 mean,再平均:

\[ \frac{1}{2} \left( 4.0 + \frac{9\times1.0}{9} \right) = 2.5. \]

同样的 logits、同样的 labels,只因为 reduction 顺序不同,短样本的权重被放大了接近 \(2\) 倍。若短样本大多来自某一类 instruction,例如格式遵循或闲聊,训练目标就会在不知不觉中偏向这类样本。

WarningPitfall: Accumulation Changes the Empirical Distribution

Gradient accumulation with per-microbatch mean loss gives each microbatch equal weight, not each token equal weight. This is a different empirical risk unless token counts are constant.

Teacher Forcing and Exposure Bias

训练时,条件 prefix 来自真实数据:

\[ p_\theta(x_t\mid x_{<t}^{\text{gold}}). \]

生成时,prefix 来自模型自己:

\[ p_\theta(\hat{x}_t\mid \hat{x}_{<t}). \]

这产生 exposure bias:早期采样错误会把后续条件推到训练数据较少覆盖的区域。它不是 next-token objective 的数学错误,而是 teacher forcing 训练和 autoregressive sampling 推理之间的分布偏移。

teacher forcing 的工程价值是并行性。给定完整序列 \(x_{1:T}\),Transformer 可以一次 forward 同时计算所有位置的 hidden states:

\[ (h_1,\ldots,h_T)=F_\theta(x_{1:T};M_{\text{causal}}), \]

然后一次性得到所有 next-token logits。训练复杂度主要是一个长度 \(T\) 的并行 forward;若像生成一样逐 token rollout,训练会慢很多。

WarningPitfall: Lower Token Loss Is Not Always Better Generation

Token-level likelihood improves local conditional modeling. Long-form generation also depends on sampling, stopping, calibration, prompt distribution, and whether the model can recover from its own earlier tokens.

Finite Context and Packed Sequences

真实模型有 context length \(L\),所以它学的不是无限历史条件分布,而是:

\[ p_\theta(x_t\mid x_{<t}) \approx p_\theta(x_t\mid x_{\max(1,t-L):t-1}). \]

当训练文本超过 \(L\),常见做法是 chunking 或 packing:

doc1 tokens ... EOS doc2 tokens ... EOS doc3 tokens ...

packing 提高吞吐,但必须处理文档边界。如果没有 EOS 或 attention boundary,模型会被训练成“doc2 的开头可以依赖 doc1 的结尾”。这在纯网页预训练里有时可以接受,在 SFT/chat 数据里通常会污染样本。

WarningPitfall: Packed SFT Needs Boundary Masks

Packed instruction examples should not allow one conversation to attend to another conversation unless the packing format intentionally inserts boundary tokens and masks.

Perplexity

Perplexity 是平均 NLL 的指数:

\[ \operatorname{PPL} = \exp\left( \frac{1}{T} \sum_{t=1}^{T} -\log p_\theta(x_t\mid x_{<t}) \right). \]

如果模型每一步平均像在 \(K\) 个等概率选项中猜,PPL 约等于 \(K\)。但不同 tokenizer 的 token 粒度不同,所以不同 tokenizer 下的 PPL 不能直接比较。

另一个等价指标是 bits per token:

\[ \operatorname{BPT} = \frac{1}{T} \sum_t -\log_2 p_\theta(x_t\mid x_{<t}) = \frac{\log \operatorname{PPL}}{\log 2}. \]

若要跨 tokenizer 或跨语言比较,更稳的是 bits per byte/character:

\[ \operatorname{BPB} = \frac{\sum_t-\log_2 p_\theta(x_t\mid x_{<t})} {\#\text{bytes in original text}}. \]

这把 tokenizer fertility 的影响部分归一化掉。否则一个 tokenizer 把中文切得更碎,token-level PPL 可能看起来更低或更高,但并不代表原始文本建模更好。

Sequence Log-Probability and Length Bias

next-token objective 是 token factorization,但很多下游系统需要比较整个 completion:

prompt: "The capital of France is"
A: " Paris."
B: " the city of Paris."

模型给 completion \(y_{1:M}\) 的条件概率是:

\[ p_\theta(y_{1:M}\mid x) = \prod_{m=1}^{M} p_\theta(y_m\mid x,y_{<m}). \]

所以 log-prob 是 token log-prob 的和:

\[ \log p_\theta(y_{1:M}\mid x) = \sum_{m=1}^{M} \log p_\theta(y_m\mid x,y_{<m}). \]

这个公式直接解释了 length bias:只要每个 token 概率小于 \(1\),多生成一个 token 通常会让总 log-prob 更小。于是按 raw sequence log-prob 排序会偏好短答案;按平均 log-prob 排序又可能偏好长而局部流畅的废话。

Score Formula Bias
sum log-prob \(\sum_m \log p(y_m\mid x,y_{<m})\) prefers shorter completions
average log-prob \(\frac{1}{M}\sum_m \log p(y_m\mid x,y_{<m})\) ignores answer length cost
length penalty \(\frac{\sum_m\log p_m}{M^\alpha}\) tunable compromise
WarningPitfall: Completion Log-Probs Need the Same Tokenizer

Sequence scores are sums over tokenizer tokens. A response split into fewer tokens can have a different length penalty even when the raw text is semantically similar.

在 preference optimization、reranking、self-consistency 和 verifier systems 中,必须明确使用哪一种 normalization。DPO 常用的是整段 response 的 log-prob sum:

\[ \log \pi_\theta(y\mid x) = \sum_{m=1}^{M_y} \log \pi_\theta(y_m\mid x,y_{<m}), \]

因为 preference label 比较的是完整 response。但如果数据中 chosen answer 系统性更长,raw sum 会混入长度偏置;这时要么做数据审计,要么显式加入 length-normalized variant,并在报告里说明目标已经改变。

Log-Probability Accounting in Code

给定 logits[:, :-1]labels[:, 1:],sequence log-prob 的核心是按 label gather:

import torch.nn.functional as F


def token_logprobs(logits, labels):
    logp = F.log_softmax(logits[:, :-1], dim=-1)
    target = labels[:, 1:].clone()
    mask = target.ne(-100)
    target = target.masked_fill(~mask, 0)
    picked = logp.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    return picked, mask


picked, mask = token_logprobs(logits, labels)
seq_logp = (picked * mask).sum(dim=-1)
seq_len = mask.sum(dim=-1).clamp_min(1)
avg_logp = seq_logp / seq_len

这里 labels[:, 1:]logits[:, :-1] 对齐;如果模型 forward 已经内部 shift,就不能再这样手动 shift。mask 必须来自 shifted labels,否则最后一个 prompt token 或第一个 assistant token 很容易错位。

Temperature, Entropy, and Evaluation Logits

temperature sampling 把 logits 改成 \(z/T\)

\[ p_T(v\mid x) = \frac{\exp(z_v/T)} {\sum_u\exp(z_u/T)}. \]

\(T>1\),分布更平,entropy 通常上升;当 \(T<1\),分布更尖,entropy 下降。对两个 token \(a,b\),temperature 不改变它们的排序:

\[ \log\frac{p_T(a)}{p_T(b)} = \frac{z_a-z_b}{T}. \]

它改变的是 odds 的强度,而不是 argmax。若 \(z_a-z_b=4\)\(T=2\) 会把 log-odds 降到 \(2\)\(T=0.5\) 会把 log-odds 放大到 \(8\)。这就是 temperature 能让采样更发散或更确定的原因。

WarningPitfall: Perplexity Uses Raw Model Distribution

Perplexity and validation NLL should be computed from the model distribution at training temperature, usually raw logits with \(T=1\). Applying inference temperature during evaluation measures a different distribution.

top-\(p\) 和 top-\(k\) sampling 更进一步:它们把部分 token 的概率截断为 \(0\),再重新归一化。这样的分布适合生成,但不是训练时最大似然学到的完整条件分布。于是报告模型能力时,要分开写:

Quantity Uses raw logits? Meaning
validation NLL / PPL yes probability-model fit
greedy accuracy yes top-1 local prediction
sampled answer quality no model plus decoding policy
pass@k no model plus sampling budget

这也是为什么同一个 checkpoint 可以在低 temperature 下很稳定,在高 temperature 下有创造性,但 validation loss 完全不变。

Log Loss Is Proper

ImportantTheorem: Log Loss Is a Proper Scoring Rule

For true distribution \(p\), the expected log loss \[ \mathbb{E}_{y\sim p}[-\log q(y)] \] is minimized at \(q=p\).

期望 log loss 是 cross entropy:

\[ H(p,q)=H(p)+\operatorname{KL}(p\|q). \]

\(H(p)\)\(q\) 无关,而 \(\operatorname{KL}(p\|q)\geq0\),等号当且仅当 \(p=q\)。所以最优预测分布就是真实分布。

这说明 next-token prediction 在概率意义上是合理的:如果真的能最小化 population log loss,模型应该输出真实条件分布,而不是只输出 mode。

Calibration, Entropy, and Why Mode Accuracy Is Not Enough

proper scoring rule 的含义还可以从 calibration 看。假设某类 prefix 下真实条件分布是:

\[ p^\star(\text{yes}\mid x)=0.7, \qquad p^\star(\text{no}\mid x)=0.3. \]

一个只输出 mode 的模型给:

\[ q_1=[1,0], \]

另一个校准模型给:

\[ q_2=[0.7,0.3]. \]

从 accuracy 看,\(q_1\) 更“坚定”,但 expected log loss 是:

\[ \mathbb{E}_{y\sim p^\star}[-\log q_1(y)] = \infty \]

因为真实数据中有 \(30\%\)"no",而 \(q_1\) 给了零概率。\(q_2\) 的 loss 是有限的熵 \(H(p^\star)\)。这解释了为什么 next-token prediction 训练的是完整分布,而不是只训练 argmax。

在 LLM 中,这个现象对应到多种合理续写:同一个 prefix 后可能有多个语义正确但风格不同的 token。CE 不要求模型永远选一个模板答案;它要求模型把概率质量放到数据条件分布支持的区域。

Instruction Tuning as Conditional LM

把 instruction example 序列化:

<user> Explain convexity.
<assistant> Convexity means ...

causal LM 仍然建模整个序列,但 SFT 常只对 assistant tokens 计 loss:

\[ \mathcal{L}_{\text{SFT}} = -\sum_{t\in\mathcal{A}} \log p_\theta(x_t\mid x_{<t}), \]

其中 \(\mathcal{A}\) 是 assistant token 位置集合。这不是换模型结构,而是换 label mask。GPT-2 的 WebText pretraining、instruction tuning、chat fine-tuning,本质上都围绕同一个条件概率核心展开。

更精确地说,prompt tokens 仍然进入条件上下文,但不贡献 loss:

\[ \mathcal{L}_{\text{SFT}} = -\sum_t m_t^{\text{assistant}} \log p_\theta(x_t\mid x_{<t}). \]

所以 SFT 不是在训练无条件回答分布,而是在训练:

\[ p_\theta(\text{assistant answer}\mid \text{system},\text{user},\text{history}). \]

如果误把 user prompt 也计入 loss,模型会学会复述用户输入;如果误把 assistant answer mask 掉,训练几乎只是在做 prompt language modeling。

Soft Targets and Distillation

one-hot label 是最常见情况,但 next-token objective 也可以用 teacher distribution:

\[ \mathcal{L}_{\text{KD}} = \tau^2 \sum_v p_T(v\mid x_{<t}) \log \frac{p_T(v\mid x_{<t})}{p_\theta^\tau(v\mid x_{<t})}. \]

这里 \(p_T\) 是 teacher,\(p_\theta^\tau\) 是 student 在 temperature \(\tau\) 下的分布。与 one-hot CE 相比,distillation 会告诉 student:哪些错误 token 是“接近正确”的,哪些是完全不应该的。这也是很多 small LLM 从 large LLM 学 reasoning/style 的数学接口。

Preference Losses Reuse Sequence Log-Prob

RLHF、DPO、IPO、KTO 这些 post-training 方法看起来已经离开 next-token prediction,但底层仍然反复调用同一个量:

\[ \log \pi_\theta(y\mid x) = \sum_{m=1}^{M} \log \pi_\theta(y_m\mid x,y_{<m}). \]

DPO 的核心 log-ratio 可以写成:

\[ r_\theta(x,y) = \beta \left[ \log \pi_\theta(y\mid x) - \log \pi_{\text{ref}}(y\mid x) \right]. \]

给定 chosen response \(y^+\) 和 rejected response \(y^-\),loss 是:

\[ \mathcal{L}_{\text{DPO}} = -\log\sigma \left( r_\theta(x,y^+)-r_\theta(x,y^-) \right). \]

这说明 preference optimization 不是换掉语言模型数学接口,而是把 sequence log-prob 组合成 pairwise objective。所有 label-shift、mask、tokenizer、length-bias、denominator 的问题都会继承过来。

如果 chosen 比 rejected 长很多,raw log-prob sum 可能自然更低;如果 response mask 错位,DPO 更新会把 prompt token 也算进偏好;如果 reference model tokenizer/template 不一致,log-ratio 就不再比较同一个事件。

NoteDefinition: Policy Log-Ratio

In preference optimization, a policy log-ratio compares how much more likely the trainable policy makes a response relative to a reference policy under the same prompt, tokenizer, template, and response mask.

一个最小 DPO log-prob accounting 应该显式返回每条 response 的 token 数:

def sequence_logp(logits, labels):
    picked, mask = token_logprobs(logits, labels)
    return {
        "logp": (picked * mask).sum(dim=-1),
        "tokens": mask.sum(dim=-1),
    }


chosen = sequence_logp(chosen_logits, chosen_labels)
rejected = sequence_logp(rejected_logits, rejected_labels)

delta_policy = chosen["logp"] - rejected["logp"]
delta_ref = chosen_ref["logp"] - rejected_ref["logp"]
loss = -F.logsigmoid(beta * (delta_policy - delta_ref)).mean()

这里的 tokens 不一定进入公式,但必须被记录。没有长度审计,就很难判断 preference loss 到底在学习质量偏好,还是在学习答案长度偏好。

Label Smoothing as Entropy Regularization

label smoothing 把 one-hot label 改成:

\[ q_\epsilon(v) = (1-\epsilon)\mathbf{1}[v=y] +\epsilon u(v), \]

其中 \(u(v)=1/|\mathcal{V}|\) 是 uniform distribution。损失为:

\[ H(q_\epsilon,p_\theta) = (1-\epsilon)(-\log p_\theta(y)) +\epsilon H(u,p_\theta). \]

第一项仍然推高真实 token;第二项要求模型不要把非真实 token 的概率压到过分接近 \(0\)。从梯度看:

\[ \frac{\partial \mathcal{L}}{\partial z_i} = p_\theta(i)-q_\epsilon(i). \]

真实 token 的目标概率从 \(1\) 变成 \(1-\epsilon+\epsilon/|\mathcal{V}|\),其他 token 的目标概率从 \(0\) 变成 \(\epsilon/|\mathcal{V}|\)。这会改善 calibration,但也可能伤害需要极低 entropy 的任务,例如严格格式输出、代码 token、数学符号补全。

WarningPitfall: Label Smoothing Changes the Target Distribution

Label smoothing is not a harmless regularizer. It replaces the empirical one-hot distribution with a higher-entropy target, so validation NLL against one-hot labels may worsen even if calibration improves.

Decoding Is Not Training

训练学到的是条件分布的参数化;生成时还要定义如何从这个分布取样。前面讲的 temperature 改变 odds 强度,top-\(p\) sampling 则选择最小集合 \(S_p\)

\[ \sum_{v\in S_p}p(v)\geq p. \]

top-\(k\) sampling 选择概率最高的 \(k\) 个 token,repetition penalty 会按历史 token 修改 logits,stop sequence 又会在文本层面截断输出。所以“同一个模型”可能因为 decoding policy 不同而表现得保守、发散、有创造力或不稳定。训练目标定义概率模型;decoding 定义如何使用概率模型。

Implementation Checklist

写或检查 causal LM training loop 时,至少逐项验证:

  1. causal mask 是否严格禁止未来 token;
  2. label shift 是在 collator、training step 还是 model forward 内部发生;
  3. PAD、prompt、system/tool tokens 是否按预期设置 labels=-100
  4. packed sequence 是否有 EOS 或 block-diagonal attention boundary;
  5. loss denominator 是有效 label 数,而不是 batch size 或 padded length;
  6. tokenizer 和 chat template 是否与训练 checkpoint 一致;
  7. perplexity 是否只在可比 tokenizer/数据切分上比较;
  8. generation 评测是否固定 decoding policy,而不是只报告 training loss。

next-token prediction 的核心非常简单:

\[ \text{prefix}\rightarrow \text{distribution over next token}. \]

但训练系统里的每个细节都在定义“哪些 prefix、哪些 target、哪些位置、哪些条件格式”真正进入这个概率目标。