1.2 Modules, OOP, and Data Loading


PyTorch 的面向对象系统不是为了把代码写得“像 Java”,而是为训练系统提供三个能力:

  1. registration: 参数、buffer、子模块能被自动发现;
  2. mode control: train/eval、device/dtype、state_dict 能递归作用;
  3. composability: layer、block、model、trainer 可以逐层组合。

nn.Module 的真正价值是把 Python 对象变成可训练状态机。

NoteDefinition: nn.Module

nn.Module is the base class for stateful neural-network components. It registers parameters, buffers, and child modules so that optimization, serialization, device movement, and mode switching can be applied recursively.

forward and __call__

自定义模型通常继承 nn.Module

import torch
from torch import nn


class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        h = self.act(self.fc1(x))
        return self.fc2(h)

调用时写:

logits = model(x)

而不是直接写 model.forward(x)。原因是 Module.__call__ 会在 forward 外面处理 hooks、autocast 等框架逻辑。forward 是你定义数学计算的地方,__call__ 是 PyTorch 调度 module 调用协议的入口。

WarningPitfall: Do Not Call forward Directly

Calling model.forward(x) bypasses module call machinery such as hooks. Use model(x) in training and inference code.

Parameters and Buffers

nn.Parameter 是一种特殊 tensor:赋值到 module 属性后,会被注册为可训练参数。

class Scale(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * self.weight

如果状态需要保存和迁移设备,但不需要梯度,应注册为 buffer:

class RunningMean(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.register_buffer("count", torch.zeros(()))
        self.register_buffer("mean", torch.zeros(dim))
State Registered by In parameters() In state_dict() Example
parameter nn.Parameter / submodule yes yes linear weight
buffer register_buffer no yes BatchNorm running mean
plain tensor attr assignment no no temporary cache

这解释了一个常见 bug:把 tensor 放在普通属性上,它不会随 model.to(device) 移动,也不会进入 checkpoint。

Registration Internals

nn.Module 的注册机制主要发生在属性赋值时。把 nn.Parameternn.Module 或 buffer 放进 module,不只是 Python 对象引用,还会进入 module 内部的三个映射:

_parameters: named Parameter objects
_buffers:    named persistent/non-persistent tensors
_modules:    named child modules

可以用一个最小例子观察:

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))
        self.linear = nn.Linear(3, 2)
        self.register_buffer("scale", torch.ones(()))
        self.cache = torch.zeros(3)


m = Tiny()
print(m._parameters.keys())  # weight
print(m._modules.keys())     # linear
print(m._buffers.keys())     # scale

cache 只是普通属性,所以不会出现在这些 registry 里。递归 API 都依赖 registry:

API Traverses
model.parameters() _parameters plus child modules
model.buffers() _buffers plus child modules
model.to(device) parameters and buffers
model.state_dict() parameters and persistent buffers
model.train() child modules
NoteDefinition: Module Registration

Module registration is the process by which assigned parameters, buffers, and child modules are recorded in a module’s internal registries so recursive APIs can discover them.

如果后来把属性覆盖掉,registry 也会变化:

m.linear = nn.Identity()

这会把原来的 linear child module 替换成 Identity。如果 optimizer 已经在替换之前创建,它仍然可能持有旧参数引用;所以“替换 module”通常应该发生在构造 optimizer 之前。

Parameters, Gradients, and Optimizer Membership

requires_grad=False 和“不在 optimizer 里”是两件事。冻结参数最稳的做法是两层都明确:

for p in backbone.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.AdamW(
    head.parameters(),
    lr=3e-4,
)

如果一个 frozen parameter 仍在 optimizer group 里,通常不会有梯度更新,但 optimizer state、weight decay 语义和内存账单会变得不清晰。反过来,如果 requires_grad=True 的参数没有进入 optimizer,它会累积梯度却永远不 step。

一个覆盖性检查:

opt_params = {
    id(p)
    for group in optimizer.param_groups
    for p in group["params"]
}

missing = [
    name
    for name, p in model.named_parameters()
    if p.requires_grad and id(p) not in opt_params
]
assert not missing, missing

