4.10 LLM Training Systems: Memory, Parallelism, and Checkpointing


LLM 训练不是把 model.train() 跑起来这么简单。真正的瓶颈通常是显存、通信、activation、optimizer states 和数据吞吐。理解训练系统,首先要会做一张 memory bill。

Training Memory Bill

训练中主要显存项:

State Scales with Needed when
parameters \(N\) forward/backward
gradients \(N\) after backward
optimizer states \(N\) optimizer step
activations \(B,T,L,d\) backward
temporary buffers kernels/communication runtime

对 AdamW,若参数 BF16、梯度 BF16、Adam states FP32:

\[ \text{param}=2N,\quad \text{grad}=2N,\quad m=4N,\quad v=4N. \]

总计约:

\[ 12N\text{ bytes}. \]

若还有 FP32 master weights,则变成:

\[ 16N\text{ bytes}. \]

一个 7B 模型,仅 optimizer 相关状态就可能是:

\[ 12\times7\text{B}=84\text{ GB}. \]

这解释了为什么 7B full fine-tuning 往往不能简单放进单张 24GB GPU。

更一般地,若参数量为 \(N\),每个参数的训练状态字节数为 \(s_{\text{train}}\),每卡可用显存为 \(M\),则只看参数相关状态的下界是:

\[ M_{\text{state}} = Ns_{\text{train}}. \]

这还没有算 activations、KV-like temporary buffers、CUDA workspace、communication buckets、fragmentation 和 dataloader pinned memory。实践中不能让 \(M_{\text{state}}\) 接近显存上限,否则 forward/backward 一开始就会 OOM。

NoteDefinition: Memory Bill

A memory bill is an explicit accounting of parameter states, gradients, optimizer states, activations, temporary buffers, communication buffers, and fragmentation margin for a training run.

Worked Memory Bill

以 7B dense model full fine-tuning 为例,假设 BF16 params、BF16 grads、FP32 Adam states、无 FP32 master weights:

Item Bytes per parameter Total for 7B
params 2 14 GB
gradients 2 14 GB
Adam \(m\) 4 28 GB
Adam \(v\) 4 28 GB
subtotal 12 84 GB

若使用 AdamW + FP32 master weights:

Extra item Bytes per parameter Total for 7B
master params 4 28 GB

subtotal 变成 112 GB。再考虑 activation,单卡 24GB 不可能 full fine-tune。可选路线只有:

  1. 减少 trainable parameters:LoRA/QLoRA;
  2. 减少状态精度:8-bit optimizer;
  3. shard 状态:ZeRO/FSDP;
  4. 降低 activation:checkpointing、短序列、小 microbatch;
  5. 组合以上方法。
WarningPitfall: Parameter Count Is Not Training Memory

Saying “the model is 7B and BF16 weights are 14 GB” only describes inference-like weight memory. Training needs gradients, optimizer states, activations, temporary buffers, and safety margin.

Activation Memory

Transformer activation 粗略随 batch、sequence、layer、hidden size 增长:

\[ O(BTLd). \]

attention 还可能引入 \(O(BHT^2)\) 中间项,虽然 FlashAttention 可以显著减少显式 attention matrix 存储。

长上下文训练里,activation 往往比参数更麻烦。因为参数只和模型大小有关,而 activation 同时和 batch size、sequence length 有关。

更细一点,若每层保存 hidden states、attention input、MLP intermediate 等,activation 可以写成:

\[ M_{\text{act}} \approx c_{\text{act}}\cdot B_{\text{micro}}\cdot T\cdot L\cdot d\cdot s, \]

其中 \(s\) 是每个元素字节数,\(c_{\text{act}}\) 是由实现决定的常数。这个常数不是小事:是否保存 attention matrix、是否用 FlashAttention、是否 checkpoint、是否用 fused kernels,都会改变 \(c_{\text{act}}\)

attention matrix 如果显式保存,额外项近似为:

\[ M_{\text{attn}} \approx B_{\text{micro}}\cdot H\cdot T^2\cdot s. \]

这就是长上下文训练对普通 attention 很不友好的原因。\(T\) 翻倍,hidden activation 近似翻倍,但 attention matrix 近似变成四倍。

NoteDefinition: Microbatch

A microbatch is the per-forward batch processed before gradient accumulation. Global batch size may be much larger than microbatch size.

Activation Checkpointing

NoteDefinition: Activation Checkpointing

Activation checkpointing saves only selected activations during forward and recomputes omitted activations during backward, trading extra compute for lower memory.

普通反传保存每层 activation:

forward:  save h_1, h_2, ..., h_L
backward: reuse saved activations

checkpointing:

forward:  save only block boundaries
backward: recompute inside each block, then backprop

如果把 \(L\) 层分成 \(K\) 个 segments,只保存 segment boundaries,activation memory 可显著降低,但 backward 要多做部分 forward。

WarningPitfall: Checkpointing Is Not Free

Activation checkpointing reduces memory but increases compute. If the job is already compute-bound and memory is sufficient, it can slow training unnecessarily.

PyTorch pattern:

from torch.utils.checkpoint import checkpoint

def block_forward(hidden_states):
    return block(hidden_states, attention_mask=mask)

