4.11 LLM Inference: Decoding, KV Cache, and Serving


训练时,LLM 的核心问题是如何用大量 token 更新参数;推理时,参数已经固定,核心问题变成:如何在有限显存和延迟预算下,反复执行条件分布

\[ p_\theta(x_{t+1}\mid x_{\le t}) \]

并把它变成用户能感知的流式输出。

如果只说 “KV cache 可以加速推理”,这还不够。真正要理解的是:

  1. prefill 和 decode 是两种完全不同的计算形态;
  2. KV cache 把重复计算换成显存占用和内存带宽压力;
  3. serving runtime 的主要工作不是“调用模型”,而是管理很多请求的 cache、position、sampling state 和 stop condition;
  4. speculative decoding、continuous batching、PagedAttention 等技术都在优化同一个瓶颈:自回归生成的串行性与动态内存。

Inference as a State Machine

NoteDefinition: Autoregressive Inference State

An autoregressive inference state contains the generated prefix, per-layer KV cache, position state, sampling state, and stopping state needed to sample the next token from \(p_\theta(x_{t+1}\mid x_{\le t})\).

一个请求从进入服务端到完成,大致经历:

tokenize -> prefill -> decode loop -> detokenize/stream -> finish

更细一点,decode loop 每步都在做:

last token + cache + position
  -> model forward
  -> logits
  -> logits processors
  -> sampling / selection
  -> append token
  -> update KV cache
  -> check stop condition

这和训练 loop 很像:训练维护的是 gradient state 和 optimizer state;推理维护的是 cache state 和 sampling state。

WarningPitfall: Inference Is Not Just model.generate()

model.generate() hides tokenization, cache layout, masking, sampling, batching, and stopping. Serving systems must manage these states explicitly for many concurrent requests.

Prefill and Decode

NoteDefinition: Prefill

Prefill is the inference stage that processes the entire prompt in parallel and writes its key/value tensors into the KV cache.

NoteDefinition: Decode

Decode is the autoregressive stage that consumes the most recent generated token, reads the previous KV cache, writes the new token’s KV tensors, and emits the next token.

假设 prompt 长度是 \(T_p\),已经生成长度是 \(T_g\)。Prefill 的输入 shape 通常是

\[ X_{\text{prefill}}\in\mathbb{N}^{B\times T_p}, \]

模型一次性计算所有 prompt positions 的 hidden states,并为每层写入

\[ K,V\in\mathbb{R}^{B\times H_{kv}\times T_p\times d_h}. \]

Decode 第 \(s\) 步只输入最新 token:

\[ X_{\text{decode}}\in\mathbb{N}^{B\times 1}, \]

但 attention 要读完整历史:

\[ K,V\in\mathbb{R}^{B\times H_{kv}\times (T_p+s)\times d_h}. \]

于是 prefill 更像大矩阵乘和长序列 attention,通常 compute-bound;decode 每步只有一个 query,却要读越来越长的 KV cache,常常 memory-bandwidth-bound。

Stage Input length Parallelism Main pressure User metric
prefill prompt length \(T_p\) high over tokens compute and KV write time to first token
decode 1 per step low over time KV read bandwidth inter-token latency
NoteDefinition: TTFT and TPOT

Time to first token (TTFT) measures prompt processing plus the first decode step. Time per output token (TPOT) measures the average latency of subsequent decode steps.

TTFT 差,用户会觉得“半天不开始回答”;TPOT 差,用户会觉得“回答吐字很慢”。这两个指标对应的优化手段不同。

Why KV Cache Works

在第 \(\ell\) 层 attention 中,位置 \(t\) 的 key/value 是

\[ k_t^{(\ell)}=h_t^{(\ell)}W_K^{(\ell)}, \qquad v_t^{(\ell)}=h_t^{(\ell)}W_V^{(\ell)}. \]

对于 causal decoder,已经生成的 token 不会再改变。权重固定,prefix 固定,位置编码规则固定,则过去 token 的 \(k_t,v_t\) 在后续 decode steps 中可以复用。

无 cache 的朴素推理,在生成第 \(s\) 个 token 时,会重新处理长度 \(T_p+s\) 的整个序列。总计算近似随

\[ \sum_{s=1}^{T_g}(T_p+s) \]

甚至 attention 部分随

\[ \sum_{s=1}^{T_g}(T_p+s)^2 \]

增长。KV cache 后,每步只计算新 token 的 projection,但 attention 仍需读历史 keys/values:

\[ q_{\text{new}}K_{\le t}^{\top} \in \mathbb{R}^{B\times H\times 1\times t}. \]

所以 decode 从“重复计算所有历史 token”变成“为新 token 查询历史 cache”。

ImportantTheorem: Cached Decode Preserves Causal Attention Outputs

For a deterministic causal Transformer with fixed weights and deterministic position encoding, cached decoding produces the same logits as recomputing the full prefix at every step, up to numerical differences from kernel implementation and precision.

