3.8 Optimizer Engineering in PyTorch
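An optimizer's mathematical formula only says how parameters should be updated. When training large models, the real bugs usually come from engineering order: when to call zero_grad, whether the loss is divided by the number of accumulation steps, whether AMP unscales before clipping, whether AdamW's weight decay accidentally leaks into the adaptive gradient, whether the scheduler advances per micro-step or per optimizer step, and whether the gradient norm under DDP is local or global.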
This section takes the optimizer from formulas down into the training loop.
One Training Step as a State Machine
A training step is not a single optimizer.step() call; it is a sequence of state transitions:
batch -> forward -> loss -> backward -> gradient transforms -> optimizer update -> scheduler -> zero_grad
In more detail:
- load a micro-batch;
- run the forward pass to compute the loss;
- scale/normalize the loss;
- run backward, which accumulates gradients;
- unscale gradients if using AMP;
- all-reduce gradients if distributed;
- clip or transform gradients;
- let the optimizer update the parameters and optimizer states;
- let the scheduler update the LR;
- clear the gradients.
Put any step in the wrong place and training can "run but learn the wrong thing."
An optimizer step is the moment when model parameters are updated. In gradient accumulation, many backward passes may correspond to one optimizer step.
A real training loop keeps at least three counters:
| counter | increments when | used for |
|---|---|---|
| micro step | every micro-batch forward/backward | dataloader progress, accumulation |
| optimizer step | parameters actually update | LR scheduler, Adam bias correction |
| token/example count | valid samples or tokens consumed | throughput, token budget, loss denominator |
Mixing these up causes subtle bugs. For example, if warmup is specified as 1000 optimizer steps but the code calls scheduler.step() on every micro-step, warmup actually finishes \(K_{\text{accum}}\) times too early.
Most LR schedules should advance only after a successful optimizer update. Advancing on every micro-step changes the intended schedule under gradient accumulation.
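To see how far the two clocks drift apart, here is a pure-Python sketch (no PyTorch) that counts scheduler ticks under both policies. warmup_lr and simulate are hypothetical helpers, and the linear warmup is an assumed schedule chosen only for illustration.

```python
def warmup_lr(step, base_lr, warmup_steps):
    """Hypothetical linear warmup: LR after `step` scheduler ticks."""
    return base_lr * min(1.0, step / warmup_steps)


def simulate(num_micro_steps, grad_accum_steps, per_micro):
    """Count scheduler ticks over num_micro_steps micro-batches.

    per_micro=True models the bug (scheduler.step() on every micro-step);
    per_micro=False ticks only on accumulation boundaries.
    """
    ticks = 0
    for micro_step in range(num_micro_steps):
        is_update = (micro_step + 1) % grad_accum_steps == 0
        if per_micro or is_update:
            ticks += 1
    return ticks


K, warmup, base_lr = 4, 1000, 3e-4
# After 250 optimizer steps (1000 micro-steps):
buggy = warmup_lr(simulate(250 * K, K, per_micro=True), base_lr, warmup)
correct = warmup_lr(simulate(250 * K, K, per_micro=False), base_lr, warmup)
print(buggy, correct)  # the buggy run is already at full LR; the correct one is at 25%
```

With \(K_{\text{accum}}=4\), the buggy clock has finished a 1000-step warmup after only 250 real parameter updates.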
A more precise state machine:
```
for each micro-batch:
    forward under autocast
    compute local normalized loss contribution
    backward accumulates .grad
    if not update boundary:
        maybe skip DDP all-reduce with no_sync
        continue
    unscale if using GradScaler
    check finite gradients
    clip/transform gradients
    optimizer.step updates parameters and moments
    scheduler.step advances optimizer-step clock
    zero_grad clears gradient buffers
    log update-level metrics
```
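Note the word "successful" in "successful optimizer update". With FP16 AMP, when an overflow occurs scaler.step(optimizer) may skip the actual parameter update; in that case the scheduler and the global optimizer-step counter usually should not advance either.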
AdamW Exact Update
Adam maintains first- and second-moment estimates:
\[ m_t=\beta_1m_{t-1}+(1-\beta_1)g_t, \]
\[ v_t=\beta_2v_{t-1}+(1-\beta_2)g_t^2. \]
Because \(m_0=v_0=0\), the early estimates are biased toward 0 and need bias correction:
\[ \hat{m}_t=\frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t=\frac{v_t}{1-\beta_2^t}. \]
Adam update:
\[ \theta_{t+1} = \theta_t-\eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}. \]
AdamW decouples weight decay from the adaptive gradient:
\[ \theta_{t+1} = (1-\eta\lambda)\theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}. \]
A sequence closer to the implementation:
grad -> update m/v -> bias correction -> adaptive step -> decoupled decay -> parameter write
Some implementations apply the decoupled decay before the adaptive step:
\[ \theta \leftarrow \theta(1-\eta\lambda), \qquad \theta \leftarrow \theta-\eta\frac{\hat m}{\sqrt{\hat v}+\epsilon}. \]
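This is equivalent to the closed form above: the adaptive term \(\hat m/(\sqrt{\hat v}+\epsilon)\) does not depend on \(\theta\), and both terms use the current step's \(\eta\). The key point is that weight decay never enters \(m_t,v_t\).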
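A minimal pure-Python sketch of one AdamW step on a list of scalars, assuming PyTorch-style default betas and eps; adamw_step is a hypothetical helper, not a library function. It checks both claims above: the decay-first order matches the closed form, and weight decay never reaches \(m\) or \(v\).

```python
import math


def adamw_step(theta, grad, m, v, t, lr, betas=(0.9, 0.999), eps=1e-8,
               weight_decay=0.0, decay_first=True):
    """One AdamW step on lists of floats; returns (theta, m, v) as new lists.

    t is the 1-based optimizer step used for bias correction.
    """
    beta1, beta2 = betas
    out_theta, out_m, out_v = [], [], []
    for p, g, mi, vi in zip(theta, grad, m, v):
        mi = beta1 * mi + (1 - beta1) * g          # moments see the loss gradient only
        vi = beta2 * vi + (1 - beta2) * g * g
        m_hat = mi / (1 - beta1 ** t)
        v_hat = vi / (1 - beta2 ** t)
        adaptive = lr * m_hat / (math.sqrt(v_hat) + eps)
        if decay_first:
            p = p * (1 - lr * weight_decay)        # decoupled shrink, then adaptive step
            p = p - adaptive
        else:
            p = (1 - lr * weight_decay) * p - adaptive  # closed form using theta_t
        out_theta.append(p)
        out_m.append(mi)
        out_v.append(vi)
    return out_theta, out_m, out_v


theta, grad, zeros = [1.0, -2.0], [0.5, 0.01], [0.0, 0.0]
a = adamw_step(theta, grad, zeros, zeros, t=1, lr=1e-3, weight_decay=0.1, decay_first=True)
b = adamw_step(theta, grad, zeros, zeros, t=1, lr=1e-3, weight_decay=0.1, decay_first=False)
c = adamw_step(theta, grad, zeros, zeros, t=1, lr=1e-3, weight_decay=0.0)
print(a[0], b[0])                   # same parameters up to rounding
print(a[1] == c[1], a[2] == c[2])   # True True: decay did not touch m or v
```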
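Assume the \(g_t\) are identically distributed random variables with \(\mathbb{E}[g_t]=\mu\). For
\[ m_t=\beta m_{t-1}+(1-\beta)g_t,\quad m_0=0, \]
unrolling gives
\[ m_t=(1-\beta)\sum_{i=1}^{t}\beta^{t-i}g_i. \]
Taking expectations and using the geometric sum \(\sum_{i=1}^{t}\beta^{t-i}=(1-\beta^t)/(1-\beta)\):
\[ \mathbb{E}[m_t] =(1-\beta)\sum_{i=1}^{t}\beta^{t-i}\mu = (1-\beta^t)\mu. \]
So \(m_t\) underestimates the true mean \(\mu\); dividing by \(1-\beta^t\) makes it unbiased.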
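With a constant gradient (the deterministic special case of \(\mathbb{E}[g_t]=\mu\)), the derivation can be checked numerically. ema_with_correction is a hypothetical helper written for this check.

```python
def ema_with_correction(grads, beta):
    """Return (m_t, m_t / (1 - beta**t)) after the EMA consumes grads, with m_0 = 0."""
    m = 0.0
    for g in grads:
        m = beta * m + (1 - beta) * g
    t = len(grads)
    return m, m / (1 - beta ** t)


mu, beta = 3.0, 0.9
for t in (1, 5, 50):
    raw, corrected = ema_with_correction([mu] * t, beta)
    # raw matches (1 - beta**t) * mu; corrected recovers mu
    print(t, round(raw, 6), round((1 - beta ** t) * mu, 6), round(corrected, 6))
```

At \(t=1\) the raw EMA is only \(0.1\mu\); by \(t=50\) the correction factor is close to 1 and matters little.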
Adam's per-coordinate effective step size can be written as
\[ \eta_{t,i}^{\text{eff}} = \frac{\eta}{\sqrt{\hat v_{t,i}}+\epsilon}. \]
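This is why Adam behaves like a diagonal preconditioner: coordinates with a large historical second moment (running average of squared gradients) take small steps, and coordinates with a small one take large steps. In practice, what you should monitor is the update itself: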
\[ \Delta\theta_t = -\eta\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon} -\eta\lambda\theta_t. \]
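If the ratio
\[ \frac{\|\Delta\theta_t\|}{\|\theta_t\|} \]
stays large for a long time, training behaves as if it keeps resetting the parameters; if it stays tiny, the model may barely move.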
The update-to-weight ratio is \(\|\Delta\theta\|/\|\theta\|\) for a parameter group or module. It measures how large one optimizer step is relative to the current parameter scale.
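eps is not an irrelevant tiny constant either. When \(\sqrt{\hat v}\) is small, \(\epsilon\) dominates the denominator and Adam degenerates into roughly momentum SGD with effective learning rate \(\eta/\epsilon\). With FP16/BF16, low-rank adapters, small batches, or sparse gradients, eps can change the scale of early updates.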
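The crossover is easiest to see on the first Adam step, where \(\hat m=g\) and \(\hat v=g^2\), so the update magnitude is \(\eta|g|/(|g|+\epsilon)\). The plain-Python sketch below (first_adam_step is a hypothetical helper) compares it with the momentum-SGD-like approximation \(\eta|g|/\epsilon\).

```python
import math


def first_adam_step(g, lr, eps):
    """Magnitude of Adam's first update, where m_hat = g and v_hat = g**2."""
    return lr * abs(g) / (math.sqrt(g * g) + eps)


lr = 1e-3
for eps in (1e-8, 1e-3):
    for g in (1e-1, 1e-6, 1e-9):
        adam = first_adam_step(g, lr, eps)
        sgd_like = lr * abs(g) / eps   # momentum-SGD view with LR lr / eps
        print(f"eps={eps:g} |g|={g:g} adam={adam:.3e} lr*|g|/eps={sgd_like:.3e}")
# When |g| >> eps the step is about lr (sign-like); when |g| << eps it is about lr*|g|/eps.
```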
Modern PyTorch optimizers also expose implementation options such as foreach, fused, and capturable:
| option | meaning | caveat |
|---|---|---|
| foreach=True | batch many tensor ops into foreach kernels | more temporary memory |
| fused=True | use fused CUDA optimizer kernel | fastest when supported, less universal |
| capturable=True | make optimizer CUDA-graph capturable | step/LR tensors must be device-friendly |
| 8-bit optimizer | quantize optimizer states | lower memory, possible numerical drift |
These options should not change the mathematical update, but they do change speed, memory, reproducibility, and which hardware is supported.
L2 Regularization vs. Weight Decay
With plain SGD, the L2 penalty
\[ \min_\theta L(\theta)+\frac{\lambda}{2}\|\theta\|^2 \]
gives the gradient
\[ g_t=\nabla L(\theta_t)+\lambda\theta_t, \]
and the SGD update
\[ \theta_{t+1} = \theta_t-\eta(\nabla L(\theta_t)+\lambda\theta_t) = (1-\eta\lambda)\theta_t-\eta\nabla L(\theta_t). \]
So under plain SGD, L2 regularization and weight decay are equivalent.
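In Adam, however, adding \(\lambda\theta\) to \(g_t\) sends it through \(m_t,v_t\) and then through the per-coordinate adaptive denominator, so it is no longer a simple parameter shrinkage. The core of AdamW is:
- update the Adam moments with the loss gradient only;
- apply a separate multiplicative shrinkage to the parameters.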
weight_decay May Mean Different Things
Modern PyTorch AdamW implements decoupled weight decay. Older optimizers or custom code may implement L2 regularization by adding \(\lambda\theta\) to gradients. They are not equivalent under adaptive scaling.
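A small pure-Python comparison makes the difference concrete. It looks only at the first step (so \(\hat m=g\), \(\hat v=g^2\)) and isolates the part of the update caused by weight decay; both step functions are hypothetical sketches, not PyTorch's implementations.

```python
import math


def adam_l2_first_step(theta, g, lr, wd, eps=1e-8):
    """Adam with L2 folded into the gradient; first step, so m_hat=g', v_hat=g'^2."""
    g_total = g + wd * theta
    return theta - lr * g_total / (math.sqrt(g_total * g_total) + eps)


def adamw_first_step(theta, g, lr, wd, eps=1e-8):
    """AdamW first step: decoupled shrink plus adaptive step on the loss gradient."""
    return (1 - lr * wd) * theta - lr * g / (math.sqrt(g * g) + eps)


def decay_part(step_fn, theta, g, lr, wd):
    """Portion of the update attributable to weight decay."""
    return step_fn(theta, g, lr, wd) - step_fn(theta, g, lr, 0.0)


theta, lr, wd = 1.0, 1e-3, 0.1
for g in (1.0, -1e-4):
    print(f"g={g:g}",
          f"Adam+L2 decay part={decay_part(adam_l2_first_step, theta, g, lr, wd):.2e}",
          f"AdamW decay part={decay_part(adamw_first_step, theta, g, lr, wd):.2e}")
# AdamW always shrinks by lr * wd * theta = 1e-4. Adam+L2 gives ~0 when |g| is large
# and ~-2e-3 (20x stronger) when the penalty flips the sign of a tiny gradient.
```

The same \(\lambda\) thus produces a decay strength that depends on each coordinate's gradient scale, which is exactly the coupling AdamW removes.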
Parameter Groups
In Transformer/LLM training, parameters are usually split into groups:
```python
import torch

# Assumes `model` is an already constructed nn.Module.
decay = []
no_decay = []
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue  # frozen parameters stay out of the optimizer
    if name.endswith(".bias") or "norm" in name.lower():
        no_decay.append(p)
    else:
        decay.append(p)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.1},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=3e-4,
    betas=(0.9, 0.95),
    eps=1e-8,
)
```
Why are bias and norm parameters usually not decayed?
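- biases mainly shift activations and hold few parameters, so decaying them does little for generalization;
- LayerNorm/RMSNorm scales control representation scale, and forcing them toward zero can disturb normalization;
- for large models, regularization comes more from data scale, dropout, weight decay on matrix weights, early stopping, or post-training.
This is an engineering convention, not a mathematical theorem. Some training recipes decay embeddings or set different LRs per layer, so check the specific model family.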
A more robust grouping checks three things:
- every trainable parameter lands in exactly one group;
- each group's hyperparameters are as intended;
- the logs show each group's parameter count, LR, and weight decay.
```python
decay_names = set()
no_decay_names = set()
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    # 1-D tensors (biases, norm scales) and anything named like a bias/norm skip decay
    if p.ndim < 2 or name.endswith(".bias") or "norm" in name.lower():
        no_decay_names.add(name)
    else:
        decay_names.add(name)

param_map = {name: p for name, p in model.named_parameters() if p.requires_grad}
assert not (decay_names & no_decay_names)               # no parameter in both groups
assert decay_names | no_decay_names == set(param_map)   # every trainable param covered

groups = [
    {"params": [param_map[n] for n in sorted(decay_names)], "weight_decay": 0.1},
    {"params": [param_map[n] for n in sorted(no_decay_names)], "weight_decay": 0.0},
]
```
Fine-tuning may call for additional groups:
| group | LR | decay | example |
|---|---|---|---|
| frozen | 0 / not in optimizer | none | base model frozen |
| adapter | high | often 0 | LoRA A/B, prefix params |
| head | medium/high | task-dependent | classifier or reward head |
| backbone | low | matrix weights only | full fine-tune |
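If LoRA parameters share the backbone's LR, the adapter may learn too slowly or the backbone may drift too fast. Parameter grouping is not boilerplate; it encodes training assumptions.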
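The grouping logic can be prototyped without a model. The sketch below uses a hypothetical name-to-shape inventory in place of model.named_parameters(), a made-up rule set, and made-up per-group hyperparameters; it checks the partition and prints the per-group counts that the logs should show.

```python
import math

# Hypothetical inventory standing in for model.named_parameters(): name -> shape.
params = {
    "embed.weight": (32000, 4096),
    "layers.0.attn.q_proj.weight": (4096, 4096),
    "layers.0.attn.q_proj.lora_A.weight": (8, 4096),
    "layers.0.attn.q_proj.lora_B.weight": (4096, 8),
    "layers.0.norm.weight": (4096,),
    "head.weight": (2, 4096),
}

# Hypothetical per-group hyperparameters for a fine-tune.
group_hparams = {
    "adapter": {"lr": 2e-4, "weight_decay": 0.0},
    "head": {"lr": 1e-4, "weight_decay": 0.01},
    "backbone": {"lr": 1e-5, "weight_decay": 0.1},
    "no_decay": {"lr": 1e-5, "weight_decay": 0.0},
}


def assign_group(name, shape):
    """Illustrative rules; real recipes differ by model family."""
    if "lora_" in name:
        return "adapter"
    if len(shape) < 2 or name.endswith(".bias") or "norm" in name.lower():
        return "no_decay"
    if name.startswith("head."):
        return "head"
    return "backbone"


groups = {g: [] for g in group_hparams}
for name, shape in params.items():
    groups[assign_group(name, shape)].append(name)

# Every parameter lands in exactly one group.
flat = [n for names in groups.values() for n in names]
assert len(flat) == len(set(flat)) and set(flat) == set(params)

for g, names in groups.items():
    count = sum(math.prod(params[n]) for n in names)
    print(f"{g:9s} tensors={len(names)} params={count:,} {group_hparams[g]}")
```

Rule order matters: checking the LoRA pattern first keeps adapter matrices out of the backbone group even though they are 2-D.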
Gradient Accumulation
When a large batch does not fit in memory, split it into \(K\) micro-batches. The effective gradient should be their average:
\[ g = \frac{1}{K}\sum_{k=1}^{K}g_k. \]
So the loss must be divided by \(K\):
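```python
import torch

# Assumes model, loader, optimizer, scheduler, and grad_accum_steps are defined.
optimizer.zero_grad(set_to_none=True)
for i, batch in enumerate(loader):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss
    loss = loss / grad_accum_steps   # average, not sum, over the K micro-batches
    loss.backward()                  # gradients accumulate in .grad
    if (i + 1) % grad_accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()             # one tick per optimizer step
        optimizer.zero_grad(set_to_none=True)
```
With plain SGD, forgetting to divide by \(K\) is equivalent to multiplying the learning rate by \(K\). With Adam the \(\hat m/\sqrt{\hat v}\) ratio largely cancels a constant gradient scale, but the clipping threshold, eps, and logged grad norms all shift by a factor of \(K\), and FP16 gradients become more likely to overflow. A small model may only show a jittery loss; a large model may overflow outright.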
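The averaging argument can be checked on a toy scalar regression in plain Python. mean_grad and the data are hypothetical; accum stands in for the .grad buffer that successive backward() calls add into.

```python
def mean_grad(w, batch):
    """d/dw of mean((w*x - y)**2) over a batch of (x, y) pairs."""
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0),
        (0.5, 1.0), (1.5, 2.0), (2.5, 6.0), (3.5, 6.5)]
w, K = 0.3, 4
micro_batches = [data[i::K] for i in range(K)]   # K equal-size micro-batches

full = mean_grad(w, data)                # gradient of the large-batch mean loss
accum, accum_no_div = 0.0, 0.0           # stand-ins for the accumulated .grad
for mb in micro_batches:
    accum += mean_grad(w, mb) / K        # backward on loss / grad_accum_steps
    accum_no_div += mean_grad(w, mb)     # bug: forgot to divide by K

print(full, accum, accum_no_div / full)  # accum == full up to rounding; ratio is K
```

The equality relies on equal-size micro-batches; with unequal sizes the mean of micro-batch means is no longer the large-batch mean.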
Effective batch size is \[ B_{\text{eff}} = B_{\text{micro}}\times K_{\text{accum}}\times N_{\text{data parallel}}. \] It is the number of examples contributing to one optimizer step.
The scheduler usually advances per optimizer step, not per micro-step; otherwise warmup ends early and the whole LR curve runs ahead of schedule.
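For token-level LMs, loss / grad_accum_steps is exactly equivalent only when every micro-batch has the same number of valid tokens. The common version normalizes each micro-batch by its own valid-token count:
```python
# loss_per_token: per-token losses (e.g. cross-entropy with reduction="none")
# loss_mask: 1.0 for tokens that count toward the loss, 0.0 for padding/prompt tokens
token_loss_sum = (loss_per_token * loss_mask).sum()
token_count = loss_mask.sum().clamp_min(1)   # avoid dividing by zero on all-masked batches
loss = token_loss_sum / token_count
loss = loss / grad_accum_steps
```
If micro-batches have different numbers of valid tokens, this is still "average within each micro-batch, then average across accumulation." To get the true large-batch token mean, accumulate the loss numerator and normalize by the total token count of the whole accumulation window. In practice there are three common strategies: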
| strategy | exact token mean? | trade-off |
|---|---|---|
| fixed-length packed batches | nearly (equal token counts per micro-batch) | collator more complex |
| divide by local valid tokens per micro-batch | no | simple, usually acceptable |
| precompute accumulation-window token count | yes | needs two-pass or delayed backward |
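The gap between the second and third strategies shows up with two micro-batches of unequal length. The per-token losses below are made-up numbers, and the sketch ignores data parallelism (under DDP the window's token count would also have to be summed across ranks).

```python
# Per-token losses of valid (unmasked) tokens for each micro-batch in one
# accumulation window; the numbers are made up.
window = [
    [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],   # 6 valid tokens
    [4.0, 4.0],                       # 2 valid tokens
]
K = len(window)

# Local denominator: mean within each micro-batch, then divide by K.
local_mean = sum(sum(tok) / len(tok) / K for tok in window)

# Window denominator: every micro-batch divides its loss *sum* by the total
# number of valid tokens in the window, which must be known before backward.
window_tokens = sum(len(tok) for tok in window)
window_mean = sum(sum(tok) / window_tokens for tok in window)

true_mean = sum(sum(tok) for tok in window) / window_tokens
print(local_mean, window_mean, true_mean)   # 3.0 2.5 2.5
```

The local version overweights the short micro-batch's tokens by a factor of 3 here, which is why the window count has to be known before the first backward of the window.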
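Under DDP, micro-steps that are not update boundaries can use no_sync() to avoid an all-reduce on every backward:
```python
from contextlib import nullcontext

# Assumes model is DDP-wrapped when ddp is True; compute_loss is a user-defined helper.
for micro_step, batch in enumerate(loader):
    is_update = (micro_step + 1) % grad_accum_steps == 0
    # Skip the gradient all-reduce on non-boundary micro-steps; sync on the boundary.
    ctx = model.no_sync() if ddp and not is_update else nullcontext()
    with ctx:
        loss = compute_loss(batch) / grad_accum_steps
        loss.backward()
    if is_update:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```
If you forget no_sync(), the result is usually still correct, but communication grows by a factor of \(K_{\text{accum}}\). If you wrongly use no_sync on the update step as well, gradients are never all-reduced and the replicas' parameters start to diverge.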
Mixed Precision and AMP
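BF16/FP16 training saves memory and raises throughput, but lower precision, and for FP16 a much narrower exponent range, makes gradients prone to underflow or overflow. The key AMP order:
```python
# Assumes scaler is a GradScaler (e.g. torch.amp.GradScaler("cuda") in recent PyTorch)
# and ready_to_step marks an accumulation boundary.
scaler.scale(loss).backward()          # backward on s * loss
if ready_to_step:
    scaler.unscale_(optimizer)         # divide .grad by s in place
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)             # skips the update if grads contain Inf/NaN
    scaler.update()                    # adjust s for the next iteration
    optimizer.zero_grad(set_to_none=True)
```
Why must clipping come after unscaling?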
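If the scaled loss is
\[ \tilde{L}=sL, \]
then backward gives
\[ \tilde{g}=s\nabla L. \]
Clipping \(\tilde{g}\) directly to threshold \(\tau\) is equivalent to clipping the true gradient to \(\tau/s\): the larger the loss scale, the smaller the effective threshold, so training is wrongly suppressed. The correct order is to first compute
\[ g=\tilde{g}/s, \]
and then clip.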
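The effect can be reproduced with a plain-Python imitation of norm clipping. clip_by_norm mirrors the clip_grad_norm_ coefficient max_norm / (norm + 1e-6) capped at 1, but it is a sketch on a list of floats, not the PyTorch function.

```python
import math


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


def clip_by_norm(grads, max_norm):
    """Rescale grads so their L2 norm is at most max_norm (coef capped at 1)."""
    coef = min(1.0, max_norm / (l2_norm(grads) + 1e-6))
    return [g * coef for g in grads]


true_grads = [3.0, 4.0]                  # true gradient norm 5
s, tau = 1024.0, 1.0                     # loss scale and clip threshold
scaled = [g * s for g in true_grads]     # what backward on s * L produces

wrong = [g / s for g in clip_by_norm(scaled, tau)]   # clip, then unscale
right = clip_by_norm([g / s for g in scaled], tau)   # unscale, then clip

print(l2_norm(wrong), l2_norm(right))    # about tau / s versus about tau
```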
FP16 often needs dynamic loss scaling. BF16 has a wider exponent range and usually does not need scaling, though it still benefits from mixed precision kernels.
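AMP also affects the scheduler: if GradScaler detects Inf/NaN, it skips the optimizer step. A conservative check is to see whether the scale decreased:
```python
old_scale = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
new_scale = scaler.get_scale()
# GradScaler lowers the scale only when it found Inf/NaN and skipped the step.
step_was_skipped = new_scale < old_scale
if not step_was_skipped:
    scheduler.step()
    global_step += 1
```
If a step is skipped but the scheduler still advances, the LR schedule and Adam's bias-correction step count gradually drift apart. For long runs the drift is usually small; with a short warmup or frequent overflows it becomes obvious.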
Mixed precision also comes with a few dtype conventions:
| tensor/state | common dtype | reason |
|---|---|---|
| model weights | BF16/FP16 or FP32 master | memory/compute |
| gradients | BF16/FP16 bucket or FP32 accumulated | communication/precision |
| Adam moments | FP32 | stable statistics |
| loss scale | FP32 scalar | overflow control |
| reductions for norms | FP32 | avoid overflow/underflow |
So grad_norm, the loss denominator, and optimizer states are best handled with FP32 semantics; do not keep every statistic in low precision just to save a little memory.
Optimizer State Memory
AdamW stores at least:
- parameters;
- gradients;
- first moment \(m\);
- second moment \(v\);
- sometimes FP32 master weights.
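If the parameter count is \(N\), BF16 parameters and gradients take 2 bytes each, and the FP32 Adam states \(m\) and \(v\) take 4 bytes each, the rough memory is:
\[ \text{memory} \approx 2N + 2N + 4N + 4N = 12N\text{ bytes}. \]
FP32 master weights add another \(4N\). So for a 7B model the training state (weights, gradients, and Adam moments) is far larger than the roughly 14 GB of BF16 weights alone, even before counting activations:
\[ 12\times 7\text{B} \approx 84\text{ GB}. \]
This explains why pretraining needs ZeRO/FSDP/offload, while inference only loads the weights and the KV cache.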
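The same arithmetic as a small helper, under the byte-size assumptions above (BF16 params/grads, FP32 moments, optional FP32 master copy). training_state_bytes is a hypothetical name, and the estimate ignores activations, communication buffers, and allocator fragmentation.

```python
def training_state_bytes(n_params, param_bytes=2, grad_bytes=2, moment_bytes=4,
                         fp32_master=False):
    """Rough AdamW training-state size: params + grads + m + v (+ FP32 master copy).

    Excludes activations, communication buffers, and allocator fragmentation.
    """
    per_param = param_bytes + grad_bytes + 2 * moment_bytes
    if fp32_master:
        per_param += 4
    return n_params * per_param


n = 7_000_000_000
print(training_state_bytes(n) / 1e9)                    # 84.0 (GB)
print(training_state_bytes(n, fp32_master=True) / 1e9)  # 112.0 (GB)
print(2 * n / 1e9)                                      # 14.0 (GB of BF16 weights)
```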
Different optimizers and strategies keep different amounts of state:
| optimizer/strategy | state per parameter | memory intuition |
|---|---|---|
| SGD | optional momentum | small |
| AdamW | \(m\), \(v\) | large but standard |
| Adafactor | factored second moment | lower for matrices |
| 8-bit Adam | quantized \(m\), \(v\) | lower memory, extra quantization |
| ZeRO/FSDP | sharded params/grads/states | each rank holds a shard |
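Adafactor factorizes a matrix's second moment into row and column statistics, avoiding a full \(v\in\mathbb{R}^{m\times n}\). This is attractive for very large embeddings or LM heads, but its update semantics differ from AdamW, so do not treat it as just a "memory-saving Adam."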
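A sketch of the factorization idea on one squared-gradient matrix, in plain Python. It keeps only row and column sums and rebuilds \(V_{ij}\approx R_iC_j/\sum_k R_k\); the real Adafactor also maintains exponential moving averages of these statistics and differs in other details, so this is not its full update.

```python
def factored_second_moment(sq):
    """Approximate a nonnegative m x n matrix from its row and column sums.

    Stores m + n numbers instead of m * n and rebuilds V_ij ~ R_i * C_j / sum(R).
    """
    rows = [sum(r) for r in sq]
    cols = [sum(c) for c in zip(*sq)]
    total = sum(rows)
    approx = [[ri * cj / total for cj in cols] for ri in rows]
    return rows, cols, approx


# Rank-1 squared gradients (outer product of [1, 4] and [1, 2, 3]): reconstruction is exact.
rank1 = [[1.0, 2.0, 3.0], [4.0, 8.0, 12.0]]
print(factored_second_moment(rank1)[2])
# A non-rank-1 pattern is smoothed out: the diagonal structure is lost.
print(factored_second_moment([[1.0, 0.0], [0.0, 1.0]])[2])   # [[0.5, 0.5], [0.5, 0.5]]
```

The second case shows why the update semantics differ from AdamW: per-coordinate second-moment detail that does not fit the row/column structure is averaged away.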
Distributed Gradients
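DDP all-reduces gradients by default. If each GPU has a local batch of \(B\) and the world size is \(N\), then after DDP each parameter's gradient is usually the average across ranks:
\[ g = \frac{1}{N}\sum_{r=1}^{N}g^{(r)}. \]
Global norm clipping should be based on the global norm over all parameters:
\[ \|g\|_2 = \sqrt{\sum_p\|g_p\|_2^2}. \]
In ZeRO/FSDP, each rank holds only part of the parameters or optimizer state, so you cannot naively compute the norm over the local shard alone. A correct implementation sums the squared norms across ranks.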
For sharded training, clipping each shard independently changes the update direction. Global norm clipping needs a cross-rank reduction over squared gradient norms.
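Simulating two ranks in plain Python shows the difference. The shards are hypothetical, and summing local_sq stands in for an all-reduce with SUM.

```python
import math


def clip_coef(sq_norm, max_norm):
    """Clipping coefficient for a gradient whose squared L2 norm is sq_norm."""
    return min(1.0, max_norm / (math.sqrt(sq_norm) + 1e-6))


shards = [[3.0, 0.0], [0.0, 4.0]]   # hypothetical per-rank gradient shards; full norm 5
max_norm = 1.0

# Correct: local squared norms -> sum across ranks (all-reduce SUM) -> one shared coef.
local_sq = [sum(g * g for g in shard) for shard in shards]
coef = clip_coef(sum(local_sq), max_norm)
global_clipped = [[g * coef for g in shard] for shard in shards]

# Wrong: each rank clips its own shard with its own coefficient.
local_clipped = [[g * clip_coef(sq, max_norm) for g in shard]
                 for shard, sq in zip(shards, local_sq)]

print(global_clipped)  # about [[0.6, 0.0], [0.0, 0.8]]: direction 3:4 preserved
print(local_clipped)   # about [[1.0, 0.0], [0.0, 1.0]]: direction changed to 1:1
```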
DDP, ZeRO, and FSDP differ in optimizer semantics:
| method | params | gradients | optimizer state | main concern |
|---|---|---|---|---|
| DDP | replicated | all-reduced | replicated | communication cost |
| ZeRO-1 | replicated | all-reduced | sharded | optimizer memory |
| ZeRO-2 | replicated | sharded/reduced | sharded | grad memory |
| ZeRO-3/FSDP | sharded | sharded | sharded | gather/scatter and checkpointing |
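In sharded training, optimizer.state_dict() may not be an ordinary single-process state dict. Save and load with the framework's full/sharded state-dict APIs; otherwise you may end up with a checkpoint that can only be restored at the same world size.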
Checkpoint and Exact Resume
Being able to resume is not the same as resuming exactly. A complete checkpoint includes at least:
| state | why |
|---|---|
| model weights | parameters |
| optimizer state | Adam moments, step count |
| scheduler state | LR phase and warmup progress |
| GradScaler state | FP16 scale and growth tracker |
| RNG states | dropout, sampling, augmentation |
| dataloader/sampler state | data order |
| global counters | micro step, optimizer step, tokens seen |
| config/git hash | reproducibility |
```python
import torch

# Assumes model, optimizer, scheduler, scaler (or None), global_step, tokens_seen, path exist.
ckpt = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "scaler": scaler.state_dict() if scaler is not None else None,
    "global_step": global_step,
    "tokens_seen": tokens_seen,
    "rng_cpu": torch.get_rng_state(),
    "rng_cuda": torch.cuda.get_rng_state_all(),
}
torch.save(ckpt, path)
```
Loading only model weights restarts optimizer moments, LR schedule, loss scale, and data order. That is fine for fine-tuning from a checkpoint, but it is not an exact continuation of the same run.
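First sanity checks after resuming:
- does the LR equal the LR the next step would have used before the save?
- is the optimizer step count continuous?
- is the loss scale restored?
- does the sampler avoid repeating or skipping large amounts of data?
- is the loss on a fixed batch close to its value before the save?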
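Why RNG state belongs in the checkpoint can be sketched with Python's random module, which here plays the role that torch.get_rng_state and torch.set_rng_state play for PyTorch. train_segment is a hypothetical stand-in for training steps that consume randomness (dropout, sampling).

```python
import random


def train_segment(rng, n_steps, state):
    """Hypothetical stand-in for n_steps optimizer steps that consume randomness."""
    draws = []
    for _ in range(n_steps):
        draws.append(rng.random())        # e.g. a dropout or sampling decision
        state["optimizer_step"] += 1
    return draws


# Reference: one uninterrupted run of 6 steps.
reference = train_segment(random.Random(0), 6, {"optimizer_step": 0})

# Interrupted run: checkpoint after 3 steps, including the RNG state.
rng, state = random.Random(0), {"optimizer_step": 0}
first = train_segment(rng, 3, state)
ckpt = {"rng": rng.getstate(), "state": dict(state)}

# "Restart": restore counters and RNG into fresh objects, then continue.
rng2 = random.Random()
rng2.setstate(ckpt["rng"])
state2 = dict(ckpt["state"])
second = train_segment(rng2, 3, state2)

print(first + second == reference, state2["optimizer_step"])   # True 6

# Re-seeding instead of restoring would replay the first draws, not continue.
print(train_segment(random.Random(0), 3, {"optimizer_step": 0}) == first)  # True
```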
Numerically Stable Training Loop
A more complete pattern:
```python
import torch

# Assumes model, loader, optimizer, scheduler, grad_accum_steps, and a log() helper exist.
optimizer.zero_grad(set_to_none=True)
for micro_step, batch in enumerate(loader):
    is_update = (micro_step + 1) % grad_accum_steps == 0
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = model(**batch)
    loss = out.loss / grad_accum_steps
    loss.backward()
    if not is_update:
        continue
    # clip_grad_norm_ returns the total norm measured before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=1.0,
    )
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad(set_to_none=True)
    log({
        "loss": loss.item() * grad_accum_steps,  # last micro-batch's loss, un-divided
        "grad_norm": grad_norm.item(),
        "lr": scheduler.get_last_lr()[0],
    })
```
This code omits DDP no_sync(), GradScaler, checkpointing, and error handling, but the order is correct.
A version with AMP, accumulation, and scheduler skipping:
```python
from contextlib import nullcontext

import torch

# Assumes scaler is a GradScaler, global_step = 0, and ddp/model/loader/scheduler exist.
optimizer.zero_grad(set_to_none=True)
for micro_step, batch in enumerate(loader):
    is_update = (micro_step + 1) % grad_accum_steps == 0
    sync_ctx = model.no_sync() if ddp and not is_update else nullcontext()
    with sync_ctx:
        with torch.autocast("cuda", dtype=torch.float16):
            loss = model(**batch).loss / grad_accum_steps
        scaler.scale(loss).backward()
    if not is_update:
        continue
    scaler.unscale_(optimizer)              # unscale before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    old_scale = scaler.get_scale()
    scaler.step(optimizer)                  # may skip the update on Inf/NaN
    scaler.update()
    step_skipped = scaler.get_scale() < old_scale
    if not step_skipped:
        scheduler.step()                    # advance the optimizer-step clock on real updates only
        global_step += 1
    optimizer.zero_grad(set_to_none=True)
```
The BF16 version usually drops GradScaler:
```python
with torch.autocast("cuda", dtype=torch.bfloat16):
    loss = model(**batch).loss / grad_accum_steps
loss.backward()
```
What matters here is the order, not the exact template: accumulate, then unscale, then clip, then step, then the scheduler, then zero the gradients.
What to Log
Logging only the loss is not enough. At minimum, log:
| Metric | Why |
|---|---|
| LR | check that warmup/decay follows optimizer steps |
| grad norm | catch spikes, exploding gradients, wrong loss scale |
| clip ratio | tell whether clipping has become routine |
| optimizer step | distinguish micro-steps from update steps |
| tokens/sec | locate dataloader/kernel/communication bottlenecks |
| loss scale | FP16 AMP stability |
| parameter norm | check whether weight decay is too strong or ineffective |
| update/weight ratio | tell whether steps are too large or too small |
| tokens seen | align with token-budget schedules |
| skipped steps | AMP overflow or anomalous batches |
| data time vs step time | detect input-pipeline bottlenecks |
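In engineering terms, "training looks like it is running" does not mean training is correct. The goal of optimizer engineering is to make the mathematical update, batch semantics, distributed semantics, and numerical-precision semantics agree.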
The update/weight ratio can be logged like this:
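```python
import torch


@torch.no_grad()
def update_ratio(params_before, model):
    """Per-parameter ||p_after - p_before|| / ||p_after||.

    params_before maps name -> tensor cloned before optimizer.step(), e.g.
    {n: p.detach().clone() for n, p in model.named_parameters()}.
    """
    ratios = {}
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        delta = p - params_before[name]
        ratios[name] = (delta.norm() / p.norm().clamp_min(1e-12)).item()
    return ratios
```
Real training does not need to keep a full params_before every step; you can sample a few modules, or collect update-norm statistics inside the optimizer step.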
Debugging Protocol
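When training "feels off," narrow the problem down in this order:
- single GPU, no AMP, no accumulation: run one batch and confirm the loss goes down;
- add AMP: confirm there is no overflow/NaN and GradScaler rarely skips steps;
- add accumulation: confirm the effective batch's loss scale and LR are unchanged;
- add DDP: confirm all ranks start from identical parameters and report similar losses;
- add no_sync(): confirm parameters are still in sync after each update step;
- add the scheduler: confirm warmup/decay advances per optimizer step;
- add checkpoint resume: confirm LR, step, loss scale, and loss are continuous.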
Before launching a long run, overfit a tiny batch for 20-100 optimizer steps. If the model cannot overfit a tiny batch, the optimizer loop, loss mask, LR, or data path is probably wrong.
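A plain-Python sketch of the tiny-batch overfit check: a two-parameter linear model trained with a hand-written Adam on one fixed four-example batch. The data, learning rate, step count, and pass threshold are arbitrary choices for illustration; the point is the check at the end, that the loss on the memorized batch collapses.

```python
import math

batch = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]   # fixed tiny batch, y = 2x + 1


def loss_and_grads(w, b):
    """Mean squared error on the fixed batch and its gradient w.r.t. (w, b)."""
    n = len(batch)
    res = [w * x + b - y for x, y in batch]
    loss = sum(r * r for r in res) / n
    gw = sum(2 * r * x for r, (x, _) in zip(res, batch)) / n
    gb = sum(2 * r for r in res) / n
    return loss, [gw, gb]


params, m, v = [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]
lr, beta1, beta2, eps = 0.05, 0.9, 0.999, 1e-8
losses = []
for t in range(1, 501):                       # 500 optimizer steps on the same batch
    loss, grads = loss_and_grads(*params)
    losses.append(loss)
    for i, g in enumerate(grads):
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        m_hat = m[i] / (1 - beta1 ** t)
        v_hat = v[i] / (1 - beta2 ** t)
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)

print(losses[0], losses[-1], params)
# A healthy loop drives the loss on a memorized batch far below its starting value.
assert losses[-1] < 0.05 * losses[0]
```

In a real run the same check applies to the full training loop: if the loss on a single repeated batch does not collapse, suspect the loop before the model.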
Implementation Checklist
- is optimizer.zero_grad(set_to_none=True) in the right place?
- under accumulation, is the loss scaled as intended?
- does the scheduler advance only after a successful optimizer step?
- does AMP call unscale_ before clipping?
- on FP16 overflow, are the scheduler and global step skipped?
- does DDP accumulation use no_sync() correctly?
- is global-norm clipping really global under FSDP/ZeRO?
- do the parameter groups cover every trainable parameter exactly once?
- has optimizer state memory been estimated, and can checkpoints be restored?
- do the logs distinguish micro step, optimizer step, and tokens seen?
- does resume restore optimizer/scheduler/scaler/RNG/sampler state?
- does the tiny-batch overfit test pass?
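Two smoke tests:
```python
# 1. Parameter groups partition the trainable parameters.
grouped = [id(p) for g in optimizer.param_groups for p in g["params"]]
trainable = {id(p) for p in model.parameters() if p.requires_grad}
assert len(grouped) == len(set(grouped))   # no parameter appears twice
assert set(grouped) == trainable           # none missing, none extra

# 2. The scheduler advances only on optimizer-update boundaries.
# Run micro_steps micro-batches of the training loop first; LRScheduler.last_epoch
# counts scheduler.step() calls since construction (default last_epoch=-1).
micro_steps = 8
grad_accum_steps = 4
expected_updates = micro_steps // grad_accum_steps
scheduler_step_count = scheduler.last_epoch
assert scheduler_step_count == expected_updates
```
These tests are simple, but they catch the most common silent errors in optimizer engineering: parameters missing from the optimizer, a scheduler running on the wrong clock, and accumulation changing the effective LR.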
References
- Adam: A Method for Stochastic Optimization, Kingma and Ba.
- Decoupled Weight Decay Regularization, Loshchilov and Hutter.