还要检查重复参数。参数共享是合法的,例如 tied embedding,但 optimizer group 中重复加入同一个 parameter 通常是 bug:

seen = set()
dups = []
for group in optimizer.param_groups:
    for p in group["params"]:
        if id(p) in seen:
            dups.append(tuple(p.shape))
        seen.add(id(p))
assert not dups, dups
WarningPitfall: Frozen Parameters Can Leave Stale Optimizer State

Changing requires_grad after constructing the optimizer does not automatically remove parameters or optimizer state. Rebuild or edit optimizer groups when freezing/unfreezing changes the training contract.

Registration Containers

PyTorch 只能递归发现注册过的 child module。普通 Python list 不会注册其中的 module:

class BadStack(nn.Module):
    def __init__(self, depth, dim):
        super().__init__()
        self.layers = [nn.Linear(dim, dim) for _ in range(depth)]

正确写法:

class GoodStack(nn.Module):
    def __init__(self, depth, dim):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(depth)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x
Container Has automatic forward? Use case
nn.Sequential yes straight-line stack
nn.ModuleList no loop or residual control flow
nn.ModuleDict no named branches, task heads
nn.ParameterList no list of raw parameters
nn.ParameterDict no named raw parameters
WarningPitfall: Python Containers Hide Modules

If submodules are stored in a plain Python list or dict, PyTorch will not find their parameters, move them to devices, or save them in state_dict.

Device and Dtype Lifecycle

model.to(device) 不是“把 Python 对象搬到 GPU”,而是递归地把 registered parameters 和 buffers 转换到目标 device/dtype。普通 tensor attribute 不会被处理:

class BadCache(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(4))
        self.cache = torch.ones(4)


m = BadCache().to("cuda")
print(m.weight.device)  # cuda
print(m.cache.device)   # still cpu

如果 cache 真的是模型状态,就注册成 buffer;如果只是 forward 中临时产生的张量,就在 forward 里按输入 device 创建:

tmp = torch.ones(x.size(0), device=x.device, dtype=x.dtype)
NoteDefinition: Device Contract

A device contract specifies which tensors must live on the same device at a module boundary, and which tensors are allowed to remain Python or CPU metadata.

Build Order

稳妥的训练构造顺序是:

  1. instantiate model;
  2. move model to target device/dtype;
  3. freeze/unfreeze parameters;
  4. build optimizer from final trainable parameters;
  5. optionally load optimizer state when resuming exactly。
model = Model(config)
model.to(device=device, dtype=torch.bfloat16)

for p in model.backbone.parameters():
    p.requires_grad_(False)

optimizer = make_optimizer(model)

原因是 optimizer 持有 parameter 对象引用和 optimizer state。构造 optimizer 后再替换 module、改变 trainable set、或做 dtype/device surgery,都会让 optimizer contract 变得不清楚。即使某些 .to() 路径在当前 PyTorch 版本中保持 parameter identity,教学和工程上仍应把“model placement before optimizer construction”作为默认规则。

WarningPitfall: Optimizer Captures Parameter Objects

An optimizer does not rediscover model parameters every step. It updates the tensors stored in its parameter groups when the optimizer was constructed or manually edited.

Mixed Dtype Parameters and Buffers

现代训练里参数、buffer、activation 的 dtype 可能不同:

Tensor Common dtype Reason
model weights BF16/FP16/FP32 memory and matmul throughput
LayerNorm/RMSNorm weights FP32 or BF16 stability policy
optimizer states FP32 stable accumulation
BatchNorm running stats FP32 stable statistics
integer ids/masks int64/bool indexing and masking

不要盲目对整个 batch 调 .to(dtype=torch.bfloat16),因为 input_ids 和 labels 必须保持 integer dtype:

batch = {
    "input_ids": input_ids.to(device),          # int64
    "labels": labels.to(device),                # int64
    "pixel_values": pixels.to(device, dtype),   # float
}

一个递归 mover 如果要支持 dtype,应该只转换 floating tensors:

def move_batch(x, device, dtype):
    if torch.is_tensor(x):
        if x.is_floating_point():
            return x.to(device=device, dtype=dtype, non_blocking=True)
        return x.to(device=device, non_blocking=True)
    if isinstance(x, dict):
        return {k: move_batch(v, device, dtype) for k, v in x.items()}
    if isinstance(x, list):
        return [move_batch(v, device, dtype) for v in x]
    if isinstance(x, tuple):
        return tuple(move_batch(v, device, dtype) for v in x)
    return x
WarningPitfall: Do Not Cast Token IDs to Floating Point

Embedding lookup expects integer token ids. Accidentally casting a whole batch to BF16/FP16 breaks indexing semantics or fails deep inside the model.

state_dict as the Serialization Boundary

state_dict() 是参数和 buffer 的有序映射:

sd = model.state_dict()
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))

它不保存 Python class 定义、optimizer 对象、训练循环代码。恢复模型通常需要:

model = MLP(in_dim=10, hidden_dim=64, out_dim=2)
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()

训练恢复还需要 optimizer、scheduler、GradScaler、random state 和 global step;这部分在训练系统章节会继续展开。

NoteDefinition: Checkpoint Boundary

A model state_dict stores learned tensors, not the Python program. Exact training resume requires model state plus optimizer/scheduler/scaler/counter/random states.

Loading Strictly, Partially, and Safely

load_state_dictstrict 参数定义 checkpoint 和当前 model contract 是否必须完全一致:

result = model.load_state_dict(state, strict=False)
print(result.missing_keys)
print(result.unexpected_keys)
Case Meaning Typical action
missing_keys model expects a tensor not in checkpoint new head, new adapter, renamed layer
unexpected_keys checkpoint has tensor not used by model old head, deleted module, wrapper prefix
shape mismatch same key but incompatible shape usually hard error and needs explicit surgery

strict=False 不是“安全加载”。它只是允许 key 集合不完全匹配;shape mismatch 仍然应该显式处理。常见迁移脚本会先过滤:

model_state = model.state_dict()
filtered = {}
for key, value in ckpt.items():
    if key not in model_state:
        continue
    if value.shape != model_state[key].shape:
        continue
    filtered[key] = value

missing, unexpected = model.load_state_dict(filtered, strict=False)

这段代码适合迁移 backbone,但必须把跳过的 key 打印出来并写进实验记录。否则你可能以为加载了 pretrained model,实际大部分层仍是随机初始化。

WarningPitfall: Prefix Mismatches Hide Real Load Failures

DDP/DataParallel checkpoints may prefix keys with module.. Blindly using strict=False can ignore every pretrained tensor if prefixes are not remapped.

Prefix Remapping and Shape Surgery

DDP、DataParallel、FSDP、Lightning、Hugging Face wrapper 都可能改变 checkpoint key。最常见的是 module. prefix:

def strip_prefix(state, prefix):
    out = {}
    for key, value in state.items():
        if key.startswith(prefix):
            out[key[len(prefix) :]] = value
        else:
            out[key] = value
    return out


state = strip_prefix(state, "module.")

更复杂的迁移要写成显式 mapping,而不是靠字符串碰碰运气:

rename = {
    "encoder.layers.": "backbone.blocks.",
    "classifier.": "head.",
}


def rename_key(key):
    for old, new in rename.items():
        if key.startswith(old):
            return new + key[len(old) :]
    return key

shape mismatch 最危险的是 embedding / LM head resize。假设旧词表大小是 \(V_{\text{old}}\),新词表是 \(V_{\text{new}}\)。如果 hidden size 没变,可以复制交集:

old = ckpt["embed_tokens.weight"]
new = model.state_dict()["embed_tokens.weight"].clone()
n = min(old.size(0), new.size(0))
new[:n].copy_(old[:n])
ckpt["embed_tokens.weight"] = new

如果输入 embedding 和 LM head tied,还要确保二者仍共享权重:

model.lm_head.weight = model.embed_tokens.weight