对任意已生成 prefix \(x_{\le t}\),full recomputation 在每层为每个位置 \(i\le t\) 计算同一个 hidden state \(h_i^{(\ell)}\)。由于 causal mask 禁止 \(i\) 看到未来位置,\(h_i^{(\ell)}\) 只依赖 \(x_{\le i}\),不依赖后续即将生成的 token。

cached decoding 在 token \(i\) 第一次出现时计算并保存

\[ k_i^{(\ell)},v_i^{(\ell)}. \]

后续步骤读取这些缓存值。因为权重、输入 prefix 和 position state 都相同,缓存中的 \(k_i^{(\ell)},v_i^{(\ell)}\) 与 full recomputation 得到的值相同。新位置 \(t\) 的 query 对同一组 \(K_{\le t},V_{\le t}\) 做同一 causal attention,因此输出 logits 相同。实际代码中可能由于 fused kernel、量化、并行 reduction 顺序带来微小数值差异。

KV Cache Memory Formula

设:

  • batch/request 数为 \(B\)
  • 总上下文长度为 \(T\)
  • 层数为 \(L\)
  • query heads 数为 \(H_q\)
  • KV heads 数为 \(H_{kv}\)
  • 每个 head 维度为 \(d_h\)
  • cache dtype bytes 为 \(b\),例如 FP16/BF16 为 2。

每层 cache 有 K 和 V 两份:

\[ \operatorname{KVBytes} = 2\cdot B\cdot T\cdot L\cdot H_{kv}\cdot d_h\cdot b. \]

因为 hidden size

\[ d=H_qd_h, \]

若是标准 MHA,\(H_{kv}=H_q\),公式可写成

\[ \operatorname{KVBytes}_{\text{MHA}} = 2BTLdb. \]

若是 GQA/MQA,\(H_{kv}<H_q\)

\[ \operatorname{KVBytes}_{\text{GQA}} = \operatorname{KVBytes}_{\text{MHA}} \cdot \frac{H_{kv}}{H_q}. \]

NoteDefinition: Grouped-Query Attention

Grouped-query attention uses more query heads than key/value heads. Several query heads share one key/value head, reducing KV cache memory and bandwidth.

例子:\(L=32\)\(H_q=32\)\(H_{kv}=8\)\(d_h=128\)\(T=8192\)\(B=8\),BF16 cache。则

\[ \operatorname{KVBytes} = 2\cdot 8\cdot8192\cdot32\cdot8\cdot128\cdot2 \approx 8.59\text{ GB}. \]

如果同样模型用 MHA,即 \(H_{kv}=32\),cache 约为

\[ 34.36\text{ GB}. \]

这就是为什么现代 LLM 常用 GQA:它不只是节省参数,而是显著降低 decode 阶段的 cache bandwidth 和 serving 显存。

KV Cache Layout in Real Runtimes

公式里的 KV cache 看起来像一个连续张量:

\[ K,V\in\mathbb{R}^{B\times L\times H_{kv}\times T\times d_h}. \]

真实 runtime 往往不会按这个逻辑形状直接存,因为请求长度不同、请求会动态进入/退出、cache 需要分页。更常见的物理布局是按 block 存:

k_cache: [num_blocks, block_size, num_kv_heads, head_dim]
v_cache: [num_blocks, block_size, num_kv_heads, head_dim]

每条 sequence 有一个 block table:

seq A logical block 0 -> physical block 17
seq A logical block 1 -> physical block 02
seq A logical block 2 -> physical block 91

decode 第 \(t\) 个位置时,runtime 先算 logical block 和 block offset:

\[ b_{\text{logical}}=\left\lfloor\frac{t}{P}\right\rfloor, \qquad o=t\bmod P, \]

其中 \(P\) 是 block size。然后通过 block table 找到 physical block,把新 token 的 K/V 写到

k_cache[physical_block, offset, :, :]
v_cache[physical_block, offset, :, :]

这种布局让 cache manager 可以在请求结束时把 physical blocks 还回 free list,也让 prefix sharing 通过 refcount 实现。

WarningPitfall: Logical Position and Physical Address Are Different

The model’s position id is a logical sequence position. The KV cache address is a physical memory location. Mixing them up causes silent cache corruption or wrong RoPE phases.

KV Cache Quantization

KV cache 有时比权重更快把显存吃完。除了 GQA/MQA,还可以把 cache 从 BF16/FP16 量化到 FP8/INT8。一个简化的 per-block quantization 是:

\[ \hat{K}_{b} = \operatorname{round}\left(\frac{K_b}{s_b}\right), \qquad s_b=\frac{\max |K_b|}{q_{\max}}, \]

读取时近似恢复:

\[ K_b\approx s_b\hat{K}_b. \]

它把 cache bytes 从 \(2\) bytes/element 降到 \(1\) byte/element,但 attention logits 直接依赖 \(q^\top k\),所以量化误差会影响注意力分布:

\[ q^\top(K+\Delta K) = q^\top K+q^\top\Delta K. \]

