1.2 Modules, OOP, and Data Loading
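PyTorch's object-oriented layer is not there to make code look "like Java". It gives a training system three capabilities:
- registration: parameters, buffers, and submodules are discovered automatically;
- mode control: train/eval, device/dtype, and state_dict apply recursively;
- composability: layers, blocks, models, and trainers compose level by level.
The real value of nn.Module is that it turns a Python object into a trainable state machine.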
nn.Module is the base class for stateful neural-network components. It registers parameters, buffers, and child modules so that optimization, serialization, device movement, and mode switching can be applied recursively.
forward and __call__
A custom model usually subclasses nn.Module:

    import torch
    from torch import nn

    class MLP(nn.Module):
        def __init__(self, in_dim, hidden_dim, out_dim):
            super().__init__()
            self.fc1 = nn.Linear(in_dim, hidden_dim)
            self.act = nn.ReLU()
            self.fc2 = nn.Linear(hidden_dim, out_dim)

        def forward(self, x):
            h = self.act(self.fc1(x))
            return self.fc2(h)

Call it as:

    logits = model(x)

rather than model.forward(x). Module.__call__ wraps forward and runs framework logic such as hooks around it. forward is where you define the math; __call__ is the entry point through which PyTorch dispatches the module call protocol.
Calling model.forward(x) bypasses module call machinery such as hooks. Use model(x) in training and inference code.
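As a small illustration of the difference, the sketch below registers a forward hook on a throwaway nn.Linear and counts how often it fires. The layer sizes and the `calls` list are illustrative choices, not part of the text above.

```python
import torch
from torch import nn

calls = []  # one entry is appended each time the forward hook fires

layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(
    lambda module, inputs, output: calls.append("hook")
)

x = torch.randn(3, 4)
layer(x)          # goes through Module.__call__, so the hook fires
layer.forward(x)  # skips __call__, so the hook does not fire
handle.remove()

print(len(calls))  # 1
```

Both calls return the same numbers; only the first one is visible to anything attached through hooks.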
Parameters and Buffers
nn.Parameter is a special tensor: once assigned to a module attribute, it is registered as a trainable parameter.

    class Scale(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            return x * self.weight

State that must be saved and moved across devices, but needs no gradient, should be registered as a buffer:

    class RunningMean(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.register_buffer("count", torch.zeros(()))
            self.register_buffer("mean", torch.zeros(dim))

| State | Registered by | In parameters() | In state_dict() | Example |
|---|---|---|---|---|
| parameter | nn.Parameter / submodule | yes | yes | linear weight |
| buffer | register_buffer | no | yes (unless persistent=False) | BatchNorm running mean |
| plain tensor attr | assignment | no | no | temporary cache |

This explains a common bug: a tensor stored as a plain attribute does not move with model.to(device) and does not enter the checkpoint.
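The three kinds of state can be compared side by side. The following is a minimal sketch with a hypothetical `Demo` module (the class and attribute names are made up for illustration) that holds one of each and lists what the recursive APIs see.

```python
import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))     # registered parameter
        self.register_buffer("mean", torch.zeros(3))  # registered buffer
        self.cache = torch.zeros(3)                   # plain attribute, not registered

m = Demo()
param_names = [name for name, _ in m.named_parameters()]
state_keys = list(m.state_dict().keys())

print(param_names)  # ['weight']
print(state_keys)   # ['weight', 'mean']
```

`cache` appears in neither list, which is why it would be lost on save and left behind on `.to(device)`.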
Registration Internals
Registration in nn.Module happens mainly at attribute assignment. Putting an nn.Parameter, an nn.Module, or a buffer into a module does more than store a Python reference; the object also enters one of three internal mappings:

    _parameters: named Parameter objects
    _buffers: named persistent/non-persistent tensors
    _modules: named child modules

A minimal example makes this visible:

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(3))
            self.linear = nn.Linear(3, 2)
            self.register_buffer("scale", torch.ones(()))
            self.cache = torch.zeros(3)

    m = Tiny()
    print(list(m._parameters.keys()))  # ['weight']
    print(list(m._modules.keys()))     # ['linear']
    print(list(m._buffers.keys()))     # ['scale']

cache is only a plain attribute, so it appears in none of these registries. The recursive APIs all depend on the registries:

| API | Traverses |
|---|---|
| model.parameters() | _parameters plus child modules |
| model.buffers() | _buffers plus child modules |
| model.to(device) | parameters and buffers |
| model.state_dict() | parameters and persistent buffers |
| model.train() | child modules |
Module registration is the process by which assigned parameters, buffers, and child modules are recorded in a module’s internal registries so recursive APIs can discover them.
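Overwriting an attribute later changes the registry too:

    m.linear = nn.Identity()

This replaces the original linear child module with Identity. An optimizer created before the replacement still holds references to the old parameters, so replacing a module should normally happen before the optimizer is constructed.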
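The stale-reference problem is easy to reproduce. The sketch below uses a hypothetical two-layer nn.Sequential and SGD (both chosen only for brevity), swaps the second layer after the optimizer exists, and then lists the parameters the optimizer no longer knows about.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Replace the head AFTER the optimizer was built.
model[1] = nn.Linear(4, 3)

opt_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
untracked = [
    name for name, p in model.named_parameters() if id(p) not in opt_ids
]
print(untracked)  # ['1.weight', '1.bias']
```

The optimizer would keep updating the discarded head's tensors while the new head never receives a step.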
Parameters, Gradients, and Optimizer Membership
requires_grad=False and "not in the optimizer" are two separate things. The most robust way to freeze parameters is to make both explicit:

    for p in backbone.parameters():
        p.requires_grad_(False)

    optimizer = torch.optim.AdamW(
        head.parameters(),
        lr=3e-4,
    )

If a frozen parameter stays in an optimizer group, it usually receives no update, but optimizer state, weight-decay semantics, and the memory bill become unclear. Conversely, a parameter with requires_grad=True that never enters the optimizer accumulates gradients but is never stepped.

A coverage check:

    opt_params = {
        id(p)
        for group in optimizer.param_groups
        for p in group["params"]
    }
    missing = [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and id(p) not in opt_params
    ]
    assert not missing, missing

Also check for duplicates. Parameter sharing is legitimate, for example tied embeddings, but adding the same parameter to optimizer groups more than once is usually a bug:

    seen = set()
    dups = []
    for group in optimizer.param_groups:
        for p in group["params"]:
            if id(p) in seen:
                dups.append(tuple(p.shape))
            seen.add(id(p))
    assert not dups, dups

Changing requires_grad after constructing the optimizer does not automatically remove parameters or optimizer state. Rebuild or edit optimizer groups when freezing/unfreezing changes the training contract.
Registration Containers
PyTorch can only recursively discover child modules that were registered. A plain Python list does not register the modules it contains:

    class BadStack(nn.Module):
        def __init__(self, depth, dim):
            super().__init__()
            self.layers = [nn.Linear(dim, dim) for _ in range(depth)]

The correct version:

    class GoodStack(nn.Module):
        def __init__(self, depth, dim):
            super().__init__()
            self.layers = nn.ModuleList(
                [nn.Linear(dim, dim) for _ in range(depth)]
            )

        def forward(self, x):
            for layer in self.layers:
                x = torch.relu(layer(x))
            return x

| Container | Has automatic forward? | Use case |
|---|---|---|
| nn.Sequential | yes | straight-line stack |
| nn.ModuleList | no | loop or residual control flow |
| nn.ModuleDict | no | named branches, task heads |
| nn.ParameterList | no | list of raw parameters |
| nn.ParameterDict | no | named raw parameters |
If submodules are stored in a plain Python list or dict, PyTorch will not find their parameters, move them to devices, or save them in state_dict.
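A quick way to see the effect is to count parameters. This is a minimal sketch; the `count_params` helper and the depth/dim values are illustrative assumptions.

```python
import torch
from torch import nn

class BadStack(nn.Module):
    def __init__(self, depth, dim):
        super().__init__()
        # Plain list: the Linear layers are NOT registered as child modules.
        self.layers = [nn.Linear(dim, dim) for _ in range(depth)]

class GoodStack(nn.Module):
    def __init__(self, depth, dim):
        super().__init__()
        # ModuleList: every Linear is registered and discoverable.
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

def count_params(module):
    return sum(p.numel() for p in module.parameters())

bad_count = count_params(BadStack(depth=3, dim=8))
good_count = count_params(GoodStack(depth=3, dim=8))
print(bad_count, good_count)  # 0 216
```

An optimizer built from `BadStack(...).parameters()` would receive an empty parameter list, so a parameter-count assertion right after model construction is a cheap guard.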
Device and Dtype Lifecycle
model.to(device) does not "move a Python object to the GPU". It recursively converts registered parameters and buffers to the target device/dtype. Plain tensor attributes are left alone:

    class BadCache(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(4))
            self.cache = torch.ones(4)

    m = BadCache().to("cuda")
    print(m.weight.device)  # cuda:0
    print(m.cache.device)   # still cpu

If cache really is model state, register it as a buffer. If it is only a temporary tensor produced inside forward, create it in forward on the input's device:

    tmp = torch.ones(x.size(0), device=x.device, dtype=x.dtype)

A device contract specifies which tensors must live on the same device at a module boundary, and which tensors are allowed to remain Python or CPU metadata.
Build Order
A safe construction order for training is:
- instantiate the model;
- move the model to the target device/dtype;
- freeze/unfreeze parameters;
- build the optimizer from the final trainable parameters;
- optionally load optimizer state when resuming exactly.

    model = Model(config)
    model.to(device=device, dtype=torch.bfloat16)
    for p in model.backbone.parameters():
        p.requires_grad_(False)
    optimizer = make_optimizer(model)

The reason is that the optimizer holds references to parameter objects along with optimizer state. Replacing modules, changing the trainable set, or doing dtype/device surgery after constructing the optimizer all make the optimizer contract unclear. Even if some .to() paths preserve parameter identity in current PyTorch versions, "model placement before optimizer construction" should remain the default rule in both teaching and engineering.
An optimizer does not rediscover model parameters every step. It updates the tensors stored in its parameter groups when the optimizer was constructed or manually edited.
Mixed Dtype Parameters and Buffers
In modern training, parameters, buffers, and activations may each use a different dtype:
| Tensor | Common dtype | Reason |
|---|---|---|
| model weights | BF16/FP16/FP32 | memory and matmul throughput |
| LayerNorm/RMSNorm weights | FP32 or BF16 | stability policy |
| optimizer states | FP32 | stable accumulation |
| BatchNorm running stats | FP32 | stable statistics |
| integer ids/masks | int64/bool | indexing and masking |
Do not blindly call .to(dtype=torch.bfloat16) on the whole batch, because input_ids and labels must keep an integer dtype:

    batch = {
        "input_ids": input_ids.to(device),          # int64
        "labels": labels.to(device),                # int64
        "pixel_values": pixels.to(device, dtype),   # float
    }

A recursive mover that supports dtype should cast only floating-point tensors:

    def move_batch(x, device, dtype):
        if torch.is_tensor(x):
            if x.is_floating_point():
                return x.to(device=device, dtype=dtype, non_blocking=True)
            return x.to(device=device, non_blocking=True)
        if isinstance(x, dict):
            return {k: move_batch(v, device, dtype) for k, v in x.items()}
        if isinstance(x, list):
            return [move_batch(v, device, dtype) for v in x]
        if isinstance(x, tuple):
            return tuple(move_batch(v, device, dtype) for v in x)
        return x

Embedding lookup expects integer token ids. Accidentally casting a whole batch to BF16/FP16 breaks indexing semantics or fails deep inside the model.
state_dict as the Serialization Boundary
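state_dict() is an ordered mapping of parameters and persistent buffers:

    sd = model.state_dict()
    for name, tensor in sd.items():
        print(name, tuple(tensor.shape))

It does not store the Python class definition, the optimizer object, or the training-loop code. Restoring a model usually requires:

    model = MLP(in_dim=10, hidden_dim=64, out_dim=2)
    model.load_state_dict(torch.load("model.pt", map_location="cpu"))
    model.eval()

Resuming training additionally needs the optimizer, scheduler, GradScaler, random state, and global step; the training-systems chapter covers this further.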
A model state_dict stores learned tensors, not the Python program. Exact training resume requires model state plus optimizer/scheduler/scaler/counter/random states.
Loading Strictly, Partially, and Safely
The strict argument of load_state_dict defines whether the checkpoint and the current model contract must match exactly:

    result = model.load_state_dict(state, strict=False)
    print(result.missing_keys)
    print(result.unexpected_keys)

| Case | Meaning | Typical cause or action |
|---|---|---|
| missing_keys | model expects a tensor not in checkpoint | new head, new adapter, renamed layer |
| unexpected_keys | checkpoint has tensor not used by model | old head, deleted module, wrapper prefix |
| shape mismatch | same key but incompatible shape | usually hard error and needs explicit surgery |

strict=False is not "safe loading". It only allows the key sets to differ; a shape mismatch must still be handled explicitly. Migration scripts commonly filter first:

    model_state = model.state_dict()
    filtered = {}
    skipped_shape = []  # keys dropped because the shapes disagree
    for key, value in ckpt.items():
        if key not in model_state:
            continue
        if value.shape != model_state[key].shape:
            skipped_shape.append(key)
            continue
        filtered[key] = value
    missing, unexpected = model.load_state_dict(filtered, strict=False)

This code suits backbone migration, but the skipped keys must be printed and written into the experiment record. Otherwise you may believe you loaded a pretrained model while most layers are in fact still randomly initialized.
DDP/DataParallel checkpoints may prefix keys with module.. Blindly using strict=False can ignore every pretrained tensor if prefixes are not remapped.
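The silent failure can be reproduced on a toy model. In the sketch below, the `module.` keys are simulated by hand rather than produced by a real DDP wrapper, and the zero-filled tensors stand in for pretrained weights; both are assumptions made for illustration.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)

# Simulate a checkpoint saved from a wrapped model: every key carries "module.".
wrapped = {
    f"module.{k}": torch.zeros_like(v) for k, v in model.state_dict().items()
}

result = model.load_state_dict(wrapped, strict=False)  # no error is raised
print(sorted(result.missing_keys))     # ['bias', 'weight']
print(sorted(result.unexpected_keys))  # ['module.bias', 'module.weight']

# After remapping the prefix, a strict load succeeds and the zeros arrive.
fixed = {k[len("module."):]: v for k, v in wrapped.items()}
model.load_state_dict(fixed, strict=True)
print(bool((model.weight == 0).all()))  # True
```

The first load reports every model key as missing, which means nothing was loaded even though the call returned normally.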
Prefix Remapping and Shape Surgery
DDP, DataParallel, FSDP, Lightning, and Hugging Face wrappers can all change checkpoint keys. The most common case is the module. prefix:

    def strip_prefix(state, prefix):
        out = {}
        for key, value in state.items():
            if key.startswith(prefix):
                out[key[len(prefix):]] = value
            else:
                out[key] = value
        return out

    state = strip_prefix(state, "module.")

More complex migrations should be written as an explicit mapping rather than relying on lucky string matches:

    rename = {
        "encoder.layers.": "backbone.blocks.",
        "classifier.": "head.",
    }

    def rename_key(key):
        for old, new in rename.items():
            if key.startswith(old):
                return new + key[len(old):]
        return key

The most dangerous shape mismatch is an embedding / LM head resize. Suppose the old vocabulary size is \(V_{\text{old}}\) and the new one is \(V_{\text{new}}\). If the hidden size is unchanged, the overlapping rows can be copied:

    old = ckpt["embed_tokens.weight"]
    new = model.state_dict()["embed_tokens.weight"].clone()
    n = min(old.size(0), new.size(0))
    new[:n].copy_(old[:n])
    ckpt["embed_tokens.weight"] = new

If the input embedding and the LM head are tied, also make sure they still share weights:

    model.lm_head.weight = model.embed_tokens.weight

or call the tie_weights() method provided by the model library. Otherwise the model may silently go from tied to untied after save/load.
Checkpoint surgery is an explicit transformation of checkpoint keys or tensors before loading them into a model with a changed architecture or vocabulary.
Load Report as an Artifact
Every partial load should produce a log, rather than relying on Python not raising an error:

    result = model.load_state_dict(filtered, strict=False)
    report = {
        "loaded": sorted(filtered.keys()),
        "missing": sorted(result.missing_keys),
        "unexpected": sorted(result.unexpected_keys),
        "skipped_shape": skipped_shape,
    }

This report should be saved alongside the experiment config. A simple audit rule is:
\[ \text{loaded ratio} = \frac{\sum_{k\in\text{loaded}}\operatorname{numel}(W_k)} {\sum_{k\in\text{model}}\operatorname{numel}(W_k)}. \]
If you believe you loaded the backbone but the loaded ratio is only 5%, that is not "a slightly worse fine-tuning starting point"; the meaning of the experiment has changed entirely.
Whenever strict=False is used, save the missing, unexpected, skipped, and loaded tensor report. Otherwise checkpoint migration is not reproducible.
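The loaded ratio can be computed directly from the model's state_dict. The sketch below is one possible helper; the name `loaded_ratio` and the toy two-layer model are assumptions for illustration, and the "loaded" key list is made up to mimic a partial load.

```python
import torch
from torch import nn

def loaded_ratio(model, loaded_keys):
    """Fraction of model state elements (by numel) covered by loaded_keys."""
    state = model.state_dict()
    total = sum(t.numel() for t in state.values())
    loaded = sum(state[k].numel() for k in loaded_keys if k in state)
    return loaded / total

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))

# Hypothetical partial load: only the first layer's tensors were loaded.
ratio = loaded_ratio(model, ["0.weight", "0.bias"])
print(f"{ratio:.3f}")  # 0.833
```

Here 110 of 132 elements are covered. In a real migration the key list would come from the load report, and a threshold on this number can fail the run early.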
Buffers also support persistent=False:

    self.register_buffer("rotary_cache", cache, persistent=False)

Such a buffer moves with .to(device) but does not enter state_dict(). It suits rebuildable caches such as a RoPE cache or a mask cache; it does not suit training state such as BatchNorm running statistics.
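A short sketch of the difference, using a hypothetical `WithCache` module that holds one persistent and one non-persistent buffer (the names `stats` and `cache` are illustrative):

```python
import torch
from torch import nn

class WithCache(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("stats", torch.zeros(4))                   # saved
        self.register_buffer("cache", torch.ones(4), persistent=False)  # not saved

m = WithCache()
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

print(buffer_names)  # ['stats', 'cache']
print(state_keys)    # ['stats']
```

Both tensors are still buffers, so both follow the module across devices; only `stats` is written to a checkpoint.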
Train/Eval Mode
model.train() and model.eval() recursively set the training flag on modules. The flag affects layers such as Dropout and BatchNorm:

    model.train()
    loss = criterion(model(x), y)

    model.eval()
    with torch.no_grad():
        pred = model(x_val)

eval() is not the same as no_grad():

| Switch | Controls | Does not control |
|---|---|---|
| model.eval() | module behavior | autograd graph construction |
| torch.no_grad() | gradient recording | dropout/batchnorm mode |
| requires_grad_(False) | parameter gradient accumulation | module behavior |

Validation and inference usually combine the three as needed.
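The independence of these switches can be checked on a tiny model. This is a minimal sketch; the Linear-plus-Dropout model is an arbitrary illustrative choice.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

model.eval()
out_eval = model(x)        # dropout is off, but autograd still records
with torch.no_grad():
    out_nograd = model(x)  # same values, no graph is built

print(out_eval.requires_grad)             # True
print(out_nograd.requires_grad)           # False
print(torch.equal(out_eval, out_nograd))  # True

model.train()
with torch.no_grad():
    still_training = model.training       # no_grad leaves the mode alone
print(still_training)                     # True
```

eval() alone still pays for graph construction, and no_grad() alone still applies dropout, which is why validation code usually sets both.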
Hooks for Inspection, Not Core Logic
Module hooks let you observe activations or gradients without modifying forward:

    acts = {}

    def save_activation(name):
        def hook(module, inputs, output):
            acts[name] = output.detach()
        return hook

    handle = model.fc1.register_forward_hook(save_activation("fc1"))
    _ = model(x)
    handle.remove()

Common hooks:

| Hook | Fires when | Use |
|---|---|---|
| forward_pre_hook | before forward | input checks, shape logging |
| forward_hook | after forward | activation capture |
| full_backward_hook | during backward | gradient diagnostics |

Hooks are well suited to debugging, visualization, and statistics, but not to carrying core model logic: hook execution order, distributed wrappers, compiled modes, and exception handling all make the control flow harder to reason about.
Saving raw hook outputs without .detach() can keep the autograd graph alive and cause memory growth across steps.
Initialization and apply
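module.apply(fn) recursively visits all child modules and then the module itself:

    def init_linear(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:  # Linear(bias=False) has no bias tensor
                nn.init.zeros_(m.bias)

    model.apply(init_linear)

Using isinstance here is an acceptable piece of peripheral adaptation in an initialization script; the forward logic of core modules should not be littered with type branches.

More robust initialization habits:
- be explicit about which layers need special initialization;
- do not use .data to modify tensors that participate in autograd behind its back;
- run one forward/backward smoke test after initialization;
- save the initialization config so the run is reproducible.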
Dataset Protocol
A map-style dataset implements two methods:

    from torch.utils.data import Dataset

    class MyDataset(Dataset):
        def __init__(self, records, transform):
            self.records = records
            self.transform = transform

        def __len__(self):
            return len(self.records)

        def __getitem__(self, idx):
            item = self.records[idx]
            return self.transform(item)

DataLoader calls dataset[i] for each index, then uses collate_fn to merge a group of samples into a batch.
A collate_fn maps a Python list of samples into one batch object, usually by stacking tensors, padding variable-length fields, and constructing masks.
Map-Style vs Iterable Datasets
PyTorch datasets follow two common protocols:

| Dataset type | Required methods | Best for |
|---|---|---|
| map-style | __len__, __getitem__ | finite indexed datasets, shuffling, random access |
| iterable-style | __iter__ | streams, generated data, huge shards, remote sources |

A map-style dataset lets the sampler control the index order:

    sampler -> indices -> dataset[i] -> samples -> collate_fn -> batch

An iterable dataset produces samples directly:

    dataset iterator -> samples -> collate_fn -> batch
A sampler defines the order in which examples or indices are drawn from a dataset. In distributed training, it also defines which rank owns which shard of an epoch.
In distributed training, DistributedSampler splits the indices by rank. Every epoch must call:

    sampler.set_epoch(epoch)

Otherwise every epoch reuses the same shuffle order. This bug raises no error, but it reduces stochasticity.
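The effect of set_epoch can be observed without launching a distributed job. The sketch below passes num_replicas and rank explicitly so that no process group is needed; the 100-element list is a stand-in dataset chosen for illustration.

```python
import torch
from torch.utils.data import DistributedSampler

dataset = list(range(100))  # any object with __len__ is enough here

# Explicit num_replicas/rank: no initialized process group is required.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)

first = list(sampler)
second = list(sampler)   # set_epoch not called: same order again
sampler.set_epoch(1)
third = list(sampler)    # new epoch: the permutation is reseeded

print(len(first))        # 50
print(first == second)   # True
print(first == third)
```

Rank 0 sees half of the indices, and the order only changes once the epoch number changes.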
Worker Seeding and Random Transforms
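When num_workers>0, the dataset and its transforms run in subprocesses. Without a worker seed, random augmentation may produce the same random sequence in every worker. A safe pattern:

    import random
    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    def seed_worker(worker_id):
        # Derive a per-worker seed from the seed PyTorch assigned to this worker.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    generator = torch.Generator()
    generator.manual_seed(1234)

    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        worker_init_fn=seed_worker,
        generator=generator,
    )

Three layers of randomness are involved:
- the sample order from the sampler;
- the random numbers used by transforms inside each worker;
- the random numbers used by the model and operators on the GPU.
Setting only torch.manual_seed does not necessarily cover Python random and NumPy inside dataloader workers. A reproducible experiment must state all of these layers explicitly.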
With persistent_workers=True, worker processes survive across epochs. Any mutable dataset state inside workers may persist unless explicitly reset.
IterableDataset Worker Sharding
With num_workers>0, each worker creates its own iterator over an iterable dataset. Without explicit sharding, multiple workers may read exactly the same data stream.

    from torch.utils.data import IterableDataset, get_worker_info

    class ShardedStream(IterableDataset):
        def __init__(self, records):
            self.records = records

        def __iter__(self):
            info = get_worker_info()
            if info is None:
                # Single-process loading: yield everything.
                yield from self.records
                return
            worker_id = info.id
            num_workers = info.num_workers
            for i, item in enumerate(self.records):
                if i % num_workers == worker_id:
                    yield item

This round-robin sharding suits records that can be enumerated sequentially. For file shards, assignment by shard id is more common:
\[ \text{owned}(s) \iff s\bmod N_{\text{workers}}=\text{worker\_id}. \]
In distributed training, the rank must be layered on top:
\[ \text{global\_worker\_id} = \text{rank}\cdot N_{\text{workers}}+\text{worker\_id}. \]
The total number of workers is
\[ N_{\text{global}} = \text{world\_size}\cdot N_{\text{workers}}. \]
Otherwise every GPU and every worker may read duplicate samples.
Worker sharding partitions an iterable data stream across dataloader workers and distributed ranks so that each example is produced by exactly one consumer per epoch or stream pass.
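The shard-ownership rule can be written as a small pure-Python function. This is a sketch under the modulo assignment described above; the function name `owned_shards` and the 2-rank, 2-worker, 10-shard setup are hypothetical.

```python
def owned_shards(num_shards, rank, world_size, worker_id, num_workers):
    """Shard ids owned by one (rank, worker) pair under modulo assignment."""
    global_worker_id = rank * num_workers + worker_id
    num_global = world_size * num_workers
    return [s for s in range(num_shards) if s % num_global == global_worker_id]

# Hypothetical setup: 2 ranks, 2 workers per rank, 10 shards.
assignment = {
    (rank, worker): owned_shards(10, rank, 2, worker, 2)
    for rank in range(2)
    for worker in range(2)
}

print(assignment[(0, 0)])  # [0, 4, 8]
print(assignment[(1, 0)])  # [2, 6]
```

Every shard lands in exactly one list, but the lists differ in length when the shard count is not a multiple of the global worker count, which is one source of uneven worker tails.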
Length, Epoch, and Stop Semantics
For a map-style dataset the epoch is well defined: the sampler produces one pass of indices. For an iterable dataset the epoch is something the training system defines. Common strategies:
| Strategy | Meaning | Risk |
|---|---|---|
| fixed examples | stop after \(N\) examples | uneven worker tails |
| fixed tokens | stop after \(N\) tokens | variable batch count |
| fixed steps | stop after \(S\) optimizer steps | data stream position must be checkpointed |
| infinite stream | no natural epoch | scheduler/eval cadence must be external |
A checkpoint for an iterable pipeline therefore cannot store the epoch alone. At minimum it should store shard id, offset, global step, consumed examples/tokens, and the shuffle seed/state. Otherwise a resume may repeat or skip data.
For iterable datasets, epoch is not enough to resume exactly. Save stream shard, offset, worker/rank partition policy, and RNG state.
The default collate only works when the tensors in every sample share a shape. NLP, detection, and multimodal tasks usually need a custom one:

    import torch

    def collate(batch):
        # pad_id is assumed to be defined by the tokenizer/config in scope.
        input_ids = [b["input_ids"] for b in batch]
        labels = [b["labels"] for b in batch]
        padded = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=pad_id,
        )
        label_pad = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=-100,
        )
        attention_mask = padded.ne(pad_id)
        return {
            "input_ids": padded,
            "labels": label_pad,
            "attention_mask": attention_mask,
        }

Here pad_id controls attention and -100 controls the loss. The two must not be mixed up.
Collate as a Batch Contract
collate_fn does more than "turn a list into a tensor": it defines the batch contract. A training step should be able to tell from the batch keys alone:

    input_ids: [B, T] int64
    attention_mask: [B, T] bool/int
    labels: [B, T] int64, -100 ignored
    metadata: Python-only, not moved to GPU

For image/detection tasks, the batch contract might be:

    images: [B, C, H, W]
    targets: list[dict], one per image

Forcing variable-size boxes into one dense stacked tensor here only adds mask complexity. A good collate makes the input semantics clear for the model and loss that follow, rather than trying to turn everything into tensors.

A recursive mover:

    def move_to_device(x, device):
        if torch.is_tensor(x):
            return x.to(device, non_blocking=True)
        if isinstance(x, dict):
            return {k: move_to_device(v, device) for k, v in x.items()}
        if isinstance(x, list):
            return [move_to_device(v, device) for v in x]
        if isinstance(x, tuple):
            return tuple(move_to_device(v, device) for v in x)
        return x

The host-to-device copy is only likely to be asynchronous when pin_memory=True is combined with non_blocking=True. Python objects mixed into the batch are not pinned and should not be moved to the GPU.
DataLoader Engineering
Common arguments:

| Argument | Meaning | Common failure |
|---|---|---|
| batch_size | samples per batch | too large causes OOM |
| shuffle | random sample order | false during train reduces stochasticity |
| num_workers | subprocess data loading | too high causes CPU/RAM contention |
| pin_memory | pinned host memory | useful for GPU transfer |
| drop_last | drop incomplete last batch | changes epoch token/example count |
| persistent_workers | reuse worker processes | stale dataset state if misused |
| collate_fn | batch construction | wrong padding/masks |

The goal of a data pipeline is not merely that it "can load", but that the GPU never waits for the CPU. The minimum metrics to observe:

    data_time = time until batch arrives
    step_time = forward + backward + optimizer
    gpu_util = GPU active ratio

If data_time approaches or exceeds step_time, the bottleneck is usually dataset decoding, CPU transforms, collate padding, disk reads, or the worker-count configuration.
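One way to collect data_time and step_time is to wrap the loader iteration in a small generator. The sketch below is framework-agnostic; the name `timed_steps` and the `step_fn` callback are assumptions, and on a GPU the step timing is only meaningful if `step_fn` synchronizes the device (for example with torch.cuda.synchronize()) before returning, because CUDA kernels launch asynchronously.

```python
import time

def timed_steps(loader, step_fn):
    """Yield (data_time, step_time) in seconds for each batch."""
    it = iter(loader)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(it)      # time spent waiting for the batch
        except StopIteration:
            return
        t1 = time.perf_counter()
        step_fn(batch)            # forward + backward + optimizer
        t2 = time.perf_counter()
        yield t1 - t0, t2 - t1

# Stand-in loader and step function, just to show the call shape.
timings = list(timed_steps([1, 2, 3], lambda batch: None))
print(len(timings))  # 3
```

Averaging the two columns over a few hundred steps, after skipping the warm-up steps, gives a first indication of whether the loader or the model is the bottleneck.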
Optimizer Interface
An optimizer accepts either an iterable of parameters or parameter groups:

    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name.endswith("bias") or "norm" in name.lower():
            no_decay.append(p)
        else:
            decay.append(p)

    optimizer = torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": 0.1},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=3e-4,
    )

Parameter groups encode the training assumptions: which parameters decay, which are frozen, and which use a larger learning rate should all be readable from the parameter groups.
Minimal Trainer Skeleton
The minimal state machine of a training loop:

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            batch = move_to_device(batch, device)
            optimizer.zero_grad(set_to_none=True)
            loss = compute_loss(model, batch)
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            metrics = evaluate(model, val_loader)

This code omits AMP, gradient accumulation, the scheduler, checkpointing, and distributed training, but the order is right. What matters most is writing the state boundaries down clearly:
- train mode before train batches;
- zero grad before backward;
- eval mode and no_grad before validation;
- metrics do not participate in autograd;
- the checkpoint saves both the model and the training state.
Module and Data Smoke Tests
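Many nn.Module and DataLoader bugs are of the kind "it trains, but not the model you think it is training". The small tests below are worth running right after writing the model, collate, and checkpoint logic.
Test 1: All Trainable Parameters Are Optimized

    def assert_optimizer_covers_trainable(model, optimizer):
        opt_ids = {
            id(p)
            for group in optimizer.param_groups
            for p in group["params"]
        }
        missing = [
            name
            for name, p in model.named_parameters()
            if p.requires_grad and id(p) not in opt_ids
        ]
        if missing:
            raise AssertionError(f"trainable params missing from optimizer: {missing}")

This test catches forgetting to rebuild the optimizer after replacing a head, LoRA parameters that never entered the optimizer, and similar problems. It does not catch modules hidden in a plain list, because their parameters never appear in named_parameters(); compare parameter counts against the expected total for that.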
Test 2: No Duplicate Optimizer Parameters

    def assert_no_duplicate_optimizer_params(optimizer):
        seen = set()
        dup = 0
        for group in optimizer.param_groups:
            for p in group["params"]:
                if id(p) in seen:
                    dup += 1
                seen.add(id(p))
        if dup:
            raise AssertionError(f"duplicate optimizer parameters: {dup}")

If the model has tied weights, repeated references may be intentional, but the same parameter appearing more than once in optimizer groups usually causes it to be stepped twice or triggers a framework warning or error. An explicit test is more reliable than eyeballing the param groups.
Test 3: Save/Load Gives Same Output

    import io

    def assert_state_dict_roundtrip(model, example):
        model.eval()
        with torch.no_grad():
            ref = model(example)

        # Serialize to an in-memory buffer instead of a file on disk.
        buf = io.BytesIO()
        torch.save(model.state_dict(), buf)
        buf.seek(0)

        clone = type(model)(model.config)
        clone.load_state_dict(torch.load(buf, map_location="cpu"))
        clone.eval()
        with torch.no_grad():
            out = clone(example.cpu())
        assert torch.allclose(ref.cpu(), out, atol=1e-6)

In real projects type(model)(model.config) may not hold; substitute the project's own factory. The core check is that the model state boundary is complete enough that the restored model gives the same output under eval/no_grad.
Test 4: Batch Contract

    def assert_lm_batch(batch, pad_id):
        ids = batch["input_ids"]
        labels = batch["labels"]
        mask = batch["attention_mask"]
        assert ids.dtype == torch.long
        assert labels.dtype == torch.long
        assert mask.shape == ids.shape == labels.shape
        # torch.long is an alias of torch.int64, so accept int32 as the other integer option.
        assert mask.dtype in (torch.bool, torch.int32, torch.int64)
        assert torch.equal(mask.bool(), ids.ne(pad_id))
        assert (labels[~mask.bool()] == -100).all()

This test pins down the collate semantics: the attention mask is determined by pad_id, and the loss mask is determined by -100. Mixing the two up is a very common silent error in language-model training.
Test 5: Iterable Workers Do Not Duplicate
For a small iterable dataset, run one pass with num_workers=2 and check whether any sample id repeats:

    def assert_no_duplicate_ids(loader, limit):
        seen = set()
        for i, batch in enumerate(loader):
            for item_id in batch["id"]:
                item_id = int(item_id)
                if item_id in seen:
                    raise AssertionError(f"duplicate item id: {item_id}")
                seen.add(item_id)
            if i + 1 >= limit:
                break

This test cannot prove that large-scale streaming is fully correct, but it immediately catches the basic error where every worker reads the full dataset from the start.
For module and dataloader code, test the boundaries: parameter discovery, optimizer membership, save/load identity, batch masks, and worker sharding.
Implementation Checklist
When writing an nn.Module or a data pipeline, check at least the following:
- all submodules are registered through ModuleList/ModuleDict/Sequential;
- state that must be saved but not trained uses a buffer;
- the code calls model(x) rather than model.forward(x);
- train()/eval()/no_grad() are each placed correctly;
- state_dict gives consistent output after save/load;
- collate_fn builds correct padding, attention mask, and label mask;
- the DataLoader worker count actually improves throughput;
- optimizer parameter groups cover all trainable parameters with no duplicates;
- model.to(device, dtype) happens before the optimizer is constructed;
- the batch mover casts only floating tensors to the training dtype;
- every partial checkpoint load saves a missing/unexpected/skipped/loaded report;
- IterableDataset is sharded by worker and distributed rank;
- iterable pipeline resume saves shard, offset, seed, and consumed tokens/examples.