或者调用模型库提供的 tie_weights()。否则 save/load 之后可能悄悄从 tied 变成 untied。

NoteDefinition: Checkpoint Surgery

Checkpoint surgery is an explicit transformation of checkpoint keys or tensors before loading them into a model with a changed architecture or vocabulary.

Load Report as an Artifact

每次 partial load 都应该产生日志,而不是只看 Python 没报错:

result = model.load_state_dict(filtered, strict=False)
report = {
    "loaded": sorted(filtered.keys()),
    "missing": sorted(result.missing_keys),
    "unexpected": sorted(result.unexpected_keys),
    "skipped_shape": skipped_shape,
}

这份 report 应该和实验配置一起保存。一个简单审计规则是:

\[ \text{loaded ratio} = \frac{\sum_{k\in\text{loaded}}\operatorname{numel}(W_k)} {\sum_{k\in\text{model}}\operatorname{numel}(W_k)}. \]

如果你以为加载了 backbone,但 loaded ratio 只有 5%,那不是“fine-tuning 起点差一点”,而是实验语义完全变了。

WarningPitfall: Partial Loading Must Be Auditable

Whenever strict=False is used, save the missing, unexpected, skipped, and loaded tensor report. Otherwise checkpoint migration is not reproducible.

buffer 还有 persistent=False

self.register_buffer("rotary_cache", cache, persistent=False)

这种 buffer 会跟着 .to(device) 移动,但不会进入 state_dict()。它适合可重建 cache,例如 RoPE cache、mask cache;不适合 BatchNorm running statistics 这类训练状态。

Train/Eval Mode

model.train()model.eval() 递归设置 module 的 training flag。它影响 Dropout、BatchNorm 等层:

model.train()
loss = criterion(model(x), y)

model.eval()
with torch.no_grad():
    pred = model(x_val)

eval() 不等于 no_grad()

Switch Controls Does not control
model.eval() module behavior autograd graph construction
torch.no_grad() gradient recording dropout/batchnorm mode
requires_grad_(False) parameter gradient accumulation module behavior

验证/推理通常三者按需组合。

Hooks for Inspection, Not Core Logic

Module hooks 可以在不改 forward 的情况下观察 activation 或 gradient:

acts = {}


def save_activation(name):
    def hook(module, inputs, output):
        acts[name] = output.detach()
    return hook


handle = model.fc1.register_forward_hook(save_activation("fc1"))
_ = model(x)
handle.remove()

常见 hook:

Hook Fires when Use
forward_pre_hook before forward input checks, shape logging
forward_hook after forward activation capture
full_backward_hook during backward gradient diagnostics

hooks 很适合 debug、可视化和统计,但不适合承载模型核心逻辑。原因是 hooks 的执行顺序、分布式包装、编译模式和异常处理都会让控制流更难推理。

WarningPitfall: Hooks Can Retain Graphs

Saving raw hook outputs without .detach() can keep the autograd graph alive and cause memory growth across steps.

Initialization and apply

module.apply(fn) 会递归访问所有 child modules:

def init_linear(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)


model.apply(init_linear)

这里用 isinstance 是初始化脚本中可以接受的外围适配;核心 module 的 forward 逻辑不应到处写类型分支。

更稳的初始化习惯:

  1. 明确哪些层需要特殊初始化;
  2. 不用 .data 偷改参与 autograd 的 tensor;
  3. 初始化后做一次 forward/backward smoke test;
  4. 保存初始化配置,保证可复现。

Dataset Protocol

Map-style dataset 实现两个方法:

from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, records, transform):
        self.records = records
        self.transform = transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        item = self.records[idx]
        return self.transform(item)

DataLoader 会根据 index 调用 dataset[i],再用 collate_fn 把一组样本合成 batch。

NoteDefinition: Collate Function

A collate_fn maps a Python list of samples into one batch object, usually by stacking tensors, padding variable-length fields, and constructing masks.

Map-Style vs Iterable Datasets

PyTorch dataset 有两种常见协议:

Dataset type Required methods Best for
map-style __len__, __getitem__ finite indexed datasets, shuffling, random access
iterable-style __iter__ streams, generated data, huge shards, remote sources

Map-style dataset 让 sampler 控制 index 顺序:

sampler -> indices -> dataset[i] -> samples -> collate_fn -> batch

Iterable dataset 直接产生样本:

dataset iterator -> samples -> collate_fn -> batch
NoteDefinition: Sampler

A sampler defines the order in which examples or indices are drawn from a dataset. In distributed training, it also defines which rank owns which shard of an epoch.

分布式训练中,DistributedSampler 会按 rank 切分 indices。每个 epoch 必须调用:

sampler.set_epoch(epoch)

否则每个 epoch 的 shuffle 顺序可能完全相同。这个 bug 不会报错,但会降低 stochasticity。

Worker Seeding and Random Transforms

num_workers>0 时,dataset 和 transform 在子进程里运行。随机增强如果没有 worker seed,可能出现每个 worker 产生相同随机序列的问题。稳妥写法:

import random

import numpy as np


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


generator = torch.Generator()
generator.manual_seed(1234)

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    worker_init_fn=seed_worker,
    generator=generator,
)

这里有三层随机性:

  1. sampler 的样本顺序;
  2. worker 内部 transform 的随机数;
  3. GPU 上模型/算子的随机数。

只设 torch.manual_seed 不一定覆盖 dataloader worker 的 Python random 和 NumPy 随机数。可复现实验必须把这几层都写清楚。

WarningPitfall: Persistent Workers Keep Dataset State Alive

With persistent_workers=True, worker processes survive across epochs. Any mutable dataset state inside workers may persist unless explicitly reset.

IterableDataset Worker Sharding

Iterable dataset 在 num_workers>0 时,每个 worker 都会各自创建 dataset iterator。如果不显式分片,多个 worker 可能读取完全相同的数据流。

from torch.utils.data import IterableDataset, get_worker_info


class ShardedStream(IterableDataset):
    def __init__(self, records):
        self.records = records

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            yield from self.records
            return

        worker_id = info.id
        num_workers = info.num_workers
        for i, item in enumerate(self.records):
            if i % num_workers == worker_id:
                yield item

这个 round-robin 分片适合能顺序枚举的 records。对文件 shards,更常见的是按 shard id 分配:

\[ \text{owned}(s) \iff s\bmod N_{\text{workers}}=\text{worker\_id}. \]

在分布式训练里还要叠加 rank:

\[ \text{global\_worker\_id} = \text{rank}\cdot N_{\text{workers}}+\text{worker\_id}. \]

总 worker 数是

\[ N_{\text{global}} = \text{world\_size}\cdot N_{\text{workers}}. \]

否则每张卡、每个 worker 都可能读到重复样本。

NoteDefinition: Worker Sharding

Worker sharding partitions an iterable data stream across dataloader workers and distributed ranks so that each example is produced by exactly one consumer per epoch or stream pass.

Length, Epoch, and Stop Semantics

Map-style dataset 的 epoch 很清楚:sampler 产生一轮 indices。Iterable dataset 的 epoch 是训练系统定义出来的。常见策略:

Strategy Meaning Risk
fixed examples stop after \(N\) examples uneven worker tails
fixed tokens stop after \(N\) tokens variable batch count
fixed steps stop after \(S\) optimizer steps data stream position must be checkpointed
infinite stream no natural epoch scheduler/eval cadence must be external

因此 iterable pipeline 的 checkpoint 不能只保存 epoch。至少要保存 shard id、offset、global step、consumed examples/tokens,以及数据混洗的 seed/state。否则 resume 后可能重复或跳过数据。

WarningPitfall: Iterable Data Resume Needs Stream State

For iterable datasets, epoch is not enough to resume exactly. Save stream shard, offset, worker/rank partition policy, and RNG state.

默认 collate 只适合 shape 一致的 tensor。NLP、检测、多模态任务通常需要自定义:

def collate(batch):
    input_ids = [b["input_ids"] for b in batch]
    labels = [b["labels"] for b in batch]
    padded = torch.nn.utils.rnn.pad_sequence(
        input_ids,
        batch_first=True,
        padding_value=pad_id,
    )
    label_pad = torch.nn.utils.rnn.pad_sequence(
        labels,
        batch_first=True,
        padding_value=-100,
    )
    attention_mask = padded.ne(pad_id)
    return {
        "input_ids": padded,
        "labels": label_pad,
        "attention_mask": attention_mask,
    }

这里 pad_id 控制 attention,-100 控制 loss。二者不能混用。

Collate as a Batch Contract

collate_fn 不只是“把 list 变成 tensor”,它定义 batch contract。一个训练 step 应该能只看 batch keys 就知道:

input_ids:      [B, T] int64
attention_mask: [B, T] bool/int
labels:         [B, T] int64, -100 ignored
metadata:       Python-only, not moved to GPU

对图像/检测任务,batch contract 可能是:

images:  [B, C, H, W]
targets: list[dict], one per image

这时强行把 variable-size boxes stack 成一个 dense tensor 反而会制造 mask 复杂度。好的 collate 应该让后续 model/loss 的输入语义清晰,而不是追求所有东西都 tensor 化。

一个递归搬运函数:

def move_to_device(x, device):
    if torch.is_tensor(x):
        return x.to(device, non_blocking=True)
    if isinstance(x, dict):
        return {k: move_to_device(v, device) for k, v in x.items()}
    if isinstance(x, list):
        return [move_to_device(v, device) for v in x]
    if isinstance(x, tuple):
        return tuple(move_to_device(v, device) for v in x)
    return x

pin_memory=True 配合 non_blocking=True 才能让 host-to-device copy 更可能异步。若 batch 里混有 Python 对象,它们不会被 pin,也不应该被搬到 GPU。

DataLoader Engineering

常用参数:

Argument Meaning Common failure
batch_size samples per batch too large causes OOM
shuffle random sample order false during train reduces stochasticity
num_workers subprocess data loading too high causes CPU/RAM contention
pin_memory pinned host memory useful for GPU transfer
drop_last drop incomplete last batch changes epoch token/example count
persistent_workers reuse worker processes stale dataset state if misused
collate_fn batch construction wrong padding/masks

Data pipeline 的目标不是只“能加载”,而是让 GPU 不等 CPU。最小观测指标:

data_time = time until batch arrives
step_time = forward + backward + optimizer
gpu_util  = GPU active ratio

data_time 接近或超过 step_time,瓶颈通常在 dataset decoding、CPU transforms、collate padding、磁盘读取或进程数配置。

Optimizer Interface

optimizer 接收 parameter iterable 或 parameter groups:

decay, no_decay = [], []
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    if name.endswith("bias") or "norm" in name.lower():
        no_decay.append(p)
    else:
        decay.append(p)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.1},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=3e-4,
)

参数组是训练假设的编码:哪些参数 decay,哪些参数冻结,哪些参数学习率更大,都应该能从 parameter groups 中读出来。

Minimal Trainer Skeleton

训练循环的最小状态机:

for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        batch = move_to_device(batch, device)
        optimizer.zero_grad(set_to_none=True)
        loss = compute_loss(model, batch)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        metrics = evaluate(model, val_loader)

这段代码省略了 AMP、gradient accumulation、scheduler、checkpoint、distributed training,但顺序是对的。最重要的是把状态边界写清楚:

  1. train mode before train batches;
  2. zero grad before backward;
  3. eval mode and no_grad before validation;
  4. metrics 不参与 autograd;
  5. checkpoint 保存模型和训练状态。

Module and Data Smoke Tests

nn.ModuleDataLoader 的 bug 很多是“能训练,但训练的不是你以为的模型”。下面这些小测试适合在写完模型、collate、checkpoint 逻辑后立刻跑。

Test 1: All Trainable Parameters Are Optimized