\(q^\top\Delta K\) 在长上下文中累积到足够大,softmax 的 top positions 可能改变。工程上通常要记录:

Choice Trade-off
BF16/FP16 KV safest quality, highest memory
FP8 KV lower memory/bandwidth, needs scale management
INT8 KV stronger compression, more quality risk
per-tensor scale cheap, worse outlier handling
per-head/per-block scale better range, more metadata
WarningPitfall: Weight Quantization and KV Quantization Are Different

Quantizing model weights changes static parameters. Quantizing KV cache changes dynamic activations that grow with context and are read every decode step. They have different error patterns and should be evaluated separately.

Position State and RoPE

KV cache 不只是 K/V tensor,还隐含 position state。对 absolute position embedding,位置 \(t\) 的 embedding 是 \(P_t\)。对 RoPE,query/key 会乘上位置相关旋转:

\[ \operatorname{RoPE}(q_t,t), \qquad \operatorname{RoPE}(k_t,t). \]

所以 cached decoding 必须保证:

  1. 新 token 使用正确 position id;
  2. left padding/right padding 不改变真实 token 的位置语义;
  3. packed prompts 的不同样本边界不会串位;
  4. long-context scaling 的 RoPE 参数和训练/部署配置一致。
WarningPitfall: Wrong Position IDs Can Make Cache Correct but Semantics Wrong

The KV cache may have the right shape while storing keys computed with wrong positions. This often appears as degraded long-context behavior rather than a shape error.

这也是为什么推理服务里 position id、attention mask、sequence length metadata 必须和 cache manager 放在一起管理。它们共同定义“这个 cache block 属于哪个请求的哪个位置”。

A Minimal Cached Decode Loop

概念上,单请求 decode 可以写成:

tokens = tokenizer(prompt)
cache = None

logits, cache = model(tokens, past_key_values=cache, use_cache=True)
next_id = sample(logits[:, -1])
tokens.append(next_id)

while not stop(tokens):
    logits, cache = model([next_id], past_key_values=cache, use_cache=True)
    next_id = sample(logits[:, -1])
    tokens.append(next_id)

实际 serving 代码会复杂很多,因为 cache 不是 Python tuple 这么简单,而是一个由 runtime 管理的显存对象。它需要支持:

  1. append new KV;
  2. read old KV;
  3. evict finished requests;
  4. share prefix blocks;
  5. handle beam/parallel sampling;
  6. compact or page memory;
  7. map logical sequence positions to physical memory blocks。
NoteDefinition: Logical and Physical KV Blocks

A logical KV block is a block index in a sequence’s cache address space. A physical KV block is an actual GPU-memory block storing K/V tensors. A block table maps logical blocks to physical blocks.

Cache Manager Implementation Sketch

一个 serving runtime 至少需要三个表:

State Example Role
free block list [3, 8, 19, ...] physical KV blocks not in use
sequence table request_id -> SequenceState per-request logical metadata
block table seq_id -> [physical_block_ids] logical-to-physical mapping

可以把每个请求的状态写成:

class SequenceState:
    def __init__(self, request_id, sampling_cfg):
        self.request_id = request_id
        self.token_ids = []
        self.block_table = []
        self.length = 0
        self.sampling_cfg = sampling_cfg
        self.finished = False

append 一个 token 时,cache manager 先确认最后一个 block 是否还有位置:

def ensure_slot(seq, free_blocks, block_size):
    if seq.length % block_size == 0:
        if not free_blocks:
            raise RuntimeError("KV cache exhausted")
        seq.block_table.append(free_blocks.pop())
    logical = seq.length // block_size
    offset = seq.length % block_size
    return seq.block_table[logical], offset

请求结束时释放 blocks:

def release_sequence(seq, free_blocks, refcount):
    for block in seq.block_table:
        refcount[block] -= 1
        if refcount[block] == 0:
            free_blocks.append(block)
    seq.block_table.clear()

prefix sharing 或 beam search 会让一个 physical block 的 refcount > 1。一旦某个分支要写入共享 block,就必须 copy-on-write:

def cow_last_block(seq, free_blocks, refcount):
    block = seq.block_table[-1]
    if refcount[block] == 1:
        return block
    new_block = free_blocks.pop()
    copy_kv_block(dst=new_block, src=block)
    refcount[block] -= 1
    refcount[new_block] = 1
    seq.block_table[-1] = new_block
    return new_block
WarningPitfall: Cache Exhaustion Is a Scheduling Event

When the free KV block pool is empty, the runtime should reject, queue, preempt, or swap requests deliberately. Silently truncating context or overwriting blocks corrupts generation.

Logits Processors and Sampling

模型输出 logits

\[ z\in\mathbb{R}^{|\mathcal{V}|}. \]

Sampling 不是直接 argmax 的同义词。Serving runtime 通常会按顺序应用一组 logits processors:

raw logits
  -> repetition / presence penalty
  -> bad words or grammar mask
  -> temperature
  -> top-k / top-p / min-p
  -> sample or argmax

Temperature:

\[ p_i(\tau) = \frac{\exp(z_i/\tau)} {\sum_j\exp(z_j/\tau)}. \]

\(\tau<1\) 让分布更尖,\(\tau>1\) 让分布更平。

Top-k 保留概率最高的 \(k\) 个 token;top-p nucleus sampling 保留最小集合 \(S\),使得

\[ \sum_{i\in S}p_i\ge p_{\text{nucleus}}. \]

Repetition penalty 则根据已生成 token 调整 logits。一个简化版本可以写成:

\[ z_i' = \begin{cases} z_i/\alpha,& i\in \text{prefix and } z_i>0,\\ z_i\alpha,& i\in \text{prefix and } z_i<0,\\ z_i,& \text{otherwise}. \end{cases} \]

WarningPitfall: Sampling Parameters Change the Runtime Contract

The same model with different temperature, top-p, repetition penalty, or stop strings is a different decoding policy. Evaluations and demos should log these settings.

Processor Order Is Part of the Policy

Logits processors 不是一组可以随便交换的函数。设 raw logits 为 \(z\),一个解码策略可以形式化写成:

\[ \tilde{z} = \mathcal{F}_m\circ\cdots\circ\mathcal{F}_2\circ\mathcal{F}_1(z;\mathcal{S}_t), \qquad p=\operatorname{softmax}(\tilde{z}), \]

其中 \(\mathcal{S}_t\) 是当前 request state:历史 token、grammar state、ban list、temperature、top-p、RNG state 等。若两个 processor 不可交换:

\[ \mathcal{F}_a(\mathcal{F}_b(z)) \neq \mathcal{F}_b(\mathcal{F}_a(z)), \]

那么它们的顺序就定义了不同的 decoding policy。

NoteDefinition: Logits Processor

A logits processor is a deterministic transformation from raw model logits and decoding state to processed logits before sampling or argmax selection.

常见顺序背后的逻辑是:

Stage Input Output Why order matters
penalties raw logits + history adjusted logits repetition/presence uses generated prefix
hard masks adjusted logits + parser/ban state some logits set to \(-\infty\) invalid tokens must never be sampled
temperature valid logits rescaled logits controls entropy of allowed set
truncation probabilities or logits smaller support top-p depends on post-temperature probabilities
sampling final distribution + RNG next token consumes RNG state

Top-k by rank is invariant to positive temperature scaling, but top-p is not invariant because cumulative probability changes after softmax temperature:

\[ S_p(\tau) = \min \left\{ S: \sum_{i\in S} \frac{\exp(z_i/\tau)} {\sum_j\exp(z_j/\tau)} \geq p \right\}. \]

所以“temperature before top-p”和“top-p before temperature”会产生不同支持集。服务端、评测脚本和复现实验必须记录 processor order,而不只是记录 top_p=0.9

WarningPitfall: Hard Masks Must Precede Sampling

Bad-word masks, grammar masks, and tool-call boundary masks should set invalid logits to \(-\infty\) before normalization. Filtering after sampling is rejection sampling, which changes latency and may fail if most mass is invalid.

Implementing Common Logits Processors

一个工程上稳的约定是:processor 接收并返回 logits,不直接返回 probability。这样所有 hard mask 都可以通过 \(-\infty\) 表示,最后只做一次 stable softmax。

import torch


def apply_repetition_penalty(logits, generated_ids, penalty):
    if penalty == 1.0 or not generated_ids:
        return logits

    out = logits.clone()
    seen = torch.as_tensor(sorted(set(generated_ids)), device=logits.device)
    values = out[..., seen]
    values = torch.where(values > 0, values / penalty, values * penalty)
    out[..., seen] = values
    return out


def apply_allowed_token_mask(logits, allowed_ids):
    mask = torch.full_like(logits, -torch.inf)
    ids = torch.as_tensor(allowed_ids, device=logits.device)
    mask[..., ids] = logits[..., ids]
    return mask

上面的 repetition penalty 是 logits-level 的经验规则,不是概率模型的定理。它的特点是保留 logits 符号:正 logits 被除以 \(\alpha\),负 logits 乘以 \(\alpha\),都让已出现 token 变得更不可能。

No-repeat n-gram 也是常见 processor。若已经生成:

the cat sat on the cat

并设置 no-repeat 3-gram,那么 prefix the cat 后面曾经接过 sat,当前 suffix 若也是 the cat,就要 ban sat

from collections import defaultdict


def banned_ngram_next_tokens(tokens, n):
    if n <= 0 or len(tokens) + 1 < n:
        return set()

    table = defaultdict(set)
    for i in range(len(tokens) - n + 1):
        prefix = tuple(tokens[i : i + n - 1])
        nxt = tokens[i + n - 1]
        table[prefix].add(nxt)

    current = tuple(tokens[-(n - 1) :])
    return table[current]


def apply_ban_ids(logits, banned_ids):
    if not banned_ids:
        return logits
    out = logits.clone()
    ids = torch.as_tensor(sorted(banned_ids), device=logits.device)
    out[..., ids] = -torch.inf
    return out

Top-k/top-p filtering 要在最终 logits 上产生一个截断支持集。一个 batch 内每个 request 的 \(k,p\) 可能不同;生产系统通常按 sampling config 分组,或者写支持 per-row 参数的 kernel。

def top_k_top_p_filter(logits, top_k=None, top_p=None):
    out = logits.clone()

    if top_k is not None and top_k > 0:
        kth = torch.topk(out, k=top_k, dim=-1).values[..., -1, None]
        out = out.masked_fill(out < kth, -torch.inf)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(out, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        cum = probs.cumsum(dim=-1)
        remove = cum > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, -torch.inf)
        out = torch.full_like(out, -torch.inf)
        out.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)

    return out

