4.6 Mathematical Foundations of Next-Token Prediction
Next-token prediction 看起来只是“预测下一个 token”的分类任务,但它能成为 LLM 预训练核心,是因为它同时连接了概率链式法则、最大似然、cross entropy、KL divergence、teacher forcing、perplexity 和 instruction tuning。
这一节把它当成数学对象,而不是 API 里的 labels=input_ids。
Chain Rule of Probability
For a discrete sequence \(x_{1:T}\), \[ p(x_{1:T}) = \prod_{t=1}^{T}p(x_t\mid x_{<t}). \] This identity holds for any joint distribution.
自回归语言模型的做法是用 neural network 参数化每个条件分布:
\[ p_\theta(x_t\mid x_{<t}) = \operatorname{softmax}(f_\theta(x_{<t}))_{x_t}. \]
Transformer decoder 的 causal mask 保证 \(f_\theta\) 在位置 \(t\) 只能依赖 prefix,而不能看未来。
Causal Mask as a Conditional Independence Constraint
在实现中,causal mask 通常是加到 attention logits 上的上三角矩阵:
\[ M_{ij} = \begin{cases} 0, & j\leq i,\\ -\infty, & j>i. \end{cases} \]
于是第 \(i\) 个位置的 attention 权重为:
\[ \alpha_{ij} = \frac{\exp((q_i^\top k_j)/\sqrt{d}+M_{ij})} {\sum_{\ell}\exp((q_i^\top k_\ell)/\sqrt{d}+M_{i\ell})}. \]
若 \(j>i\),则 \(\exp(-\infty)=0\),未来 token 对当前位置 hidden state 没有贡献。这不是一个“为了训练方便”的 mask,而是在结构上把模型限制为:
\[ h_i=f_\theta(x_{\leq i}), \qquad p_\theta(x_{i+1}\mid x_{\leq i}) = \operatorname{softmax}(W h_i)_{x_{i+1}}. \]
如果 mask 泄漏未来 token,训练 loss 会异常低,但模型学到的是一个不可能在生成时使用的条件分布。
Maximum Likelihood
给定训练集 \(\mathcal{D}=\{x^{(n)}_{1:T_n}\}_{n=1}^{N}\),最大似然估计是:
\[ \theta^\star = \arg\max_\theta \sum_{n=1}^{N} \log p_\theta(x^{(n)}_{1:T_n}). \]
代入 chain rule:
\[ \theta^\star = \arg\max_\theta \sum_{n=1}^{N} \sum_{t=1}^{T_n} \log p_\theta(x_t^{(n)}\mid x_{<t}^{(n)}). \]
因此 next-token prediction 不是启发式,而是序列分布的 maximum likelihood。
Empirical Distribution View
训练集定义了一个 empirical distribution:
\[ \hat{p}_{\mathcal{D}}(x_{1:T}) = \frac{1}{N}\sum_{n=1}^N \mathbf{1}[x_{1:T}=x^{(n)}_{1:T_n}]. \]
最大似然等价于最小化 empirical negative log-likelihood:
\[ \hat{\mathcal{L}}(\theta) = \mathbb{E}_{x\sim \hat{p}_{\mathcal{D}}} \left[ -\log p_\theta(x) \right]. \]
代入 chain rule 后,训练样本不再只是 \(N\) 条 sequence,而是大量 prefix-target pairs:
\[ (x_{<1},x_1),\ (x_{<2},x_2),\ldots,\ (x_{<T},x_T). \]
这解释了为什么 next-token prediction 数据效率高:一段长度为 \(T\) 的文本提供 \(T\) 个监督位置,而不是一个 sequence-level label。
For one-hot labels, summing token-level cross entropy over all positions equals the negative log-likelihood of an autoregressive sequence model.
对位置 \(t\),真实 token 是 \(y=x_t\)。模型分布为
\[ q_\theta(v\mid x_{<t}) = \operatorname{softmax}(z_t)_v. \]
one-hot empirical distribution:
\[ \hat{p}(v\mid x_{<t})=\mathbf{1}[v=y]. \]
cross entropy:
\[ H(\hat{p},q_\theta) = -\sum_{v\in\mathcal{V}} \hat{p}(v\mid x_{<t}) \log q_\theta(v\mid x_{<t}) = -\log q_\theta(y\mid x_{<t}). \]
对所有 token 求和:
\[ \sum_t H(\hat{p}_t,q_{\theta,t}) = -\sum_t\log p_\theta(x_t\mid x_{<t}) = -\log p_\theta(x_{1:T}). \]
Cross Entropy and KL
population objective:
\[ \mathcal{L}(\theta) = \mathbb{E}_{x\sim p_{\text{data}}} \left[ -\sum_t\log p_\theta(x_t\mid x_{<t}) \right]. \]
写成条件分布的 cross entropy:
\[ \mathcal{L}(\theta) = \sum_t \mathbb{E}_{x_{<t}\sim p_{\text{data}}} H\left( p_{\text{data}}(\cdot\mid x_{<t}), p_\theta(\cdot\mid x_{<t}) \right). \]
由于
\[ H(p,q)=H(p)+\operatorname{KL}(p\|q), \]
最小化 cross entropy 等价于最小化真实条件分布和模型条件分布之间的 KL divergence。
KL divergence:
\[ \operatorname{KL}(p\|q) = \sum_x p(x)\log\frac{p(x)}{q(x)} = \sum_x p(x)\log p(x)-\sum_xp(x)\log q(x). \]
而
\[ H(p)=-\sum_xp(x)\log p(x), \qquad H(p,q)=-\sum_xp(x)\log q(x). \]
所以
\[ H(p,q)=H(p)+\operatorname{KL}(p\|q). \]
若数据无限、模型容量无限、优化达到全局最优,则:
\[ p_\theta(\cdot\mid x_{<t}) = p_{\text{data}}(\cdot\mid x_{<t}) \]
在数据分布支持集上成立。
KL Direction Matters
next-token MLE 最小化的是:
\[ \operatorname{KL} \left( p_{\text{data}}(\cdot\mid x_{<t}) \middle\| p_\theta(\cdot\mid x_{<t}) \right), \]
也就是 forward KL。它对“真实分布中出现但模型给低概率”的事件惩罚很重,因为 \(-\log p_\theta(y)\) 会爆炸。反过来的 reverse KL,
\[ \operatorname{KL}(p_\theta\|p_{\text{data}}), \]
更偏向 mode-seeking,会强烈惩罚模型把概率放到数据分布认为不可能的地方。预训练常用 forward KL 形式,是因为数据只给我们样本 \(y\sim p_{\text{data}}\),而不是完整的 \(p_{\text{data}}\) 密度表。
所以“next-token prediction 只是在模仿数据”这句话要精确一点:它在 empirical prefix distribution 上,用 teacher-forced samples 估计 forward KL 的梯度。
Softmax Gradient
令
\[ p_i=\frac{e^{z_i}}{\sum_j e^{z_j}}, \qquad \mathcal{L}=-\log p_y. \]
则:
\[ \mathcal{L} = -z_y+\log\sum_j e^{z_j}. \]
对 \(z_i\) 求导:
\[ \frac{\partial\mathcal{L}}{\partial z_i} = p_i-\mathbf{1}[i=y]. \]
这解释了分类式 LM loss 的梯度含义:真实 token 的 logit 被往上推,模型过度赋概率的错误 token 被往下压。
若 LM head 和 input embedding tied,logits 可以写成:
\[ z_t=E h_t, \qquad E\in\mathbb{R}^{|\mathcal{V}|\times d}. \]
令 \(e_y\) 是 one-hot label,则:
\[ \frac{\partial\mathcal{L}_t}{\partial z_t} = p_t-e_y, \]
并且
\[ \frac{\partial\mathcal{L}_t}{\partial h_t} = E^\top(p_t-e_y). \]
这说明 token-level CE 不只是更新最后一层分类器。错误分布通过 \(E^\top\) 回传到 hidden state,再穿过 Transformer block,最后影响 attention、MLP 和 embedding。若某个错误 token 和真实 token embedding 很接近,梯度方向也会带有这种几何结构。
A tied language-model head reuses the input embedding matrix as the output projection, usually computing logits with \(z_t=E h_t\) or \(z_t=h_t E^\top\) depending on tensor convention.
工程上这也解释了为什么新增 tokenizer token 后必须 resize embedding:LM head 的类别数就是 vocabulary size。
Tiny Gradient Example
用一个三词表例子看 CE 在做什么。设词表是:
["cat", "dog", "runs"]
当前 prefix 的真实下一个 token 是 "dog",模型 logits 为:
\[ z=[2.0,1.0,0.0]. \]
softmax 约为:
\[ p\approx[0.665,0.245,0.090]. \]
label one-hot 是:
\[ e_{\text{dog}}=[0,1,0]. \]
所以 logits 梯度是:
\[ \nabla_z\mathcal{L} = p-e_{\text{dog}} \approx [0.665,-0.755,0.090]. \]
梯度下降更新 logits 的方向是 \(-\nabla_z\mathcal{L}\):降低 "cat" 和 "runs",提高 "dog"。但注意它不是只惩罚 top-1 错误 "cat";所有 token 都按当前概率分到梯度。模型越自信地把概率放错,错 token 的负向更新越大。
这就是 next-token loss 比 accuracy 更细的地方。两个模型都预测错 top-1,但一个给正确 token \(0.40\),另一个给 \(10^{-6}\),CE 会把它们区分开:
\[ -\log 0.40\approx0.916, \qquad -\log 10^{-6}\approx13.82. \]
Hessian and Local Geometry
softmax CE 的 Hessian 是:
\[ \frac{\partial^2\mathcal{L}} {\partial z\,\partial z^\top} = \operatorname{diag}(p)-pp^\top. \]
对任意向量 \(u\):
\[ u^\top(\operatorname{diag}(p)-pp^\top)u = \mathbb{E}_{i\sim p}[u_i^2] -\left(\mathbb{E}_{i\sim p}[u_i]\right)^2 = \operatorname{Var}_{i\sim p}(u_i) \geq0. \]
所以对 logits 而言,单个 token 的 CE 是 convex 的。深度网络整体非凸,是因为 logits \(z_\theta(x)\) 是参数 \(\theta\) 的非线性函数;但最后一层分类损失本身的几何很干净。
Cross entropy with softmax is convex in logits, but a Transformer language model is not convex in its weights because logits are nonlinear functions of the weights.
这也解释了为什么很多优化诊断会看 logits、entropy、margin 和 calibration:它们位于损失最直接作用的空间,比直接看所有权重更可解释。
Fisher View of the Same Hessian
对 categorical distribution,log-likelihood 的 score function 是:
\[ \nabla_z\log p_y = e_y-p. \]
Fisher information 是 score 的二阶矩:
\[ F_z = \mathbb{E}_{y\sim p} \left[ (e_y-p)(e_y-p)^\top \right]. \]
展开第 \(i,j\) 项:
\[ \mathbb{E}[e_{y,i}e_{y,j}] -p_i\mathbb{E}[e_{y,j}] -p_j\mathbb{E}[e_{y,i}] +p_ip_j. \]
因为 \(\mathbb{E}[e_y]=p\),且 \(\mathbb{E}[e_{y,i}e_{y,j}]=p_i\mathbf{1}[i=j]\),所以:
\[ F_z = \operatorname{diag}(p)-pp^\top. \]
这和 softmax CE 对 logits 的 Hessian 相同。于是 CE 的局部二阶几何可以理解成:模型越不确定的 token 子空间,曲率越大;所有 logits 同时加同一个常数的方向曲率为零,因为 softmax 对常数平移不变:
\[ \operatorname{softmax}(z+c\mathbf{1}) = \operatorname{softmax}(z). \]
Fisher information is the expected outer product of the score function. For categorical logits, it equals the softmax covariance matrix \(\operatorname{diag}(p)-pp^\top\).
这也是 natural gradient 的直觉来源:普通梯度在参数空间里走,natural gradient 用 Fisher 修正,使一步更新更接近“在分布空间里移动固定距离”。完整 LLM 的 Fisher 太大,但 K-FAC、Shampoo、Adam 的二阶/预条件影子都可以从这里读出来。
Binary Logistic Regression as One-Token LM
把 vocabulary 缩到两个 token,softmax CE 就退化为 logistic loss。设真实 label \(y\in\{0,1\}\),logit difference 为 \(a=z_1-z_0\),则:
\[ p_\theta(y=1\mid x)=\sigma(a)=\frac{1}{1+e^{-a}}. \]
负对数似然为:
\[ \ell(a,y) = -y\log\sigma(a) -(1-y)\log(1-\sigma(a)). \]
若把 \(s=2y-1\in\{-1,+1\}\),它也可以写成:
\[ \ell(a,s)=\log(1+\exp(-sa)). \]
导数是:
\[ \frac{\partial \ell}{\partial a} = \sigma(a)-y, \qquad \frac{\partial^2 \ell}{\partial a^2} = \sigma(a)(1-\sigma(a)). \]
这个最小例子说明 next-token CE 并不神秘:它就是多类别 logistic regression 在每个 prefix 上重复一次。Transformer 复杂在如何把 prefix 映射成 logits;loss 对 logits 的数学结构非常清楚。
Label Shift and Masked Loss
训练代码里常写 labels=input_ids,但数学上模型在位置 \(t\) 预测的是下一个 token。对 token ids
x0 x1 x2 x3
监督关系是:
hidden at x0 -> label x1
hidden at x1 -> label x2
hidden at x2 -> label x3
第一种实现显式 shift:
inputs = input_ids[:, :-1]
labels = input_ids[:, 1:]
logits = model(inputs).logits
loss = cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1))第二种实现把完整 input_ids 和 labels=input_ids 传给模型,由 model 内部做 shift。两者数学目标相同,但调试时必须知道 shift 在哪一层发生。
带 padding、prompt mask、assistant mask 的一般形式是:
\[ \mathcal{L} = -\frac{1}{\sum_{n,t}m_t^{(n)}} \sum_{n,t} m_t^{(n)} \log p_\theta \left( x_{t+1}^{(n)} \mid x_{\leq t}^{(n)} \right), \]
其中 \(m_t^{(n)}\in\{0,1\}\) 表示该位置是否计入 loss。padding、prompt-only token、system token、tool observation token 都可能被 mask 掉。
If labels are shifted twice, the model learns to predict \(x_{t+2}\) from \(x_{\leq t}\). If labels are not shifted at all, the task may leak the current token. Both bugs can produce plausible tensor shapes.
最小检查方法是拿一条短序列手工打印:
input: [BOS, A, B, C]
target: [A, B, C, EOS]
mask: [1, 1, 1, 1]
如果 batch 中有 padding:
input: [BOS, A, B, PAD]
target: [A, B, EOS, PAD]
mask: [1, 1, 1, 0]
PAD 位置必须既不被 attention 当成真实内容,也不被 loss 当成预测目标。
Loss Denominator Is Part of the Objective
同一批 logits 和 labels,只要 denominator 不同,优化目标就不同。下面三种写法表面都叫 “mean loss”,但梯度尺度不同:
| Reduction | Formula | When acceptable |
|---|---|---|
| token mean | \(\sum m_t\ell_t / \sum m_t\) | most LM/SFT training |
| sequence mean | \(\frac{1}{B}\sum_n \frac{\sum_t m_{nt}\ell_{nt}}{\sum_t m_{nt}}\) | each example should have equal weight |
| batch-length mean | \(\sum m_t\ell_t /(BT)\) | only when padding ratio is fixed and intentional |
若做 gradient accumulation,正确的 token mean 是 accumulation window 上的全局 denominator:
\[ \mathcal{L} = \frac{ \sum_{k=1}^{K} \sum_{n,t} m_{nt}^{(k)}\ell_{nt}^{(k)} }{ \sum_{k=1}^{K} \sum_{n,t} m_{nt}^{(k)} }. \]
如果每个 micro-batch 先各自除以有效 token 数,再除以 \(K\),短样本 micro-batch 会被赋予更大权重。这个 bug 不一定让 loss 爆炸,但会悄悄改变 instruction 数据、长文档数据和 padding-heavy 数据的训练比例。
loss_sum = 0.0
token_sum = 0
for batch in accumulation_window:
logits = model(batch["input_ids"]).logits
losses = cross_entropy_per_token(logits, batch["labels"])
mask = batch["labels"].ne(-100)
loss_sum = loss_sum + (losses * mask).sum()
token_sum = token_sum + mask.sum()
loss = loss_sum / token_sum.clamp_min(1)
loss.backward()实践中为了省显存,很多框架仍采用 micro-batch 内部 backward;那就至少要确认不同 micro-batch 的有效 token 数差异不大,或使用框架提供的 token-normalized accumulation。
Worked Denominator Example
考虑两个 micro-batches。第一批只有一个有效 assistant token,loss 为 \(4.0\);第二批有九个有效 token,每个 loss 为 \(1.0\)。
全局 token mean 是:
\[ \frac{4.0+9\times1.0}{1+9} = 1.3. \]
若先对每个 micro-batch 求 mean,再平均:
\[ \frac{1}{2} \left( 4.0 + \frac{9\times1.0}{9} \right) = 2.5. \]
同样的 logits、同样的 labels,只因为 reduction 顺序不同,短样本的权重被放大了接近 \(2\) 倍。若短样本大多来自某一类 instruction,例如格式遵循或闲聊,训练目标就会在不知不觉中偏向这类样本。
Gradient accumulation with per-microbatch mean loss gives each microbatch equal weight, not each token equal weight. This is a different empirical risk unless token counts are constant.
Teacher Forcing and Exposure Bias
训练时,条件 prefix 来自真实数据:
\[ p_\theta(x_t\mid x_{<t}^{\text{gold}}). \]
生成时,prefix 来自模型自己:
\[ p_\theta(\hat{x}_t\mid \hat{x}_{<t}). \]
这产生 exposure bias:早期采样错误会把后续条件推到训练数据较少覆盖的区域。它不是 next-token objective 的数学错误,而是 teacher forcing 训练和 autoregressive sampling 推理之间的分布偏移。
teacher forcing 的工程价值是并行性。给定完整序列 \(x_{1:T}\),Transformer 可以一次 forward 同时计算所有位置的 hidden states:
\[ (h_1,\ldots,h_T)=F_\theta(x_{1:T};M_{\text{causal}}), \]
然后一次性得到所有 next-token logits。训练复杂度主要是一个长度 \(T\) 的并行 forward;若像生成一样逐 token rollout,训练会慢很多。
Token-level likelihood improves local conditional modeling. Long-form generation also depends on sampling, stopping, calibration, prompt distribution, and whether the model can recover from its own earlier tokens.
Finite Context and Packed Sequences
真实模型有 context length \(L\),所以它学的不是无限历史条件分布,而是:
\[ p_\theta(x_t\mid x_{<t}) \approx p_\theta(x_t\mid x_{\max(1,t-L):t-1}). \]
当训练文本超过 \(L\),常见做法是 chunking 或 packing:
doc1 tokens ... EOS doc2 tokens ... EOS doc3 tokens ...
packing 提高吞吐,但必须处理文档边界。如果没有 EOS 或 attention boundary,模型会被训练成“doc2 的开头可以依赖 doc1 的结尾”。这在纯网页预训练里有时可以接受,在 SFT/chat 数据里通常会污染样本。
Packed instruction examples should not allow one conversation to attend to another conversation unless the packing format intentionally inserts boundary tokens and masks.
Perplexity
Perplexity 是平均 NLL 的指数:
\[ \operatorname{PPL} = \exp\left( \frac{1}{T} \sum_{t=1}^{T} -\log p_\theta(x_t\mid x_{<t}) \right). \]
如果模型每一步平均像在 \(K\) 个等概率选项中猜,PPL 约等于 \(K\)。但不同 tokenizer 的 token 粒度不同,所以不同 tokenizer 下的 PPL 不能直接比较。
另一个等价指标是 bits per token:
\[ \operatorname{BPT} = \frac{1}{T} \sum_t -\log_2 p_\theta(x_t\mid x_{<t}) = \frac{\log \operatorname{PPL}}{\log 2}. \]
若要跨 tokenizer 或跨语言比较,更稳的是 bits per byte/character:
\[ \operatorname{BPB} = \frac{\sum_t-\log_2 p_\theta(x_t\mid x_{<t})} {\#\text{bytes in original text}}. \]
这把 tokenizer fertility 的影响部分归一化掉。否则一个 tokenizer 把中文切得更碎,token-level PPL 可能看起来更低或更高,但并不代表原始文本建模更好。
Sequence Log-Probability and Length Bias
next-token objective 是 token factorization,但很多下游系统需要比较整个 completion:
prompt: "The capital of France is"
A: " Paris."
B: " the city of Paris."
模型给 completion \(y_{1:M}\) 的条件概率是:
\[ p_\theta(y_{1:M}\mid x) = \prod_{m=1}^{M} p_\theta(y_m\mid x,y_{<m}). \]
所以 log-prob 是 token log-prob 的和:
\[ \log p_\theta(y_{1:M}\mid x) = \sum_{m=1}^{M} \log p_\theta(y_m\mid x,y_{<m}). \]
这个公式直接解释了 length bias:只要每个 token 概率小于 \(1\),多生成一个 token 通常会让总 log-prob 更小。于是按 raw sequence log-prob 排序会偏好短答案;按平均 log-prob 排序又可能偏好长而局部流畅的废话。
| Score | Formula | Bias |
|---|---|---|
| sum log-prob | \(\sum_m \log p(y_m\mid x,y_{<m})\) | prefers shorter completions |
| average log-prob | \(\frac{1}{M}\sum_m \log p(y_m\mid x,y_{<m})\) | ignores answer length cost |
| length penalty | \(\frac{\sum_m\log p_m}{M^\alpha}\) | tunable compromise |
Sequence scores are sums over tokenizer tokens. A response split into fewer tokens can have a different length penalty even when the raw text is semantically similar.
在 preference optimization、reranking、self-consistency 和 verifier systems 中,必须明确使用哪一种 normalization。DPO 常用的是整段 response 的 log-prob sum:
\[ \log \pi_\theta(y\mid x) = \sum_{m=1}^{M_y} \log \pi_\theta(y_m\mid x,y_{<m}), \]
因为 preference label 比较的是完整 response。但如果数据中 chosen answer 系统性更长,raw sum 会混入长度偏置;这时要么做数据审计,要么显式加入 length-normalized variant,并在报告里说明目标已经改变。
Log-Probability Accounting in Code
给定 logits[:, :-1] 和 labels[:, 1:],sequence log-prob 的核心是按 label gather:
import torch.nn.functional as F
def token_logprobs(logits, labels):
logp = F.log_softmax(logits[:, :-1], dim=-1)
target = labels[:, 1:].clone()
mask = target.ne(-100)
target = target.masked_fill(~mask, 0)
picked = logp.gather(-1, target.unsqueeze(-1)).squeeze(-1)
return picked, mask
picked, mask = token_logprobs(logits, labels)
seq_logp = (picked * mask).sum(dim=-1)
seq_len = mask.sum(dim=-1).clamp_min(1)
avg_logp = seq_logp / seq_len这里 labels[:, 1:] 与 logits[:, :-1] 对齐;如果模型 forward 已经内部 shift,就不能再这样手动 shift。mask 必须来自 shifted labels,否则最后一个 prompt token 或第一个 assistant token 很容易错位。
Temperature, Entropy, and Evaluation Logits
temperature sampling 把 logits 改成 \(z/T\):
\[ p_T(v\mid x) = \frac{\exp(z_v/T)} {\sum_u\exp(z_u/T)}. \]
当 \(T>1\),分布更平,entropy 通常上升;当 \(T<1\),分布更尖,entropy 下降。对两个 token \(a,b\),temperature 不改变它们的排序:
\[ \log\frac{p_T(a)}{p_T(b)} = \frac{z_a-z_b}{T}. \]
它改变的是 odds 的强度,而不是 argmax。若 \(z_a-z_b=4\),\(T=2\) 会把 log-odds 降到 \(2\);\(T=0.5\) 会把 log-odds 放大到 \(8\)。这就是 temperature 能让采样更发散或更确定的原因。
Perplexity and validation NLL should be computed from the model distribution at training temperature, usually raw logits with \(T=1\). Applying inference temperature during evaluation measures a different distribution.
top-\(p\) 和 top-\(k\) sampling 更进一步:它们把部分 token 的概率截断为 \(0\),再重新归一化。这样的分布适合生成,但不是训练时最大似然学到的完整条件分布。于是报告模型能力时,要分开写:
| Quantity | Uses raw logits? | Meaning |
|---|---|---|
| validation NLL / PPL | yes | probability-model fit |
| greedy accuracy | yes | top-1 local prediction |
| sampled answer quality | no | model plus decoding policy |
| pass@k | no | model plus sampling budget |
这也是为什么同一个 checkpoint 可以在低 temperature 下很稳定,在高 temperature 下有创造性,但 validation loss 完全不变。
Log Loss Is Proper
For true distribution \(p\), the expected log loss \[ \mathbb{E}_{y\sim p}[-\log q(y)] \] is minimized at \(q=p\).
期望 log loss 是 cross entropy:
\[ H(p,q)=H(p)+\operatorname{KL}(p\|q). \]
\(H(p)\) 与 \(q\) 无关,而 \(\operatorname{KL}(p\|q)\geq0\),等号当且仅当 \(p=q\)。所以最优预测分布就是真实分布。
这说明 next-token prediction 在概率意义上是合理的:如果真的能最小化 population log loss,模型应该输出真实条件分布,而不是只输出 mode。
Calibration, Entropy, and Why Mode Accuracy Is Not Enough
proper scoring rule 的含义还可以从 calibration 看。假设某类 prefix 下真实条件分布是:
\[ p^\star(\text{yes}\mid x)=0.7, \qquad p^\star(\text{no}\mid x)=0.3. \]
一个只输出 mode 的模型给:
\[ q_1=[1,0], \]
另一个校准模型给:
\[ q_2=[0.7,0.3]. \]
从 accuracy 看,\(q_1\) 更“坚定”,但 expected log loss 是:
\[ \mathbb{E}_{y\sim p^\star}[-\log q_1(y)] = \infty \]
因为真实数据中有 \(30\%\) 的 "no",而 \(q_1\) 给了零概率。\(q_2\) 的 loss 是有限的熵 \(H(p^\star)\)。这解释了为什么 next-token prediction 训练的是完整分布,而不是只训练 argmax。
在 LLM 中,这个现象对应到多种合理续写:同一个 prefix 后可能有多个语义正确但风格不同的 token。CE 不要求模型永远选一个模板答案;它要求模型把概率质量放到数据条件分布支持的区域。
Instruction Tuning as Conditional LM
把 instruction example 序列化:
<user> Explain convexity.
<assistant> Convexity means ...
causal LM 仍然建模整个序列,但 SFT 常只对 assistant tokens 计 loss:
\[ \mathcal{L}_{\text{SFT}} = -\sum_{t\in\mathcal{A}} \log p_\theta(x_t\mid x_{<t}), \]
其中 \(\mathcal{A}\) 是 assistant token 位置集合。这不是换模型结构,而是换 label mask。GPT-2 的 WebText pretraining、instruction tuning、chat fine-tuning,本质上都围绕同一个条件概率核心展开。
更精确地说,prompt tokens 仍然进入条件上下文,但不贡献 loss:
\[ \mathcal{L}_{\text{SFT}} = -\sum_t m_t^{\text{assistant}} \log p_\theta(x_t\mid x_{<t}). \]
所以 SFT 不是在训练无条件回答分布,而是在训练:
\[ p_\theta(\text{assistant answer}\mid \text{system},\text{user},\text{history}). \]
如果误把 user prompt 也计入 loss,模型会学会复述用户输入;如果误把 assistant answer mask 掉,训练几乎只是在做 prompt language modeling。
Soft Targets and Distillation
one-hot label 是最常见情况,但 next-token objective 也可以用 teacher distribution:
\[ \mathcal{L}_{\text{KD}} = \tau^2 \sum_v p_T(v\mid x_{<t}) \log \frac{p_T(v\mid x_{<t})}{p_\theta^\tau(v\mid x_{<t})}. \]
这里 \(p_T\) 是 teacher,\(p_\theta^\tau\) 是 student 在 temperature \(\tau\) 下的分布。与 one-hot CE 相比,distillation 会告诉 student:哪些错误 token 是“接近正确”的,哪些是完全不应该的。这也是很多 small LLM 从 large LLM 学 reasoning/style 的数学接口。
Preference Losses Reuse Sequence Log-Prob
RLHF、DPO、IPO、KTO 这些 post-training 方法看起来已经离开 next-token prediction,但底层仍然反复调用同一个量:
\[ \log \pi_\theta(y\mid x) = \sum_{m=1}^{M} \log \pi_\theta(y_m\mid x,y_{<m}). \]
DPO 的核心 log-ratio 可以写成:
\[ r_\theta(x,y) = \beta \left[ \log \pi_\theta(y\mid x) - \log \pi_{\text{ref}}(y\mid x) \right]. \]
给定 chosen response \(y^+\) 和 rejected response \(y^-\),loss 是:
\[ \mathcal{L}_{\text{DPO}} = -\log\sigma \left( r_\theta(x,y^+)-r_\theta(x,y^-) \right). \]
这说明 preference optimization 不是换掉语言模型数学接口,而是把 sequence log-prob 组合成 pairwise objective。所有 label-shift、mask、tokenizer、length-bias、denominator 的问题都会继承过来。
如果 chosen 比 rejected 长很多,raw log-prob sum 可能自然更低;如果 response mask 错位,DPO 更新会把 prompt token 也算进偏好;如果 reference model tokenizer/template 不一致,log-ratio 就不再比较同一个事件。
In preference optimization, a policy log-ratio compares how much more likely the trainable policy makes a response relative to a reference policy under the same prompt, tokenizer, template, and response mask.
一个最小 DPO log-prob accounting 应该显式返回每条 response 的 token 数:
def sequence_logp(logits, labels):
picked, mask = token_logprobs(logits, labels)
return {
"logp": (picked * mask).sum(dim=-1),
"tokens": mask.sum(dim=-1),
}
chosen = sequence_logp(chosen_logits, chosen_labels)
rejected = sequence_logp(rejected_logits, rejected_labels)
delta_policy = chosen["logp"] - rejected["logp"]
delta_ref = chosen_ref["logp"] - rejected_ref["logp"]
loss = -F.logsigmoid(beta * (delta_policy - delta_ref)).mean()这里的 tokens 不一定进入公式,但必须被记录。没有长度审计,就很难判断 preference loss 到底在学习质量偏好,还是在学习答案长度偏好。
Label Smoothing as Entropy Regularization
label smoothing 把 one-hot label 改成:
\[ q_\epsilon(v) = (1-\epsilon)\mathbf{1}[v=y] +\epsilon u(v), \]
其中 \(u(v)=1/|\mathcal{V}|\) 是 uniform distribution。损失为:
\[ H(q_\epsilon,p_\theta) = (1-\epsilon)(-\log p_\theta(y)) +\epsilon H(u,p_\theta). \]
第一项仍然推高真实 token;第二项要求模型不要把非真实 token 的概率压到过分接近 \(0\)。从梯度看:
\[ \frac{\partial \mathcal{L}}{\partial z_i} = p_\theta(i)-q_\epsilon(i). \]
真实 token 的目标概率从 \(1\) 变成 \(1-\epsilon+\epsilon/|\mathcal{V}|\),其他 token 的目标概率从 \(0\) 变成 \(\epsilon/|\mathcal{V}|\)。这会改善 calibration,但也可能伤害需要极低 entropy 的任务,例如严格格式输出、代码 token、数学符号补全。
Label smoothing is not a harmless regularizer. It replaces the empirical one-hot distribution with a higher-entropy target, so validation NLL against one-hot labels may worsen even if calibration improves.
Decoding Is Not Training
训练学到的是条件分布的参数化;生成时还要定义如何从这个分布取样。前面讲的 temperature 改变 odds 强度,top-\(p\) sampling 则选择最小集合 \(S_p\):
\[ \sum_{v\in S_p}p(v)\geq p. \]
top-\(k\) sampling 选择概率最高的 \(k\) 个 token,repetition penalty 会按历史 token 修改 logits,stop sequence 又会在文本层面截断输出。所以“同一个模型”可能因为 decoding policy 不同而表现得保守、发散、有创造力或不稳定。训练目标定义概率模型;decoding 定义如何使用概率模型。
Implementation Checklist
写或检查 causal LM training loop 时,至少逐项验证:
- causal mask 是否严格禁止未来 token;
- label shift 是在 collator、training step 还是 model forward 内部发生;
PAD、prompt、system/tool tokens 是否按预期设置labels=-100;- packed sequence 是否有 EOS 或 block-diagonal attention boundary;
- loss denominator 是有效 label 数,而不是 batch size 或 padded length;
- tokenizer 和 chat template 是否与训练 checkpoint 一致;
- perplexity 是否只在可比 tokenizer/数据切分上比较;
- generation 评测是否固定 decoding policy,而不是只报告 training loss。
next-token prediction 的核心非常简单:
\[ \text{prefix}\rightarrow \text{distribution over next token}. \]
但训练系统里的每个细节都在定义“哪些 prefix、哪些 target、哪些位置、哪些条件格式”真正进入这个概率目标。