1.6 Data Pipeline and Reproducibility
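Training outcomes are not determined by the model and optimizer alone. How data is split, shuffled, batched, and padded, how it is moved to the GPU, and how workers apply random augmentation all change the optimization problem that is actually being solved. In code, a DataLoader often looks like it merely "feeds data", but it actually defines the sampling process of the training distribution.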
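This section splits the PyTorch data pipeline into three layers:
- dataset: which example an index maps to;
- sampler: in what order indices are drawn;
- collate: how several examples are combined into a batch.
It then connects these choices to reproducibility, padding, masks, throughput, and distributed training.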
Dataset as a Map from Index to Example
A map-style dataset implements __getitem__(index) and __len__(). It represents a finite collection where each integer index maps to one example.
A typical classification dataset:
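```python
import torch
from torchvision.io import read_image  # assumed image reader: file path -> uint8 tensor [C, H, W]


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, records, transform):
        # records: list of (path, label) pairs; their order defines the index contract
        self.records = records
        self.transform = transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        path, label = self.records[idx]
        image = read_image(path)
        image = self.transform(image)
        return {"image": image, "label": label}
```
Here the Dataset handles neither batching nor shuffling. It answers a single question: given idx, which training example is returned?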
An iterable-style dataset implements __iter__() and yields examples sequentially. It is useful for streams, generated data, or data sources without random access.
An iterable dataset behaves more like a data stream:
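```python
import json

import torch


class JsonlStream(torch.utils.data.IterableDataset):
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        # Sequential read; no random access is required.
        with open(self.path) as f:
            for line in f:
                yield json.loads(line)
```
It suits log streams, large-scale shards, and data generated online. It is also easier to get wrong: with multiple workers, every worker runs __iter__() from the beginning, so without manual sharding each record is read once per worker.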
For an IterableDataset, every DataLoader worker constructs its own iterator. Without worker-aware sharding, multiple workers may yield the same examples.
A minimal worker-aware stream can use get_worker_info() and assign lines by line number modulo the worker count:
```python
import json

import torch


class ShardedJsonlStream(torch.utils.data.IterableDataset):
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading (num_workers=0): this process reads everything.
            worker_id = 0
            num_workers = 1
        else:
            worker_id = info.id
            num_workers = info.num_workers
        with open(self.path) as f:
            for i, line in enumerate(f):
                # Line i belongs to worker (i mod num_workers).
                if i % num_workers != worker_id:
                    continue
                yield json.loads(line)
```
In distributed training, the rank must also be part of the shard id. With \(R\) data-parallel ranks, each running \(W\) loader workers, the global number of workers is
\[ N_{\text{stream workers}}=R W. \]
The global worker id can be written as
\[ s=\operatorname{rank}\cdot W+\operatorname{worker\_id}. \]
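Each record \(i\) is then read only by the worker whose global id satisfies
\[ i\bmod (RW)=s. \]
A rank- and worker-aware version of the stream: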
```python
import json

import torch


class RankWorkerJsonlStream(torch.utils.data.IterableDataset):
    def __init__(self, path, rank, world_size):
        self.path = path
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        worker_id = 0 if info is None else info.id
        num_workers = 1 if info is None else info.num_workers
        # Global shard id s = rank * W + worker_id, out of R * W shards.
        shard_id = self.rank * num_workers + worker_id
        nshards = self.world_size * num_workers
        with open(self.path) as f:
            for i, line in enumerate(f):
                if i % nshards == shard_id:
                    yield json.loads(line)
```
This code is good for explaining the semantics, but large-scale text training usually shards by file rather than by line number modulo. Line-modulo sharding makes every worker scan the whole file, amplifying IO by a factor of \(RW\); file-level sharding lets each worker open only its own subset of files.
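Assigning whole files instead of individual lines can be sketched as follows. This is a minimal illustration, not a library API: the helper name files_for_worker is hypothetical, and it assumes every rank sees the same list of shard files.

```python
def files_for_worker(files, rank, world_size, worker_id, num_workers):
    """Return the files owned by one (rank, worker) pair, assigned round-robin."""
    # Sort first so that every process agrees on the global file order.
    ordered = sorted(files)
    # Same global shard id as the line-based version: s = rank * W + worker_id.
    shard_id = rank * num_workers + worker_id
    nshards = world_size * num_workers
    # File k belongs to shard (k mod nshards); each worker opens only its own files.
    return ordered[shard_id::nshards]


files = [f"part-{k:03d}.jsonl" for k in range(8)]
print(files_for_worker(files, rank=1, world_size=2, worker_id=1, num_workers=2))
```

Inside __iter__, a worker would call get_worker_info() as above and then read only the returned files. When the number of files is not a multiple of \(RW\), some shards receive one more file than others, so ranks can run out of data at different steps; it is worth choosing the file count with that in mind.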
A data shard is a subset of records or files assigned to one worker, rank, or process so that large datasets can be read in parallel without accidental duplication.
A map-style dataset also needs an explicit index contract. A robust approach is to freeze the training examples into a manifest table:
| Column | Meaning |
|---|---|
| id | stable example id |
| path | storage location |
| label | target or metadata |
| split | train/validation/test |
| length | tokens, frames, or pixels used for batching |
| hash | content fingerprint when feasible |
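This way the meaning of __getitem__(idx) does not change with the file system's traversal order:
```python
import torch


class ManifestDataset(torch.utils.data.Dataset):
    def __init__(self, rows, transform):
        # rows: manifest records (dicts with "id", "path", ...) in a fixed, persisted order
        self.rows = rows
        self.transform = transform

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        # load_record(path) -> dict is assumed to be a storage-specific reader defined elsewhere.
        item = load_record(row["path"])
        item = self.transform(item)
        # Carry the stable id so logging and per-example seeding do not depend on idx.
        item["example_id"] = row["id"]
        return item
```
If training needs exact resume or comparable ablations, dataset index order must be stable. Do not rely on unsorted directory listings or mutable remote query order.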
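One way to freeze such a manifest is sketched below, assuming each example is one file under a single root directory. build_manifest and file_sha256 are hypothetical helpers; the label and length columns are left out because they depend on the task.

```python
import hashlib
from pathlib import Path


def file_sha256(path, chunk_size=1 << 20):
    """Content fingerprint of one file, read in chunks to bound memory use."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def build_manifest(root, pattern, split):
    """Manifest rows in an order that does not depend on directory listing order."""
    # sorted() replaces the file system's arbitrary listing order with a fixed one.
    paths = sorted(Path(root).glob(pattern))
    rows = []
    for path in paths:
        rows.append({
            # A root-relative path stays stable even if the dataset directory moves.
            "id": path.relative_to(root).as_posix(),
            "path": str(path),
            "split": split,
            "hash": file_sha256(path),
        })
    return rows
```

The rows would then be written out once (for example as JSON lines) and loaded as-is in every run, rather than rebuilt from a fresh directory listing.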
Sampler Defines the Training Order
A sampler is an iterator that yields dataset indices. It determines the order in which examples are drawn from a map-style dataset.
shuffle=True looks like a simple boolean, but it is equivalent to using a random sampler. The empirical risk minimized in training is usually written as
\[ \hat{R}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\ell(f_\theta(x_i),y_i). \]
Each SGD step draws a mini-batch \(B_t\) and computes
\[ g_t = \frac{1}{|B_t|} \sum_{i\in B_t} \nabla_\theta \ell_i(\theta). \]
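If the sampler produces a different random permutation each epoch, \(g_t\) is a stochastic estimate of the full-data gradient. If the sampler is systematically biased, for example examples of the same class arrive consecutively, long texts are clustered together, or a time series leaks future samples, the training dynamics change noticeably.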
Changing shuffle, sampler seed, bucket order, or distributed sharding can change optimization trajectories even when model code is identical.
Common sampler types:
| Sampler | Use case |
|---|---|
| sequential | validation, deterministic evaluation |
| random | standard SGD training |
| weighted random | class imbalance, sampling curriculum |
| distributed | split indices across ranks |
| bucketed | group similar lengths for efficient padding |
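A Sampler yields single indices; a BatchSampler yields lists of indices. Length-aware batching is usually best written as a batch sampler, because what it decides is which examples end up in the same batch: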
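```python
import torch


class LengthBucketBatchSampler(torch.utils.data.Sampler):
    """Yields lists of indices; pass it to DataLoader(batch_sampler=...)."""

    def __init__(self, lengths, batch_size, bucket_size, generator):
        self.lengths = lengths
        self.batch_size = batch_size
        self.bucket_size = bucket_size
        self.generator = generator

    def __iter__(self):
        n = len(self.lengths)
        # 1) Global random permutation (advances the generator, so each epoch differs).
        order = torch.randperm(n, generator=self.generator).tolist()
        for start in range(0, n, self.bucket_size):
            # 2) Sort within a bucket so that similar lengths share a batch.
            bucket = order[start : start + self.bucket_size]
            bucket.sort(key=lambda i: self.lengths[i])
            for j in range(0, len(bucket), self.batch_size):
                batch = bucket[j : j + self.batch_size]
                # 3) Incomplete batches at the end of each bucket are dropped.
                if len(batch) == self.batch_size:
                    yield batch

    def __len__(self):
        # Count only the full batches __iter__ actually yields, bucket by bucket.
        full_buckets, remainder = divmod(len(self.lengths), self.bucket_size)
        per_bucket = self.bucket_size // self.batch_size
        return full_buckets * per_bucket + remainder // self.batch_size
```
The trade-off is explicit: shuffle all indices first, cut the permutation into large buckets, and sort by length inside each bucket. Larger buckets mean less padding but weaker randomness; smaller buckets mean more randomness but more wasted padding.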
A batch sampler yields lists of indices. It controls mini-batch composition directly and is the right abstraction for length bucketing, curriculum batches, and token-budget batching.
Dynamic token batching suits LLMs better: instead of a fixed number \(B\) of examples, it fixes a total token budget. If example \(i\) has length \(T_i\), each batch satisfies
\[ \sum_{i\in B}T_i \le T_{\max}. \]
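A simple greedy version:
```python
def token_budget_batches(indices, lengths, max_tokens):
    """Greedily group indices so that the sum of lengths stays within max_tokens."""
    batch = []
    ntok = 0
    for i in indices:
        length = lengths[i]
        # Close the current batch if adding this example would exceed the budget.
        if batch and ntok + length > max_tokens:
            yield batch
            batch = []
            ntok = 0
        # An example longer than max_tokens still forms its own (oversized) batch.
        batch.append(i)
        ntok += length
    if batch:
        yield batch
```
This produces a variable number of examples per batch, so training logs cannot record only batch_size; they should also record the effective tokens per step: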
\[ T_{\text{step}}=\sum_{i\in B_t}T_i. \]
With token-budget batching, the number of sequences per optimizer step varies. Loss normalization, logging, and LR schedule should be defined in tokens or optimizer steps, not assumed fixed by sample count.
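The budget above counts real tokens. When batches are padded rather than packed, the compute actually spent is closer to \(|B|\,T_{\max}\). A variant that bounds the padded size instead can be sketched as follows; padded_token_budget_batches is a hypothetical name, and, like the greedy version, it lets a single over-long example form its own batch.

```python
def padded_token_budget_batches(indices, lengths, max_padded_tokens):
    """Group indices so that len(batch) * max_len_in_batch <= max_padded_tokens."""
    batch = []
    max_len = 0
    for i in indices:
        new_max = max(max_len, lengths[i])
        # Padded cost if example i joins: every row is padded to the new maximum.
        if batch and (len(batch) + 1) * new_max > max_padded_tokens:
            yield batch
            batch = []
            new_max = lengths[i]
        batch.append(i)
        max_len = new_max
    if batch:
        yield batch
```

Feeding this with indices sorted or bucketed by length keeps the padded cost close to the real token count; on a random order it mostly produces small batches.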
Collate Function and Batch Semantics
A collate function transforms a list of individual examples into one batch object. It defines stacking, padding, masks, and type conversion.
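The default collate tries to stack fields that share a name. The list
```python
import torch
from torch.utils.data import default_collate

examples = [
    {"x": torch.tensor([1, 2]), "y": 0},
    {"x": torch.tensor([3, 4]), "y": 1},
]
batch = default_collate(examples)
```
becomes
```python
{
    "x": tensor([[1, 2], [3, 4]]),
    "y": tensor([0, 1]),
}
```
NLP/LLM examples usually have different lengths, however, so padding needs a custom collate function: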
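```python
import torch


def collate_lm(examples, pad_id):
    # Each example is {"input_ids": list[int]}; lengths may differ.
    lengths = [len(ex["input_ids"]) for ex in examples]
    max_len = max(lengths)
    input_ids = []
    attention_mask = []
    labels = []
    for ex in examples:
        ids = ex["input_ids"]
        pad = max_len - len(ids)
        # Right padding up to the longest sequence in the batch.
        input_ids.append(ids + [pad_id] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        # Shift in the collator: position t is trained to predict ids[t + 1];
        # the last real position and all padding get ignore_index -100.
        labels.append(ids[1:] + [-100] * (pad + 1))
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }
```
What matters most here is not the shape of the code but the meaning of each field in the batch: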
| Field | Meaning |
|---|---|
| input_ids | tokens fed into the model |
| attention_mask | which positions are real tokens |
| labels | targets used by the loss |
| -100 | ignore index for CrossEntropyLoss in many LM heads |
Padding only the inputs is not enough. The loss must ignore padded target positions, otherwise the model is trained to predict pad tokens.
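Another common boundary question is where label shifting happens. Some models shift labels=input_ids automatically inside forward; other training code expects the collator to have produced shifted labels already. Pick exactly one. If the collator has already built
```text
input: [A, B, C, EOS]
label: [B, C, EOS, IGN]
```
the model must not shift again. Otherwise the effective targets become
```text
A -> C
B -> EOS
```
which turns next-token prediction into a misaligned two-step-ahead prediction.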
For causal language modeling, decide whether shifting happens in the collator or inside the model/loss wrapper. Double shifting silently trains the wrong conditional distribution.
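A cheap guard against double shifting is to assert, on one real batch, that every non-ignored label equals the next input token. The sketch assumes the collator-shifted convention of collate_lm (labels[:, t] targets input_ids[:, t + 1]) and -100 as the ignore index; check_single_shift is a hypothetical helper.

```python
import torch


def check_single_shift(input_ids, labels, ignore_index=-100):
    """True if labels[:, t] == input_ids[:, t + 1] wherever the label is not ignored."""
    # The final column has no next input token, so it must be ignored.
    last_ok = bool((labels[:, -1] == ignore_index).all())
    # Compare each remaining non-ignored label with the *next* input token.
    lab = labels[:, :-1]
    nxt = input_ids[:, 1:]
    valid = lab != ignore_index
    return last_ok and bool((lab[valid] == nxt[valid]).all())
```

If the model shifts internally instead, the expected relation on non-ignored positions is labels == input_ids, and the check should be adapted accordingly.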
A batch contract should say more than how many tensors there are. A training step should know at least:
| Field | Shape | Dtype | Device before step | Semantics |
|---|---|---|---|---|
| input_ids | [B, T] | torch.long | CPU | discrete token ids |
| attention_mask | [B, T] | torch.bool or integer | CPU | real-token mask |
| labels | [B, T] | torch.long | CPU | target ids or ignore_index |
| position_ids | [B, T] if present | torch.long | CPU | position contract |
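Such a table is not decoration but a debugging tool. If labels are mistakenly cast to FP16, CrossEntropyLoss fails outright; if the 1/0 convention of attention_mask is the opposite of what the model expects, the model may still run, but the loss cannot be trusted.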
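The table can be turned into an executable check that runs on the first batch of a job. This is a sketch: check_batch_contract is a hypothetical helper, and the allowed dtypes in CONTRACT mirror the table above and should be adjusted per model. It can verify the type and shape of attention_mask, but not the meaning of its 1/0 values.

```python
import torch

# Allowed dtypes per field; attention_mask may be bool or an integer type.
CONTRACT = {
    "input_ids": (torch.long,),
    "attention_mask": (torch.bool, torch.long, torch.int32, torch.uint8),
    "labels": (torch.long,),
    "position_ids": (torch.long,),
}


def check_batch_contract(batch):
    """Raise ValueError if a batch violates the shape/dtype contract."""
    ref_shape = batch["input_ids"].shape
    if len(ref_shape) != 2:
        raise ValueError(f"input_ids must be [B, T], got {tuple(ref_shape)}")
    for name, allowed in CONTRACT.items():
        if name not in batch:
            continue  # optional fields such as position_ids
        t = batch[name]
        if t.dtype not in allowed:
            raise ValueError(f"{name}: dtype {t.dtype} not in {allowed}")
        if t.shape != ref_shape:
            raise ValueError(f"{name}: shape {tuple(t.shape)} != {tuple(ref_shape)}")
```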
When moving a nested batch to the GPU, a recursive function avoids moving only the top-level tensors:
```python
import torch


def move_to_device(x, device):
    # Recurse into containers so nested tensors are moved too.
    if torch.is_tensor(x):
        # non_blocking only overlaps with compute when the source is in pinned memory.
        return x.to(device, non_blocking=True)
    if isinstance(x, dict):
        return {k: move_to_device(v, device) for k, v in x.items()}
    if isinstance(x, list):
        return [move_to_device(v, device) for v in x]
    if isinstance(x, tuple):
        return tuple(move_to_device(v, device) for v in x)
    # Non-tensor leaves (strings, ints, None) are returned unchanged.
    return x
```
Images or activations may use floating dtypes, but labels, token ids, indices, and masks often have semantic integer or boolean dtypes. A blanket .half() or .float() over the batch can corrupt the objective.
Padding, Masks, and Effective Tokens
Suppose two sequences are
[A, B, C, EOS]
[D, E, EOS]
Right padding gives
[A, B, C, EOS]
[D, E, EOS, PAD]
For causal LM, labels are shifted targets. If position \(t\) predicts \(t+1\), then the second sequence should contribute loss on:
D -> E
E -> EOS
but not on
EOS -> PAD
PAD -> PAD
The masked token-level loss is
\[ L = \frac{\sum_{b,t}m_{b,t}\, [-\log p_\theta(y_{b,t}\mid x_{b,\le t})]} {\sum_{b,t}m_{b,t}}, \]
where \(m_{b,t}=1\) for valid target positions and \(0\) for padding or ignored positions.
Effective token count is the number of non-ignored target tokens contributing to the loss. In sequence modeling, it is often a better denominator than batch size.
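In PyTorch the masked loss can be written with F.cross_entropy, whose ignore_index argument (default -100) drops ignored targets. The sketch returns the summed loss and the effective token count separately; masked_token_loss is a hypothetical name, and labels are assumed to be already shifted.

```python
import torch
import torch.nn.functional as F


def masked_token_loss(logits, labels, ignore_index=-100):
    """Return (sum of per-token losses, number of non-ignored targets)."""
    # logits: [B, T, V]; labels: [B, T] with ignored positions set to ignore_index.
    B, T, V = logits.shape
    loss_sum = F.cross_entropy(
        logits.reshape(B * T, V),
        labels.reshape(B * T),
        ignore_index=ignore_index,
        reduction="sum",
    )
    n_tokens = (labels != ignore_index).sum()
    return loss_sum, n_tokens


logits = torch.randn(2, 4, 7)
labels = torch.tensor([[1, 2, 3, -100], [4, -100, -100, -100]])
loss_sum, n_tokens = masked_token_loss(logits, labels)
# clamp(min=1) avoids 0/0 for a batch whose targets are all ignored.
loss = loss_sum / n_tokens.clamp(min=1)
```

For a single micro-batch, loss_sum / n_tokens is exactly the formula above. Under gradient accumulation or data parallelism, dividing each micro-batch by its own count weights tokens unevenly; a token-weighted alternative divides every micro-batch's loss_sum by the total token count of the whole optimizer step, which then has to be known (for example, counted from the labels) before the backward passes.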
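This is also why LLM training logs usually track tokens/sec rather than samples/sec. Two batches may both contain 8 sequences, but if one averages 128 tokens per sequence and the other 2048, their compute cost and effective training signal are completely different.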
Bucketing and Packing
Padding wastes compute. If the longest sequence in a batch has length \(T_{\max}\) and sequence \(i\) has true length \(T_i\), the fraction of padded positions, a rough proxy for wasted compute, is
\[ \rho = 1-\frac{\sum_i T_i}{B T_{\max}}. \]
When lengths vary widely, \(\rho\) can be very high. Common remedies:
| Method | Idea | Trade-off |
|---|---|---|
| bucketing | batch examples with similar lengths | less random order |
| dynamic batching | constrain total tokens per batch | variable sample count |
| packing | concatenate short sequences into fixed blocks | need boundary masks |
| truncation | cap maximum length | may lose information |
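The effect of grouping by length on \(\rho\) can be checked with a few lines. The lengths below are made-up numbers for illustration, and padding_waste (a hypothetical helper) aggregates \(\rho\) over several batches by summing real and padded positions.

```python
def padding_waste(batches, lengths):
    """rho aggregated over batches: 1 - real tokens / padded tokens."""
    real = 0
    padded = 0
    for batch in batches:
        batch_lengths = [lengths[i] for i in batch]
        real += sum(batch_lengths)
        # Every row of a batch is padded to the longest sequence in that batch.
        padded += len(batch) * max(batch_lengths)
    return 1.0 - real / padded


# Illustrative lengths: two short and two long sequences.
lengths = [10, 100, 12, 95]
mixed = [[0, 1], [2, 3]]      # each batch pairs a short and a long sequence
by_length = [[0, 2], [3, 1]]  # each batch pairs similar lengths
print(padding_waste(mixed, lengths), padding_waste(by_length, lengths))
```

With these lengths, pairing short with long sequences pads about 44% of the positions, while pairing similar lengths pads about 3%.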
LLM pretraining often packs a text stream into fixed-length blocks:
doc1 <eos> doc2 <eos> doc3 <eos> ...
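and then cuts it into chunks of length \(T\). This keeps GPU utilization high, but one question must be answered explicitly: is attention across documents allowed? Some recipes allow it, treating <eos> as a sufficient boundary; others build a block-diagonal mask that prevents tokens of different documents from attending to each other.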
Packing can introduce cross-document context unless the attention mask prevents it. This changes the conditional distribution seen by the model.
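A minimal stream-packing sketch, assuming documents are already tokenized into lists of ids: it appends EOS to each document, concatenates everything, drops the trailing partial block, and also returns a per-token doc_id that boundary-aware masks can use. pack_stream is a hypothetical helper.

```python
import torch


def pack_stream(docs, eos_id, block_len):
    """Concatenate docs with EOS separators and cut into [N, block_len] blocks.

    Returns (input_ids, doc_ids); the trailing partial block is dropped.
    """
    tokens, doc_ids = [], []
    for d, doc in enumerate(docs):
        # Each document is followed by EOS; EOS keeps the id of the document it ends.
        tokens.extend(doc + [eos_id])
        doc_ids.extend([d] * (len(doc) + 1))
    n_blocks = len(tokens) // block_len
    usable = n_blocks * block_len
    ids = torch.tensor(tokens[:usable], dtype=torch.long).view(n_blocks, block_len)
    docs_t = torch.tensor(doc_ids[:usable], dtype=torch.long).view(n_blocks, block_len)
    return ids, docs_t
```

For docs = [[1, 2], [3, 4], [5]] with eos_id=0 and block_len=4, the blocks are [1, 2, 0, 3] and [4, 0, 5, 0]: document 1 is split across the two blocks, which is the usual behavior of stream packing.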
Packing Implementation Contract
Packing has at least three distinct semantics:
| Packing mode | Attention across documents | Loss across boundary | Typical use |
|---|---|---|---|
| stream packing | yes | yes, usually through EOS | pretraining text stream |
| packed but boundary-aware | no | no | instruction/SFT examples |
| packed with EOS-only boundary | maybe | no after EOS | mixed corpora with explicit separators |
Boundary-aware packing can give every token a doc_id:
tokens: [A, B, EOS, D, E, EOS, X, EOS]
doc_id: [0, 0, 0, 1, 1, 1, 2, 2]
The causal attention mask is then not just lower-triangular; the two positions must also belong to the same document:
\[ M_{t,s} = \mathbb{1}[s\le t]\cdot \mathbb{1}[\operatorname{doc}(s)=\operatorname{doc}(t)]. \]
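It can be constructed in code as follows:
```python
import torch


def packed_causal_mask(doc_ids):
    # doc_ids: [B, T]; returns bool [B, T, T] where True means query t may attend to key s.
    _, seqlen = doc_ids.shape
    same_doc = doc_ids[:, :, None].eq(doc_ids[:, None, :])
    causal = torch.ones(seqlen, seqlen, dtype=torch.bool, device=doc_ids.device).tril()
    # Broadcast [T, T] over the batch: allowed iff s <= t and doc(s) == doc(t).
    return same_doc & causal
```
The loss mask must also drop the last token of each document, because it has no next token within the same document: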
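```python
import torch


def packed_labels(input_ids, doc_ids, ignore_index=-100):
    # Next-token targets: labels[:, t] = input_ids[:, t + 1]; the wrapped last column is masked below.
    labels = input_ids.roll(shifts=-1, dims=1)
    # Keep position t only if token t + 1 belongs to the same document.
    same_next_doc = doc_ids[:, :-1].eq(doc_ids[:, 1:])
    loss_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    loss_mask[:, :-1] = same_next_doc
    labels = labels.masked_fill(~loss_mask, ignore_index)
    return labels, loss_mask
```
If you allow loss on EOS -> next_doc_first_token, the model learns that the beginning of whatever example follows a document boundary is predictable. That can be reasonable for a continuous text stream, but it is usually wrong for instruction data, where the adjacency of two examples is an artifact of packing.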
Packed batches need explicit boundary metadata when examples are independent. Without doc_id, sequence_id, or equivalent masks, packing may leak context and loss across examples.
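Packing also affects positional encoding. If each document should start at position 0 inside the packed block, the position ids must be reset:
```python
import torch


def reset_position_ids(doc_ids):
    # doc_ids: [B, T]; positions restart from 0 at every document boundary.
    pos = torch.zeros_like(doc_ids)
    for b in range(doc_ids.shape[0]):
        start = 0
        for t in range(1, doc_ids.shape[1] + 1):
            # A segment ends at the end of the row or where doc_id changes.
            is_end = t == doc_ids.shape[1] or doc_ids[b, t] != doc_ids[b, t - 1]
            if is_end:
                pos[b, start:t] = torch.arange(t - start, device=doc_ids.device)
                start = t
    return pos
```
With absolute positional embeddings, resetting or not resetting encodes different modeling assumptions. With RoPE, relative distances within a document are the same either way, but if cross-document attention is allowed, resetting changes the relative distances between tokens of different documents. Packing is not pure data compression; it changes the positional geometry the model sees.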
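The double Python loop above is clear but slow for long blocks. A vectorized sketch (the function name is hypothetical) marks the first token of every segment and uses torch.cummax to find, for each position, the most recent segment start; it should return the same result as the loop version.

```python
import torch


def reset_position_ids_vectorized(doc_ids):
    """Positions restart at 0 wherever doc_id changes along the sequence."""
    B, T = doc_ids.shape
    idx = torch.arange(T, device=doc_ids.device).expand(B, T)
    # True at the first token of every document segment.
    is_start = torch.ones_like(doc_ids, dtype=torch.bool)
    is_start[:, 1:] = doc_ids[:, 1:] != doc_ids[:, :-1]
    # Index of the most recent segment start at or before each position.
    start_idx = torch.where(is_start, idx, torch.zeros_like(idx))
    start_idx = torch.cummax(start_idx, dim=1).values
    return idx - start_idx
```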
DataLoader Workers and Host-to-GPU Transfer
DataLoader throughput is usually determined by these stages:
disk/network -> decode -> transform -> collate -> pinned memory -> H2D copy -> GPU compute
If the GPU often waits for data, no amount of model optimization helps. Commonly used options:
| Option | Meaning |
|---|---|
| num_workers | subprocesses for loading data |
| pin_memory | allocate page-locked host memory for faster H2D transfer |
| prefetch_factor | batches prepared ahead per worker |
| persistent_workers | keep workers alive across epochs |
| drop_last | drop incomplete last batch |
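A typical training loader:
```python
from torch.utils.data import DataLoader

# dataset and collate_fn are assumed to be defined elsewhere.
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
    collate_fn=collate_fn,
)
```
and, in the training loop:
```python
# Assumes a flat dict of tensors; nested batches need a recursive helper such as move_to_device.
batch = {
    k: v.to(device, non_blocking=True)
    for k, v in batch.items()
}
```
When pin_memory=True is combined with non_blocking=True, host-to-device copies can overlap with GPU compute more easily. This is not always faster, though: if the data is small, the CPU is already busy, or host memory is tight, adding more workers can make loading slower rather than faster.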
Too many workers can cause CPU contention, memory pressure, file descriptor exhaustion, or duplicated preprocessing overhead.
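A simple way to tell whether loading is the bottleneck is to time the loader alone, without a model. The sketch assumes each batch is a tuple or list whose first element is a tensor with the batch dimension first (as TensorDataset produces); loader_throughput is a hypothetical helper.

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def loader_throughput(loader, max_batches=50, warmup=5):
    """Iterate the loader without a model; return (batches/sec, samples/sec)."""
    it = iter(loader)
    # Skip the first batches: worker startup and the initial prefetch would dominate.
    for _ in range(warmup):
        next(it)
    n_batches = n_samples = 0
    start = time.perf_counter()
    for batch in it:
        n_samples += batch[0].shape[0]
        n_batches += 1
        if n_batches == max_batches:
            break
    elapsed = max(time.perf_counter() - start, 1e-9)
    return n_batches / elapsed, n_samples / elapsed


dataset = TensorDataset(torch.randn(640, 8), torch.zeros(640))
loader = DataLoader(dataset, batch_size=32, num_workers=0)
print(loader_throughput(loader, max_batches=10, warmup=2))
```

If the loader alone delivers batches much faster than a training step consumes them, adding workers is unlikely to help; rerunning the measurement for a few num_workers values is a cheap way to choose one.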
Preprocessing Cache and Data Fingerprints
Expensive, deterministic preprocessing is best cached offline, for example:
| Raw modality | Expensive deterministic preprocessing |
|---|---|
| text | normalization, tokenization, chat-template rendering |
| image | decode, resize, center crop, feature extraction |
| audio | resampling, spectrogram, VAD segmentation |
| graph | neighbor lists, degree features, subgraph extraction |
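The risk of caching is that the content has changed while a stale cache is still in use. Cache files should therefore carry a manifest:
```python
# code_hash, tokenizer_name, tokenizer_hash, chat_template_hash and records are computed elsewhere.
cache_manifest = {
    "dataset_name": "my_corpus",
    "dataset_version": "2026-06-10",
    "preprocess_code_hash": code_hash,
    "tokenizer_name": tokenizer_name,
    "tokenizer_hash": tokenizer_hash,
    "template_hash": chat_template_hash,
    "num_records": len(records),
}
```
For text LLMs, the contract of a tokenized cache is especially strict: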
| Cached field | Why it matters |
|---|---|
| input_ids | tokenizer vocabulary and merge table dependent |
| special_tokens_mask | chat/control token boundaries |
| attention_mask | padding or packed-block semantics |
| labels or loss_mask | SFT/RLHF objective semantics |
| length | batching and token-budget accounting |
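If the tokenizer gains a new special token while an old cache still uses the old ids, the model trains in the wrong discrete space. The safest approach is to hash the contents of the tokenizer files into the cache key instead of recording only the tokenizer's name.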
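Hashing tokenizer contents can be sketched as follows, assuming the tokenizer is stored as a directory of files (vocabulary, merges, special-token and template configs). directory_content_hash is a hypothetical helper.

```python
import hashlib
from pathlib import Path


def directory_content_hash(root):
    """SHA-256 over relative paths and bytes of all files under root, in sorted order."""
    root = Path(root)
    h = hashlib.sha256()
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        data = path.read_bytes()
        # The relative path is hashed too, so adding or renaming a file changes the result.
        h.update(path.relative_to(root).as_posix().encode() + b"\0")
        # A length prefix keeps file boundaries unambiguous.
        h.update(len(data).to_bytes(8, "little"))
        h.update(data)
    return h.hexdigest()
```

The result can go into the tokenizer_hash field of the cache manifest; the same function works for a directory of chat templates.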
A data fingerprint is a stable hash or manifest that identifies the raw data, preprocessing code, tokenizer/config, and filtering rules used to produce a training dataset.
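A simple fingerprint can combine several layers of information:
```python
import hashlib


def fingerprint(parts):
    # parts: dict of name -> value; sorting makes the hash independent of insertion order.
    h = hashlib.sha256()
    for key, value in sorted(parts.items()):
        h.update(str(key).encode())
        h.update(b"\0")
        h.update(str(value).encode())
        h.update(b"\0")
    return h.hexdigest()
```
Do not record only "dataset=v1". A reproducible experiment should be able to answer at least: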
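- which raw files went into training;
- which examples were filtered out;
- which versions of the tokenizer, normalizer, and chat template were used;
- what the packing, truncation, and loss-mask rules were;
- whether the cache was produced by the current code.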
If tokenization happens inside DataLoader workers, changing tokenizer files or chat templates can silently change the training distribution without changing model code.
Randomness in Data Pipelines
Randomness in training comes from at least:
- model initialization;
- data shuffling;
- random data augmentation;
- dropout;
- CUDA kernels;
- distributed reduction order.
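A basic seed setup:
```python
import random

import numpy as np
import torch


def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all visible GPUs; safe to call even when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
```
DataLoader workers, however, are separate processes. If augmentation uses NumPy or Python's random module, seed them inside each worker: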
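```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_worker(worker_id):
    # Inside a worker, torch.initial_seed() is already base_seed + worker_id.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


generator = torch.Generator()
generator.manual_seed(42)

loader = DataLoader(
    dataset,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=generator,
)
```
A worker seed guarantees that each worker has a different random stream. Some augmentations, however, are better served by a stateless per-example seed: a given example receives the same random augmentation in a given epoch no matter which worker reads it.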
With example id \(u\), epoch \(e\), and global seed \(s_0\), one can define
\[ s(u,e)=H(s_0,u,e). \]
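In code:
```python
import hashlib
import random


def example_seed(base_seed, example_id, epoch):
    # s(u, e) = H(s0, u, e). hashlib is used because Python's built-in hash() of
    # strings is randomized per process.
    text = f"{base_seed}:{example_id}:{epoch}"
    digest = hashlib.blake2b(text.encode(), digest_size=8).digest()
    return int.from_bytes(digest, "little") % 2**32


def augment_with_seed(example, epoch, base_seed):
    seed = example_seed(base_seed, example["example_id"], epoch)
    # A private RNG: no shared state, so worker count and read order do not matter.
    rng = random.Random(seed)
    if rng.random() < 0.5:
        # horizontal_flip is a placeholder for the actual image transform.
        example["image"] = horizontal_flip(example["image"])
    return example
```
The benefit of this approach: changing num_workers, the number of ranks, the prefetch depth, or even the batch order does not change the augmentation random numbers a given example receives in a given epoch.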
Worker seeds make worker processes independent; stateless per-example seeds make augmentation reproducible under changes to worker count, sharding, and resume position.
Reproducibility means that repeated runs under specified software, hardware, seeds, and deterministic settings produce the same or statistically comparable results.
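Note that "exactly identical" and "statistically comparable" are different standards. Research replication usually requires at least stable means, variances, and trends; bitwise-identical results are needed only when debugging a specific bug.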
Deterministic Algorithms
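PyTorch can be asked to use deterministic algorithms:
```python
torch.use_deterministic_algorithms(True)
```
This helps debugging but can cost speed, and some operations have no deterministic implementation; with this flag set, they raise an error instead of running. The cuDNN benchmark setting also affects which convolution algorithms are selected:
```python
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
```
Deterministic kernels may be slower, and exact reproducibility can depend on PyTorch version, CUDA version, GPU model, and distributed topology.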
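The debugging-tier settings can be collected in one function. This is a sketch: configure_determinism is a hypothetical name; warn_only (available in recent PyTorch versions) turns errors from nondeterministic operations into warnings; and CUBLAS_WORKSPACE_CONFIG is the environment variable PyTorch's reproducibility notes require for deterministic cuBLAS on CUDA 10.2 and later.

```python
import os

import torch


def configure_determinism(strict=True):
    """Debug-oriented settings; strict=False only warns on nondeterministic ops."""
    # Must be set before the first cuBLAS call to take effect; an existing value is kept.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.use_deterministic_algorithms(True, warn_only=not strict)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
```

Calling it at the very start of a debugging run, before any CUDA work, avoids the case where cuBLAS has already been initialized without the workspace setting.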
In practice, the settings can be tiered:
| Stage | Recommended setting |
|---|---|
| debugging | strict seeds, deterministic algorithms, small data |
| ablation | fixed split, fixed data order policy, multiple seeds |
| large training | controlled seeds, logged versions, tolerate non-bitwise differences |
Train/Validation/Test Splits
Data splitting is the experimental-design question most easily underestimated. Random splitting is not always appropriate:
| Data type | Split risk |
|---|---|
| image duplicates | near-duplicate leakage |
| user behavior | same user in train and test |
| time series | future leaking into past |
| documents | paragraphs from same document split apart |
| graphs | edges/nodes leak through neighborhood |
Formally, we want the validation set to estimate the risk under the target distribution:
\[ R(\theta)=\mathbb{E}_{(x,y)\sim P_{\text{target}}}\ell(f_\theta(x),y). \]
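If information leaks from the training set into the validation set, the validation estimate measures something closer to memorization than to generalization risk.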
If examples share users, documents, time windows, graphs, or generated variants, random splitting may leak information across splits.
The experiment record should save at least:
- split generation code;
- split seed;
- exact file list or index list;
- preprocessing version;
- filtering rules;
- tokenizer or vocabulary version.
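A few simple leakage tests are also worth writing. For example, a document-level split should guarantee that no doc_id appears in more than one split:
```python
def assert_disjoint_groups(rows, group_key, split_key):
    # Remember the first split seen for each group; any different split is a leak.
    owners = {}
    for row in rows:
        group = row[group_key]
        split = row[split_key]
        old = owners.setdefault(group, split)
        if old != split:
            raise ValueError(f"group {group} appears in {old} and {split}")
```
For a time-series split, at least check that no validation timestamp precedes the end of the training period: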
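```python
def assert_time_order(train_rows, val_rows, time_key):
    max_train = max(row[time_key] for row in train_rows)
    min_val = min(row[time_key] for row in val_rows)
    # Validation may start at, but not before, the last training timestamp.
    if min_val < max_train:
        raise ValueError("validation contains timestamps before training ends")
```
Near-duplicate images, mirrored web pages, and forked code repositories cannot be caught by exact ids alone; they need hashes, embeddings, or domain-specific rules. Such cases are not necessarily code bugs, but they make the validation loss look more trustworthy than it is.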
Split by the unit that carries dependence: user, document, time window, graph component, repository, or source domain. IID row-level splitting is valid only when rows are genuinely independent.
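Splitting by the dependence unit can be made deterministic by hashing the group id rather than the row. A minimal sketch, assuming each row carries a group key such as a user or document id; group_split, its fractions, and the salt string are hypothetical choices.

```python
import hashlib


def group_split(group_id, val_frac=0.1, test_frac=0.1, salt="split-v1"):
    """Deterministically map a group (user, document, repository, ...) to a split."""
    # Hash the group id, not the row, so every row of a group lands in the same split.
    digest = hashlib.sha256(f"{salt}:{group_id}".encode()).digest()
    u = int.from_bytes(digest[:8], "big") / 2**64  # roughly uniform in [0, 1)
    if u < test_frac:
        return "test"
    if u < test_frac + val_frac:
        return "validation"
    return "train"
```

The salt is part of the split definition: changing it reshuffles groups across splits, so it belongs in the experiment record together with the fractions.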
Distributed Data Loading
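In data-parallel training, each rank should see a different shard of the data. The core job of DistributedSampler is to partition the dataset indices across ranks:
```python
import torch
from torch.utils.data import DataLoader

sampler = torch.utils.data.DistributedSampler(
    dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,  # do not also pass shuffle=True: the sampler owns the order
    collate_fn=collate_fn,
)
```
At the start of every epoch, also call:
```python
sampler.set_epoch(epoch)
```
Otherwise every epoch repeats the same shuffle order, because the sampler seeds its permutation with seed + epoch.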
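Because num_replicas and rank can be passed explicitly, DistributedSampler can be inspected in a single process without initializing a process group. The sketch below (shard_indices is a hypothetical helper) shows that the ranks partition the indices and that set_epoch changes the permutation.

```python
from torch.utils.data import DistributedSampler


def shard_indices(n, world_size, epoch, seed=0):
    """Return the index list each rank would iterate in a given epoch."""
    dataset = range(n)  # anything with __len__ is enough for the sampler
    shards = []
    for rank in range(world_size):
        # Explicit num_replicas/rank: no process group is needed for this inspection.
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=seed
        )
        sampler.set_epoch(epoch)
        shards.append(list(sampler))
    return shards


print(shard_indices(10, world_size=2, epoch=0))
```

When the dataset length is not divisible by the world size and drop_last=False, DistributedSampler pads by repeating indices so that every rank gets the same number of samples; for validation, those duplicates enter the metrics unless they are removed.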
Global batch size is \[ B_{\text{global}} = B_{\text{per-rank}} \times N_{\text{ranks}} \times K_{\text{accum}}. \] It is the number of examples contributing to one optimizer step across all workers.
For token-level tasks, the global token batch is the more common measure:
\[ T_{\text{global}} = \sum_{\text{rank}}\sum_{\text{micro}}\sum_{b,t}m_{b,t}. \]
This is more accurate than an example count, because different ranks can have different numbers of effective tokens.
Exact Resume of Data Order
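Exact resume requires that the next batch after an interruption is identical to the one an uninterrupted run would have seen. A model checkpoint only covers parameters; reproducing the data order also requires the following state: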
| State | Map-style dataset | Iterable or streaming dataset |
|---|---|---|
| epoch | sampler epoch | stream epoch or pass id |
| position | batch index or consumed indices | file shard and byte/record cursor |
| RNG | sampler generator state | shuffle buffer RNG state |
| distributed | rank/world size contract | shard assignment contract |
| accounting | examples/tokens seen | examples/tokens seen |
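For a map-style dataset, the simplest resumable strategy is to derive each epoch's permutation from (seed, epoch) and to save epoch and batch_in_epoch in the checkpoint. On resume, regenerate the permutation and skip the batches already consumed:
```python
import torch


def epoch_indices(n, seed, epoch):
    # The permutation is a pure function of (seed, epoch), so it can be regenerated on resume.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    return torch.randperm(n, generator=g).tolist()


def resume_batches(n, seed, epoch, batch_size, batch_in_epoch):
    indices = epoch_indices(n, seed, epoch)
    # Skip the batches already consumed before the checkpoint.
    start = batch_in_epoch * batch_size
    return indices[start:]
```
This strategy is simple but conditional: the dataset length, index order, world size, batch size, and drop_last must all stay the same. If any of them changes, batch_in_epoch no longer points to the same examples.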
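The same (seed, epoch) permutation can be wrapped as a Sampler so that a DataLoader can resume mid-epoch. A sketch under two assumptions: the class name ResumableEpochSampler is hypothetical, and the offset comes from the training loop's count of consumed examples rather than from the sampler itself, because DataLoader prefetching makes the sampler run ahead of the optimizer.

```python
import torch
from torch.utils.data import Sampler


class ResumableEpochSampler(Sampler):
    """Yields the (seed, epoch) permutation starting at a given offset."""

    def __init__(self, n, seed, epoch=0, start=0):
        self.n = n
        self.seed = seed
        self.epoch = epoch
        # Number of examples already consumed in this epoch (from the checkpoint).
        self.start = start

    def set_epoch(self, epoch, start=0):
        self.epoch = epoch
        self.start = start

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        order = torch.randperm(self.n, generator=g).tolist()
        return iter(order[self.start:])

    def __len__(self):
        return self.n - self.start
```

On resume, one would construct ResumableEpochSampler(n, seed, epoch=ckpt_epoch, start=batch_in_epoch * batch_size) and pass it as sampler= to the DataLoader, keeping batch_size and drop_last unchanged.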
Changing world_size, per-rank batch size, gradient accumulation, or drop_last can change which examples contribute to the next optimizer step, even if model and optimizer states load successfully.
Streaming datasets are harder because the data cannot be indexed randomly. Common engineering choices:
| Strategy | Pros | Cons |
|---|---|---|
| save exact cursor | precise | storage-specific; hard for compressed streams |
| deterministic shard order + skip records | simple | slow when skipping deep into stream |
| checkpoint precomputed sample ids | precise | extra manifest/storage |
| tolerate statistical resume | scalable | not bitwise exact |
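The "deterministic shard order + skip records" row can be sketched with itertools.islice, assuming line-delimited shard files and a checkpointed count of records consumed. Both helper names are hypothetical. Skipping still reads every earlier record, which is exactly why this strategy becomes slow deep into a stream.

```python
import itertools


def iter_stream(shard_paths):
    """Yield records (lines) from all shards in a fixed, sorted shard order."""
    for path in sorted(shard_paths):
        with open(path) as f:
            for line in f:
                yield line.rstrip("\n")


def resume_stream(shard_paths, records_consumed):
    """Replay the deterministic stream and skip records consumed before the checkpoint."""
    # islice discards the first records_consumed items, but they are still read from disk.
    return itertools.islice(iter_stream(shard_paths), records_consumed, None)
```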
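Large-scale pretraining usually accepts statistical resume: after a restart, training keeps sampling from the same distribution, but the exact order of examples is not guaranteed. If the goal is debugging or an ablation, verify exact resume on a map-style subset first.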
Statistical resume restores training from the same distribution and comparable counters, but does not guarantee the exact next batch equals the uninterrupted run.
drop_last and BatchNorm
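drop_last=True discards the final incomplete batch. It looks like it only keeps batch shapes uniform, but it affects:
- the number of examples actually trained on per epoch;
- whether minority-class examples are more likely to be dropped under class imbalance;
- whether all ranks get the same number of batches in distributed training;
- whether BatchNorm statistics are disturbed by a small batch.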
If the model uses BatchNorm, statistics from a small batch are noisy, and a small final batch can make the running statistics jump. With LayerNorm/RMSNorm the effect is usually small.
Some code assumes fixed batch size for reshaping, BatchNorm, contrastive negatives, or distributed collectives. Decide deliberately whether to keep or drop the last batch.
A Practical Data Pipeline Template
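A relatively robust supervised-learning pipeline:
```python
import torch
from torch.utils.data import DataLoader


def make_loader(dataset, batch_size, seed, train):
    # seed_worker is the worker seeding function from the randomness section;
    # collate_fn is assumed to be defined elsewhere.
    generator = torch.Generator()
    generator.manual_seed(seed)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
        drop_last=train,
        worker_init_fn=seed_worker,
        generator=generator,
        collate_fn=collate_fn,
    )
```
and, at training time, record: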
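```python
metadata = {
    "seed": seed,
    "batch_size": batch_size,
    "num_workers": 4,
    "drop_last": train,
    "dataset_version": dataset.version,
    "preprocess_version": preprocess_version,
}
```
This metadata may look verbose, but when results change unexpectedly, it narrows the problem from mysterious fluctuation down to a few concrete variables.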
Data Pipeline Checklist
Before starting large-scale training, check at least the following:
- dataset length matches expectation;
- one sample has correct dtype, shape, and label;
- one batch has correct padding and mask;
- loss ignores padded or invalid targets;
- train/validation split has no leakage;
- shuffling is enabled only where needed;
- worker seeds cover Python, NumPy, and PyTorch randomness;
- distributed sampler calls set_epoch(epoch);
- effective batch or token count is logged;
- throughput is measured as both samples/sec and tokens/sec when relevant;
- iterable or streaming datasets shard by rank and worker without duplication;
- causal LM label shifting happens exactly once;
- packed examples carry document-boundary masks when examples are independent;
- cached preprocessing artifacts include data and tokenizer fingerprints;
- resume checkpoints record sampler epoch, batch offset, and token/example counters.
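The quality of a data pipeline is not measured by how many DataLoader arguments it sets, but by whether the sampling distribution, batch semantics, mask semantics, and experiment records it defines are clear. When training is unstable, do not rush to change the model; many problems are hiding in the batch.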