这里 remove[..., 1:] = remove[..., :-1] 的作用是保留第一个让 cumulative probability 超过 top-p 的 token。否则极端情况下 top-p 太小可能把所有 token 都删掉。

NoteDefinition: Sampling Support

The sampling support is the set of vocabulary ids with finite processed logits after all hard masks, penalties, temperature scaling, and truncation processors have been applied.

一个必要 invariant 是:

\[ |\operatorname{supp}(\tilde{z})|\geq1. \]

如果 grammar mask、bad-word mask、top-k/top-p 叠加后支持集为空,runtime 应该抛出结构化错误或回退到明确策略,而不是让 softmax 产生 NaN。

def assert_nonempty_support(logits):
    ok = torch.isfinite(logits).any(dim=-1)
    if not torch.all(ok):
        bad_rows = torch.nonzero(~ok, as_tuple=False).flatten().tolist()
        raise RuntimeError(f"empty sampling support rows: {bad_rows}")

Sampling Trace and RNG State

为了复现一次 generation,只保存最终文本远远不够。每个 decode step 至少要能回答:

  1. raw logits 对应哪个 model revision 和 adapter;
  2. 哪些 processors 被应用、顺序是什么;
  3. 最终 token 是 argmax 还是随机采样;
  4. 随机采样消耗了哪个 RNG state;
  5. 记录的 logprob 是 raw policy 还是 processed policy。
NoteDefinition: Processed-Policy Log Probability

The processed-policy log probability of a sampled token is the log probability under the distribution after all logits processors have been applied.

对于 serving trace,常见有两种 logprob:

\[ \log p_{\text{raw}}(y) = \log\operatorname{softmax}(z)_y, \qquad \log p_{\text{proc}}(y) = \log\operatorname{softmax}(\tilde{z})_y. \]

raw logprob 更接近模型本身;processed logprob 才是实际 decoding policy。若 top-p 把某个 token mask 掉,它在 processed policy 下概率就是 \(0\),logprob 是 \(-\infty\)

一个最小单步采样函数应该显式返回 trace:

from dataclasses import dataclass


@dataclass(frozen=True)
class SampleTrace:
    token_id: int
    raw_logprob: float
    processed_logprob: float
    support_size: int
    temperature: float
    top_k: int | None
    top_p: float | None
    rng_state_hash: str


def sample_one(raw_logits, processed_logits, cfg, generator):
    assert_nonempty_support(processed_logits)

    raw_lp = torch.log_softmax(raw_logits.float(), dim=-1)
    proc_lp = torch.log_softmax(processed_logits.float(), dim=-1)
    probs = proc_lp.exp()

    rng_state = generator.get_state()
    token = torch.multinomial(probs, num_samples=1, generator=generator)
    tid = int(token.item())

    trace = SampleTrace(
        token_id=tid,
        raw_logprob=float(raw_lp[..., tid].item()),
        processed_logprob=float(proc_lp[..., tid].item()),
        support_size=int(torch.isfinite(processed_logits).sum().item()),
        temperature=cfg.temperature,
        top_k=cfg.top_k,
        top_p=cfg.top_p,
        rng_state_hash=str(hash(bytes(rng_state.cpu().tolist()))),
    )
    return tid, trace

生产代码不会用 Python hash() 做跨进程稳定审计,因为它可能受解释器随机化影响;这里表达的是 contract:采样前的 RNG state 必须能被追踪或复现。更稳的做法是保存 seed、request-local counter、step index,或保存框架提供的 generator state digest。

ImportantTheorem: Deterministic Replay Requires State Before Sampling

For stochastic decoding, replaying the same model inputs and sampling parameters reproduces the same token only if the RNG state immediately before the sampling draw is also reproduced.

随机采样可写成:

\[ y_t = F(p_t, u_t), \]

其中 \(p_t\) 是 processed distribution,\(u_t\) 是伪随机数生成器在第 \(t\) 步产生的随机变量。即使 \(p_t\) 完全相同,只要 \(u_t\) 不同,\(F\) 的输出 token 也可能不同。

伪随机数生成器的 \(u_t\) 由初始 seed 和之前消耗随机数的次数决定。continuous batching、speculative decoding、rejection sampling、grammar fallback 都可能改变随机数消耗顺序。因此要 replay 单个 request,必须保存 request-local RNG state 或保存足够恢复它的 seed/counter,而不能依赖进程级全局 RNG。