hidden_states = checkpoint(block_forward, hidden_states, use_reentrant=False)

现代 Transformer 训练通常按 block checkpoint,而不是对每个小 op checkpoint。

Checkpointing Cost Model

若不 checkpoint,保存每层 activation:

\[ M_{\text{act}}=O(L), \qquad C_{\text{train}}\approx C_{\text{fwd}}+C_{\text{bwd}}. \]

若每个 Transformer block checkpoint,只保存 block 输入,backward 时重算 block forward:

\[ M_{\text{act}}\downarrow, \qquad C_{\text{train}}\approx C_{\text{fwd}}+C_{\text{bwd}}+C_{\text{recompute}}. \]

通常 \(C_{\text{bwd}}\) 已经约为 forward 的 2 倍,checkpointing 可能让总计算增加约 20% 到 40%,具体取决于 checkpoint 粒度和 kernel。

极端理论上,若把 \(L\) 层分成 \(\sqrt{L}\) 个段并递归 checkpoint,可把 activation memory 从 \(O(L)\) 降到近似 \(O(\sqrt{L})\),代价是更多 recompute。但实际 LLM 训练更多选择 block-level checkpoint,因为它简单、稳定、容易和 FSDP/TP 组合。

WarningPitfall: RNG State Must Be Preserved

Checkpointing layers with dropout or stochastic operations must preserve RNG state during recomputation, otherwise backward sees a different computation graph.

Data Parallelism

Data parallelism 复制完整模型到每张 GPU,每张卡处理不同 mini-batch,反传后 all-reduce gradients:

\[ g=\frac{1}{R}\sum_{r=1}^{R}g^{(r)}. \]

优点:简单、吞吐好。缺点:每张卡都保存完整参数、梯度、optimizer states,无法解决单卡装不下模型的问题。

NoteDefinition: Data Parallelism

Data parallelism replicates the model on each worker and splits data across workers. Gradients are synchronized across workers before the optimizer step.

All-Reduce Semantics

假设每张卡计算局部平均梯度:

\[ g^{(r)} = \frac{1}{B_{\text{local}}} \sum_{i\in\mathcal{B}_r}\nabla_\theta \ell_i. \]

DDP all-reduce 后得到:

\[ g = \frac{1}{R}\sum_{r=1}^{R}g^{(r)}. \]

若每张卡 batch size 相同,这等于 global batch 的平均梯度。若不同 rank 的有效 token 数不同,例如 packed sequence 或 mask 后 token 数不同,简单按 rank 平均可能偏离 token-level average。更严格的做法是按有效 token 数归一化:

\[ g = \frac{\sum_r\sum_{i,t}m_{i,t}^{(r)}\nabla_\theta \ell_{i,t}^{(r)}} {\sum_r\sum_{i,t}m_{i,t}^{(r)}}. \]

这在 SFT 和 packed language modeling 里很常见:loss denominator 应该是有效 label 数,不是 padded token 数,也不是 rank 数。

Gradient Accumulation with DDP

若每个 microbatch 都 all-reduce,通信会浪费。典型做法是 accumulation 中间用 no_sync(),最后一个 microbatch 再同步:

from contextlib import nullcontext

for step, batch in enumerate(loader):
    sync = (step + 1) % grad_accum == 0
    context = model.no_sync() if not sync else nullcontext()
    with context:
        loss = model(batch).loss / grad_accum
        loss.backward()
    if sync:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

这保持优化语义接近大 batch,同时减少通信次数。需要注意的是,gradient clipping 应该在同步后的 accumulated gradient 上做。

Token-Normalized Training Step Contract

上面的代码在每个 microbatch loss 都已经是同一种平均语义时可用。但 LLM 训练经常有 variable-length packing、prompt mask、padding mask、SFT answer-only mask。此时每个 rank、每个 microbatch 的有效 label token 数都不同。正确目标通常是:

\[ \mathcal{L} = \frac{ \sum_{r=1}^{R} \sum_{k=1}^{K} \sum_{i,t} m_{r,k,i,t}\ell_{r,k,i,t} }{ \sum_{r=1}^{R} \sum_{k=1}^{K} \sum_{i,t} m_{r,k,i,t} }, \]

其中 \(R\) 是 data-parallel world size,\(K\) 是 accumulation steps,\(m_{r,k,i,t}\in\{0,1\}\) 表示这个 token 是否参与 loss。

NoteDefinition: Token-Normalized Update

A token-normalized update scales the accumulated gradient by the total number of valid loss tokens across all ranks and accumulation microbatches, rather than by the number of sequences, ranks, or microbatches.

这个细节在 DDP 里尤其容易错。PyTorch DDP 的同步语义通常是 all-reduce 后再除以 world size。若 rank \(r\) 对本地 loss numerator 反传得到

\[ G_r = \nabla_\theta \sum_{k,i,t} m_{r,k,i,t}\ell_{r,k,i,t}, \]

我们想得到:

\[ G_{\text{target}} = \frac{\sum_r G_r}{N_{\text{tok}}}, \qquad N_{\text{tok}} = \sum_{r,k,i,t}m_{r,k,i,t}. \]

