1.4 Tensor Operation Patterns
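There are many tensor APIs, but the common operations fall into a few families: indexing, selection, concatenation, shape transformation, windowing, padding, reduction, and matrix multiplication. What matters is not memorizing function names but knowing, for each family, its shape rule, whether it copies, whether it preserves gradients, and what it usually expresses in a model.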
Indexing: View or Copy?
Basic slicing returns a view:
import torch
import torch.nn.functional as F
x = torch.arange(24).reshape(2, 3, 4)
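y = x[:, 1:, :]
y and x share storage, so an in-place write to y is visible through x.
Advanced indexing returns a copy:
idx = torch.tensor([0, 2])
z = x[:, idx, :]
This kind of indexing has to gather the selected elements into new memory. It is common in sample selection, token selection, and hard negative mining. It is still differentiable, but its memory semantics differ from slicing.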
| Operation | Usually view? | Shape rule |
|---|---|---|
| x[i] | yes | remove indexed dim |
| x[:, 1:3] | yes | slice range |
| x[..., None] | yes | insert size-1 dim |
| x[mask] | no | flatten selected elements |
| x.index_select(dim, idx) | no | replace dim by len(idx) |
| torch.gather | no | output shape = index shape |
x[mask] returns a 1D tensor of selected elements when mask has the same shape as x. Use masked_fill or where if you need to preserve shape.
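The difference between boolean indexing and shape-preserving masking can be seen on a tiny tensor. This is a minimal sketch; the values of x are made up for illustration.

```python
import torch

x = torch.tensor([[1.0, -2.0, 3.0],
                  [-4.0, 5.0, -6.0]])
mask = x > 0

selected = x[mask]                                # 1D: row/column structure is gone
kept = torch.where(mask, x, torch.zeros_like(x))  # same shape as x
filled = x.masked_fill(~mask, 0.0)                # equivalent to `kept` here

print(selected.shape)  # torch.Size([3])
print(kept.shape)      # torch.Size([2, 3])
```

Anything that later needs a per-row reduction should start from kept or filled, not from selected.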
Storage, Stride, and View Safety
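A tensor's shape is not the complete description. A PyTorch tensor also carries a storage, a storage offset, and one stride per dimension. For a 2D tensor with strides \(s_0, s_1\), the address of element \((i,j)\), counted in elements, is: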
\[ \operatorname{addr}(i,j) = \operatorname{storage\_offset} +i\cdot s_0+j\cdot s_1. \]
transpose usually does not copy data; it only changes the strides:
x = torch.arange(12).reshape(3, 4)
y = x.transpose(0, 1)
print(x.stride()) # (4, 1)
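print(y.stride()) # (1, 4)
view requires that the new shape can be expressed with the existing strides; reshape falls back to a copy when it cannot return a view:
y = x.transpose(0, 1)
z = y.reshape(12) # copies here, because y is not contiguous
w = y.contiguous().view(12) # same result, with the copy made explicit
A tensor is contiguous when its logical index order matches a dense row-major memory layout, so adjacent logical elements are adjacent in storage.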
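Many kernels accept non-contiguous input, but at one of two costs:
- the kernel reads memory through complex strides, which hurts coalescing;
- the kernel or the framework first calls .contiguous() implicitly, which adds an extra copy.
This is why the following is so common in attention code:
x = x.transpose(1, 2).contiguous().view(B, T, H * D)
contiguous() is not decoration: it explicitly pays for one memory reordering in exchange for a layout that the following view and GEMM can use.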
reshape may return a view or allocate a copy. In performance-critical code, check is_contiguous, strides, or memory traces instead of assuming it is free.
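One way to check whether reshape copied is to compare storage pointers. The sketch below assumes PyTorch 2.x (for untyped_storage); shares_storage is a hypothetical helper name, not a library function.

```python
import torch

def shares_storage(a, b):
    # Hypothetical helper: True when both tensors are backed by the same storage.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

x = torch.arange(12).reshape(3, 4)
y = x.transpose(0, 1)          # non-contiguous view of x

a = x.reshape(12)              # x is contiguous, so this is a view
b = y.reshape(12)              # y is not, so reshape has to copy

print(y.is_contiguous(), y.stride())                # False (1, 4)
print(shares_storage(a, x), shares_storage(b, x))   # True False
```

In performance-critical code the same comparison can sit in a unit test next to the function that is supposed to be copy-free.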
Broadcasting as Shape Algebra
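Broadcasting is not PyTorch cleverly filling in dimensions for you; it is a deterministic shape algebra. The two shapes are aligned from the right, and every aligned pair of dimensions must satisfy one of the following:
- the two sizes are equal;
- one of them is 1;
- one of them is missing, in which case it is treated as 1.
For example: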
x: [B, T, C]
bias: [C]
result: [B, T, C]
bias is treated as [1, 1, C]. Another example is an attention mask:
scores: [B, H, Q, K]
mask: [B, 1, 1, K]
result: [B, H, Q, K]
Here the mask is broadcast over the head and query dimensions, which means the same padding mask is reused by all heads and all queries.
Broadcasting aligns tensor shapes from the trailing dimension and virtually expands dimensions of size 1 to match the other operand without necessarily copying data.
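The three rules above are small enough to write out by hand. The function below is an illustrative re-implementation (broadcast_shape is a made-up name); torch.broadcast_shapes is the library call it is compared against.

```python
import torch

def broadcast_shape(a, b):
    """Broadcast two shapes by the right-aligned rules, or raise ValueError."""
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1   # a missing dim counts as 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {a} with {b} at dim -{i}")
        out.append(db if da == 1 else da)
    return tuple(reversed(out))

print(broadcast_shape((2, 3, 4), (4,)))              # (2, 3, 4)
print(broadcast_shape((2, 8, 5, 7), (2, 1, 1, 7)))   # (2, 8, 5, 7)
```

Calling it on the shapes you intend to combine, before running the real op, turns a silent broadcast into an explicit statement of the expected result.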
Gradient Through Broadcast
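A broadcast dimension acts like a copy in the forward pass; in the backward pass the gradient is summed back into the original shape. Let
\[ y_{btc}=x_{btc}+b_c, \qquad L=L(y). \]
Then
\[ \frac{\partial L}{\partial b_c} = \sum_{b,t} \frac{\partial L}{\partial y_{btc}}. \]
This is exactly where the bias gradient comes from: the bias is broadcast over batch and time, so backpropagation reduces over those dimensions.
The stride-0 semantics of expand show the same thing: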
bias = torch.randn(4, requires_grad=True)
y = bias[None, None, :].expand(2, 3, 4)
loss = y.sum()
loss.backward()
print(bias.grad) # tensor([6., 6., 6., 6.])
Each bias element is logically used \(2\times3\) times, so its gradient accumulates to 6.
If a tensor has an accidental size-1 dimension, PyTorch may broadcast it instead of erroring. The result has a valid shape but wrong semantics.
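A common instance of this failure is a [B, 1] prediction subtracted from a [B] target. The sketch below uses made-up values in which the prediction is exactly right, so the correct loss is zero.

```python
import torch

pred = torch.tensor([[1.0], [2.0], [3.0]])   # [B, 1], e.g. output of a linear head
target = torch.tensor([1.0, 2.0, 3.0])       # [B]

wrong = (pred - target).square()               # silently broadcasts to [B, B]
right = (pred.squeeze(-1) - target).square()   # [B]

print(tuple(wrong.shape), wrong.mean().item())
print(tuple(right.shape), right.mean().item())
```

The wrong version runs without error and reports a nonzero loss, because it compares every prediction with every target.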
Broadcasting Debug Pattern
When shapes are complex, do not rely on an error message to catch mistakes. First reshape every tensor explicitly to the intended contract:
# logits: [B, T, V], labels: [B, T]
B, T, V = logits.shape
labels = labels.reshape(B, T, 1)
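chosen = logits.gather(-1, labels).squeeze(-1)
If a tensor is semantically [B, T], do not let it slip into the loss as [B, T, 1] or [B, 1, T]. Both shapes broadcast without error, but they put the data on entirely different axes.
A practical check is to log the inputs and the output of each critical binary op together:
def log_binary_shapes(name, a, b, out):
    print(name, tuple(a.shape), tuple(b.shape), "->", tuple(out.shape))
In real training runs, many "why did the loss suddenly drop or jump" problems turn out to be a [B, T] mask that was wrongly broadcast to [B, B, T] or [B, T, T].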
Masks and where
Mask operations are everywhere in DL: padding masks, causal masks, loss masks, attention masks.
x = torch.randn(2, 4)
mask = x > 0
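y = x.masked_fill(~mask, 0.0)
torch.where selects elementwise between two tensors:
y = torch.where(mask, x, torch.zeros_like(x))
For masking logits, a common form is:
# attention_mask is bool, True where attention is allowed
scores = scores.masked_fill(~attention_mask, float("-inf"))
probs = torch.softmax(scores, dim=-1)
The ~ operator requires a bool mask, and a row that is masked everywhere becomes all -inf, for which softmax returns NaN. Also check the broadcast shape of the mask: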
scores: [B, H, Q, K]
attention_mask: [B, 1, 1, K]
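A mask is not decoration; it defines the conditional-independence structure of the model. With a wrong attention mask the model can peek at future tokens or at padding; with a wrong loss mask it optimizes the wrong objective.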
Masked Reductions
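A masked mean cannot stop at x[mask].mean(): boolean indexing flattens the result, and an empty mask yields NaN. A more robust version keeps the shape:
def masked_mean(x, mask, dim):
    # mask has the same shape as x (or broadcasts to it); True/1 marks valid elements
    mask = mask.to(dtype=x.dtype)
    total = (x * mask).sum(dim=dim)
    denom = mask.sum(dim=dim).clamp_min(1) # avoid 0/0 when nothing is valid
    return total / denom
For a loss, the denominator is part of the objective. Token mean, sequence mean, and batch mean give different samples different weights: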
\[ \mathcal{L}_{\text{token}} = \frac{\sum_{b,t}m_{bt}\ell_{bt}} {\sum_{b,t}m_{bt}}, \qquad \mathcal{L}_{\text{seq}} = \frac{1}{B}\sum_b \frac{\sum_tm_{bt}\ell_{bt}} {\sum_tm_{bt}}. \]
If both the mask and the loss are [B, T], do not let boolean indexing silently drop the batch dimension; once it is gone, sequence-level weighting is hard to recover.
If a mask has zero valid elements, decide whether the result should be zero, skipped, or an error. mean over an empty selection silently creates NaN.
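The two denominators in the formulas above give different numbers on the same data. A minimal sketch with made-up per-token losses, one long sequence and one short one:

```python
import torch

loss = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                     [3.0, 0.0, 0.0, 0.0]])           # [B, T] per-token loss
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 0, 0, 0]], dtype=loss.dtype)  # 1 = valid token

# Token mean: every valid token has equal weight.
token_mean = (loss * mask).sum() / mask.sum().clamp_min(1)

# Sequence mean: every sequence has equal weight, regardless of its length.
per_seq = (loss * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1)
seq_mean = per_seq.mean()

print(token_mean.item(), seq_mean.item())
```

The short sequence contributes one fifth of the token mean but half of the sequence mean, which is why the choice of denominator has to be a deliberate part of the objective.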
Gather and Scatter
gather picks values along one dimension according to an index:
logits = torch.randn(2, 3, 5) # [B, T, V]
labels = torch.tensor([[1, 3, 4], [0, 2, 2]]) # [B, T]
token_logits = logits.gather(
    dim=-1,
    index=labels.unsqueeze(-1),
).squeeze(-1) # [B, T]
This takes the logit of the true label at every position. scatter and scatter_add go the other way and write values back to target positions:
out = torch.zeros(2, 5)
index = torch.tensor([[1, 3, 3], [0, 2, 2]])
src = torch.ones(2, 3)
out.scatter_add_(dim=1, index=index, src=src)
scatter_add is commonly used for histograms, segment aggregation, GNN message passing, and the combine step after MoE dispatch.
Gather selects values from indexed positions; scatter writes or accumulates values into indexed positions. Many sparse and routing operations are gather-scatter programs.
Gradients of Gather and Scatter
The backward of gather is a scatter_add. If the same source index is gathered several times, the gradients accumulate at that position:
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
idx = torch.tensor([0, 0, 2])
y = x.gather(0, idx)
loss = y.sum()
loss.backward()
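print(x.grad) # tensor([2., 0., 1.])
Mathematically, if
\[ y_i=x_{\operatorname{idx}_i}, \]
then
\[ \frac{\partial L}{\partial x_j} = \sum_{i:\operatorname{idx}_i=j} \frac{\partial L}{\partial y_i}. \]
This is exactly the gradient semantics of an embedding lookup: when a token id occurs several times in a batch, the gradients for its embedding row accumulate.
The backward of scatter_add, in turn, gathers the target gradient back to src. The overwriting scatter_ is more dangerous when indices repeat: several sources write to the same position, and which one wins can depend on the order of execution in the implementation. When you need aggregation, prefer scatter_add, scatter_reduce, or an explicit segment reduction.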
If multiple values write to the same index, use an operation with explicit reduction semantics such as scatter_add instead of relying on overwrite order.
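With repeated indices, the reduction can be stated explicitly instead of left to write order. The sketch below assumes a PyTorch version that provides Tensor.scatter_reduce with the include_self argument; the index and src values are made up.

```python
import torch

index = torch.tensor([0, 0, 2, 2, 2])
src = torch.tensor([1.0, 5.0, 2.0, 7.0, 3.0])

# Sum of all values routed to each slot.
summed = torch.zeros(3).scatter_add(0, index, src)

# Max of all values routed to each slot; include_self=False ignores the
# initial zeros in slots that receive at least one value.
maxed = torch.zeros(3).scatter_reduce(0, index, src, reduce="amax", include_self=False)

print(summed)  # tensor([ 6.,  0., 12.])
print(maxed)   # tensor([5., 0., 7.])
```

Slot 1 receives nothing and keeps its initial value in both cases, so the choice of initial value is part of the semantics.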
Segment Reductions
GNNs, MoE layers, and variable-length sequence pooling often need to aggregate by group id:
def segment_sum(src, segment_ids, num_segments):
    # src: [N, D], segment_ids: [N] (int64)
    out = src.new_zeros(num_segments, src.size(-1))
    index = segment_ids[:, None].expand_as(src)
    return out.scatter_add_(0, index, src)
A segment mean additionally divides by the count:
def segment_mean(src, segment_ids, num_segments):
    total = segment_sum(src, segment_ids, num_segments)
    count = src.new_zeros(num_segments, 1)
    ones = src.new_ones(src.size(0), 1)
    count.scatter_add_(0, segment_ids[:, None], ones)
    return total / count.clamp_min(1) # empty segments return 0
This pattern is closer to how a GPU works than a for-loop: flatten all edges, messages, or token assignments to [N, D], then use an index to reduce them back to groups.
Concatenation and Stacking
cat concatenates along an existing dimension:
a = torch.randn(2, 3)
b = torch.randn(2, 5)
c = torch.cat([a, b], dim=1) # [2, 8]
stack adds a new dimension:
a = torch.randn(3)
b = torch.randn(3)
s = torch.stack([a, b], dim=0) # [2, 3]
The difference:
| Operation | Adds new dim? | Inputs shape |
|---|---|---|
| cat | no | same except concat dim |
| stack | yes | all same |
Both cat and stack allocate a new tensor. Calling cat repeatedly inside a loop re-allocates and re-copies everything accumulated so far:
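# bad: every iteration copies all rows accumulated so far
out = torch.empty(0, d)
for h in chunks:
    out = torch.cat([out, h], dim=0)
# better: collect first, concatenate once
outs = []
for h in chunks:
    outs.append(h)
out = torch.cat(outs, dim=0)
Split and Chunk
chunk splits into a given number of chunks:
parts = torch.chunk(x, chunks=4, dim=0)
split splits by size:
parts = torch.split(x, split_size_or_sections=[2, 3, 5], dim=0)
Both usually return views. In training code they are commonly used to split heads, micro-batches, and a fused QKV projection:
q, k, v = qkv.chunk(3, dim=-1)
Shape Transformations
Common shape transformations: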
x.reshape(B, T, H, D)
x.permute(0, 2, 1, 3)
x.transpose(1, 2)
x.unsqueeze(1)
x.squeeze(-1)
x.flatten(1)
Prefer naming every axis with a variable over bare magic numbers:
B, T, C = x.shape
x = x.reshape(B, T, num_heads, head_dim)
x = x.transpose(1, 2) # [B, H, T, Dh]
This lets the reader see the semantics of the tensor program.
Calling x.squeeze() removes all size-1 dimensions. Prefer x.squeeze(dim) when batch size may be 1.
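The batch-size-1 case is easy to reproduce. In this sketch, head is a hypothetical stand-in for any module whose output has a trailing size-1 dimension.

```python
import torch

def head(x):
    # Hypothetical scalar head: [B, T, C] -> [B, T, 1]
    return x.sum(dim=-1, keepdim=True)

for B in (4, 1):
    out = head(torch.randn(B, 5, 8))
    print(B, tuple(out.squeeze().shape), tuple(out.squeeze(-1).shape))
# 4 (4, 5) (4, 5)
# 1 (5,) (1, 5)
```

With B = 1 the bare squeeze() also removes the batch dimension, which typically surfaces only at inference time or on the last, smaller batch of an epoch.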
Shape Assertions as Documentation
In complex models, shape comments are best turned into executable checks. A lightweight helper:
def check_shape(x, shape, name):
    if len(x.shape) != len(shape):
        raise ValueError(f"{name}: expected rank {len(shape)}, got {tuple(x.shape)}")
    for got, exp in zip(x.shape, shape):
        if exp is not None and got != exp:
            raise ValueError(f"{name}: expected {shape}, got {tuple(x.shape)}")
check_shape(q, (B, H, T, D), "query")
check_shape(attention_mask, (B, 1, 1, T), "attention_mask")
Here None marks a dimension that is not checked. Checking shapes explicitly early in training costs less time than waiting for matmul to raise an error from deep inside a CUDA call.
A tensor contract records the expected axes, dtype, device, mask semantics, and gradient requirements of tensors crossing a module boundary.
A typical contract can be written in the forward docstring or in comments:
input_ids: [B, T], int64, CPU/GPU
attention_mask: [B, T], bool, True for real token
hidden_states: [B, T, C], float, requires_grad during train
logits: [B, T, V], float
This text is not a formality. It tells the next reader that a [B, T] mask must be reshaped to [B, 1, 1, T] before attention, and that the label mask and the attention mask are not the same tensor.
expand vs repeat
expand does not copy data and only changes strides; repeat really copies:
x = torch.randn(3, 1)
y = x.expand(3, 4) # view, stride 0 on expanded dim
z = x.repeat(1, 4) # copy
| Operation | Copies? | Writable semantics | Use case |
|---|---|---|---|
| expand | no | aliased, often not safe in-place | broadcasting constants |
| repeat | yes | independent values | materialized tiled tensor |
If the result is only read afterwards, prefer expand. If it will be modified, or passed to a kernel that requires contiguous input, you may need repeat or contiguous().
unfold and Sliding Windows
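unfold turns sliding windows into a new dimension:
x = torch.arange(10)
w = x.unfold(dimension=0, size=4, step=2)
# shape [4, 4]: windows [0:4], [2:6], [4:8], [6:10]
For images, nn.Unfold flattens local patches into columns, which is the im2col layout:
patches = torch.nn.Unfold(kernel_size=3, padding=1)(images)
A convolution can then be understood as:
\[ \text{conv}(x,W) = \text{matmul}(\operatorname{unfold}(x), \operatorname{flatten}(W)). \]
This explains why convolution implementations are often lowered to GEMM or implicit GEMM: matrix multiplication is the most heavily optimized core operator on GPUs.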
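The identity between convolution and unfold followed by a matrix product can be checked numerically on a small input. This is a sketch for stride 1, no bias, and no dilation; the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)   # [N, C_in, H, W]
w = torch.randn(5, 3, 3, 3)   # [C_out, C_in, kH, kW]

# im2col: each column holds one flattened 3x3 patch across all input channels.
cols = F.unfold(x, kernel_size=3, padding=1)   # [N, C_in*kH*kW, L], L = H*W here

# One GEMM per batch element: [C_out, 27] @ [N, 27, L] -> [N, C_out, L]
out = (w.flatten(1) @ cols).reshape(2, 5, 8, 8)

ref = F.conv2d(x, w, padding=1)
print(torch.allclose(out, ref, atol=1e-4))
```

The explicit cols tensor is nine times larger than x here, which is the memory cost that implicit-GEMM kernels avoid by forming patches on the fly.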
as_strided and Overlapping Views
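The core of unfold can be viewed as a controlled as_strided: by specifying shape and stride by hand, the same storage is accessed as a set of windows. For a 1D tensor:
x = torch.arange(8)
w = x.as_strided(size=(3, 4), stride=(2, 1))
Logically this yields the windows:
row 0: x[0], x[1], x[2], x[3]
row 1: x[2], x[3], x[4], x[5]
row 2: x[4], x[5], x[6], x[7]
Note that the windows overlap. The address formula is:
\[ \operatorname{addr}(i,j) = \operatorname{offset}+2i+j. \]
So w[0, 2] and w[1, 0] both point to x[2]. Such a view is useful for read-only computation, but in-place writes through it have ill-defined overwrite semantics.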
If two logical elements share the same storage location, in-place writes through the view can overwrite each other. Use unfold/as_strided for read patterns, not casual mutation.
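Backward also accumulates over overlapping positions. If
\[ y_{ij}=x_{2i+j}, \qquad L=\sum_{i,j}y_{ij}, \]
then \(x_2\) appears in two windows, so
\[ \frac{\partial L}{\partial x_2}=2. \]
This is the same mathematical structure as gradient accumulation for repeated gather indices: one storage element is read by several logical outputs, and in backward the gradients along all of those paths are added back into the same position.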
An overlapping view is a tensor view where two or more logical indices map to the same underlying storage location.
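The accumulation over overlapping windows can be observed directly with autograd. A minimal sketch using the same size and stride as the example above, with a float tensor so that gradients are defined:

```python
import torch

x = torch.arange(8, dtype=torch.float32, requires_grad=True)
w = x.as_strided(size=(3, 4), stride=(2, 1))   # windows [0:4], [2:6], [4:8]

w.sum().backward()
print(x.grad)  # tensor([1., 1., 2., 2., 2., 2., 1., 1.])

# The same windows expressed with the safer, higher-level call.
same = x.detach().unfold(dimension=0, size=4, step=2)
print(torch.equal(same, w.detach()))  # True
```

Each entry of x.grad counts how many windows contain that element, so the elements at positions 2 to 5 receive a gradient of 2.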
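The engineering advice is simple: before writing as_strided directly, ask whether unfold, view, transpose, permute, or a library function can express the same thing. as_strided is a low-level tool, suited to prototyping kernels or explaining memory layout, not to casual use in ordinary training code.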
Padding
The pad argument of F.pad is written starting from the last dimension:
x = torch.randn(2, 3)
y = F.pad(x, (1, 2), mode="constant", value=0) # [2, 6]: 1 on the left, 2 on the right of the last dim
For a 2D spatial tensor the order is (left, right, top, bottom):
img = torch.randn(1, 1, 32, 32)
padded = F.pad(img, (1, 1, 2, 2)) # [1, 1, 36, 34]
Different padding modes make different assumptions about the boundary:
| Mode | Meaning | Common use |
|---|---|---|
| constant | fill with value | masks, sequence pad |
| reflect | mirror without edge repeat | image augmentation/filtering |
| replicate | repeat edge value | image boundary |
| circular | wrap around | periodic signal |
For NLP padding, padding input_ids alone is usually not enough; attention_mask and labels have to be padded consistently.
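Padding the three tensors together might look like the sketch below. PAD_ID = 0 is an assumption about the tokenizer, pad_example is a hypothetical helper, and -100 is used because it is the default ignore_index of F.cross_entropy. The sketch pads on the right.

```python
import torch
import torch.nn.functional as F

PAD_ID = 0            # assumption: the tokenizer's pad token id
IGNORE_INDEX = -100   # default ignore_index of F.cross_entropy

def pad_example(input_ids, labels, max_len):
    """Right-pad one example to max_len and build the matching mask."""
    n = max_len - input_ids.size(0)
    padded_ids = F.pad(input_ids, (0, n), value=PAD_ID)
    padded_labels = F.pad(labels, (0, n), value=IGNORE_INDEX)
    attention_mask = torch.arange(max_len) < input_ids.size(0)  # True = real token
    return padded_ids, attention_mask, padded_labels

ids, mask, labels = pad_example(torch.tensor([5, 6, 7]), torch.tensor([6, 7, 8]), max_len=5)
print(ids.tolist(), mask.tolist(), labels.tolist())
```

Keeping the three outputs in one function makes it harder for the mask and the labels to drift out of sync with the padded ids.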
Reductions and keepdim
Reductions remove dimensions:
x = torch.randn(2, 3, 4)
m1 = x.mean(dim=1) # [2, 4]
m2 = x.mean(dim=1, keepdim=True) # [2, 1, 4]
keepdim=True makes later broadcasting safer:
x_centered = x - x.mean(dim=1, keepdim=True)
Common reductions:
| Operation | Meaning |
|---|---|
| sum/mean | loss/statistics |
| max/min | pooling or clipping |
| argmax | index of max, non-differentiable |
| logsumexp | stable softmax/partition |
| norm | gradient norm, regularization |
The numerically stable log-sum-exp:
\[ \log\sum_i e^{x_i} = m+\log\sum_i e^{x_i-m}, \qquad m=\max_i x_i. \]
This is why torch.logsumexp is more stable than a hand-written torch.log(torch.exp(x).sum()).
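The overflow is easy to trigger in float32. A minimal sketch comparing the naive form, the max-shifted identity above, and the library call:

```python
import torch

x = torch.tensor([1000.0, 1000.0])

naive = torch.log(torch.exp(x).sum())            # exp(1000) overflows to inf
m = x.max()
manual = m + torch.log(torch.exp(x - m).sum())   # max-shifted identity
stable = torch.logsumexp(x, dim=0)

print(naive.item(), manual.item(), stable.item())
```

The exact answer is 1000 + log 2; the shifted form reaches it because the largest exponent it ever evaluates is exp(0).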
Matrix Multiplication Family
PyTorch has several matrix-multiplication functions:
| Function | Input | Meaning |
|---|---|---|
| torch.dot | [N], [N] | vector dot |
| torch.mv | [M,N], [N] | matrix-vector |
| torch.mm | [M,N], [N,P] | matrix-matrix |
| torch.bmm | [B,M,N], [B,N,P] | batched matmul |
| torch.matmul / @ | broadcasted | general matmul |
| torch.einsum | named equation | explicit contraction |
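An attention score can be written as:
scores = q @ k.transpose(-2, -1)
or as:
scores = torch.einsum("bhtd,bhsd->bhts", q, k)
einsum is closer to the mathematical notation and suits lecture notes and complex contractions; matmul is usually easier for optimized libraries to recognize. On a performance-critical path, profile rather than choose by feel.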
einsum as Axis Algebra
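The value of einsum is that it writes the axis semantics directly into the formula. Take a batched bilinear score, with summation over the repeated indices \(i\) and \(j\):
\[ s_{bnm} = x_{bni}W_{ij}y_{bmj}. \]
It can be written as:
scores = torch.einsum("bni,ij,bmj->bnm", x, W, y)
Compared with manual reshape/matmul, this makes it harder to swap the \(n\) and \(m\) axes by accident. Another example is multi-head attention:
scores = torch.einsum("bhtd,bhsd->bhts", q, k)
out = torch.einsum("bhts,bhsd->bhtd", probs, v)
The drawback is that not every einsum is dispatched to the best kernel. Hot spots such as plain GEMM, batched GEMM, and attention should usually go through library functions or fused kernels first; einsum fits complex but infrequent statistics and debugging computations.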
Vectorization Instead of Python Loops
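PyTorch's performance comes from handing large batches of work to low-level kernels. A per-sample Python loop is usually a bad sign:
# slow
outs = []
for i in range(B):
    outs.append(layer(x[i]))
out = torch.stack(outs)
# better
out = layer(x)
Some logic looks inherently per-example but can be expressed with the batch dimension, a mask, or gather/scatter. For example, taking the last valid token of every sequence:
# hidden: [B, T, C], attention_mask: [B, T] (assumes right padding)
lengths = attention_mask.long().sum(dim=1).clamp_min(1)
idx = (lengths - 1)[:, None, None].expand(-1, 1, hidden.size(-1))
last = hidden.gather(dim=1, index=idx).squeeze(1)
If you really need to run the same function on every sample of a batch, consider torch.vmap / torch.func.vmap. vmap requires, however, that the function have no Python side effects that are hard to batch, and not every op has a batching rule.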
Vectorizing a loop may create a large intermediate tensor. Always compare both compute time and peak memory, especially for pairwise [B, B] or [T, T] constructions.
Pairwise Tensors and Blockwise Computation
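Pairwise operations are the most common vectorization trap. Take all pairwise distances between two sets of vectors, written directly:
diff = x[:, None, :] - y[None, :, :] # [N, M, D]
dist = diff.square().sum(dim=-1) # [N, M]
The math is clear:
\[ d_{ij} = \|x_i-y_j\|_2^2. \]
But the intermediate tensor has \(NMD\) elements. With \(N=M=8192,D=1024\), diff has
\[ 8192^2\cdot1024 \approx 6.87\times10^{10} \]
elements. Even in BF16 that is 128 GiB (about 137 GB), so it is not a viable "more vectorized" implementation.
A better version expands the square algebraically:
\[ \|x_i-y_j\|^2 = \|x_i\|^2+\|y_j\|^2-2x_i^\top y_j. \]
x2 = x.square().sum(dim=-1, keepdim=True) # [N, 1]
y2 = y.square().sum(dim=-1).unsqueeze(0) # [1, M]
dist = x2 + y2 - 2 * (x @ y.T) # [N, M]
dist = dist.clamp_min(0) # rounding can leave small negative values
This still produces [N, M] but avoids [N, M, D]. If [N, M] itself is too large, you need blockwise computation: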
def pairwise_topk(x, y, k, block):
    vals = []
    inds = []
    y2 = y.square().sum(dim=-1).unsqueeze(0) # [1, M], computed once
    for start in range(0, x.size(0), block):
        xb = x[start : start + block]
        xb2 = xb.square().sum(dim=-1, keepdim=True)
        dist = xb2 + y2 - 2 * (xb @ y.T) # [block, M]
        v, i = torch.topk(-dist, k=k, dim=-1) # largest -dist = nearest
        vals.append(v)
        inds.append(i)
    return torch.cat(vals, dim=0), torch.cat(inds, dim=0)
This returns the top-k negative distances, that is, the k nearest neighbours, rather than the full [N, M] distance table. Retrieval, contrastive learning, and hard negative mining often need this way of thinking: compute a lot, but keep only a few results.
If a vectorized expression creates [N, M, D] or [B, H, T, T], compute in blocks or use a fused library kernel before increasing batch size.
Attention as a Pairwise Operation
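The attention score
\[ S_{bhts} = q_{bht}^\top k_{bhs} \]
is also a pairwise tensor at heart. Naive attention materializes [B, H, T, T] scores and probabilities; the core optimization of FlashAttention is to compute softmax attention block by block, without writing the full attention matrix back to HBM.
This is the same engineering principle as for the pairwise distance above: the full matrix can be written down mathematically, but the system should not necessarily materialize it.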
| Operation | Natural matrix | Common memory fix |
|---|---|---|
| pairwise distance | [N, M] or [N, M, D] | algebraic expansion, block top-k |
| attention | [B, H, T, T] | tiled/fused attention |
| contrastive logits | [B, B] | cross-device gather + chunked CE |
| graph messages | [E, D] | sparse gather-scatter |
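The block-by-block softmax idea behind tiled attention can be sketched in plain PyTorch. This is a single-head, unmasked illustration of the streaming (online) softmax recurrence, not the FlashAttention kernel itself; blockwise_attention and its argument names are hypothetical.

```python
import torch

def blockwise_attention(q, k, v, block):
    """softmax(q k^T / sqrt(D)) v without forming the full [T, S] matrix.

    q: [T, D], k: [S, D], v: [S, Dv]. Only [T, block] scores exist at a time.
    """
    scale = q.size(-1) ** -0.5
    m = q.new_full((q.size(0), 1), float("-inf"))   # running row max
    l = q.new_zeros(q.size(0), 1)                   # running softmax denominator
    acc = q.new_zeros(q.size(0), v.size(-1))        # running unnormalized output
    for s in range(0, k.size(0), block):
        scores = (q @ k[s:s + block].T) * scale     # [T, block]
        m_new = torch.maximum(m, scores.max(dim=-1, keepdim=True).values)
        p = torch.exp(scores - m_new)
        correction = torch.exp(m - m_new)           # rescale earlier partial sums
        l = l * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v[s:s + block]
        m = m_new
    return acc / l

torch.manual_seed(0)
q, k, v = torch.randn(5, 8), torch.randn(11, 8), torch.randn(11, 4)
ref = torch.softmax(q @ k.T * 8 ** -0.5, dim=-1) @ v
print(torch.allclose(blockwise_attention(q, k, v, block=4), ref, atol=1e-5))
```

The loop keeps O(T * block) scores alive instead of O(T * S); a fused kernel applies the same recurrence on-chip so that the scores never reach HBM.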
Top-k, Sort, and Non-Differentiable Indices
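topk and sort return values and indices:
values, indices = torch.topk(scores, k=5, dim=-1)
values is differentiable with respect to scores; indices is a discrete selection and is not differentiable. Retrieval, hard negative mining, MoE routing, and beam search all use top-k, but keep in mind that the selection itself usually receives no gradient. If the training objective needs the selection process to be differentiable, common alternatives include softmax relaxation, Gumbel-Softmax, the straight-through estimator, and policy-gradient methods.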
A discrete index operation returns integer positions. Gradients can flow through selected values, but not through the integer selection rule itself.
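This is also why MoE router training relies on an additionally designed load-balance loss, auxiliary loss, or straight-through trick: the indices of top-k dispatch do not by themselves tell the router how an expert that narrowly missed selection should be updated.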
Tensor Operation Smoke Tests
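Many tensor bugs are not "the code does not run" but silently wrong results. The small tests below are meant to live next to model utility functions, especially when you have written a custom collator, mask reduction, routing, sliding-window, or gather/scatter logic.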
Test 1: View or Copy
import torch
x = torch.arange(12).reshape(3, 4)
view = x[:, 1:]
copy = x[:, torch.tensor([1, 2])]
assert view.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
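assert copy.untyped_storage().data_ptr() != x.untyped_storage().data_ptr()
The point of this test is not that application code should compare storage pointers everywhere; it is a reminder that basic slicing and advanced indexing have different memory semantics. If an optimization depends on "no copy happened", a test should cover it.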
Test 2: Broadcast Gradient Is Summed
bias = torch.zeros(4, requires_grad=True)
y = bias[None, None, :].expand(2, 3, 4)
y.sum().backward()
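assert torch.allclose(bias.grad, torch.full((4,), 6.0))
This test confirms that the backward of a broadcast dimension is a sum reduction. It also catches hand-written backward passes or custom ops that forget to reduce over broadcast dimensions.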
Test 3: Masked Mean Does Not Produce NaN
x = torch.randn(2, 3)
mask = torch.zeros(2, 3, dtype=torch.bool)
out = masked_mean(x, mask, dim=1)
assert torch.isfinite(out).all()
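assert torch.allclose(out, torch.zeros_like(out))
The expected behavior here is "an empty mask returns 0". If your task semantics say an empty mask should raise, change the test to assert that it raises. The key is not to let NaN enter the loss silently.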
Test 4: Gather Repeated Indices Accumulate Gradients
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
idx = torch.tensor([0, 0, 2])
y = x.gather(0, idx)
y.sum().backward()
assert torch.allclose(x.grad, torch.tensor([2.0, 0.0, 1.0]))
Embeddings, MoE combine, and GNN aggregation all depend on this semantics. Repeated indices are not an edge case; they are the norm.
Test 5: Pairwise Implementation Matches Naive Small Case
A blockwise or algebraically expanded pairwise distance should agree with the naive implementation on small inputs:
x = torch.randn(5, 7)
y = torch.randn(6, 7)
naive = (x[:, None, :] - y[None, :, :]).square().sum(-1)
fast = x.square().sum(-1, keepdim=True) + y.square().sum(-1)[None, :] - 2 * x @ y.T
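assert torch.allclose(fast.clamp_min(0), naive, atol=1e-5)
Large-scale implementations rely on blocks and kernels; small-scale tests rely on the direct formula. This is the most reliable debugging habit for tensor programs.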
For tensor utilities, tiny deterministic tensors often catch more bugs than a full training run because shape, stride, broadcast, and gradient semantics become inspectable.
Operation Checklist
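When writing tensor operations, go through the following checks:
- do the shape comments state the meaning of every axis;
- is contiguous needed after slicing/transpose;
- does advanced indexing introduce a copy;
- does the mask shape broadcast correctly;
- is cat allocating repeatedly inside a loop;
- are in-place writes avoided after expand;
- does squeeze specify dim;
- does the reduction denominator match the loss definition;
- is the pad argument ordered from the last dimension;
- do the einsum/matmul dimensions match the formula;
- do gather/scatter handle repeated indices explicitly;
- are top-k/sort indices mistaken for something differentiable;
- does vectorization introduce an oversized intermediate tensor;
- does the backward of a broadcast dimension need to sum back to the original shape;
- can a sliding-window/view operation produce an overlapping view;
- does a pairwise operation materialize an unaffordable intermediate tensor;
- is the tensor contract written down at module boundaries;
- do tiny-tensor smoke tests cover view/copy, mask, gather, and pairwise semantics.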