这也是 preemption/swap 的一部分:暂停请求时,不能只保存 tokens 和 KV blocks,还要保存 sampling RNG state。否则恢复后虽然 cache 正确,后续随机 token 仍可能不同。

Grammar Masks and Structured Decoding

Grammar constrained decoding 可以看成每一步由 parser state 给出 allowed token set:

\[ A_t = \operatorname{Allowed}(\text{parser_state}_t) \subseteq \mathcal{V}. \]

然后 hard mask:

\[ \tilde{z}_i = \begin{cases} z_i,& i\in A_t,\\ -\infty,& i\notin A_t. \end{cases} \]

采样 token \(y_t\) 后,parser state 也要更新:

\[ \text{parser_state}_{t+1} = \operatorname{Step}(\text{parser_state}_t,y_t). \]

所以 structured decoding 的 request state 至少包含:

State Why
parser stack / DFA state decides allowed next tokens
tokenizer boundary state tokens may be partial strings
emitted text buffer validates stop and final parse
fallback policy what to do if support becomes empty
WarningPitfall: Token-Level Grammar Is Not Character-Level Grammar

A token can contain multiple characters, partial UTF-8 bytes, or text that crosses grammar boundaries. Grammar masks must be built against tokenizer pieces, not only against character-level parser transitions.

一个简单但清晰的 structured decoding loop 是:

def structured_decode_step(model, req, logits):
    allowed = req.parser.allowed_token_ids(req.tokenizer)
    logits = apply_allowed_token_mask(logits, allowed)
    logits = top_k_top_p_filter(logits, req.top_k, req.top_p)
    token_id, trace = sample_one(req.raw_logits, logits, req.cfg, req.rng)
    piece = req.tokenizer.decode([token_id])
    req.parser.step(piece)
    req.generated_ids.append(token_id)
    req.trace.append(trace)
    return token_id

工程上要注意 allowed_token_ids 可能很大,不能每步在 Python 里扫描整个 vocabulary。常见优化是:

  1. 为 grammar state 预计算 token mask;
  2. 把 mask 存成 bitset;
  3. 在 GPU 上应用 mask;
  4. 对高频 grammar state 做 cache;
  5. 对 JSON/tool-call 这类常见格式写专门 parser。

Stop Conditions

停止生成看起来简单,其实有多层:

Stop condition Example
token stop EOS token
string stop "\n\nUser:"
length stop max new tokens
grammar stop JSON object closed
tool stop tool-call boundary reached
safety stop policy filter interrupts

String stop 不能只看最后一个 token,因为 stop sequence 可能跨 token:

token_1 = "\n"
token_2 = "\nUser"
token_3 = ":"

所以 runtime 需要维护 detokenized suffix 或 token-level automaton。对结构化输出,grammar constrained decoding 会在每步根据 parser state mask 掉非法 token。

Stop Strings as Automata

字符串级 stop condition 的安全做法,是维护一个 suffix buffer 或 automaton,而不是只检查新 token 的文本。设 stop string 是

"\n\nUser:"

tokenizer 可能把它切成多种形式。每次 decode 出新 token 后,runtime 追加其文本到 suffix buffer,只保留最长 stop string 长度的后缀:

class StopMatcher:
    def __init__(self, stops):
        self.stops = stops
        self.max_len = max(len(s) for s in stops)
        self.suffix = ""

    def update(self, piece):
        self.suffix = (self.suffix + piece)[-self.max_len:]
        return any(self.suffix.endswith(s) for s in self.stops)

这个版本简单但有两个坑:

  1. byte-level tokenizer 的 decode piece 可能包含不完整 UTF-8 片段;
  2. streaming 时 stop string 之前的部分可能已经发给用户。

更严谨的实现会延迟发送最近 \(M\) 个字符,直到确认它们不可能成为 stop string 的一部分。若做 JSON/grammar constrained decoding,则 stop condition 还要和 parser state 绑定:不是看到 } 就停,而是 parser 确认整个 JSON object 闭合、栈为空、后续 token 非法或无需继续。

Continuous Batching

NoteDefinition: Continuous Batching

Continuous batching is a serving strategy that dynamically adds, removes, and schedules requests between decode iterations instead of waiting for a fixed batch to finish.

普通 static batching 的问题是:一个 batch 内请求长度不同,短请求完成后,长请求还在跑,GPU batch 逐渐变小,利用率下降。Continuous batching 每个 decode tick 重新组织 active requests:

tick 0: A B C D enter prefill/decode
tick 1: A B C D decode
tick 2: B finishes, E enters, A C D E decode
tick 3: A finishes, F enters, C D E F decode

这要求 runtime 每步维护:

  1. active sequence table;
  2. per-sequence cache block table;
  3. per-sequence position;
  4. per-sequence sampling configuration;
  5. prefill jobs and decode jobs 的调度优先级。