但 DDP 同步后给的是 rank 平均:

\[ G_{\text{ddp}} = \frac{1}{R}\sum_r \tilde{G}_r. \]

因此每个 rank 本地反传前应把 local numerator 乘上:

\[ s = \frac{R}{N_{\text{tok}}}. \]

这样:

\[ G_{\text{ddp}} = \frac{1}{R} \sum_r \frac{R}{N_{\text{tok}}}G_r = \frac{\sum_rG_r}{N_{\text{tok}}}. \]

设每个 rank 反传的 scaled local objective 是

\[ \tilde{\mathcal{L}}_r = s\sum_{k,i,t}m_{r,k,i,t}\ell_{r,k,i,t}. \]

则本地梯度为 \(\tilde{G}_r=sG_r\)。DDP 同步后:

\[ G_{\text{ddp}} = \frac1R\sum_r\tilde{G}_r = \frac1R\sum_r sG_r. \]

\(s=R/N_{\text{tok}}\),就得到目标 token average gradient:

\[ G_{\text{ddp}} = \frac1R \sum_r \frac{R}{N_{\text{tok}}}G_r = \frac{\sum_rG_r}{N_{\text{tok}}}. \]

工程上最好让 batch sampler 一次吐出一个 accumulation window,这样在任何 forward 之前就能统计整个 window 的有效 token 数:

from contextlib import nullcontext

import torch
import torch.distributed as dist
import torch.nn.functional as F


def valid_label_count(batch):
    return (batch["labels"] != -100).sum()


def all_reduce_sum_scalar(x, device):
    t = torch.as_tensor(float(x), device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t


def lm_loss_sum(logits, labels):
    # logits: [B, T, V], labels: [B, T], labels already shifted or aligned.
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)).float(),
        labels.reshape(-1),
        ignore_index=-100,
        reduction="sum",
    )

一个最小但语义完整的 accumulation step:

def train_update(model, optimizer, scaler, scheduler, window, max_grad_norm):
    device = next(model.parameters()).device
    world = dist.get_world_size()

    local_valid = sum(valid_label_count(batch) for batch in window)
    global_valid = all_reduce_sum_scalar(local_valid, device)
    if global_valid.item() == 0:
        optimizer.zero_grad(set_to_none=True)
        return {"skipped": True, "reason": "empty valid-token window"}

    loss_scale = world / global_valid
    optimizer.zero_grad(set_to_none=True)

    for micro_idx, batch in enumerate(window):
        sync = micro_idx == len(window) - 1
        context = nullcontext() if sync else model.no_sync()
        with context:
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                out = model(**batch)
                loss = lm_loss_sum(out.logits, batch["labels"]) * loss_scale
            scaler.scale(loss).backward()

    scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_grad_norm,
    )

    old_scale = scaler.get_scale()
    scaler.step(optimizer)
    scaler.update()
    new_scale = scaler.get_scale()
    stepped = new_scale >= old_scale
    if stepped:
        scheduler.step()

    return {
        "skipped": not stepped,
        "valid_tokens": int(global_valid.item()),
        "grad_norm": float(grad_norm),
        "loss_scale": float(loss_scale.item()),
    }

这里有几个 contract:

Step Contract
count valid tokens before forward/backward for the whole accumulation window
scale loss numerator use world_size / global_valid_tokens under DDP averaging
no_sync() only disable sync for non-final microbatches
unscale_ before gradient clipping under AMP
clip gradients after all microbatches and after DDP sync
scheduler.step() only after a real optimizer step
WarningPitfall: Scheduler Has Two Clocks

There is a data-consumption clock and an optimizer-update clock. If AMP overflow skips optimizer.step(), the run has consumed data but has not changed parameters. Decide explicitly whether LR schedules follow successful optimizer updates, consumed valid tokens, or wall-clock steps, and log both counters.

对于 token-based schedule,建议至少记录两条计数:

\[ T_{\text{seen}} \leftarrow T_{\text{seen}}+N_{\text{tok}}, \]

\[ T_{\text{updated}} \leftarrow \begin{cases} T_{\text{updated}}+N_{\text{tok}},&\text{optimizer stepped},\\ T_{\text{updated}},&\text{AMP overflow skipped step}. \end{cases} \]

前者用于数据审计和吞吐,后者用于优化进度和 LR schedule。若训练系统在 overflow 时重放同一个 batch,则两者可以重新合并;若直接丢弃该 window,就必须把差异写进日志,否则 resume 后很难解释 loss/LR 对不上。

Accumulation Smoke Tests

token-normalized accumulation 的最小测试,是比较“单进程大 batch”和“多 microbatch accumulation”的梯度是否一致。设同一批数据被分成 \(K\) 个 microbatches:

def flatten_grads(model):
    return torch.cat([
        p.grad.detach().float().flatten()
        for p in model.parameters()
        if p.grad is not None
    ])