def assert_optimizer_covers_trainable(model, optimizer):
    opt_ids = {
        id(p)
        for group in optimizer.param_groups
        for p in group["params"]
    }
    missing = [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and id(p) not in opt_ids
    ]
    if missing:
        raise AssertionError(f"trainable params missing from optimizer: {missing}")

这个测试能抓住 plain list 隐藏 module、替换 head 后忘记重建 optimizer、LoRA 参数没有进 optimizer 等问题。

Test 2: No Duplicate Optimizer Parameters

def assert_no_duplicate_optimizer_params(optimizer):
    seen = set()
    dup = 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            if id(p) in seen:
                dup += 1
            seen.add(id(p))
    if dup:
        raise AssertionError(f"duplicate optimizer parameters: {dup}")

如果模型有 tied weights,重复引用可能是有意的,但 optimizer group 里重复出现同一个 parameter 通常会让它被 step 两次或触发框架警告。显式测试比靠肉眼看 param group 稳。

Test 3: Save/Load Gives Same Output

import io


def assert_state_dict_roundtrip(model, example):
    model.eval()
    with torch.no_grad():
        ref = model(example)

    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    buf.seek(0)

    clone = type(model)(model.config)
    clone.load_state_dict(torch.load(buf, map_location="cpu"))
    clone.eval()
    with torch.no_grad():
        out = clone(example.cpu())

    assert torch.allclose(ref.cpu(), out, atol=1e-6)

真实项目里 type(model)(model.config) 未必成立,可以换成项目自己的 factory。核心检查是:model state boundary 足够完整,恢复后在 eval/no_grad 下输出一致。

Test 4: Batch Contract

def assert_lm_batch(batch, pad_id):
    ids = batch["input_ids"]
    labels = batch["labels"]
    mask = batch["attention_mask"]

    assert ids.dtype == torch.long
    assert labels.dtype == torch.long
    assert mask.shape == ids.shape == labels.shape
    assert mask.dtype in (torch.bool, torch.long, torch.int64)
    assert torch.equal(mask.bool(), ids.ne(pad_id))
    assert (labels[~mask.bool()] == -100).all()

这个测试把 collate 的语义钉住:attention mask 由 pad_id 决定,loss mask 由 -100 决定。二者混用是语言模型训练里非常常见的静默错误。

Test 5: Iterable Workers Do Not Duplicate

对小的 iterable dataset,可以用 num_workers=2 跑一轮,检查样本 id 是否重复:

def assert_no_duplicate_ids(loader, limit):
    seen = set()
    for i, batch in enumerate(loader):
        for item_id in batch["id"]:
            item_id = int(item_id)
            if item_id in seen:
                raise AssertionError(f"duplicate item id: {item_id}")
            seen.add(item_id)
        if i + 1 >= limit:
            break

这个测试不能证明大规模 streaming 完全正确,但能立刻抓住“每个 worker 都从头读全量数据”的基础错误。

TipImplementation Pattern: Test State Boundaries

For module and dataloader code, test the boundaries: parameter discovery, optimizer membership, save/load identity, batch masks, and worker sharding.

Implementation Checklist

nn.Module / data pipeline 时至少检查:

  1. 所有子模块是否在 ModuleList/Dict/Sequential 中注册;
  2. 需要保存但不训练的状态是否用 buffer;
  3. model(x) 而不是 model.forward(x)
  4. train() / eval() / no_grad() 是否各自放在正确位置;
  5. state_dict 是否能 save/load 后给出一致输出;
  6. collate_fn 是否构造了正确 padding、attention mask 和 label mask;
  7. DataLoader worker 数量是否真的提高吞吐;
  8. optimizer parameter groups 是否覆盖全部 trainable parameters 且无重复;
  9. model.to(device, dtype) 是否发生在构造 optimizer 前;
  10. batch mover 是否只把 floating tensors cast 到训练 dtype;
  11. partial checkpoint load 是否保存 missing/unexpected/skipped/loaded report;
  12. IterableDataset 是否按 worker 和 distributed rank 分片;
  13. iterable pipeline 的 resume 是否保存 shard、offset、seed 和 consumed tokens/examples。