WarningPitfall: Bigger Batch Can Hurt Latency

Large batches improve throughput but can increase queueing delay and TTFT. Serving is a latency-throughput trade-off, not just a tokens/sec maximization problem.

Scheduler Queues and Chunked Prefill

真实服务端通常有至少两类工作:

Work type Shape Pressure
prefill many prompt tokens at once compute, attention over prompt
decode one token per active request KV bandwidth, low arithmetic intensity

如果每个 tick 都让长 prompt prefill 独占 GPU,已有用户的 streaming 会卡住;如果只服务 decode,新请求 TTFT 会很差。常见做法是设置 token budget:

per_tick_budget = max_num_batched_tokens
reserve_decode_slots(active_sequences)
use_remaining_budget_for_prefill_chunks

长 prompt 可以被切成 chunks:

prompt length 12000, chunk size 2048
prefill chunks: 2048 + 2048 + 2048 + 2048 + 2048 + 1760

每个 chunk 写入同一条 sequence 的 KV cache,并维护 position offset。这样调度器可以在 chunks 之间插入 decode ticks,降低其他请求的 TPOT 抖动。

一个简化 scheduler:

def schedule_tick(waiting_prefills, active_decodes, token_budget):
    batch = []

    # Decode has low latency tolerance, so reserve one token per active request.
    for seq in active_decodes:
        if token_budget <= 0:
            break
        batch.append(("decode", seq, 1))
        token_budget -= 1

    # Fill remaining budget with prefill chunks.
    while token_budget > 0 and waiting_prefills:
        seq = waiting_prefills[0]
        n = min(seq.remaining_prompt_tokens, token_budget, seq.max_chunk)
        batch.append(("prefill", seq, n))
        token_budget -= n
        seq.remaining_prompt_tokens -= n
        if seq.remaining_prompt_tokens == 0:
            waiting_prefills.pop(0)

    return batch
NoteDefinition: Decode-Prefill Interleaving

Decode-prefill interleaving schedules latency-sensitive decode steps together with throughput-oriented prefill chunks so that TTFT and TPOT can be controlled independently.

调度器还必须检查兼容性:不同 tokenizer、adapter、LoRA、grammar mask、tensor-parallel group、KV dtype、speculative draft model 都可能让请求不能放在同一个 batch。Continuous batching 的难点不是“把 list concat 起来”,而是维护一组兼容的动态状态。

Preemption and Swapping

当 KV cache 不够时,runtime 可以:

Policy Meaning Risk
reject return overload error simple but hurts availability
queue wait until blocks free increases TTFT
preempt pause low-priority sequence needs resumable state
swap move KV blocks CPU/GPU PCIe bandwidth and latency
recompute discard KV and recompute prefix later compute overhead

Preemption 的正确性条件是:恢复时 tokens、position ids、adapter state、sampling RNG state、KV blocks 或可重算 prefix 都一致。否则同一个请求暂停再恢复会生成不同分布,甚至直接错位。

PagedAttention

PagedAttention 的核心动机是:KV cache 大、动态增长、请求长度不同,如果为每个请求预留连续最大长度 cache,就会严重浪费显存。PagedAttention 借鉴操作系统分页,把每条序列的 KV cache 切成固定大小 blocks。

逻辑上,一条序列的 cache 是连续的:

logical blocks: 0 1 2 3 ...

物理上,这些 blocks 可以分散在显存池:

logical 0 -> physical 17
logical 1 -> physical 02
logical 2 -> physical 91
logical 3 -> physical 06

attention kernel 通过 block table 读取分散的 K/V blocks。这样 request 生成到哪里,就按需分配到哪里;request 完成后,blocks 归还给池。

ImportantTheorem: Paging Reduces KV Internal Fragmentation

If KV memory is allocated in fixed-size blocks and requests receive blocks on demand, unused reserved KV space per active request is bounded by one block rather than by the difference between maximum context length and actual sequence length.

传统连续预分配若为每个请求保留 \(T_{\max}\) tokens,但实际长度是 \(T_i\),浪费为

\[ T_{\max}-T_i. \]

这个浪费可以接近 \(T_{\max}\)。分页后,设 block size 为 \(P\),请求实际需要 \(T_i\) tokens,则分配

\[ \left\lceil\frac{T_i}{P}\right\rceil \]

个 blocks。除最后一个 block 外都被填满,最后一个 block 最多浪费 \(P-1\) 个 token slots。因此每个请求的内部碎片上界从 \(O(T_{\max})\) 降到 \(O(P)\)

PagedAttention 不是改变 Transformer 数学,而是改变 KV cache 的内存布局和 attention kernel 读 cache 的方式。代价是 kernel 和 runtime 更复杂;收益是显存利用率提高,continuous batching 更自然。

Speculative Decoding

NoteDefinition: Speculative Decoding

Speculative decoding uses a fast draft model to propose several future tokens and a target model to verify them in parallel, reducing the number of serial target-model decode steps while preserving the target distribution under the exact sampling algorithm.