def assert_accum_matches_full_batch(model_a, model_b, full_batch, windows):
    out = model_a(**full_batch)
    loss = lm_loss_sum(out.logits, full_batch["labels"])
    denom = valid_label_count(full_batch).clamp_min(1)
    (loss / denom).backward()
    g_full = flatten_grads(model_a)

    denom = sum(valid_label_count(batch) for batch in windows).clamp_min(1)
    for batch in windows:
        out = model_b(**batch)
        loss = lm_loss_sum(out.logits, batch["labels"]) / denom
        loss.backward()
    g_accum = flatten_grads(model_b)

    assert torch.allclose(g_full, g_accum, atol=1e-5, rtol=1e-4)

分布式版本再额外检查:

  1. 每个 rank 的 global_valid_tokens 相同;
  2. final microbatch 才触发 DDP sync;
  3. pre-clip grad norm 在 rank 间一致;
  4. AMP skipped step 不推进 scheduler;
  5. resume 后 T_seenT_updated、sampler position 同时恢复。

这些测试不华丽,但能抓住大多数“loss 看起来差不多,但 scaling 其实错了”的训练 bug。

ZeRO and FSDP

ZeRO 的思想是:data parallel 中很多状态在每张卡上重复保存,可以 shard。

Stage Sharded states
ZeRO-1 optimizer states
ZeRO-2 optimizer states + gradients
ZeRO-3 optimizer states + gradients + parameters

FSDP 与 ZeRO-3 思想相近:参数按 rank 分片,计算某层时 all-gather 该层参数,计算后释放完整参数,只保留 shard。

NoteDefinition: Fully Sharded Data Parallel

FSDP shards model parameters, gradients, and optimizer states across data-parallel workers, all-gathering parameters only when needed for computation.

如果 world size 是 \(R\),理想情况下参数/梯度/optimizer states 可从每卡 \(O(N)\) 降到 \(O(N/R)\),但通信增加。

如果 AdamW 每参数 12 bytes,ZeRO-3/FSDP 理想状态下每卡参数相关状态约为:

\[ M_{\text{state, shard}} \approx \frac{12N}{R}. \]

但 forward/backward 中某个 block 的 full parameters 会被 all-gather 到每张卡,所以 peak memory 不是纯 \(12N/R\)

\[ M_{\text{peak}} \approx \frac{12N}{R} +M_{\text{full block}} +M_{\text{act}} +M_{\text{comm}} +M_{\text{temp}}. \]

这就是 wrapping granularity 重要的原因:block 太大,\(M_{\text{full block}}\) 高;block 太小,通信调用太碎。

FSDP Forward/Backward Timeline

对一个 wrapped block:

forward:
  all-gather full parameters for block
  compute block forward
  optionally discard full parameters

backward:
  all-gather full parameters again if needed
  compute gradients
  reduce-scatter gradients back to shards

这解释了 FSDP 的 trade-off:

  1. 显存下降;
  2. 通信增加;
  3. wrapping granularity 很重要;
  4. checkpoint save/load 也变复杂。
WarningPitfall: Too Fine FSDP Wrapping Can Kill Throughput

Wrapping very small modules can cause excessive all-gather/reduce-scatter overhead. Wrapping too coarsely can keep too many full parameters in memory.

FSDP Prefetch and Overlap

高效 FSDP 不只是“能放下”。它还要把通信藏在计算后面:

compute block k
  while computing: prefetch full params for block k+1
backward block k
  while computing: reduce-scatter grads for block k+1

如果通信无法 overlap,GPU 会出现周期性空转。日志里常见症状是 GPU utilization 低、tokens/sec 不随 GPU 数线性增长、NCCL 时间占比高。

State Dicts and Resume

FSDP checkpoint 可能有几种形态:

Checkpoint style What is stored Use case
full state dict gathered full weights export / inference
sharded state dict per-rank shards training resume
local state dict rank-local layout framework-specific recovery

只保存 full weights 不能恢复 Adam states、scheduler、random state 和 dataloader progress。它能拿来推理,但不能精确继续训练。对大模型训练,checkpoint 设计应该一开始就决定,而不是训练崩了以后再补。

FSDP / ZeRO Configuration Surface

FSDP/ZeRO 的难点不只是“开一个 flag”。真正影响显存和吞吐的是一组配置:

Knob Meaning Trade-off
wrapping policy which modules become sharding units peak full-param memory vs communication calls
mixed precision dtype for params, reductions, buffers memory/bandwidth vs numerical risk
backward prefetch gather next shard before it is needed overlap communication vs peak memory
limit all-gathers throttle outstanding all-gathers lower memory spikes vs less overlap
CPU offload move params/optimizer states to CPU lower GPU memory vs PCIe bottleneck
activation checkpointing recompute block activations lower activation memory vs extra compute
sharded state dict checkpoint shards per rank faster resume vs export complexity

一个 mental model 是:FSDP 把每个 wrapped module 的参数状态分成三种时刻:

resting state: only local shard lives on this rank
compute state: full module parameters are all-gathered
cleanup state: full parameters are freed, gradients reduce-scattered

所以 peak memory 取决于“同时有多少 full modules 被 gather”。如果 prefetch 太 aggressive,可能出现:

current block full params
+ next block prefetched full params
+ activations
+ communication bucket

导致理论 sharding 够用,但实际 peak OOM。

WarningPitfall: FSDP Peak Memory Is a Timeline Property

