1.6 Data Pipeline and Reproducibility


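Training results are not determined by the model and optimizer alone. How the data is split, shuffled, batched, and padded, how it is moved to the GPU, and how workers apply random augmentation all change the optimization problem that is actually being solved. In code, a DataLoader often looks like it merely "feeds data", but it in fact defines the sampling process of the training distribution.

This section breaks the PyTorch data pipeline into three layers:

  1. dataset: which example a given index corresponds to;
  2. sampler: in what order indices are drawn;
  3. collate: how several examples are combined into a batch.

It then connects these choices to reproducibility, padding, masks, throughput, and distributed training.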
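A minimal sketch of the three layers wired together explicitly. The toy dataset (`SquaresDataset`) and collate function (`collate_pairs`) are hypothetical. `SequentialSampler`, `BatchSampler`, and the `batch_sampler=` argument of `DataLoader` are standard PyTorch:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class SquaresDataset(Dataset):
    """Layer 1 (dataset): hypothetical toy map from index to example."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return {"x": idx, "y": idx * idx}


def collate_pairs(examples):
    """Layer 3 (collate): turn a list of examples into one batch of tensors."""
    return {
        "x": torch.tensor([ex["x"] for ex in examples]),
        "y": torch.tensor([ex["y"] for ex in examples]),
    }


dataset = SquaresDataset(5)
sampler = SequentialSampler(dataset)  # Layer 2 (sampler): order of indices
batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_pairs)

batches = list(loader)
# batches[0]["x"] -> tensor([0, 1]); the last batch holds only index 4
```

Swapping `SequentialSampler` for `RandomSampler` changes only the order. The dataset and collate layers stay untouched.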

Dataset as a Map from Index to Example

Definition: Map-Style Dataset

A map-style dataset implements __getitem__(index) and __len__(). It represents a finite collection where each integer index maps to one example.

A typical classification dataset:

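import torch
from torchvision.io import read_image  # decodes an image file into a uint8 CHW tensor


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, records, transform):
        # records: list of (path, label) pairs in a fixed, stable order
        self.records = records
        self.transform = transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        # Map one index to one example; batching and shuffling happen elsewhere.
        path, label = self.records[idx]
        image = read_image(path)
        image = self.transform(image)
        return {"image": image, "label": label}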

Here the Dataset handles neither batching nor shuffling. It answers exactly one question: given idx, which training example do we get?

Definition: Iterable Dataset

An iterable-style dataset implements __iter__() and yields examples sequentially. It is useful for streams, generated data, or data sources without random access.

An iterable dataset behaves more like a data stream:

import json

import torch


class JsonlStream(torch.utils.data.IterableDataset):
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        # Read the file sequentially; there is no random access by index.
        with open(self.path) as f:
            for line in f:
                yield json.loads(line)

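It suits log streams, large-scale shards, and data generated online. It is also easier to get wrong. With multiple workers, every worker runs __iter__() on its own, so without manual sharding the data is read more than once.

Pitfall: IterableDataset Duplicates Data Unless Sharded

For an IterableDataset, every DataLoader worker constructs its own iterator. Without worker-aware sharding, multiple workers may yield the same examples.

A minimal worker-aware stream can use get_worker_info() and assign lines by line number modulo the worker count: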

import json

import torch


class ShardedJsonlStream(torch.utils.data.IterableDataset):
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        # get_worker_info() returns None in the main process (num_workers=0).
        info = torch.utils.data.get_worker_info()
        if info is None:
            worker_id = 0
            num_workers = 1
        else:
            worker_id = info.id
            num_workers = info.num_workers

        with open(self.path) as f:
            for i, line in enumerate(f):
                # Each worker keeps only the lines with i % num_workers == worker_id.
                if i % num_workers != worker_id:
                    continue
                yield json.loads(line)

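In distributed training the rank must also be part of the shard id. With \(R\) data-parallel ranks and \(W\) loader workers per rank, the total number of workers is

\[ N_{\text{stream workers}}=R W. \]

The global worker id can be written as

\[ s=\operatorname{rank}\cdot W+\operatorname{worker\_id}. \]

Each record \(i\) is then read only by the worker whose id \(s\) satisfies

\[ i\bmod (RW)=s: \]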

import json

import torch


class RankWorkerJsonlStream(torch.utils.data.IterableDataset):
    def __init__(self, path, rank, world_size):
        self.path = path
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        worker_id = 0 if info is None else info.id
        num_workers = 1 if info is None else info.num_workers

        # Global shard id s = rank * W + worker_id over R * W shards.
        shard_id = self.rank * num_workers + worker_id
        nshards = self.world_size * num_workers

        with open(self.path) as f:
            for i, line in enumerate(f):
                if i % nshards == shard_id:
                    yield json.loads(line)

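This code explains the semantics well, but large-scale text training usually shards by file rather than by line modulo. Line-modulo sharding makes every worker scan the entire file, which amplifies reads by a factor of \(RW\). File-level sharding lets each worker open only its own subset of files.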
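File-level sharding can be sketched by giving each (rank, worker) pair a strided slice of a sorted file list. This minimal sketch assumes many more files than \(RW\) shards and one JSON object per line. The class name is hypothetical:

```python
import json

import torch


class FileShardedJsonlStream(torch.utils.data.IterableDataset):
    """Sketch: each (rank, worker) opens only its own subset of files."""

    def __init__(self, paths, rank=0, world_size=1):
        self.paths = sorted(paths)  # identical order in every process
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        worker_id = 0 if info is None else info.id
        num_workers = 1 if info is None else info.num_workers
        shard_id = self.rank * num_workers + worker_id
        nshards = self.world_size * num_workers
        # Strided slice of files; if len(paths) < nshards, some workers get nothing.
        for path in self.paths[shard_id::nshards]:
            with open(path) as f:
                for line in f:
                    yield json.loads(line)
```

Unlike line-modulo sharding, each file is opened by exactly one worker. Balance then depends on files having similar sizes.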

Definition: Data Shard

A data shard is a subset of records or files assigned to one worker, rank, or process so that large datasets can be read in parallel without accidental duplication.

Map-style datasets also need an explicit index contract. A safe approach is to freeze the manifest of training examples into a table:

| Column | Meaning |
|---|---|
| id | stable example id |
| path | storage location |
| label | target or metadata |
| split | train/validation/test |
| length | tokens, frames, or pixels used for batching |
| hash | content fingerprint when feasible |

With such a manifest, the meaning of __getitem__(idx) no longer depends on the order of a file-system traversal:

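import torch


class ManifestDataset(torch.utils.data.Dataset):
    def __init__(self, rows, transform):
        # rows: manifest rows (dicts with at least "id" and "path"), in a fixed order
        self.rows = rows
        self.transform = transform

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        # load_record is a placeholder for a loader that returns a dict-like example
        item = load_record(row["path"])
        item = self.transform(item)
        # Carry the stable id so later stages (seeding, logging, resume) can use it.
        item["example_id"] = row["id"]
        return item

Implementation Contract: Index Stability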

If training needs exact resume or comparable ablations, dataset index order must be stable. Do not rely on unsorted directory listings or mutable remote query order.
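One way to honor this contract is to build the manifest once, from a sorted listing, and store it alongside the data. This minimal sketch assumes all files sit under one root directory and uses the relative path as a hypothetical stable id. Columns such as label, split, and length would be filled in by dataset-specific code:

```python
import hashlib
from pathlib import Path


def build_manifest(root, pattern="*.jsonl"):
    """Return manifest rows in an order that does not depend on directory listing order."""
    root = Path(root)
    # rglob order is filesystem-dependent, so sort by the relative POSIX path.
    paths = sorted(root.rglob(pattern), key=lambda p: p.relative_to(root).as_posix())
    rows = []
    for path in paths:
        rel = path.relative_to(root).as_posix()
        rows.append({
            "id": rel,  # hypothetical stable id: relative path
            "path": str(path),
            "hash": hashlib.sha256(path.read_bytes()).hexdigest(),
        })
    return rows
```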

Sampler Defines the Training Order

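Definition: Sampler

A sampler is an iterator that yields dataset indices. It determines the order in which examples are drawn from a map-style dataset.

shuffle=True looks like just a boolean, but it is essentially equivalent to using a random sampler. The empirical risk of training is usually written as

\[ \hat{R}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\ell(f_\theta(x_i),y_i). \]

At each step, SGD draws a mini-batch \(B_t\) and uses the gradient estimate

\[ g_t = \frac{1}{|B_t|} \sum_{i\in B_t} \nabla_\theta \ell_i(\theta). \]

If the sampler produces a fresh random permutation every epoch, \(g_t\) is a stochastic estimate of the full-data gradient \(\nabla_\theta\hat{R}(\theta)\). If the sampler has a systematic bias, the training dynamics change noticeably. Examples include the same class appearing consecutively, long texts clustered together, or a time series leaking future samples.

Pitfall: Data Order Is Part of the Algorithm

Changing shuffle, sampler seed, bucket order, or distributed sharding can change optimization trajectories even when model code is identical.

Common sampler types: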

| Sampler | Use case |
|---|---|
| sequential | validation, deterministic evaluation |
| random | standard SGD training |
| weighted random | class imbalance, sampling curriculum |
| distributed | split indices across ranks |
| bucketed | group similar lengths for efficient padding |
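For the weighted-random row, PyTorch provides `WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)`. This minimal sketch uses inverse class frequency as the per-example weight, which is one common heuristic among several. A seeded generator makes the draw reproducible. The helper name is hypothetical:

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler


def balanced_sampler(labels, num_samples, seed):
    """Sample indices so that each class is drawn roughly equally often."""
    counts = Counter(labels)
    # Inverse class frequency: rare classes get proportionally larger weights.
    weights = torch.tensor([1.0 / counts[y] for y in labels], dtype=torch.double)
    generator = torch.Generator()
    generator.manual_seed(seed)
    return WeightedRandomSampler(weights, num_samples=num_samples,
                                 replacement=True, generator=generator)


labels = [0] * 90 + [1] * 10
sampler = balanced_sampler(labels, num_samples=10_000, seed=0)
drawn = list(sampler)
share_of_minority = sum(labels[i] for i in drawn) / len(drawn)  # close to 0.5
```

Sampling is with replacement, so an "epoch" of `num_samples` draws repeats minority examples and skips some majority ones. It is clearer to log it as a number of draws than as passes over the data.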

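A Sampler yields single indices, while a BatchSampler yields groups of indices. Length-aware batching should usually be written as a batch sampler, because it decides which examples go into the same batch: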

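import torch


class LengthBucketBatchSampler(torch.utils.data.Sampler):
    def __init__(self, lengths, batch_size, bucket_size, generator):
        self.lengths = lengths
        self.batch_size = batch_size
        self.bucket_size = bucket_size
        self.generator = generator  # torch.Generator controlling the shuffle

    def __iter__(self):
        n = len(self.lengths)
        # 1. Shuffle all indices, then cut the permutation into large buckets.
        order = torch.randperm(n, generator=self.generator).tolist()

        for start in range(0, n, self.bucket_size):
            bucket = order[start : start + self.bucket_size]
            # 2. Sort inside a bucket so that each batch has similar lengths.
            bucket.sort(key=lambda i: self.lengths[i])

            for j in range(0, len(bucket), self.batch_size):
                batch = bucket[j : j + self.batch_size]
                # 3. Drop the incomplete batch at the end of each bucket.
                if len(batch) == self.batch_size:
                    yield batch

    def __len__(self):
        # Incomplete batches are dropped per bucket, so count full batches per bucket.
        full_buckets, last_bucket = divmod(len(self.lengths), self.bucket_size)
        return (full_buckets * (self.bucket_size // self.batch_size)
                + last_bucket // self.batch_size)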

The trade-off is explicit: first shuffle into large buckets, then sort by length within each bucket. Larger buckets mean less padding but weaker randomness. Smaller buckets mean stronger randomness but more wasted padding.

Definition: Batch Sampler

A batch sampler yields lists of indices. It controls mini-batch composition directly and is the right abstraction for length bucketing, curriculum batches, and token-budget batching.

Dynamic token batching suits LLMs better. Instead of a fixed number \(B\) of examples, it fixes a total token budget. With \(T_i\) the length of example \(i\), a batch satisfies

\[ \sum_{i\in B}T_i \le T_{\max}. \]

A simple greedy version:

def token_budget_batches(indices, lengths, max_tokens):
    batch = []
    ntok = 0

    for i in indices:
        length = lengths[i]
        if batch and ntok + length > max_tokens:
            yield batch
            batch = []
            ntok = 0

        batch.append(i)
        ntok += length

    if batch:
        yield batch

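This produces a variable number of examples per batch, and a single example longer than max_tokens still becomes a batch of its own. Training logs therefore cannot record only batch_size; they must also record the effective tokens per step:

\[ T_{\text{step}}=\sum_{i\in B_t}T_i. \]

Pitfall: Token-Budget Batches Change Step Semantics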

With token-budget batching, the number of sequences per optimizer step varies. Loss normalization, logging, and LR schedule should be defined in tokens or optimizer steps, not assumed fixed by sample count.
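A small worked example of why normalization matters. Two micro-batches with different token counts give different losses depending on whether per-micro-batch means are averaged or the summed loss is divided by total tokens. The numbers are illustrative, not measured:

```python
def mean_of_means(loss_sums, token_counts):
    """Average of per-micro-batch mean losses: each micro-batch weighs the same."""
    means = [s / n for s, n in zip(loss_sums, token_counts)]
    return sum(means) / len(means)


def token_weighted_mean(loss_sums, token_counts):
    """Total loss divided by total tokens: each token weighs the same."""
    return sum(loss_sums) / sum(token_counts)


# Micro-batch A: 100 tokens with summed loss 200 (mean 2.0).
# Micro-batch B: 10 tokens with summed loss 50 (mean 5.0).
loss_sums = [200.0, 50.0]
token_counts = [100, 10]

print(mean_of_means(loss_sums, token_counts))        # 3.5
print(token_weighted_mean(loss_sums, token_counts))  # 250 / 110, about 2.27
```

With token-budget batches, the second form keeps each token's contribution fixed no matter how sequences were grouped into micro-batches.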

Collate Function and Batch Semantics

Definition: Collate Function

A collate function transforms a list of individual examples into one batch object. It defines stacking, padding, masks, and type conversion.

The default collate function tries to stack fields with the same name:

examples = [
    {"x": torch.tensor([1, 2]), "y": 0},
    {"x": torch.tensor([3, 4]), "y": 1},
]

is turned into

{
    "x": tensor([[1, 2], [3, 4]]),
    "y": tensor([0, 1]),
}
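This behavior can be checked directly with `torch.utils.data.default_collate`, which is a public function in PyTorch 1.11 and later. The same call fails once the `x` tensors have different lengths, which is what motivates a custom collate function:

```python
import torch
from torch.utils.data import default_collate

examples = [
    {"x": torch.tensor([1, 2]), "y": 0},
    {"x": torch.tensor([3, 4]), "y": 1},
]
batch = default_collate(examples)
# batch["x"] -> tensor([[1, 2], [3, 4]]); batch["y"] -> tensor([0, 1])

ragged = [
    {"x": torch.tensor([1, 2, 3]), "y": 0},
    {"x": torch.tensor([4]), "y": 1},
]
try:
    default_collate(ragged)
    ragged_ok = True
except RuntimeError:
    # torch.stack requires equal shapes, so variable-length fields need padding first.
    ragged_ok = False
```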

NLP/LLM examples, however, usually have different lengths, so padding has to be implemented in a custom collate function:

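import torch


def collate_lm(examples, pad_id):
    # Use as DataLoader(collate_fn=functools.partial(collate_lm, pad_id=...)).
    # Assumes each ex["input_ids"] is a Python list of token ids.
    lengths = [len(ex["input_ids"]) for ex in examples]
    max_len = max(lengths)

    input_ids = []
    attention_mask = []
    labels = []

    for ex in examples:
        ids = ex["input_ids"]
        pad = max_len - len(ids)
        # Right padding.
        input_ids.append(ids + [pad_id] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        # Shifted targets: position t predicts ids[t + 1]; last real position and pads are ignored.
        labels.append(ids[1:] + [-100] * (pad + 1))

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
    }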
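The same contract can be written with tensor operations. This sketch uses `torch.nn.utils.rnn.pad_sequence`, and the function name is hypothetical. Unlike `collate_lm`, it returns a boolean attention mask, which some models may not accept:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_lm_tensors(examples, pad_id, ignore_index=-100):
    ids = [torch.as_tensor(ex["input_ids"], dtype=torch.long) for ex in examples]
    input_ids = pad_sequence(ids, batch_first=True, padding_value=pad_id)  # [B, T]
    lengths = torch.tensor([len(x) for x in ids])
    positions = torch.arange(input_ids.shape[1])
    attention_mask = positions[None, :] < lengths[:, None]  # True for real tokens

    # Shifted targets: labels[:, t] = input_ids[:, t + 1].
    labels = torch.full_like(input_ids, ignore_index)
    labels[:, :-1] = input_ids[:, 1:]
    # Only positions 0 .. len - 2 have a real next token; everything else is ignored.
    valid = positions[None, :] < (lengths[:, None] - 1)
    labels = labels.masked_fill(~valid, ignore_index)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
```

Passing it to a DataLoader still needs `functools.partial(collate_lm_tensors, pad_id=...)`.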

What matters most here is not the form of the code but the semantics of each field in the batch:

| Field | Meaning |
|---|---|
| input_ids | tokens fed into the model |
| attention_mask | which positions are real tokens |
| labels | targets used by the loss |
| -100 (value in labels) | default ignore_index of CrossEntropyLoss, used by many LM heads |

Pitfall: Padding Must Match the Loss

Padding only the inputs is not enough. The loss must ignore padded target positions, otherwise the model is trained to predict pad tokens.

A common boundary question is where label shifting happens. Some models shift labels=input_ids automatically inside forward. Other training code requires the collate stage to produce already-shifted labels. Only one of the two may be used. If the collator has already built

input:  [A, B, C, EOS]
label:  [B, C, EOS, IGN]

then the model must not shift again. Otherwise the actual targets become

A -> C
B -> EOS

This turns next-token prediction into a misaligned two-step-ahead prediction.

Implementation Contract: Shift Labels Exactly Once

For causal language modeling, decide whether shifting happens in the collator or inside the model/loss wrapper. Double shifting silently trains the wrong conditional distribution.
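The double-shift failure can be made concrete with a toy function that lists the (input token, target) pairs a loss would see. `prediction_pairs` and its `model_shifts` flag are hypothetical. `model_shifts=True` imitates a model that shifts labels internally, scoring the output at position t against the label at t + 1:

```python
IGN = "IGN"  # stands in for ignore_index


def prediction_pairs(input_ids, labels, model_shifts):
    """Return the (input token, target) pairs that contribute to the loss."""
    offset = 1 if model_shifts else 0
    pairs = []
    for t in range(len(input_ids) - offset):
        target = labels[t + offset]
        if target != IGN:
            pairs.append((input_ids[t], target))
    return pairs


inputs = ["A", "B", "C", "EOS"]
shifted_labels = ["B", "C", "EOS", IGN]  # already shifted by the collator

print(prediction_pairs(inputs, shifted_labels, model_shifts=False))
# [('A', 'B'), ('B', 'C'), ('C', 'EOS')]  (shift applied once)
print(prediction_pairs(inputs, shifted_labels, model_shifts=True))
# [('A', 'C'), ('B', 'EOS')]  (shift applied twice)
```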

A batch contract should say more than "how many tensors there are". A training step should know at least:

| Field | Shape | Dtype | Device before step | Semantics |
|---|---|---|---|---|
| input_ids | [B, T] | torch.long | CPU | discrete token ids |
| attention_mask | [B, T] | torch.bool or integer | CPU | real-token mask |
| labels | [B, T] | torch.long | CPU | target ids or ignore_index |
| position_ids (if present) | [B, T] | torch.long | CPU | position contract |

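A table like this is a debugging tool, not documentation decoration. If labels are mistakenly cast to FP16, CrossEntropyLoss typically raises an error, because class-index targets must be integer (torch.long). If the 1/0 meaning of attention_mask is the reverse of what the model expects, the model may still run, but the loss cannot be trusted at all.

When moving a nested batch to the GPU, it is best to write a recursive function so that tensors below the top level are not left behind: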

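import torch


def move_to_device(x, device):
    # Recursively move every tensor in a nested dict/list/tuple to `device`.
    # non_blocking=True only overlaps with compute when the source is in pinned memory.
    if torch.is_tensor(x):
        return x.to(device, non_blocking=True)
    if isinstance(x, dict):
        return {k: move_to_device(v, device) for k, v in x.items()}
    if isinstance(x, list):
        return [move_to_device(v, device) for v in x]
    if isinstance(x, tuple):
        # Note: this rebuilds namedtuples as plain tuples.
        return tuple(move_to_device(v, device) for v in x)
    return x

Pitfall: Do Not Cast the Whole Batch Blindly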

Images or activations may use floating dtypes, but labels, token ids, indices, and masks often have semantic integer or boolean dtypes. A blanket .half() or .float() over the batch can corrupt the objective.
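A safer alternative to a blanket cast is to convert only floating-point tensors and leave integer and boolean tensors alone. A minimal sketch; `cast_floating` is a hypothetical helper name:

```python
import torch


def cast_floating(x, dtype):
    """Recursively cast floating-point tensors to `dtype`; keep ids, labels, and masks as-is."""
    if torch.is_tensor(x):
        return x.to(dtype) if x.is_floating_point() else x
    if isinstance(x, dict):
        return {k: cast_floating(v, dtype) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return type(x)(cast_floating(v, dtype) for v in x)
    return x


batch = {
    "pixel_values": torch.rand(2, 3, 4, 4),
    "labels": torch.tensor([3, 7]),
    "mask": torch.tensor([True, False]),
}
half = cast_floating(batch, torch.float16)
# half["pixel_values"].dtype -> torch.float16; labels stay torch.int64, mask stays torch.bool
```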

Padding, Masks, and Effective Tokens

Suppose two sequences are

[A, B, C, EOS]
[D, E, EOS]

Right padding gives

[A, B, C, EOS]
[D, E, EOS, PAD]

For causal LM, labels are shifted targets. If position \(t\) predicts \(t+1\), then the second sequence should contribute loss on:

D -> E
E -> EOS

but not on

EOS -> PAD
PAD -> PAD

The masked token-level loss is

\[ L = \frac{\sum_{b,t}m_{b,t}\, [-\log p_\theta(y_{b,t}\mid x_{b,\le t})]} {\sum_{b,t}m_{b,t}}, \]

where \(m_{b,t}=1\) for valid target positions and \(0\) for padding or ignored positions.

Definition: Effective Token Count

Effective token count is the number of non-ignored target tokens contributing to the loss. In sequence modeling, it is often a better denominator than batch size.

This is also why LLM training logs usually track tokens/sec rather than samples/sec. Two batches may both have batch size 8. If one has an average length of 128 and the other 2048, their compute cost and effective training signal are completely different.

Bucketing and Packing

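Padding wastes compute. Consider a batch of \(B\) examples where the longest sequence has length \(T_{\max}\) and example \(i\) has true length \(T_i\). The fraction of padded positions is

\[ \rho = 1-\frac{\sum_i T_i}{B T_{\max}}. \]

When lengths vary widely, \(\rho\) can be very high. Common remedies: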

| Method | Idea | Trade-off |
|---|---|---|
| bucketing | batch examples with similar lengths | less random order |
| dynamic batching | constrain total tokens per batch | variable sample count |
| packing | concatenate short sequences into fixed blocks | need boundary masks |
| truncation | cap maximum length | may lose information |
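A small worked example of \(\rho\), showing how grouping similar lengths (the bucketing row) lowers it. The lengths are made up for illustration:

```python
def padding_waste(lengths):
    """rho = 1 - sum(T_i) / (B * T_max) for one right-padded batch."""
    return 1 - sum(lengths) / (len(lengths) * max(lengths))


def mean_waste(batches):
    return sum(padding_waste(b) for b in batches) / len(batches)


lengths = [10, 100, 12, 98]

# Batches of 2 in the original order: each batch mixes a short and a long sequence.
unsorted = [lengths[0:2], lengths[2:4]]  # [10, 100], [12, 98]
# Batches of 2 after sorting by length: similar lengths share a batch.
ordered = sorted(lengths)
bucketed = [ordered[0:2], ordered[2:4]]  # [10, 12], [98, 100]

print(round(mean_waste(unsorted), 3))  # 0.444
print(round(mean_waste(bucketed), 3))  # 0.047
```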

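LLM pretraining often packs the text stream into fixed-length blocks:

doc1 <eos> doc2 <eos> doc3 <eos> ...

The stream is then cut into chunks of length \(T\). This gives high GPU utilization, but it forces an explicit decision about whether attention across documents is allowed. Some recipes allow it, on the grounds that <eos> provides a sufficient boundary. Others construct a block-diagonal mask that prevents different documents from attending to each other.

Pitfall: Packing Changes the Objective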

Packing can introduce cross-document context unless the attention mask prevents it. This changes the conditional distribution seen by the model.

Packing Implementation Contract

Packing has at least three distinct semantics:

| Packing mode | Attention across documents | Loss across boundary | Typical use |
|---|---|---|---|
| stream packing | yes | yes, usually through EOS | pretraining text stream |
| packed but boundary-aware | no | no | instruction/SFT examples |
| packed with EOS-only boundary | maybe | no after EOS | mixed corpora with explicit separators |

Boundary-aware packing can assign each token a doc_id:

tokens: [A, B, EOS, D, E, EOS, X, EOS]
doc_id: [0, 0,   0, 1, 1,   1, 2,   2]
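When documents are separated by EOS, doc_id can be derived from the token ids with a cumulative sum. This minimal sketch assumes each packed block starts at a document boundary and each document ends with `eos_id`. The function name and token ids below are hypothetical:

```python
import torch


def doc_ids_from_eos(input_ids, eos_id):
    """input_ids: [B, T]. A token's doc id = number of EOS tokens strictly before it."""
    is_eos = (input_ids == eos_id).long()
    # cumsum counts EOS up to and including t; subtracting is_eos keeps EOS in its own document.
    return torch.cumsum(is_eos, dim=1) - is_eos


EOS = 2
# [A, B, EOS, D, E, EOS, X, EOS] with made-up ids
tokens = torch.tensor([[11, 12, EOS, 14, 15, EOS, 16, EOS]])
print(doc_ids_from_eos(tokens, EOS).tolist())  # [[0, 0, 0, 1, 1, 1, 2, 2]]
```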

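The causal attention mask is then not just lower-triangular. It also requires both positions to belong to the same document:

\[ M_{t,s} = \mathbb{1}[s\le t]\cdot \mathbb{1}[\operatorname{doc}(s)=\operatorname{doc}(t)]. \]

In code it can be built like this: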

import torch


def packed_causal_mask(doc_ids):
    # doc_ids: [B, T] -> bool mask [B, T, T]; True means query t may attend to key s.
    _, seqlen = doc_ids.shape
    same_doc = doc_ids[:, :, None].eq(doc_ids[:, None, :])
    causal = torch.ones(seqlen, seqlen, dtype=torch.bool, device=doc_ids.device).tril()
    return same_doc & causal

The loss mask must also drop the last token of each document, because that token has no next token within the same document:

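import torch


def packed_labels(input_ids, doc_ids, ignore_index=-100):
    # labels[:, t] = input_ids[:, t + 1]; the wrapped-around last column is masked below.
    labels = input_ids.roll(shifts=-1, dims=1)
    same_next_doc = doc_ids[:, :-1].eq(doc_ids[:, 1:])

    # Keep a target only if token t + 1 belongs to the same document as token t.
    loss_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    loss_mask[:, :-1] = same_next_doc
    labels = labels.masked_fill(~loss_mask, ignore_index)
    return labels, loss_mask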

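If you allow a loss on EOS -> next_doc_first_token, the model learns that "the beginning of the next example after a document boundary is also predictable". That may be reasonable for a continuous text stream. For instruction data it is usually unreasonable, because the adjacency between two examples was created artificially by packing.

Implementation Contract: Track Document Boundaries

Packed batches need explicit boundary metadata when examples are independent. Without doc_id, sequence_id, or equivalent masks, packing may leak context and loss across examples.

Packing also affects positional encodings. If each document inside a packed block should start at position 0, reset position ids must be constructed: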

import torch


def reset_position_ids(doc_ids):
    # doc_ids: [B, T] -> positions that restart at 0 at every document boundary.
    pos = torch.zeros_like(doc_ids)
    for b in range(doc_ids.shape[0]):
        start = 0
        for t in range(1, doc_ids.shape[1] + 1):
            is_end = t == doc_ids.shape[1] or doc_ids[b, t] != doc_ids[b, t - 1]
            if is_end:
                pos[b, start:t] = torch.arange(t - start, device=doc_ids.device)
                start = t
    return pos
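The Python loop above is easy to read but slow for large batches. The following vectorized alternative is a sketch that should give the same result. It finds each segment's start index with `torch.cummax` and subtracts it from the position. The function name is hypothetical:

```python
import torch


def reset_position_ids_vectorized(doc_ids):
    """doc_ids: [B, T] -> positions that restart at 0 whenever doc_id changes."""
    batch, seqlen = doc_ids.shape
    t = torch.arange(seqlen, device=doc_ids.device).expand(batch, seqlen)
    # A segment starts at t = 0 or wherever doc_id differs from the previous token.
    is_start = torch.ones_like(doc_ids, dtype=torch.bool)
    is_start[:, 1:] = doc_ids[:, 1:] != doc_ids[:, :-1]
    # Carry the most recent start index forward along the sequence.
    start = torch.where(is_start, t, torch.zeros_like(t)).cummax(dim=1).values
    return t - start


doc_ids = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2],
                        [5, 5, 5, 5, 7, 7, 7, 9]])
print(reset_position_ids_vectorized(doc_ids).tolist())
# [[0, 1, 2, 0, 1, 2, 0, 1], [0, 1, 2, 3, 0, 1, 2, 0]]
```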

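For absolute positional embeddings, resetting and not resetting are different modeling assumptions. RoPE depends only on relative offsets, so resetting leaves within-document distances unchanged. It does, however, alter the offsets between tokens of different documents when cross-document attention is allowed. Packing is not pure data compression: it changes the positional geometry the model sees.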

DataLoader Workers and Host-to-GPU Transfer

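DataLoader throughput is usually determined by the following stages:

disk/network -> decode -> transform -> collate -> pinned memory -> H2D copy -> GPU compute

If the GPU frequently waits for data, no amount of model optimization will help. A few commonly used options: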

| Option | Meaning |
|---|---|
| num_workers | subprocesses for loading data |
| pin_memory | allocate page-locked host memory for faster H2D transfer |
| prefetch_factor | batches prepared ahead per worker |
| persistent_workers | keep workers alive across epochs |
| drop_last | drop incomplete last batch |

A typical training loader:

from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,  # requires num_workers > 0
    collate_fn=collate_fn,
)

and in the training loop:

batch = {
    k: v.to(device, non_blocking=True)
    for k, v in batch.items()
}

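When pin_memory=True is combined with non_blocking=True, the CPU-to-GPU copy can more easily overlap with computation. This is not always a speedup, though. If the data is small, the CPU is already busy, or host memory is tight, too many workers can make loading slower.

Pitfall: More Workers Are Not Always Faster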

Too many workers can cause CPU contention, memory pressure, file descriptor exhaustion, or duplicated preprocessing overhead.
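Rather than guessing, the worker count can be chosen by measuring. This is a minimal timing sketch. `measure_loader_throughput` is a hypothetical helper. It assumes each batch is a sequence whose first element has one row per example, as with `TensorDataset`:

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def measure_loader_throughput(loader, max_batches=50):
    """Iterate over up to `max_batches` batches and report samples/sec.

    The first batches include worker start-up cost, so measure over enough batches
    (or discard a warm-up phase) before comparing settings.
    """
    n_batches = 0
    n_samples = 0
    start = time.perf_counter()
    for batch in loader:
        n_samples += len(batch[0])
        n_batches += 1
        if n_batches >= max_batches:
            break
    elapsed = max(time.perf_counter() - start, 1e-9)
    return {"batches": n_batches, "samples": n_samples,
            "samples_per_sec": n_samples / elapsed}


dataset = TensorDataset(torch.arange(100).float())
stats = measure_loader_throughput(DataLoader(dataset, batch_size=10), max_batches=5)
# To compare settings, repeat on the real dataset with e.g. num_workers in (0, 2, 4, 8).
```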

Preprocessing Cache and Data Fingerprints

Expensive, deterministic preprocessing is best cached offline, for example:

| Raw modality | Expensive deterministic preprocessing |
|---|---|
| text | normalization, tokenization, chat-template rendering |
| image | decode, resize, center crop, feature extraction |
| audio | resampling, spectrogram, VAD segmentation |
| graph | neighbor lists, degree features, subgraph extraction |

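The risk of caching is that the content has changed while the old cache is still being used. Cache files should therefore carry a manifest: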

cache_manifest = {
    "dataset_name": "my_corpus",
    "dataset_version": "2026-06-10",
    "preprocess_code_hash": code_hash,
    "tokenizer_name": tokenizer_name,
    "tokenizer_hash": tokenizer_hash,
    "template_hash": chat_template_hash,
    "num_records": len(records),
}

For text LLMs, the contract of a tokenized cache is especially strict:

| Cached field | Why it matters |
|---|---|
| input_ids | depends on the tokenizer vocabulary and merge table |
| special_tokens_mask | chat/control token boundaries |
| attention_mask | padding or packed-block semantics |
| labels or loss_mask | SFT/RLHF objective semantics |
| length | batching and token-budget accounting |

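If the tokenizer gains a new special token while the old cache still uses the old ids, the model trains in the wrong discrete space. The safest approach is to hash the contents of the tokenizer files into the cache key rather than recording only the tokenizer name.

Definition: Data Fingerprint

A data fingerprint is a stable hash or manifest that identifies the raw data, preprocessing code, tokenizer/config, and filtering rules used to produce a training dataset.

A simple fingerprint can combine several layers of information: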

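import hashlib


def fingerprint(parts):
    # Sort keys so the digest does not depend on dict insertion order;
    # NUL separators keep ("ab", "c") and ("a", "bc") from colliding.
    h = hashlib.sha256()
    for key, value in sorted(parts.items()):
        h.update(str(key).encode())
        h.update(b"\0")
        h.update(str(value).encode())
        h.update(b"\0")
    return h.hexdigest()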
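Hashing tokenizer file contents rather than names needs a file-content digest first. That digest can then go into the cache manifest (for example as `tokenizer_hash`) and into the parts passed to `fingerprint`. A minimal sketch; the helper name and the example file names in the test are hypothetical:

```python
import hashlib
from pathlib import Path


def hash_files(paths, chunk_size=1 << 20):
    """SHA-256 over file names and contents, in sorted order, streamed in chunks."""
    h = hashlib.sha256()
    for path in sorted(Path(p) for p in paths):
        h.update(path.name.encode())
        h.update(b"\0")
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        h.update(b"\0")
    return h.hexdigest()
```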

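Do not record only "dataset=v1". A reproducible experiment must at least be able to answer:

  1. which raw files entered training;
  2. which examples were filtered out;
  3. which versions of the tokenizer, normalizer, and chat template were used;
  4. what the packing, truncation, and loss-mask rules are;
  5. whether the cache was generated by the current code.

Pitfall: Online Tokenization Can Hide Data Drift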

If tokenization happens inside DataLoader workers, changing tokenizer files or chat templates can silently change the training distribution without changing model code.

Randomness in Data Pipelines

Randomness in training comes from at least:

  1. model initialization;
  2. data shuffling;
  3. random data augmentation;
  4. dropout;
  5. nondeterministic CUDA kernels;
  6. distributed reduction order.

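A basic seed setup:

import random

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable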

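DataLoader workers, however, are separate processes. If augmentation uses NumPy or Python random, seeds must be set inside each worker:

def seed_worker(worker_id):
    # Inside a worker, torch.initial_seed() is already base_seed + worker_id,
    # so deriving the NumPy/Python seeds from it gives each worker a distinct stream.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

generator = torch.Generator()
generator.manual_seed(42)

loader = DataLoader(
    dataset,
    shuffle=True,
    num_workers=4,  # worker_init_fn is only called when num_workers > 0
    worker_init_fn=seed_worker,
    generator=generator,
)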

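A worker seed ensures that each worker has a different random stream. Some augmentations, however, are better served by a stateless per-example seed. With such a seed, the same example receives the same random augmentation in the same epoch, regardless of which worker reads it.

Let the example id be \(u\), the epoch \(e\), and the global seed \(s_0\). With a hash function \(H\), define

\[ s(u,e)=H(s_0,u,e). \]

In code: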

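import hashlib
import random

import torch


def example_seed(base_seed, example_id, epoch):
    # Hash (base_seed, example_id, epoch) into a 32-bit seed; no process-level state involved.
    text = f"{base_seed}:{example_id}:{epoch}"
    digest = hashlib.blake2b(text.encode(), digest_size=8).digest()
    return int.from_bytes(digest, "little") % 2**32


def augment_with_seed(example, epoch, base_seed):
    seed = example_seed(base_seed, example["example_id"], epoch)
    rng = random.Random(seed)  # local RNG; the global random state is untouched

    if rng.random() < 0.5:
        # Horizontal flip of an image tensor laid out as [..., H, W].
        example["image"] = torch.flip(example["image"], dims=[-1])
    return example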

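The benefit is that the augmentation random numbers for a given example in a given epoch stay fixed. Changing num_workers, the number of ranks, the prefetch depth, or even the batch order does not affect them.

Implementation Contract: Separate Sample Randomness from Worker Randomness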

Worker seeds make worker processes independent; stateless per-example seeds make augmentation reproducible under changes to worker count, sharding, and resume position.
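If augmentation code uses NumPy, an alternative to hashing a string is to pass the integers directly to `numpy.random.default_rng`. It derives the generator state from a `SeedSequence` built from them. This sketch assumes non-negative integer example ids; the helper name is hypothetical:

```python
import numpy as np


def example_rng(base_seed, example_id, epoch):
    """A NumPy Generator determined only by (base_seed, example_id, epoch)."""
    # default_rng accepts a sequence of non-negative ints and mixes them via SeedSequence.
    return np.random.default_rng([base_seed, example_id, epoch])


# The same example in the same epoch gets the same draws, whichever worker or rank loads it.
a = example_rng(42, example_id=1234, epoch=3).random(4)
b = example_rng(42, example_id=1234, epoch=3).random(4)
c = example_rng(42, example_id=1234, epoch=4).random(4)
```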

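Definition: Reproducibility

Reproducibility means that repeated runs under specified software, hardware, seeds, and deterministic settings produce the same or statistically comparable results.

Note that "exactly identical" and "statistically comparable" are two different standards. Reproducing research results usually requires at least stable means, variances, and trends. Bitwise-identical results are needed only when debugging a specific bug.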

Deterministic Algorithms

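PyTorch can be asked to use deterministic algorithms:

torch.use_deterministic_algorithms(True)

This helps debugging but may cost speed. Some operations have no deterministic implementation and then raise an error. On CUDA, some cuBLAS operations also require the environment variable CUBLAS_WORKSPACE_CONFIG to be set (for example to :4096:8). cuDNN benchmark mode also affects the choice of convolution algorithm:

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Pitfall: Determinism Can Reduce Throughput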

Deterministic kernels may be slower, and exact reproducibility can depend on PyTorch version, CUDA version, GPU model, and distributed topology.

In practice these settings can be layered by stage:

| Stage | Recommended setting |
|---|---|
| debugging | strict seeds, deterministic algorithms, small data |
| ablation | fixed split, fixed data order policy, multiple seeds |
| large training | controlled seeds, logged versions, tolerate non-bitwise differences |
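The debugging row of this table can be collected into one helper. This sketch assumes it runs before any CUDA work starts, because CUBLAS_WORKSPACE_CONFIG must be set before cuBLAS is initialized. The function name and `strict` flag are hypothetical. The PyTorch calls are standard, and enabling cuDNN benchmark in the non-strict branch is one possible choice, not a requirement:

```python
import os
import random

import numpy as np
import torch


def configure_reproducibility(seed, strict):
    """strict=True: debugging profile. strict=False: seeds only, fast kernels allowed."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # ignored when CUDA is unavailable
    if strict:
        # Needed by some cuBLAS ops under deterministic mode; must be set before CUDA init.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.benchmark = True  # let cuDNN autotune convolution algorithms
```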

Train/Validation/Test Splits

Data splitting is the most underestimated experimental-design problem. A random split is not always appropriate:

| Data type | Split risk |
|---|---|
| image | near-duplicate leakage |
| user behavior | same user in train and test |
| time series | future leaking into past |
| documents | paragraphs from same document split apart |
| graphs | edges/nodes leak through neighborhood |

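Formally, we want the validation set to estimate the risk under the target distribution:

\[ R(\theta)=\mathbb{E}_{(x,y)\sim P_{\text{target}}}\ell(f_\theta(x),y). \]

If information leaks from the train set into the validation set, the estimate no longer measures generalization risk. It measures some form of memorization instead.

Pitfall: Random Split Is Not Always IID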

If examples share users, documents, time windows, graphs, or generated variants, random splitting may leak information across splits.

The experiment record should save at least:

  1. split generation code;
  2. split seed;
  3. exact file list or index list;
  4. preprocessing version;
  5. filtering rules;
  6. tokenizer or vocabulary version.

A few very simple leakage tests also help. For example, a document-level split should guarantee that no doc_id appears in more than one split:

def assert_disjoint_groups(rows, group_key, split_key):
    owners = {}
    for row in rows:
        group = row[group_key]
        split = row[split_key]
        old = owners.setdefault(group, split)
        if old != split:
            raise ValueError(f"group {group} appears in {old} and {split}")

For a time-series split, at least check that no validation timestamp is earlier than the end of the training period:

def assert_time_order(train_rows, val_rows, time_key):
    max_train = max(row[time_key] for row in train_rows)
    min_val = min(row[time_key] for row in val_rows)
    if min_val < max_train:
        raise ValueError("validation contains timestamps before training ends")

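Near-duplicate images, mirrored web pages, and forked code repositories cannot be caught by exact ids alone; they need hashes, embeddings, or domain-specific rules. They are not necessarily code bugs. They do, however, make validation loss optimistically low and therefore more reassuring than it should be.

Implementation Contract: Split by Dependence Unit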

Split by the unit that carries dependence: user, document, time window, graph component, repository, or source domain. IID row-level splitting is valid only when rows are genuinely independent.
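A common way to split by dependence unit is to hash each group's id and assign the whole group, not each row, to a split. This is a sketch, and the salt string, fractions, and function name are illustrative assumptions:

```python
import hashlib


def assign_split(group_id, val_fraction=0.1, test_fraction=0.1, salt="split-v1"):
    """Deterministically map a group (user, document, repo, ...) to one split."""
    digest = hashlib.sha256(f"{salt}:{group_id}".encode()).digest()
    u = int.from_bytes(digest[:8], "big") / 2**64  # pseudo-uniform in [0, 1)
    if u < test_fraction:
        return "test"
    if u < test_fraction + val_fraction:
        return "validation"
    return "train"


rows = [{"doc_id": f"doc{k % 500}", "paragraph": k} for k in range(2000)]
for row in rows:
    row["split"] = assign_split(row["doc_id"])  # all paragraphs of a document share a split
```

New rows from an existing group land in that group's split. Changing the salt reshuffles all groups, so the salt belongs in the experiment record.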

Distributed Data Loading

In data-parallel training, each rank should see a different data shard. The core job of DistributedSampler is to split dataset indices across ranks:

sampler = torch.utils.data.DistributedSampler(
    dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,
    collate_fn=collate_fn,
)

Each epoch additionally requires:

sampler.set_epoch(epoch)

Otherwise every epoch uses the same shuffle order.
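Because `num_replicas` and `rank` can be passed explicitly, `DistributedSampler` can be inspected without launching a distributed job. The sketch below shows that ranks receive disjoint indices. It also shows that when the dataset size is not divisible by the number of ranks (and `drop_last=False`), the sampler pads by repeating indices so every rank gets the same count. The helper name is hypothetical:

```python
from torch.utils.data import DistributedSampler


def rank_indices(n, world_size, epoch, seed=0):
    """Indices each rank would see in a given epoch (no process group needed)."""
    dataset = range(n)  # only len() is used by the sampler
    out = []
    for rank in range(world_size):
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                     shuffle=True, seed=seed)
        sampler.set_epoch(epoch)
        out.append(list(sampler))
    return out


even = rank_indices(10, world_size=2, epoch=0)    # 5 + 5 indices, disjoint
uneven = rank_indices(11, world_size=2, epoch=0)  # 6 + 6 indices; one index repeats
```

On uneven datasets, the padding means a few examples are seen twice per epoch. With `drop_last=True`, the tail is dropped instead.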

Definition: Global Batch Size

Global batch size is \[ B_{\text{global}} = B_{\text{per-rank}} \times N_{\text{ranks}} \times K_{\text{accum}}. \] It is the number of examples contributing to one optimizer step across all ranks.

For token-level tasks, the global token batch is used more often:

\[ T_{\text{global}} = \sum_{\text{rank}}\sum_{\text{micro}}\sum_{b,t}m_{b,t}. \]

This is more accurate than an example count, because the number of effective tokens can differ across ranks.
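Counting \(T_{\text{global}}\) requires a reduction across ranks. The sketch below sums non-ignored labels locally and runs all-reduce only if a process group is initialized. The function name is hypothetical. With the NCCL backend, the count tensor would have to live on the rank's GPU, hence the `device` argument:

```python
import torch
import torch.distributed as dist


def global_token_count(label_batches, ignore_index=-100, device="cpu"):
    """Sum of m_{b,t} over this rank's micro-batches, then over all ranks."""
    local = sum(int((labels != ignore_index).sum()) for labels in label_batches)
    total = torch.tensor(local, dtype=torch.long, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return int(total)


# Two micro-batches on one process: 3 + 2 = 5 target tokens.
micro_batches = [
    torch.tensor([[5, 6, -100], [7, -100, -100]]),
    torch.tensor([[8, 9, -100]]),
]
```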

Exact Resume of Data Order

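Exact resume requires that the next batch after an interruption be exactly the batch the uninterrupted run would have seen. A model checkpoint covers only the parameters. Data order additionally needs the following state: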

| State | Map-style dataset | Iterable or streaming dataset |
|---|---|---|
| epoch | sampler epoch | stream epoch or pass id |
| position | batch index or consumed indices | file shard and byte/record cursor |
| RNG | sampler generator state | shuffle buffer RNG state |
| distributed | rank/world size contract | shard assignment contract |
| accounting | examples/tokens seen | examples/tokens seen |
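This example covers the RNG row. A sampler may draw from a live `torch.Generator` instead of re-deriving its permutation from `(seed, epoch)`. In that case the generator state itself must go into the checkpoint. `Generator.get_state()` and `Generator.set_state()` are standard PyTorch; the surrounding dictionary layout is a hypothetical example:

```python
import torch

g = torch.Generator()
g.manual_seed(1234)
_ = torch.randperm(10, generator=g)  # epoch 0 already consumed some randomness

# Checkpoint the data-side state next to the model checkpoint.
data_state = {"epoch": 1, "batch_in_epoch": 0, "generator_state": g.get_state()}
uninterrupted = torch.randperm(10, generator=g)  # what the next epoch would use

# After a restart: rebuild the generator and restore its state.
g_resumed = torch.Generator()
g_resumed.set_state(data_state["generator_state"])
resumed = torch.randperm(10, generator=g_resumed)
```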

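For a map-style dataset, the simplest resumable strategy is to let (seed, epoch) determine each epoch's permutation. The checkpoint stores epoch and batch_in_epoch. On resume, regenerate the permutation and skip the batches already consumed: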

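import torch


def epoch_indices(n, seed, epoch):
    # The permutation depends only on (seed, epoch), so it can be regenerated on resume.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    return torch.randperm(n, generator=g).tolist()


def resume_batches(n, seed, epoch, batch_size, batch_in_epoch):
    # Skip the batch_in_epoch batches that were consumed before the checkpoint.
    indices = epoch_indices(n, seed, epoch)
    start = batch_in_epoch * batch_size
    return indices[start:]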

This strategy is simple but conditional. Dataset length, index order, world size, batch size, and drop_last must all stay the same. If any of them changes, batch_in_epoch no longer points at the same examples.

Pitfall: Resuming with a Different World Size Changes Data Order

Changing world_size, per-rank batch size, gradient accumulation, or drop_last can change which examples contribute to the next optimizer step, even if model and optimizer states load successfully.

Streaming datasets are harder, because the data is not randomly indexable. Common engineering choices:

| Strategy | Pros | Cons |
|---|---|---|
| save exact cursor | precise | storage-specific; hard for compressed streams |
| deterministic shard order + skip records | simple | slow when skipping deep into stream |
| checkpoint precomputed sample ids | precise | extra manifest/storage |
| tolerate statistical resume | scalable | not bitwise exact |

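Large-scale pretraining often accepts a statistical resume. After restarting, training keeps sampling from the same distribution, but the exact order of examples is not guaranteed. If the goal is debugging or an ablation, first verify exact resume on a map-style subset.

Definition: Statistical Resume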

Statistical resume restores training from the same distribution and comparable counters, but does not guarantee the exact next batch equals the uninterrupted run.

drop_last and BatchNorm

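drop_last=True discards the last incomplete batch. It looks like it merely tidies up batch shapes, but it affects:

  1. how many examples are actually trained on each epoch;
  2. whether, under class imbalance, minority-class examples are more likely to be dropped;
  3. whether all ranks see the same number of batches in distributed training;
  4. whether BatchNorm statistics are disturbed by a small batch.

If the model uses BatchNorm, statistics from a small batch are unstable, and the last small batch can make the running statistics jitter. With LayerNorm/RMSNorm the impact is usually small.

Pitfall: Incomplete Batches Can Break Assumptions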

Some code assumes fixed batch size for reshaping, BatchNorm, contrastive negatives, or distributed collectives. Decide deliberately whether to keep or drop the last batch.
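The effect of drop_last on the per-epoch example count is easy to check with `len(DataLoader(...))` on a toy dataset:

```python
from torch.utils.data import DataLoader

n, batch_size = 10, 4
keep = DataLoader(range(n), batch_size=batch_size, drop_last=False)
drop = DataLoader(range(n), batch_size=batch_size, drop_last=True)

print(len(keep), len(drop))        # 3 2
print([b.tolist() for b in keep])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
# With drop_last=True, n % batch_size = 2 examples are skipped this epoch.
```

With shuffle=True, the skipped examples change from epoch to epoch. Without shuffling, the same tail examples are never trained on.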

A Practical Data Pipeline Template

A reasonably safe supervised-learning pipeline:

def make_loader(dataset, batch_size, seed, train):
    generator = torch.Generator()
    generator.manual_seed(seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
        drop_last=train,
        worker_init_fn=seed_worker,
        generator=generator,
        collate_fn=collate_fn,
    )

and the corresponding metadata to record during training:

metadata = {
    "seed": seed,
    "batch_size": batch_size,
    "num_workers": 4,
    "drop_last": train,
    "dataset_version": dataset.version,
    "preprocess_version": preprocess_version,
}

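This kind of metadata looks verbose. When experimental results suddenly change, though, it narrows the problem from "mysterious fluctuation" down to a few concrete variables.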

Data Pipeline Checklist

Before starting large-scale training, check at least the following:

  1. dataset length matches expectation;
  2. one sample has correct dtype, shape, and label;
  3. one batch has correct padding and mask;
  4. loss ignores padded or invalid targets;
  5. train/validation split has no leakage;
  6. shuffling is enabled only where needed;
  7. worker seeds cover Python, NumPy, and PyTorch randomness;
  8. distributed sampler calls set_epoch(epoch);
  9. effective batch or token count is logged;
  10. throughput is measured as both samples/sec and tokens/sec when relevant;
  11. iterable or streaming datasets shard by rank and worker without duplication;
  12. causal LM label shifting happens exactly once;
  13. packed examples carry document-boundary masks when examples are independent;
  14. cached preprocessing artifacts include data and tokenizer fingerprints;
  15. resume checkpoints record sampler epoch, batch offset, and token/example counters.
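Items 2 to 4 and 9 can be automated for a causal-LM batch. This minimal sketch assumes the field names and ignore index used earlier in this section. It also assumes a 1/0 (or True/False) attention mask where 1 marks real tokens. The function name is hypothetical:

```python
import torch


def check_lm_batch(batch, ignore_index=-100):
    """Raise ValueError on common batch-contract violations; return the effective token count."""
    ids, mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
    if ids.dtype != torch.long or labels.dtype != torch.long:
        raise ValueError("input_ids and labels must be torch.long")
    if not (ids.shape == mask.shape == labels.shape):
        raise ValueError(f"shape mismatch: {ids.shape}, {mask.shape}, {labels.shape}")
    supervised = labels != ignore_index
    if (supervised & ~mask.bool()).any():
        raise ValueError("loss is computed on padded positions")
    n_tokens = int(supervised.sum())
    if n_tokens == 0:
        raise ValueError("batch has no supervised tokens")
    return n_tokens


good = {
    "input_ids": torch.tensor([[1, 2, 3, 4], [5, 6, 7, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]]),
    "labels": torch.tensor([[2, 3, 4, -100], [6, 7, -100, -100]]),
}
```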

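A good data pipeline is not one with many DataLoader options set. It is one whose sampling distribution, batch semantics, mask semantics, and experiment records are all clear. When training is unstable, do not rush to change the model; many problems are actually hiding in the batch.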
