2.2 Basic Neural Network Blocks


大多数网络都由少数基本块组合而成:affine layer、activation、normalization、regularization、residual connection 和 loss。理解这些块,比背某个模型名字更重要,因为新架构往往只是重新排列这些块,并改变它们的尺度、归一化位置和训练语义。

Linear Layer

nn.Linear(in_features, out_features) 实现:

\[ y=xW^\top+b. \]

PyTorch 中权重 shape 是:

weight: [out_features, in_features]
bias:   [out_features]

输入可以有任意 leading dimensions:

x: [B, in_features]       -> y: [B, out_features]
x: [B, T, in_features]    -> y: [B, T, out_features]

也就是最后一维做 affine transform,前面的维度当作 batch-like 维度保留。

参数量:

\[ N_{\text{param}} = d_{\text{out}}d_{\text{in}} + \mathbf{1}_{\text{bias}}d_{\text{out}}. \]

import torch
from torch import nn
import torch.nn.functional as F

layer = nn.Linear(10, 5, bias=True)
assert layer.weight.shape == (5, 10)
assert layer.bias.shape == (5,)
NoteDefinition: Affine Layer

An affine layer applies a linear map plus a bias: \(x\mapsto xW^\top+b\).

Linear as a Batched Tensor Contraction

nn.Linear 不是只服务于二维输入。若输入 shape 是 [B,T,d_in],权重 shape 是 [d_out,d_in],则

\[ y_{b,t,o} = \sum_{i=1}^{d_{\text{in}}} x_{b,t,i}W_{o,i} +b_o. \]

也就是只收缩最后一维,保留所有 leading dimensions。等价地,可以把前面的维度展平成一个大 batch:

\[ X_{\text{flat}}\in\mathbb{R}^{(BT)\times d_{\text{in}}}, \qquad Y_{\text{flat}}=X_{\text{flat}}W^\top+\mathbf{1}b^\top. \]

这解释了为什么 Transformer 里的 QKV projection、MLP projection 都可以直接对 [B,T,C]Linear(C, D)

x = torch.randn(4, 128, 768)
proj = nn.Linear(768, 2304)
y = proj(x)
assert y.shape == (4, 128, 2304)

Bias, Fused Kernels, and Initialization

Bias 是 broadcast 到所有 leading dimensions 的:

\[ y_{\ldots,o}=x_{\ldots,:}W_{o,:}^\top+b_o. \]

在很多 modern block 中,如果 Linear 后面紧跟 LayerNorm 或 BatchNorm,bias 可能可以关掉,因为 normalization 会重新引入 shift 参数。但这不是无条件规则:若 Linear 后面直接接 activation 或 residual add,bias 仍可能有用。

初始化时,fan-in / fan-out 为:

\[ \operatorname{fan\_in}=d_{\text{in}}, \qquad \operatorname{fan\_out}=d_{\text{out}}. \]

Xavier initialization 常用于 tanh/sigmoid-like activation:

\[ \operatorname{Var}(W)\approx \frac{2}{d_{\text{in}}+d_{\text{out}}}, \]

Kaiming initialization 常用于 ReLU-like activation:

\[ \operatorname{Var}(W)\approx \frac{2}{d_{\text{in}}}. \]

WarningPitfall: Flattening Can Destroy Semantics

Flattening [B,T,C] to [B,T*C] before a linear layer changes the model: it mixes time and feature dimensions. Applying Linear(C,D) to [B,T,C] keeps the token/time axis separate.

Softmax and Cross Entropy

Softmax 把 logits \(z\in\mathbb{R}^C\) 变成概率:

\[ p_i = \frac{\exp z_i}{\sum_j \exp z_j}. \]

one-hot label \(y\) 的 cross entropy:

\[ \mathcal{L} = -\sum_i y_i\log p_i = -\log p_c, \]

其中 \(c\) 是真实类别。

把 softmax 代入:

\[ \mathcal{L} = -z_c+\log\sum_j e^{z_j}. \]

梯度:

\[ \frac{\partial\mathcal{L}}{\partial z_i} = p_i-\mathbf{1}[i=c]. \]

Proof

\[ \mathcal{L}=-z_c+\log\sum_j e^{z_j} \]

求导。第一项给 \(-\mathbf{1}[i=c]\),第二项给:

\[ \frac{e^{z_i}}{\sum_j e^{z_j}}=p_i. \]

相加即 \(p_i-\mathbf{1}[i=c]\)

PyTorch 的 nn.CrossEntropyLoss 接收 raw logits,不接收 softmax 后的概率:

criterion = nn.CrossEntropyLoss()
logits = torch.randn(32, 10)
labels = torch.randint(0, 10, (32,))
loss = criterion(logits, labels)

内部等价于 stable log_softmax + nll_loss。若先手动 softmax 再传入,会损失数值稳定性并改变语义。

WarningPitfall: CrossEntropyLoss Wants Logits

Do not pass probabilities from softmax into nn.CrossEntropyLoss; pass raw logits.

Stable Log-Softmax

直接计算 \(\exp z_i\) 可能 overflow。稳定写法是先减去最大 logit:

\[ \log\sum_j e^{z_j} = m+\log\sum_j e^{z_j-m}, \qquad m=\max_j z_j. \]

由于 softmax 对整体平移不变:

\[ \operatorname{softmax}(z+c\mathbf{1}) = \operatorname{softmax}(z), \]

减去 \(m\) 不改变概率,却让指数项落在 \([0,1]\)

def stable_cross_entropy(logits, labels):
    # logits: [B, C], labels: [B]
    z = logits - logits.max(dim=-1, keepdim=True).values
    log_probs = z - torch.logsumexp(z, dim=-1, keepdim=True)
    idx = torch.arange(labels.numel(), device=labels.device)
    return -log_probs[idx, labels].mean()
NoteDefinition: LogSumExp Trick

The log-sum-exp trick rewrites \(\log\sum_i e^{z_i}\) as \(m+\log\sum_i e^{z_i-m}\) with \(m=\max_i z_i\) to avoid numerical overflow.

Label Smoothing

Hard one-hot target 把所有概率质量放在正确类别 \(c\)。Label smoothing 使用

\[ q_i = (1-\epsilon)\mathbf{1}[i=c] +\epsilon\frac{1}{C}. \]

loss 变成

\[ \mathcal{L} = -\sum_i q_i\log p_i. \]

它的梯度仍然是非常干净的形式:

\[ \frac{\partial \mathcal{L}}{\partial z_i} = p_i-q_i. \]

直观上,label smoothing 不让模型把正确类别推到概率 \(1\)。它常能改善 calibration,但也可能伤害需要极高置信度的任务,尤其是类别定义本来就很清楚、数据量又足够时。

Activation Functions

Activation 引入非线性。没有 activation,多层 affine 仍然等价于一个 affine map:

\[ W_2(W_1x+b_1)+b_2 = (W_2W_1)x+(W_2b_1+b_2). \]

常见 activation:

Activation Formula Gradient issue
ReLU \(\max(x,0)\) dead negative units
LeakyReLU \(\max(x,\alpha x)\) less dead units
Sigmoid \(1/(1+e^{-x})\) saturation
Tanh \((e^x-e^{-x})/(e^x+e^{-x})\) saturation
GELU \(x\Phi(x)\) smooth transformer FFN
SiLU/Swish \(x\sigma(x)\) smooth gated behavior

Sigmoid derivative:

\[ \sigma'(x)=\sigma(x)(1-\sigma(x)). \]

\(|x|\) 很大时,\(\sigma'(x)\approx0\),深层网络中梯度容易消失。ReLU 缓解正半轴饱和,但负半轴梯度为 0。

Derivative Shape and Saturation

Activation 选择会改变局部 Jacobian。逐元素 activation \(\phi\) 的 Jacobian 是对角矩阵:

\[ \frac{\partial \phi(x)}{\partial x} = \operatorname{diag}(\phi'(x_1),\ldots,\phi'(x_d)). \]

如果大量 \(\phi'(x_i)\approx0\),反向传播就会被压小。常见导数:

Activation Derivative
ReLU \(\mathbf{1}[x>0]\)
LeakyReLU \(1\) if \(x>0\), else \(\alpha\)
Sigmoid \(\sigma(x)(1-\sigma(x))\)
Tanh \(1-\tanh^2(x)\)
SiLU \(\sigma(x)+x\sigma(x)(1-\sigma(x))\)

GELU 常用近似:

\[ \operatorname{GELU}(x) \approx \frac{x}{2} \left( 1+\tanh\left[\sqrt{\frac{2}{\pi}}\left(x+0.044715x^3\right)\right] \right). \]

平滑 activation 的优点不是“更高级”,而是负半轴也保留一些可微的门控行为;代价是比 ReLU 稍贵,并且 kernel fusion 更依赖框架实现。

Gated Activations and SwiGLU

现代 Transformer FFN 常用 gated MLP。GLU family 的基本形式:

\[ \operatorname{GLU}(x) = (xW_v)\odot \sigma(xW_g). \]

SwiGLU 把 sigmoid gate 换成 SiLU/Swish:

\[ \operatorname{SwiGLU}(x) = (xW_v)\odot \operatorname{SiLU}(xW_g). \]

然后再投影回 hidden size:

\[ y = \left((xW_v)\odot \operatorname{SiLU}(xW_g)\right)W_o. \]

若 hidden size 是 \(d\),中间维度是 \(d_{\text{ff}}\),普通 FFN 参数约为

\[ 2dd_{\text{ff}}, \]

而 SwiGLU 有 gate/value/down 三个矩阵:

\[ 3dd_{\text{ff}}. \]

因此为了保持参数量接近,常取 \(d_{\text{ff}}\approx \frac{8}{3}d\),使 \(3d\cdot\frac{8}{3}d\approx8d^2\),接近普通 \(4d\) FFN 的 \(2d(4d)=8d^2\)

Dropout

训练时 dropout 使用 mask:

\[ m_i\sim\operatorname{Bernoulli}(1-p), \qquad y_i=\frac{m_i}{1-p}x_i. \]

这样:

\[ \mathbb{E}[y_i] = x_i. \]

也就是 inverted dropout:训练时放大保留的 activation,推理时不需要再缩放。

drop = nn.Dropout(p=0.1)
model.train()
y_train = drop(x)

model.eval()
y_eval = drop(x)  # identity

Dropout 是训练模式相关层,忘记 model.eval() 会导致推理随机。

Dropout Changes Variance

Inverted dropout 保持期望不变,但会增加方差。设 \(m\sim\operatorname{Bernoulli}(1-p)\)\(y=\frac{m}{1-p}x\)。条件在固定 \(x\) 上:

\[ \mathbb{E}[y]=x, \qquad \operatorname{Var}(y) = \frac{p}{1-p}x^2. \]

所以 dropout 不是“无影响地随机删一些神经元”。它会给 activation 注入乘性噪声;\(p\) 越大,训练越 noisy。

NoteDefinition: Inverted Dropout

Inverted dropout scales retained activations by \(1/(1-p)\) during training so that inference can use the identity map.

Dropout Mask Shape

Dropout 的 mask shape 会决定正则化语义:

Layer Typical mask semantics
Dropout on [B,T,C] independent element-wise mask
Dropout2d on [B,C,H,W] drop whole channels
attention dropout drop attention probabilities
residual dropout drop sublayer output before residual add
class ResidualDropoutBlock(nn.Module):
    def __init__(self, dim, p):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.drop = nn.Dropout(p)

    def forward(self, x):
        return x + self.drop(self.ffn(x))
WarningPitfall: Dropout Position Changes the Model

Dropping hidden activations, attention probabilities, and residual branches are different operations. They may share the name “dropout” but regularize different random variables.

Batch Normalization

BatchNorm 对 mini-batch 统计做标准化:

\[ \hat{x} = \frac{x-\mu_\mathcal{B}} {\sqrt{\sigma_\mathcal{B}^2+\epsilon}}, \qquad y=\gamma\hat{x}+\beta. \]

其中:

\[ \mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^m x_i, \qquad \sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^m(x_i-\mu_\mathcal{B})^2. \]

训练时使用 batch statistics,并更新 running mean/var;推理时使用 running statistics。

Mode Statistics used Updates running stats?
train() current mini-batch yes
eval() running mean/var no

BatchNorm 对 batch size 敏感。小 batch、分布变化很大的 batch、序列长度 padding 很多的 batch,都可能让统计量噪声很大。

BatchNorm Axes and Momentum

BatchNorm 的关键不是公式本身,而是统计轴。对 BatchNorm1d 输入 [B,C],每个 channel 的统计量在 batch 维上算;对 BatchNorm2d 输入 [B,C,H,W],统计量在 [B,H,W] 上算:

\[ \mu_c = \frac{1}{BHW} \sum_{b,h,w}x_{b,c,h,w}. \]

PyTorch 的 running statistics 更新是:

\[ \operatorname{running} \leftarrow (1-\alpha)\operatorname{running} +\alpha\operatorname{batch}, \]

其中 \(\alpha\)momentum。这和优化器里的 momentum 命名方向不同:PyTorch BN 的 momentum=0.1 表示新 batch 占 \(10\%\)

bn = nn.BatchNorm2d(num_features=64, momentum=0.1)
x = torch.randn(8, 64, 32, 32)
y = bn(x)
WarningPitfall: BatchNorm and Gradient Accumulation

Gradient accumulation increases optimizer batch size, but BatchNorm still sees each micro-batch separately. If micro-batch is tiny, BN statistics can remain noisy even when effective batch size is large.

When BatchNorm Fails

BatchNorm works well when mini-batch statistics are representative. It becomes fragile when:

  1. batch size 很小;
  2. sequence padding 很多,padding token 混进统计;
  3. train/test distribution 的 channel statistics 差异大;
  4. distributed training 中每张卡 batch 很小但没有 SyncBatchNorm;
  5. fine-tuning 时数据量太小,running stats 被少量样本污染。

常见处理:

Problem Option
tiny batch LayerNorm / GroupNorm
multi-GPU small local batch SyncBatchNorm
fine-tuning small dataset freeze BN stats
variable-length sequences mask-aware normalization or LayerNorm

LayerNorm and RMSNorm

LayerNorm 不跨 batch 做统计,而是在每个样本内部最后若干维做归一化:

\[ \operatorname{LN}(x) = \gamma \frac{x-\mu}{\sqrt{\sigma^2+\epsilon}} +\beta. \]

对 Transformer hidden state [B,T,C],LayerNorm 通常沿 C 归一化。它适合变长序列和小 batch。

RMSNorm 去掉均值中心化,只按 root mean square 缩放:

\[ \operatorname{RMSNorm}(x) = \gamma \frac{x} {\sqrt{\frac{1}{d}\sum_i x_i^2+\epsilon}}. \]

它更便宜,现代 LLM 中很常见。

LayerNorm Axes

nn.LayerNorm(normalized_shape) 对输入最后若干维做统计。若 hidden state 是 [B,T,C],常用:

ln = nn.LayerNorm(768)
y = ln(x)  # x: [B, T, 768]

对应每个 token 独立算:

\[ \mu_{b,t} = \frac{1}{C}\sum_{c=1}^Cx_{b,t,c}. \]

它不跨 batch,也不跨 time,所以非常适合变长序列和 autoregressive model。若误写成 LayerNorm([T,C]),就会把 time 维也归一化进去,语义完全不同。

NoteDefinition: Normalization Axis

The normalization axis is the set of dimensions used to compute mean and variance. Different axes define different layers even if formulas look similar.

RMSNorm and Scale Invariance

RMSNorm 定义:

\[ \operatorname{RMSNorm}(x) = \gamma\frac{x}{\operatorname{RMS}(x)}, \qquad \operatorname{RMS}(x) = \sqrt{\frac{1}{d}\sum_i x_i^2+\epsilon}. \]

若忽略 \(\epsilon\),则对任意 \(\alpha>0\)

\[ \operatorname{RMSNorm}(\alpha x) = \gamma\frac{\alpha x}{\alpha \operatorname{RMS}(x)} = \operatorname{RMSNorm}(x). \]

RMSNorm 保留均值方向,不做 centering;它只控制尺度。这也是很多 LLM 选它的原因:省掉 mean subtraction,计算更便宜,同时在 residual stream 上只规范化 magnitude。

class RMSNorm(nn.Module):
    def __init__(self, dim, eps):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.weight * x / rms
WarningPitfall: Epsilon and Dtype Matter

For FP16/BF16, too small an epsilon can make normalization numerically brittle. Keep normalization accumulation in a stable dtype when implementing custom kernels.

Residual Connections

Residual block:

\[ y=x+F(x). \]

反向传播:

\[ \frac{\partial y}{\partial x} = I+\frac{\partial F}{\partial x}. \]

这条 identity path 让梯度可以绕过复杂子层直接传播,是训练很深网络的关键。

ImportantTheorem: Residual Gradient Path

For \(y=x+F(x)\), the Jacobian contains an identity term, so gradients have a direct path independent of the sublayer Jacobian.

Proof

逐元素看 \(y_i=x_i+F_i(x)\)。对 \(x_j\) 求导:

\[ \frac{\partial y_i}{\partial x_j} = \frac{\partial x_i}{\partial x_j} + \frac{\partial F_i}{\partial x_j} = \delta_{ij} + \frac{\partial F_i}{\partial x_j}. \]

矩阵形式就是 \(I+J_F\)

Pre-Norm vs. Post-Norm

Transformer block 里 residual 和 normalization 的相对位置非常关键。

Post-norm:

\[ y=\operatorname{Norm}(x+F(x)). \]

Pre-norm:

\[ y=x+F(\operatorname{Norm}(x)). \]

Pre-norm 的 residual path 更接近纯 identity。反向看:

\[ \frac{\partial y}{\partial x} = I + J_F(\operatorname{Norm}(x))J_{\operatorname{Norm}}(x). \]

identity term 没有被 Norm 包住,因此深层网络更稳定。Post-norm 有时最终质量好,但通常需要更小心的 warmup、初始化和 residual scaling。

Layout Formula Training behavior
post-norm \(\operatorname{Norm}(x+F(x))\) classic Transformer, deeper networks harder
pre-norm \(x+F(\operatorname{Norm}(x))\) easier optimization, common in LLMs
sandwich norm \(x+F(\operatorname{Norm}_1(x))\) then \(\operatorname{Norm}_2\) extra stability, extra compute

Residual Scaling

如果有 \(L\) 个 residual updates:

\[ x_L=x_0+\sum_{\ell=1}^L F_\ell(x_{\ell-1}), \]

粗略独立假设下,若每个 update 方差为 \(\sigma^2\),residual stream 方差会随 \(L\) 增长。常见稳定化手段:

  1. 初始化最后一个 projection 为较小值或零;
  2. residual branch 乘 \(1/\sqrt{L}\) 或类似尺度;
  3. 使用 pre-norm / RMSNorm;
  4. 使用 gradient clipping 和 warmup。
class ScaledResidual(nn.Module):
    def __init__(self, dim, depth):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.scale = depth ** -0.5

    def forward(self, x):
        return x + self.scale * self.ffn(self.norm(x))

Minimal MLP Block

一个分类 MLP:

class MLPClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, num_classes, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        return self.net(x)

几点实现语义:

  1. 最后一层输出 logits,不加 softmax;
  2. dropout 在 train/eval 下行为不同;
  3. LayerNorm 的参数会进入 optimizer;
  4. CrossEntropyLoss 负责 softmax + NLL。

Modern Feed-Forward Block

现代 LLM/Transformer 的 MLP block 往往长这样:

class SwiGLUFFN(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

shape contract:

x:    [B, T, d]
gate: [B, T, h]
up:   [B, T, h]
prod: [B, T, h]
out:  [B, T, d]

参数量:

\[ N_{\text{SwiGLU}} = 3dh \quad \text{without bias}. \]

\(h=\frac{8}{3}d\),则

\[ N_{\text{SwiGLU}} = 8d^2, \]

和普通 Linear(d,4d) -> activation -> Linear(4d,d) 接近。

Block Composition Pattern

一个更接近真实训练代码的 pre-norm MLP block:

class PreNormMLPBlock(nn.Module):
    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ffn = SwiGLUFFN(dim, hidden_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        h = self.norm(x)
        h = self.ffn(h)
        return x + self.drop(h)

这段代码体现了几个 block-level 设计选择:

  1. Norm 放在 branch 内部,residual path 保持 identity;
  2. dropout 作用在 residual update 上,而不是输入 x 上;
  3. FFN 不改变 [B,T,d] 的外部 contract;
  4. 最后一层没有 activation,因为输出要回到 residual stream。

Shape Debugging

建议在每个 block 边界写 shape contract:

def forward(self, x):
    # x: [batch, in_dim]
    logits = self.net(x)
    # logits: [batch, num_classes]
    return logits

对复杂模型可以用 hooks:

def print_shape(name):
    def hook(module, inputs, output):
        print(name, tuple(output.shape))
    return hook

for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        module.register_forward_hook(print_shape(name))

Implementation Checklist

构建基础网络块时检查:

  1. Linear 输入最后一维是否等于 in_features
  2. 是否错误 flatten 了 batch/time/spatial 语义维度;
  3. logits 是否未经过 softmax 就传给 CrossEntropyLoss
  4. CE / log-softmax 是否使用 stable logsumexp;
  5. activation 是否适合当前初始化和深度;
  6. SwiGLU/GLU 的 gate、up、down shape 是否匹配;
  7. dropout 是否只在训练时生效,mask 语义是否符合任务;
  8. BatchNorm 小 batch、gradient accumulation、distributed local batch 下是否稳定;
  9. LayerNorm/RMSNorm 归一化维度是否正确;
  10. epsilon 和 accumulation dtype 是否适合 FP16/BF16;
  11. residual branch shape 是否完全一致;
  12. pre-norm/post-norm 选择是否匹配网络深度和 warmup;
  13. loss reduction 是否符合 batch/token 权重设定;
  14. model.train() / model.eval() 是否在正确阶段调用。