自回归 decode 慢,是因为目标模型每次只能吐一个 token:

target forward -> token 1
target forward -> token 2
target forward -> token 3
...

Speculative decoding 改成:

draft model proposes y1, y2, ..., yK
target model scores all K positions in one forward
accept a prefix of the draft
fallback sample at first rejection

设 draft distribution 为 \(q\),target distribution 为 \(p\)。draft 先采样候选 token \(y\)。验证时接受概率为

\[ a(y)=\min\left(1,\frac{p(y)}{q(y)}\right). \]

若拒绝,则从修正分布采样:

\[ r(y) = \frac{[p(y)-q(y)]_+} {\sum_v[p(v)-q(v)]_+}. \]

直觉是:draft 已经覆盖了 \(q\) 那部分概率;如果某个 token 被 draft 过度提出,就按比例拒绝;拒绝后从 target 比 draft 更“缺”的正残差里补回来。

对单步 token,最终输出 \(y\) 的概率由两部分组成:

  1. draft 提出 \(y\) 且被接受;
  2. draft 提出某个 token 后被拒绝,再从 residual distribution 中采到 \(y\)

第一部分概率为

\[ q(y)\min\left(1,\frac{p(y)}{q(y)}\right) = \min(q(y),p(y)). \]

所有拒绝事件的总概率为

\[ 1-\sum_v\min(q(v),p(v)) = \sum_v[p(v)-q(v)]_+. \]

拒绝后从

\[ r(y)=\frac{[p(y)-q(y)]_+}{\sum_v[p(v)-q(v)]_+} \]

采样,因此第二部分给出 \([p(y)-q(y)]_+\)。总和为

\[ \min(q(y),p(y))+[p(y)-q(y)]_+=p(y). \]

多 token speculative decoding 对 accepted prefix 逐位应用同样思想,并让 target model 并行计算这些位置的条件分布。

Speculative decoding 的收益取决于:

  1. draft model 是否足够快;
  2. draft distribution 是否接近 target;
  3. target model 一次验证多个 token 的并行效率;
  4. batch 中请求长度和采样参数是否适合一起调度。

如果 draft 太差,大量 token 被拒绝;如果 draft 太慢,省下的 target steps 抵不过 draft 成本。

Serving Metrics

训练常看 loss、tokens/sec、GPU utilization。Serving 至少要分开看:

Metric Meaning
TTFT time to first token
TPOT time per output token
ITL inter-token latency distribution
request throughput completed requests per second
token throughput generated tokens per second
prefill throughput prompt tokens per second
cache utilization fraction of KV blocks used
queue time time waiting before prefill

一个系统可以 tokens/sec 很高但 TTFT 很差,因为它把请求攒成大 batch;也可以 TTFT 很低但吞吐差,因为每个请求单独跑。真实服务需要根据场景取舍:

Scenario Priority
chat UI low TTFT and stable TPOT
offline generation high token throughput
evaluation deterministic settings and reproducible prompts
agent/tool use stop condition and structured decoding correctness
batch labeling cost per million tokens

Fragmented GPUs

在碎片 GPU 场景里,不要幻想调度器能把很多慢互联单卡变成一台大机器。Serving 通常比训练更适合碎片资源,因为很多请求天然可以分散:

GPU 0: worker for model shard or full small model
GPU 1: another worker
GPU 2: embedding / rerank / draft model
queue: routes requests by model, adapter, length, priority

但如果一个模型本身放不进单卡,就需要 tensor parallel 或 weight quantization;这会引入跨卡通信和部署复杂度。碎片资源最稳妥的使用方式通常是:

  1. 小模型多副本;
  2. quantized model 多副本;
  3. batch/offline inference;
  4. draft model or verifier separation;
  5. 避免对互联要求很高的 TP job。
WarningPitfall: Serving Parallelism Is Constrained by Memory Layout

You can route requests across many workers, but one request’s decode step still needs coherent access to its model weights and KV cache. Poor placement can erase batching gains.

Implementation Checklist

写或调 serving 代码时,先检查:

  1. prefill and decode paths use the same tokenizer, chat template, and special tokens;
  2. position_ids match padding, packing, and cache offsets;
  3. KV cache dtype and shape match model attention type;
  4. GQA/MQA head expansion is handled correctly in attention kernels;
  5. stop strings are checked across token boundaries;
  6. processor order, sampling parameters, and logprob type are logged with outputs;
  7. continuous batching does not mix incompatible adapter or grammar states;
  8. finished requests release KV blocks;
  9. prefix sharing only happens for identical prefix state;
  10. RNG state is request-local and saved for preemption/replay;
  11. TTFT and TPOT are measured separately。

推理系统的本质是把一个数学上简单的自回归条件分布变成大量动态请求上的低延迟执行问题。KV cache 解决重复计算,PagedAttention 解决动态内存,continuous batching 解决 GPU 利用率,speculative decoding 解决串行步数。它们不是互相独立的小技巧,而是同一条服务端链路上的不同瓶颈。

References