Average sharded memory can look safe while peak memory fails because all-gather, prefetch, activation, and communication buffers overlap in time.

一个简化的 PyTorch FSDP 配置可以写成:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

mp = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,
    auto_wrap_policy=transformer_auto_wrap_policy({DecoderBlock}),
    mixed_precision=mp,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    limit_all_gathers=True,
    use_orig_params=True,
)

这里 auto_wrap_policy 决定 sharding granularity;limit_all_gathers 控制内存尖峰;use_orig_params 影响 optimizer 和参数访问语义。不同 PyTorch 版本细节会变,但排查时永远要问:wrapper 边界在哪里?full params 什么时候出现?gradients 什么时候 reduce-scatter?

Offload and Hierarchical Memory

当 GPU 显存不够时,可以把某些状态 offload 到 CPU 或 NVMe:

Offloaded state Saves GPU memory Cost
optimizer states large for Adam optimizer step waits for CPU/NVMe transfer
gradients moderate backward/step synchronization complexity
parameters large forward/backward needs frequent transfers
activations can help long context recompute/transfer overhead

Offload 的核心约束是带宽。若每步要搬运 \(C\) bytes,而链路带宽是 \(W\) bytes/s,则最小额外时间为:

\[ t_{\text{offload}} \geq \frac{C}{W}. \]

PCIe 带宽远低于 HBM,NVMe 又远低于 HBM。于是 offload 通常是“让 job 能跑起来”的手段,不一定提高吞吐。它适合:

  1. 小集群上勉强 full fine-tune;
  2. LoRA/QLoRA 外再压显存;
  3. batch size 很小、通信/计算不是满载;
  4. 对吞吐不敏感但必须完成的实验。
WarningPitfall: Offload Can Hide OOM by Creating a New Bottleneck

Moving optimizer states or parameters to CPU/NVMe may avoid GPU OOM, but the training step can become dominated by host-device transfer or storage I/O.

调试 offload 时应同时记录:

Metric Why
H2D/D2H transfer time confirms PCIe bottleneck
optimizer step time offloaded states often stall here
GPU idle time shows whether compute waits for transfer
CPU memory usage prevents host OOM
page fault / NVMe throughput catches storage-backed stalls

Tensor Parallelism

Tensor parallelism 把单个矩阵乘法切到多张 GPU 上。例如线性层:

\[ Y=XW. \]

Column parallel 把 \(W\) 按输出维切分:

\[ W=[W_1,W_2,\ldots,W_R], \]

每张卡算一部分输出:

\[ Y_r=XW_r. \]

Row parallel 把 \(W\) 按输入维切分,需要对 partial outputs 做 all-reduce:

\[ Y=\sum_r X_rW_r. \]

Tensor parallel 适合单层矩阵太大或需要降低单卡参数/激活压力的场景,但要求高速互联。

MLP Tensor Parallel Example

Transformer MLP 常见形式:

\[ \operatorname{MLP}(x)=W_2\sigma(W_1x). \]

Column parallel 切 \(W_1\)

\[ W_1=[W_{1,1},\ldots,W_{1,R}], \qquad u_r=W_{1,r}x. \]

每张卡只算一部分 intermediate hidden。然后 \(W_2\) 可以 row parallel:

\[ y=\sum_{r=1}^R W_{2,r}\sigma(u_r). \]

最后需要 all-reduce 得到完整 \(y\)。所以 TP 的通信通常在 row-parallel output 或 attention output 处发生。

Attention Tensor Parallel

多头注意力天然适合按 heads 切分:

\[ H = H_1\cup H_2\cup\cdots\cup H_R. \]

每张卡计算一部分 heads 的 \(Q,K,V\) 和 attention output,再在 output projection 后合并。GQA/MQA 会影响 KV heads 的切分:如果 KV heads 很少,TP rank 太多时可能出现 KV head 不够分或通信/复制设计变复杂。

WarningPitfall: Tensor Parallelism Wants Fast Interconnect

Tensor parallelism communicates inside almost every Transformer layer. It is much more sensitive to NVLink/InfiniBand bandwidth than ordinary data parallelism.

Pipeline Parallelism

Pipeline parallelism 把 layers 分到不同 GPU:

GPU0: layers 0-7
GPU1: layers 8-15
GPU2: layers 16-23
GPU3: layers 24-31

为了减少 pipeline bubble,会把 batch 切成 micro-batches:

microbatch 1 -> stage 0 -> stage 1 -> ...
microbatch 2 -> stage 0 -> stage 1 -> ...

缺点是实现复杂,activation 需要跨 stage 传递,pipeline schedule 会影响吞吐和显存。

Pipeline Bubble

若 pipeline 有 \(S\) 个 stages,microbatch 数为 \(M\),简单 GPipe schedule 的有效利用率近似:

\[ \eta_{\text{pipe}} \approx \frac{M}{M+S-1}. \]

\(M\) 很小时,pipeline bubble 很大。例如 \(S=8,M=4\)

\[ \eta_{\text{pipe}}\approx \frac{4}{11}\approx36\%. \]

所以 pipeline parallelism 需要足够多 microbatches 来填满流水线。但 microbatch 太多又会影响 activation memory、optimizer step latency 和 batch semantics。

