1.5 Autograd and the Training Loop


深度学习代码最容易给人一种错觉:只要写出

loss.backward()
optimizer.step()

模型就会自动学习。但 PyTorch 的 autograd 不是魔法,它只是把一次 forward 中发生的 tensor operations 记录成一张动态计算图,然后按 chain rule 反向传播。训练 loop 的可靠性来自一个非常具体的状态流:

parameters -> forward graph -> scalar loss -> backward gradients -> optimizer update

这一节把 autogradnn.Module、loss、optimizer 和 mode switching 连成一个完整训练步骤。目标不是背 API,而是看懂每一行训练代码到底改变了哪些状态。

Dynamic Computation Graph

NoteDefinition: Dynamic Computation Graph

A dynamic computation graph is a graph built at runtime from the tensor operations executed during one forward pass. PyTorch records the operations needed to compute gradients and discards the graph after backward by default.

假设我们有

x = torch.randn(4, 3)
w = torch.randn(3, 2, requires_grad=True)
b = torch.zeros(2, requires_grad=True)

logits = x @ w + b
loss = logits.pow(2).mean()
loss.backward()

这里真正需要求梯度的是 \(w,b\),因为它们是可训练参数。x @ w+ bpowmean 会在 forward 时形成一条计算链:

w ---- matmul ---- add ---- pow ---- mean ---- loss
                  /
b ---------------

backward() 从标量 loss 出发,沿着图反向应用 chain rule:

\[ \frac{\partial L}{\partial w} = \frac{\partial L}{\partial \operatorname{logits}} \frac{\partial \operatorname{logits}}{\partial w}, \qquad \frac{\partial L}{\partial b} = \frac{\partial L}{\partial \operatorname{logits}} \frac{\partial \operatorname{logits}}{\partial b}. \]

因为图是在 Python 执行过程中即时构造的,所以 iffor、递归、不同 batch 的不同分支都可以参与建图。这是 PyTorch eager mode 的核心便利:模型结构可以像普通程序一样写。

WarningPitfall: The Graph Is Per Forward Pass

PyTorch does not keep one permanent graph for the whole model. Each forward pass builds a fresh graph for the operations actually executed in that pass.

这意味着两件事:

  1. 如果某个参数在本次 forward 没有参与 loss,它这次不会得到梯度;
  2. loss.backward() 默认释放本次图的中间状态,不能对同一个 graph 反复 backward,除非显式使用 retain_graph=True

Saved Tensors and Version Counters

反向传播并不是只记住 operation 名字。很多 backward 公式还需要 forward 的中间值。例如

\[ y=\operatorname{relu}(x) \]

的 backward 需要知道哪些位置 \(x>0\);softmax-cross-entropy 的 backward 需要概率分布;矩阵乘法

\[ Y=XW \]

\(W\) 的梯度是

\[ \frac{\partial L}{\partial W} = X^\top \frac{\partial L}{\partial Y}, \]

所以 backward 至少需要能拿到 forward 时的 \(X\)。这些被 backward 函数保留下来的值通常称为 saved tensors。

NoteDefinition: Saved Tensor

A saved tensor is an intermediate tensor or metadata object stored by an autograd node during forward because its backward formula needs it later.

这也是 activation memory 的来源之一。训练时显存通常不只由参数决定,还包括:

\[ \text{memory} \approx \text{parameters} + \text{gradients} + \text{optimizer states} + \text{saved activations}. \]

当模型很深时,saved activations 往往随 batch size、sequence length、hidden size 和 layer 数增长。activation checkpointing 的基本思想就是不保存某些中间 activation,而是在 backward 时重新执行一段 forward,用额外计算换显存。

NoteDefinition: Version Counter

A version counter is metadata attached to tensors that increments when an in-place operation mutates tensor storage. Autograd uses it to detect whether a saved value was changed before backward.

考虑一个简化例子:

x = torch.randn(4, requires_grad=True)
u = x * 1.0        # non-leaf tensor with autograd history
y = u.pow(2)
u.add_(1.0)        # unsafe if backward needs the old u
loss = y.sum()
loss.backward()