1F1B Schedule

1F1B 表示 warmup 后每个 stage 交替做 one-forward-one-backward:

warmup: F1 F2 F3 ...
steady: B1 F4, B2 F5, B3 F6 ...
cooldown: remaining backward

相比先跑完所有 forward 再 backward,1F1B 可以降低 activation 存活时间,因为早期 microbatch 的 backward 更早发生。

WarningPitfall: Pipeline Changes Layer Placement Bugs

When layers live on different devices, tied embeddings, final LM head, loss computation, and activation checkpointing must agree on device placement.

Parallelism Map

Parallelism Splits Solves Cost
data parallel batch throughput replicated states
ZeRO/FSDP states/params memory communication
tensor parallel matrix ops per-layer size high-bandwidth communication
pipeline parallel layers model depth bubbles/scheduling
sequence parallel sequence activations long context activation layout complexity

大规模训练通常组合这些策略,例如 DP+FSDP+TP。小规模微调则更常用 LoRA/QLoRA、gradient accumulation、activation checkpointing。

Sequence Parallelism

Tensor parallel 会让某些 activation 在每张 TP rank 上重复保存。Sequence parallel 把 sequence dimension 也切开,让某些 per-token operations 分摊:

\[ X\in\mathbb{R}^{B\times T\times d} \quad\rightarrow\quad X_r\in\mathbb{R}^{B\times (T/R)\times d}. \]

LayerNorm、dropout、residual 等逐 token 操作可以在 shard 上做。遇到需要全序列的信息时再 all-gather。它的意义主要是长上下文 activation memory,而不是参数 memory。

Choosing a 3D Parallel Strategy

可以把总 GPU 数分解为:

\[ G=R_{\text{dp}}\cdot R_{\text{tp}}\cdot R_{\text{pp}}. \]

粗略原则:

  1. 单层矩阵太大或单卡 matmul 太慢:增大 \(R_{\text{tp}}\)
  2. 层数太多、模型深度装不下:增大 \(R_{\text{pp}}\)
  3. 模型能装下但想提高吞吐:增大 \(R_{\text{dp}}\)
  4. 参数状态装不下:用 FSDP/ZeRO shard DP states;
  5. 长上下文 activation 装不下:activation checkpointing + sequence parallel。

这不是纯数学分解,因为通信拓扑很重要。TP rank 最好放在高速互联组内;DP/FSDP 可以跨节点;PP 跨节点时 stage 边界 activation 传输会变贵。

Gradient Accumulation and Global Batch

有效 batch:

\[ B_{\text{global}} = B_{\text{micro}} \times K_{\text{accum}} \times R_{\text{data}}. \]

如果增大 data parallel world size 但不调整 micro-batch 或 accumulation,global batch 会变大。LR schedule、warmup steps 和 optimizer dynamics 都会变。

WarningPitfall: Scaling GPUs Changes the Optimization Problem

Adding GPUs changes global batch size unless compensated. A run that is numerically stable on 1 GPU can diverge or generalize differently on 64 GPUs if the effective batch and LR schedule change.

Token-Based Batch Size

LLM 训练更应该关心 tokens per update:

\[ T_{\text{update}} = B_{\text{micro}}\cdot T_{\text{seq}}\cdot K_{\text{accum}}\cdot R_{\text{data}}. \]

若使用 variable-length packing,实际有效 label tokens 是:

\[ T_{\text{valid}} = \sum_{r,k,i,t}m_{r,k,i,t}. \]

训练日志里只记录 batch_size 不够,应该记录 tokens/update 和 tokens/sec。否则两个 run 可能 batch size 相同,但一个充满 padding,另一个 packed 得很满,优化和吞吐完全不同。

Learning-Rate Scaling

常见 linear scaling rule 是:

\[ \eta_{\text{new}} \approx \eta_{\text{old}} \cdot \frac{B_{\text{new}}}{B_{\text{old}}}. \]

但它不是定理。大 batch 会降低梯度噪声,可能需要更长 warmup 或不同 weight decay。LLM pretraining 常以 tokens 计 schedule:

\[ \text{warmup ratio} = \frac{\text{warmup tokens}}{\text{total training tokens}}. \]

Checkpointing

训练 checkpoint 至少包含:

  1. model weights;
  2. optimizer states;
  3. scheduler state;
  4. random number generator states;
  5. dataloader/progress state;
  6. sharding metadata if FSDP/ZeRO。

只保存 model.state_dict() 可以用于推理,但通常不能精确 resume training。对长训练,resume correctness 是工程底线。

What Exact Resume Means

NoteDefinition: Exact Training Resume

Exact resume means that continuing from a checkpoint produces the same subsequent parameter trajectory as an uninterrupted run, up to nondeterminism explicitly accepted by the system.

要接近 exact resume,checkpoint 还需要:

State Why
model weights current parameters
optimizer states Adam moments / momentum
scheduler state current LR and warmup/decay position
gradient scaler AMP overflow/loss-scale behavior
RNG states dropout, sampling, data augmentation
dataloader sampler state same next samples
consumed tokens/steps schedule and logging alignment
sharding metadata FSDP/ZeRO layout
tokenizer/config hash data interpretation consistency

一个实用测试是:训练 \(K\) steps 保存 checkpoint,继续 \(M\) steps;另一个 run 从 checkpoint resume 后跑 \(M\) steps。比较 loss、LR、grad norm、参数 hash 是否在可接受误差内一致。

Checkpoint Frequency Trade-off

checkpoint 太频繁会浪费 I/O;太稀疏会让崩溃损失大量训练。若一次 checkpoint 写入 \(C\) GB,文件系统带宽 \(W\) GB/s,保存耗时约:

\[ t_{\text{ckpt}} \approx \frac{C}{W}. \]

如果每 \(S\) 秒保存一次,I/O overhead 约为 \(t_{\text{ckpt}}/S\)。大规模训练会用 async checkpoint、分片保存、对象存储或低优先级后台上传来降低停顿。

Checkpoint Manifests and Fault Tolerance

大训练不要只把 checkpoint 当成一个目录。它应该有 manifest,记录哪些 shard、哪个 step、什么配置、哪些 hash 是一致的:

{
  "step": 120000,
  "consumed_tokens": 503316480000,
  "world_size": 64,
  "parallelism": {"dp": 8, "tp": 4, "pp": 2},
  "model_config_hash": "sha256:...",
  "tokenizer_hash": "sha256:...",
  "data_manifest_hash": "sha256:...",
  "shards": [
    {"rank": 0, "path": "rank_000.pt", "sha256": "..."},
    {"rank": 1, "path": "rank_001.pt", "sha256": "..."}
  ]
}

这个 manifest 的作用是把“能 load”升级成“知道自己 load 了什么”。没有它,常见问题包括:

  1. 某个 rank 的 shard 来自旧 step;
  2. tokenizer/config 变了但 checkpoint 还能加载;
  3. dataloader 继续位置错了,训练样本重复或跳过;
  4. TP/PP/FSDP layout 改了,state dict silently remap;
  5. async checkpoint 尚未完全写完就被当成可恢复点。
NoteDefinition: Checkpoint Atomicity

Checkpoint atomicity means a checkpoint is either fully visible as a consistent recovery point or not visible at all.

一种简单做法是先写临时目录:

ckpt_120000.tmp/
  rank_000.pt
  rank_001.pt
  ...
  manifest.json

所有 shard 和 manifest 校验完成后,再 atomic rename:

ckpt_120000.tmp -> ckpt_120000

分布式文件系统不一定保证跨目录 rename 完全符合直觉,所以工程上还会用 COMMITTED 标记文件或对象存储事务语义。恢复时只扫描带 committed marker 的 checkpoint。

Failure Modes to Rehearse

训练系统应该主动演练故障,而不是等真崩溃:

Failure Expected recovery behavior
one rank killed all ranks exit or elastic restart from last committed checkpoint
checkpoint interrupted incomplete checkpoint ignored
dataloader worker crash job fails loudly or restarts without sample duplication
NCCL timeout diagnostics include rank, collective, tensor size
disk full checkpoint fails before deleting last good checkpoint
resume with wrong tokenizer hash check aborts

一个最小 fault-tolerance smoke test:

run 100 steps
save checkpoint
kill one worker during next checkpoint
restart job
verify it resumes from the last committed checkpoint
compare consumed tokens, LR, loss scale, and sampler position
WarningPitfall: Last Checkpoint Should Not Be Deleted First

Checkpoint rotation must never delete the last known-good checkpoint before the new checkpoint is fully committed and verified.

Choosing a Strategy

Scenario Strategy
single GPU, limited memory LoRA/QLoRA + grad accumulation
full fine-tune 7B on multi-GPU FSDP/ZeRO + checkpointing
pretrain large dense model DP + TP + PP + checkpointing
long-context SFT FlashAttention + activation checkpointing
fragmented single GPUs independent workers / LoRA jobs / inference rollouts

训练系统没有万能配置。先算 memory bill,再决定是减少 trainable parameters、shard states、切矩阵、切层,还是牺牲 compute 做 recompute。

Throughput, MFU, and Bottlenecks

tokens/sec 是最直接的系统指标:

\[ \operatorname{TPS} = \frac{\text{valid tokens processed}}{\text{wall-clock seconds}}. \]

MFU,也就是 model FLOPs utilization,粗略衡量实际计算达到硬件峰值的比例:

\[ \operatorname{MFU} = \frac{\text{model FLOPs per second}} {\text{hardware peak FLOPs per second}}. \]

对 dense decoder-only Transformer,一个常见粗略估计是每 token 训练 FLOPs 约为:

\[ 6N \]

其中 \(N\) 是参数量。于是:

\[ \operatorname{FLOPs/sec} \approx 6N\cdot \operatorname{TPS}. \]

这个估计忽略 attention 的 \(T^2\) 项和 embedding/head 细节,但足够用来判断 run 是否离谱。如果 MFU 很低,可能瓶颈在:

Symptom Likely bottleneck
GPU utilization sawtooth dataloader or checkpoint stalls
high NCCL time FSDP/TP communication
memory allocated far below max but OOM fragmentation / temp buffers
tokens/sec drops with longer sequence attention or activation memory
CPU high, GPU waiting tokenization / packing
resume changes loss curve missing optimizer/RNG/sampler state
WarningPitfall: Samples/sec Can Lie

For variable-length language data, samples/sec is not comparable across runs. Use valid tokens/sec and tokens/update.

Data Pipeline and Collator

训练系统常常不是 GPU 算不动,而是数据喂不动。LLM data path 通常是:

raw documents
  -> filtering / dedup
  -> tokenizer
  -> packing / chunking
  -> collator
  -> device transfer
  -> forward/backward

collator 决定 input_idsattention_masklabelsposition_ids 和 loss mask。一个 SFT collator 可能要做:

labels = input_ids.clone()
labels[attention_mask == 0] = -100
labels[prompt_mask == 1] = -100

packing 时还可能需要 block-diagonal attention mask,防止一个样本 attend 到另一个样本:

\[ M_{ij} = \begin{cases} 0, & \operatorname{doc}(i)=\operatorname{doc}(j)\text{ and }j\leq i,\\ -\infty, & \text{otherwise}. \end{cases} \]

WarningPitfall: Collator Bugs Look Like Model Bugs

Wrong labels, position ids, or attention masks can produce stable-looking training loss while teaching the model the wrong conditional distribution.

Streaming Data, Shards, and Determinism

大规模 LLM 数据通常不是一个小 Dataset,而是一堆 shard:

shard_00000.jsonl.zst
shard_00001.jsonl.zst
...

每个 worker 需要知道自己读哪些 shard、从哪里继续、如何 shuffle。若目标是 exact-ish resume,dataloader state 至少包括:

State Why
shard order same data permutation
offset within current shard resume at the right record
document/token buffer packing may span records
RNG state deterministic shuffle and augmentation
consumed tokens align LR schedule and logging
epoch / sample cursor avoid repeat or skip

对于 streaming packing,样本边界经常被打散成 token blocks:

doc A tokens + EOS + doc B tokens + EOS -> block length T

这意味着 “sample index” 不再足够描述训练进度。更可靠的是记录 consumed valid tokens 和 packer buffer state。否则 resume 后可能从同一个 document 中间重新切块,造成重复数据或不同 attention boundary。

WarningPitfall: Resume Can Change the Data Distribution

If streaming shuffle, packing buffers, or shard offsets are not restored, a resumed run may see a different token order even when model and optimizer states are correct.

Data Throughput Budget

GPU 需要的输入 token 速率由训练吞吐决定。如果目标是

\[ \operatorname{TPS}_{\text{train}} = \text{tokens/sec}, \]

那么 tokenizer、decompression、packing、host-to-device copy 至少要提供同样的速率,并留出余量:

\[ \operatorname{TPS}_{\text{data}} \geq \rho\operatorname{TPS}_{\text{train}}, \qquad \rho>1. \]

\(\rho\) 接近 \(1\),任何 shard 慢读、CPU 抖动、网络文件系统延迟都会让 GPU 等数据。常见优化:

Bottleneck Fix
tokenizer CPU slow offline tokenize to ids
decompression slow use larger workers / faster codec / local cache
network storage slow prefetch shards to local NVMe
packing slow vectorized packer, prepacked token blocks
H2D copy stalls pinned memory and async transfer
worker imbalance shard size balancing and monitoring

一个训练日志应该同时有:

data_tokens/sec
gpu_tokens/sec
dataloader_wait_ms
host_to_device_ms
packer_buffer_fill

如果 gpu_tokens/sec 低,先看 dataloader_wait_ms,别急着改模型并行配置。

Minimal Debug Protocol

当一个 training run 不稳,不要先调一堆超参。先做小规模闭环:

  1. 用 1 个 batch overfit,确认 loss 能接近 0;
  2. 打印一条样本的 input_ids -> labels shift;
  3. 检查 pad/prompt positions 是否是 -100
  4. 记录 pre-clip grad norm 和 optimizer LR;
  5. 跑 20 steps,保存 checkpoint,再 resume 20 steps;
  6. 比较 uninterrupted 和 resumed run 的 LR/loss/step;
  7. 再打开 FSDP/TP/PP,而不是一开始就全开。

这条流程朴素,但能把 bug 分层:objective bug、data bug、optimizer bug、distributed bug、I/O bug,不会把所有问题都怪给“模型不好训”。

Engineering Checklist

Check Why
peak allocated/reserved memory 发现 fragmentation 和 temporary buffers
tokens/sec 判断吞吐是否合理
MFU 估计 GPU 利用率
global batch 保证优化语义一致
valid tokens/update 确认 variable-length packing 的真实 denominator
loss scale factor 检查 DDP token-normalized scaling 是否正确
gradient norm 发现 scale/accumulation 错误
AMP skipped steps 避免 scheduler 和 optimizer step 错位
checkpoint resume 验证能否真实恢复
communication time 判断 TP/FSDP 是否通信瓶颈
dataloader throughput 避免 GPU 等数据

LLM 训练系统的核心不是某个库名,而是把内存、计算、通信和优化语义同时对齐。

References