pow(2) 的 backward 需要 forward 时的 \(u\),因为

\[ \frac{d}{du}u^2=2u. \]

如果 u 在 backward 前被原地改掉,使用新值会得到错误梯度。PyTorch 通过 version counter 发现 saved tensor 的版本变化,从而报错。这类报错不是“PyTorch 太严格”,而是在阻止 silent wrong gradient。

WarningPitfall: Backward May Need Values, Not Only Shapes

If a backward formula needs the original forward value, mutating that value in-place can invalidate the gradient. Shape-compatible code can still be mathematically wrong.

Leaf Tensors and Parameters

NoteDefinition: Leaf Tensor

A leaf tensor is a tensor created by the user rather than produced as the result of an autograd-tracked operation. If a leaf tensor has requires_grad=True, PyTorch accumulates its gradient in .grad during backward.

大多数时候,模型参数都是 leaf tensors:

linear = torch.nn.Linear(3, 2)
for name, p in linear.named_parameters():
    print(name, p.is_leaf, p.requires_grad)

典型输出是:

weight True True
bias   True True

而 forward 里的中间结果通常不是 leaf:

y = linear(x)
print(y.is_leaf)   # False
print(y.grad_fn)   # AddmmBackward0 or similar

中间结果有 grad_fn,说明它知道自己由哪个 operation 产生。反传时,PyTorch 用这些 grad_fn 节点把梯度传回 leaf parameters。

WarningPitfall: .grad Is Usually Populated Only on Leaf Tensors

Intermediate tensors do not keep .grad by default. Use tensor.retain_grad() only when debugging intermediate gradients.

这点对调试很重要。如果你写:

h = model.encoder(x)
loss = head(h).mean()
loss.backward()
print(h.grad)

大概率会看到 None。这不表示梯度没有穿过 h,只表示 PyTorch 没有把中间梯度存到 h.grad 上。真正要检查的是相关参数的 .grad,或者在调试时对 h 调用 retain_grad()

requires_grad 的传播也值得单独记住。若一个 operation 的任意输入需要梯度,而且该 operation 有可微实现,那么输出通常也会 requires_grad=True

x = torch.randn(8, 16)
w = torch.randn(16, 32, requires_grad=True)

h = x @ w
print(h.requires_grad)  # True
print(h.grad_fn)        # MmBackward0 or similar

反过来,下面这些操作会把 tensor 从 autograd graph 中拿出去:

Pattern Effect Common use
x.detach() shares storage, stops gradient target network, frozen teacher
x.item() converts scalar tensor to Python number logging, control flow
x.tolist() converts tensor values to Python list debugging, serialization
x.cpu().numpy() leaves PyTorch autograd external libraries
torch.tensor(x) copies data into a new leaf-like tensor often accidental

其中最隐蔽的是 torch.tensor(existing_tensor)。如果你只是想改变 dtype/device,应使用 .to(...);如果你想复制但保留 autograd lineage,应使用 clone();如果你想明确断图,才使用 detach()

WarningPitfall: Rewrapping a Tensor Breaks the Graph

torch.tensor(existing_tensor) constructs a new tensor from data and does not preserve the original autograd history. Prefer clone, to, or detach according to intent.

Chain Rule in Tensor Form

设一个两层网络为

\[ h = \phi(xW_1+b_1), \qquad z = hW_2+b_2, \qquad L=\ell(z,y). \]

\[ G_z=\frac{\partial L}{\partial z}\in\mathbb{R}^{B\times C}. \]

则第二层参数梯度为

\[ \frac{\partial L}{\partial W_2} = h^\top G_z, \qquad \frac{\partial L}{\partial b_2} = \sum_{i=1}^{B}G_{z,i}. \]

传回 hidden layer:

\[ G_h = G_zW_2^\top. \]

再经过 activation:

\[ G_a = G_h\odot \phi'(a), \qquad a=xW_1+b_1. \]

第一层参数梯度:

\[ \frac{\partial L}{\partial W_1} = x^\top G_a, \qquad \frac{\partial L}{\partial b_1} = \sum_{i=1}^{B}G_{a,i}. \]

把 batch 维看成独立样本求和。对第二层,

\[ z_{ic}=\sum_j h_{ij}W_{2,jc}+b_{2,c}. \]

于是

\[ \frac{\partial L}{\partial W_{2,jc}} = \sum_i \frac{\partial L}{\partial z_{ic}} \frac{\partial z_{ic}}{\partial W_{2,jc}} = \sum_i G_{z,ic}h_{ij}, \]

这正是矩阵乘法 \(h^\top G_z\) 的第 \((j,c)\) 项。对 hidden state,

\[ \frac{\partial L}{\partial h_{ij}} = \sum_c \frac{\partial L}{\partial z_{ic}} \frac{\partial z_{ic}}{\partial h_{ij}} = \sum_c G_{z,ic}W_{2,jc}, \]

\(G_zW_2^\top\)。第一层同理,只是多乘一个 activation derivative \(\phi'(a)\)

PyTorch 的 autograd 做的就是这套张量链式法则,只是它不需要你手写每个矩阵的 backward。理解这些 shape 可以帮你定位常见错误:loss reduction 是否正确、batch 维是否被错误广播、bias 梯度是否应该按 batch 求和。

Scalar Loss and Vector-Jacobian Product

loss.backward() 最常见的形式要求 loss 是标量。原因是反向传播本质上计算 vector-Jacobian product。若

\[ y=f(x)\in\mathbb{R}^{m}, \qquad x\in\mathbb{R}^{n}, \]

Jacobian 是

\[ J=\frac{\partial y}{\partial x}\in\mathbb{R}^{m\times n}. \]

如果我们有一个上游向量 \(v\in\mathbb{R}^{m}\),反传计算的是

\[ v^\top J. \]

\(y=L\) 是标量时,\(v=1\),所以可以省略。若输出不是标量,就必须传入上游梯度:

y = model(x)              # shape [B, C]
v = torch.ones_like(y)
y.backward(v)

更常见的写法是先把 per-example loss 聚合成标量:

loss_vec = criterion(logits, labels, reduction="none")
loss = loss_vec.mean()
loss.backward()
WarningPitfall: sum and mean Change Gradient Scale

Using loss.sum() makes gradient magnitude grow with batch size. Using loss.mean() keeps the scale closer across different batch sizes.

这不是纯粹的风格问题。batch size 改变时,如果从 mean 换成 sum,等价学习率会随 batch size 放大,optimizer 的稳定区间也会改变。

backward vs autograd.grad

loss.backward() 的副作用是把梯度累积到 leaf parameters 的 .grad buffer。训练 optimizer 时,这正是我们想要的。但有些场景只想“查询某个梯度”,不想污染参数 .grad,这时更适合 torch.autograd.grad

logits = model(x)
loss = criterion(logits, y)

grads = torch.autograd.grad(
    loss,
    tuple(model.parameters()),
    retain_graph=False,
    create_graph=False,
)

两者的区别可以概括为:

API Returns gradients Writes .grad Typical use
loss.backward() no yes ordinary training
torch.autograd.grad(...) yes no penalties, meta-learning, analysis

这和数学里的两种视角对应。若参数为 \(\theta\),普通训练需要把

\[ g=\nabla_\theta L(\theta) \]

存到 optimizer 能读取的位置,也就是 .grad。而 gradient penalty 可能只是要构造一个新的 loss:

\[ R(\theta) = \left\|\nabla_x f_\theta(x)\right\|_2^2, \]

此时我们需要先拿到 \(\nabla_x f_\theta(x)\),再让 \(R\)\(\theta\) 反传:

x = x.detach().requires_grad_(True)
score = discriminator(x).sum()

grad_x = torch.autograd.grad(
    score,
    x,
    create_graph=True,
)[0]
penalty = grad_x.pow(2).flatten(1).sum(dim=1).mean()
penalty.backward()

这里 create_graph=True 是关键:否则 grad_x 会被当成普通数值,penalty 无法继续对模型参数产生二阶路径。

WarningPitfall: Unused Inputs and Silent Assumptions

If an input is not connected to the loss, autograd.grad normally raises an error. Treat allow_unused=True as a debugging tool, not a way to hide disconnected computation.

Gradient Accumulation Semantics

PyTorch 的 .grad 是累积的,不是覆盖的:

loss1.backward()
loss2.backward()

执行后,参数梯度是

\[ g = \nabla_\theta L_1+\nabla_\theta L_2. \]

这给 gradient accumulation 提供了自然实现:

optimizer.zero_grad(set_to_none=True)

for step, batch in enumerate(loader):
    logits = model(batch["x"])
    loss = criterion(logits, batch["y"])
    loss = loss / accum_steps
    loss.backward()

    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

为什么要除以 accum_steps?如果目标是模拟一个大 batch 的平均梯度:

\[ g_{\text{large}} = \frac{1}{K}\sum_{k=1}^{K}\nabla_\theta L_k, \]

那么每个 micro-batch 的 loss 就应该先除以 \(K\)。否则你得到的是

\[ \sum_{k=1}^{K}\nabla_\theta L_k, \]

等价于把 learning rate 放大 \(K\) 倍。

NoteDefinition: Gradient Buffer

A gradient buffer is the .grad storage attached to a parameter. Backward adds gradients into this buffer; zero_grad clears or resets it before the next optimizer step.

optimizer.zero_grad(set_to_none=True) 会把 .grad 设为 None,而不是写入全零 tensor。这样通常更省内存、更快,也能暴露“本 step 某参数没有收到梯度”的情况。

Variable-Length Token Loss

上面的 loss / accum_steps 默认每个 micro-batch 的 loss 已经是同一种平均方式。对语言模型,这个假设经常不成立,因为每个 micro-batch 的有效 token 数可能不同。设第 \(k\) 个 micro-batch 有 token loss

\[ \ell_{ki}, \qquad m_{ki}\in\{0,1\} \]

其中 \(m_{ki}=1\) 表示该 token 参与 loss。真正的 large-batch token mean 是

\[ L = \frac{\sum_{k=1}^{K}\sum_i m_{ki}\ell_{ki}} {\sum_{k=1}^{K}\sum_i m_{ki}}. \]

如果每个 micro-batch 先各自取 mean,再除以 \(K\)

\[ \tilde L = \frac{1}{K} \sum_{k=1}^{K} \frac{\sum_i m_{ki}\ell_{ki}} {\sum_i m_{ki}}, \]

那么每个 micro-batch 权重相同,而不是每个 token 权重相同。当 prompt/response 长度差异很大时,这会改变训练目标。

一个工程上更明确的写法是先让 loss function 返回 per-token loss,再按 mask 归一化:

loss_tok = F.cross_entropy(
    logits.flatten(0, 1),
    labels.flatten(),
    ignore_index=-100,
    reduction="none",
).view_as(labels)

mask = labels.ne(-100)
loss = (loss_tok * mask).sum() / mask.sum().clamp_min(1)

如果还要做 gradient accumulation,并且想严格得到整个 accumulation window 的 token mean,有两种常见选择:

Method Exact token mean Trade-off
pre-count valid tokens for the window yes needs looking ahead or staging batches
scale by local token count per micro-batch no simple and often acceptable

严格版本需要先知道本 accumulation window 的总有效 token 数 \(N_{\text{tok}}\),再对每个 micro-batch 使用

\[ L_k = \frac{\sum_i m_{ki}\ell_{ki}}{N_{\text{tok}}}. \]

这样 backward 累积后才等价于一次大 batch:

\[ \sum_{k=1}^{K}\nabla_\theta L_k = \nabla_\theta \frac{\sum_{k,i}m_{ki}\ell_{ki}}{N_{\text{tok}}}. \]

ImportantImplementation Contract: Normalize the Object You Mean

For token-level objectives, decide whether each sequence, each micro-batch, or each token should have equal weight. The loss normalization code is the mathematical objective.

A Minimal Robust Training Step

一个可以作为模板的单步训练如下:

def train_step(model, batch, optimizer, criterion, device):
    model.train()

    x = batch["x"].to(device, non_blocking=True)
    y = batch["y"].to(device, non_blocking=True)

    optimizer.zero_grad(set_to_none=True)

    logits = model(x)
    loss = criterion(logits, y)

    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    return loss.detach()

这里每一行都有状态含义:

Line State change
model.train() enable training behavior such as dropout and batch norm updates
.to(device) move batch tensors to the same device as parameters
zero_grad clear old gradient buffers
model(x) build a fresh computation graph
criterion produce scalar objective
backward accumulate gradients into parameters
clip_grad_norm_ transform gradients before update
optimizer.step() update parameters and optimizer states
loss.detach() return a value without keeping the graph alive
WarningPitfall: Returning Raw loss Can Keep Graphs Alive

If you append raw loss tensors to a Python list for logging, you may accidentally keep computation graphs in memory. Prefer loss.detach() or loss.item() for logging.

Parameter Update Boundary

Autograd 负责算梯度,optimizer 负责改参数。这个边界应该非常清楚:参数更新不应该被 autograd 记录成下一张训练图的一部分。因此 optimizer 内部通常在 no-grad 语义下做 in-place update:

\[ \theta_{t+1} = \theta_t-\eta g_t. \]

手写 update 时也应该这样写:

with torch.no_grad():
    for p in model.parameters():
        if p.grad is not None:
            p.add_(p.grad, alpha=-lr)

不要用 .data 偷改参数:

# Avoid this pattern.
for p in model.parameters():
    p.data.add_(p.grad, alpha=-lr)

.data 会绕过 autograd 的安全检查,包括 version counter。它有时看起来能跑,但可能让 saved tensor 和真实 storage 的关系变得不可追踪。

WarningPitfall: Optimizer Coverage Is a Training Contract

If a trainable parameter is not passed to the optimizer, it can receive gradients but never update. If a frozen parameter remains in the optimizer, it may still carry stale optimizer state.

一个简单覆盖检查:

opt_ids = {
    id(p)
    for group in optimizer.param_groups
    for p in group["params"]
}

for name, p in model.named_parameters():
    if p.requires_grad and id(p) not in opt_ids:
        print("trainable but not optimized:", name)

训练 loop 中的四种状态可以这样分层:

State Owned by Cleared or updated by
forward graph autograd engine freed after backward
parameter .grad parameter tensor zero_grad and backward
parameter value module parameter optimizer.step()
optimizer state optimizer optimizer.step() and checkpoint load

这解释了为什么只保存 model.state_dict() 不能 exact resume:Adam 的 moment、scheduler step、AMP scale、随机数状态和 dataloader 位置都不是模型参数本身。

AMP Training Step Semantics

Mixed precision 的目标是让大部分矩阵计算用 BF16/FP16 提高吞吐、降低显存,同时保留足够稳定的梯度更新。FP16 的指数范围较窄,所以常用 loss scaling:

\[ \tilde L=sL. \]

反传得到的是

\[ \nabla_\theta \tilde L = s\nabla_\theta L. \]

optimizer step 前必须 unscale:

\[ g = \frac{1}{s}\nabla_\theta \tilde L. \]

典型 FP16 AMP 单步:

scaler = torch.amp.GradScaler("cuda")

optimizer.zero_grad(set_to_none=True)

with torch.amp.autocast("cuda", dtype=torch.float16):
    logits = model(x)
    loss = criterion(logits, y)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()

顺序不能随便换:

  1. autocast 只包 forward 和 loss;
  2. scale(loss).backward() 产生 scaled gradients;
  3. unscale_(optimizer).grad 还原到真实尺度;
  4. gradient clipping 应该发生在 unscaled gradients 上;
  5. scaler.step(optimizer) 可能因为 Inf/NaN 跳过 update;
  6. scaler.update() 调整下一步 loss scale。
WarningPitfall: Clip After Unscale

Clipping scaled gradients clips the artificial loss-scale factor, not the real gradient norm. In AMP, unscale before gradient clipping.

BF16 通常不需要 GradScaler,因为它和 FP32 有相同的 exponent 位数,动态范围更宽。它仍然可能因为模型本身不稳定而 NaN,但一般不靠 loss scaling 解决。

Train Mode, Eval Mode, and Grad Mode

PyTorch 有两个容易混淆的开关:

  1. module mode: model.train() / model.eval()
  2. grad mode: normal grad / torch.no_grad() / torch.inference_mode()

它们不是同一个东西。

Switch Controls Typical use
model.train() module behavior training loop
model.eval() module behavior validation or inference
torch.no_grad() autograd recording validation metrics
torch.inference_mode() autograd + extra inference optimizations pure inference

model.eval() 会让 Dropout 停止随机丢弃,让 BatchNorm 使用 running statistics,但它不关闭 autograd。验证时仍然应该写:

model.eval()
total = 0.0

with torch.no_grad():
    for batch in val_loader:
        logits = model(batch["x"].to(device))
        loss = criterion(logits, batch["y"].to(device))
        total += loss.item()

torch.inference_mode()no_grad() 更强,适合纯推理,因为它还会关闭一些 view tracking 和 version counter 相关开销。但它创建的 tensor 不适合拿回训练图里继续参与 autograd。因此训练中的验证一般用 no_grad() 就够;部署推理可考虑 inference_mode()

WarningPitfall: eval() Does Not Mean No Gradients

model.eval() changes module behavior but does not disable gradient tracking. Use torch.no_grad() or torch.inference_mode() to stop autograd recording.

Detach, Clone, and In-Place Operations

detach() 返回一个和原 tensor 共享数据、但不再连接当前 autograd graph 的 tensor:

h = encoder(x)
h_stop = h.detach()
loss = head(h_stop).mean()

此时 loss 的梯度不会传回 encoder。这在 contrastive learning、target networks、teacher-student、truncated BPTT 中很常见。

clone() 复制数据,但不自动断开 autograd;detach().clone() 同时断图并复制数据。

NoteDefinition: Stop-Gradient

Stop-gradient means treating a computed tensor as a constant for subsequent gradient computation, preventing gradients from flowing back through the operations that produced it.

In-place operations 是另一个常见坑:

x.relu_()

如果某个 backward 需要用到 in-place 修改前的值,PyTorch 会通过 version counter 检测并报错。这个错误看起来烦,但其实是保护你:否则梯度会 silently wrong。

WarningPitfall: In-Place Ops Can Break Backward

In-place operations may overwrite values needed by backward. Use out-of-place operations unless memory pressure or profiling clearly justifies in-place code.

Accidental Graph Breaks in Training Code

这里的 graph break 不是 torch.compile 的 graph break,而是 autograd 路径被你无意中切断。常见例子:

loss = criterion(logits, labels)

# Bad: Python float no longer has autograd history.
loss_value = loss.item()
loss_value.backward()   # fails: float has no backward

更隐蔽的是只丢失一部分路径:

loss = criterion(logits, labels)
penalty = weight_decay_penalty(model)

total = loss.item() + penalty
total.backward()        # only the penalty path remains differentiable

另一个隐蔽例子是用 Python control flow 根据 tensor value 选择分支:

if loss.item() > 1.0:
    loss = loss * 0.5

这段代码不是错的,因为分支决策本来就不可微;但要知道梯度只会沿被选中的 tensor expression 传播,不会对“阈值决策”本身求导。如果你需要可微的 soft gating,应写成 tensor operation:

\[ \alpha=\sigma(k(L-\tau)), \qquad L'=\alpha L_1+(1-\alpha)L_2. \]

alpha = torch.sigmoid(k * (loss - tau))
loss = alpha * loss_a + (1.0 - alpha) * loss_b
ImportantImplementation Contract: Log Values Outside the Graph

Use .item() for logging and metrics, not for values that will still participate in differentiable objectives.

Higher-Order Gradients

通常训练只需要一阶梯度。但 meta-learning、implicit differentiation、gradient penalty 可能需要二阶梯度。此时要让 backward 本身也被记录进 graph:

grad = torch.autograd.grad(
    loss,
    params,
    create_graph=True,
)
penalty = sum(g.pow(2).sum() for g in grad)
penalty.backward()

retain_graph=Truecreate_graph=True 不一样:

Option Meaning
retain_graph=True keep the current graph for another backward
create_graph=True build a graph of the gradient computation

若只是需要对两个 loss 共用同一个 forward graph,可以用 retain_graph=True。若需要对梯度再求梯度,才需要 create_graph=True

Debugging Autograd

当训练“能跑但不学”时,先查这几个状态:

Symptom Likely check
loss constant labels, loss mask, requires_grad, optimizer parameter list
.grad is None parameter unused, detached path, no_grad, frozen parameter
gradients all zero saturated activation, wrong loss, underflow, hard mask
gradients explode LR too high, missing normalization, unstable recurrence
memory grows every step storing graph tensors, missing detach, retain_graph=True misuse
train/eval mismatch dropout, batch norm, data preprocessing mismatch

几个实用检查:

for name, p in model.named_parameters():
    if p.requires_grad and p.grad is None:
        print("no grad:", name)

total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
print("grad norm:", float(total_norm))

也可以临时打开 anomaly detection:

with torch.autograd.detect_anomaly():
    loss.backward()

但它会显著变慢,只适合定位 NaN 或 backward 报错。

更系统的检查可以分三层:参数覆盖、梯度数值、更新幅度。

def grad_report(model):
    rows = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.grad is None:
            rows.append((name, "none", None, None))
            continue
        g = p.grad.detach()
        rows.append((
            name,
            str(tuple(g.shape)),
            float(g.norm()),
            bool(torch.isfinite(g).all()),
        ))
    return rows

如果所有梯度都有限但模型不动,可以检查 update ratio:

\[ r = \frac{\|\Delta\theta\|_2}{\|\theta\|_2+\epsilon}. \]

一个小的 wrapper:

before = {
    name: p.detach().clone()
    for name, p in model.named_parameters()
    if p.requires_grad
}

optimizer.step()

for name, p in model.named_parameters():
    if name not in before:
        continue
    delta = p.detach() - before[name]
    ratio = delta.norm() / before[name].norm().clamp_min(1e-12)
    print(name, float(ratio))

ratio 接近零可能表示 LR 太小、梯度被 clip 没了、参数没进 optimizer、或者 AMP step 被跳过;突然巨大则要查 loss spike、学习率 schedule、weight decay 和异常 batch。

Training Loop Checklist

一个可靠的 loop 至少满足:

  1. model.train() and model.eval() are placed deliberately;
  2. gradients are cleared exactly once per optimizer step;
  3. loss reduction matches the intended effective batch size;
  4. gradient accumulation normalization matches sequence/token weighting;
  5. validation uses no_grad() and does not update optimizer;
  6. logged tensors are detached or converted to Python numbers;
  7. clipping, AMP unscale, scheduler step, and optimizer step are ordered consistently;
  8. frozen parameters are removed from optimizer or have requires_grad=False
  9. train and validation preprocessing are aligned;
  10. loss masks match the mathematical objective;
  11. trainable parameters are all covered by optimizer param groups;
  12. torch.autograd.grad is used when gradients should be returned but not accumulated;
  13. .item(), NumPy conversion, and torch.tensor(x) are not used inside differentiable paths;
  14. AMP skipped steps do not accidentally advance scheduler or global update counters.

训练循环的核心不是“把 API 拼起来”,而是维护几个状态不混乱:graph、grad buffer、parameter value、optimizer state、module mode、randomness。只要这些状态清楚,很多 bug 会从玄学变成可定位的工程问题。

References