3.6 Parameter Initialization
Initialization 决定训练开始时信号和梯度能否穿过网络。好的初始化让每层 activation variance 和 gradient variance 大致稳定;坏的初始化会让网络一开始就 saturated、dead 或 exploding。
Variance Propagation
考虑线性层
\[ y_i=\sum_{j=1}^{n}w_{ij}x_j. \]
若 \(w_{ij}\) 和 \(x_j\) 独立,均值为 \(0\),则
\[ \operatorname{Var}(y_i) = n\operatorname{Var}(w)\operatorname{Var}(x). \]
为了让 \(\operatorname{Var}(y_i)\approx\operatorname{Var}(x)\),需要
\[ \operatorname{Var}(w)\approx\frac{1}{n}. \]
这就是 Xavier/He initialization 的基本来源。
若 \(w_{ij},x_j\) 独立且均值为 0,则
\[ \operatorname{Var}(w_{ij}x_j) = \mathbb{E}[w_{ij}^2x_j^2] = \mathbb{E}[w_{ij}^2]\mathbb{E}[x_j^2] = \operatorname{Var}(w)\operatorname{Var}(x). \]
又因为不同 \(j\) 的项独立,
\[ \operatorname{Var}\left(\sum_{j=1}^{n}w_{ij}x_j\right) = \sum_{j=1}^{n}\operatorname{Var}(w_{ij}x_j) = n\operatorname{Var}(w)\operatorname{Var}(x). \]
要让输出方差接近输入方差,就需要
\[ n\operatorname{Var}(w)\approx1. \]
初始化的核心不是“随机小一点”,而是让信号在 forward 和 backward 两个方向都不要系统性变大或变小。
Fan-In, Fan-Out, and Convolution
初始化公式里的 \(n\) 不是随便取的。对线性层
\[ y = xW^\top, \qquad W\in\mathbb{R}^{n_{\text{out}}\times n_{\text{in}}}, \]
有
\[ \operatorname{fan\_in}=n_{\text{in}}, \qquad \operatorname{fan\_out}=n_{\text{out}}. \]
对卷积核
\[ W\in \mathbb{R}^{C_{\text{out}}\times C_{\text{in}}\times k_h\times k_w}, \]
每个输出位置看到 \(C_{\text{in}}k_hk_w\) 个输入,所以
\[ \operatorname{fan\_in}=C_{\text{in}}k_hk_w, \qquad \operatorname{fan\_out}=C_{\text{out}}k_hk_w. \]
fan_in is the number of input degrees contributing to one output unit; fan_out is the number of output degrees receiving gradient from one input unit. Initialization uses these counts to balance forward activation variance and backward gradient variance.
在 grouped convolution 中,fan-in 要除以 groups;在 depthwise convolution 中,每个 channel 只看自己的 kernel,所以 fan-in 约为 \(k_hk_w\)。这类细节如果算错,模型可能一开始 activation scale 就偏掉。
PyTorch 的 nn.init 会根据 tensor shape 自动估 fan-in/fan-out,但自定义权重布局时要小心。例如有些代码把 linear weight 存成 [in, out] 而不是 PyTorch 默认 [out, in],直接套 kaiming_uniform_ 就会用错方向。
Xavier Initialization
For activations roughly symmetric around zero, Xavier initialization sets \[ \operatorname{Var}(w) = \frac{2}{n_{\text{in}}+n_{\text{out}}}. \] For a uniform distribution, \[ w\sim U\left[-\sqrt{\frac{6}{n_{\text{in}}+n_{\text{out}}}}, \sqrt{\frac{6}{n_{\text{in}}+n_{\text{out}}}}\right]. \]
它平衡 forward 和 backward 的 variance,适合 tanh、linear、sigmoid 附近的浅层情形。
为什么是
\[ \frac{2}{n_{\text{in}}+n_{\text{out}}} \]
而不是只用 \(1/n_{\text{in}}\)?Forward variance 希望
\[ \operatorname{Var}(w)\approx\frac1{n_{\text{in}}}, \]
Backward gradient variance 希望
\[ \operatorname{Var}(w)\approx\frac1{n_{\text{out}}}. \]
Xavier 取二者的折中,使两个方向都不太坏:
\[ \operatorname{Var}(w) = \frac{2}{n_{\text{in}}+n_{\text{out}}}. \]
对 uniform 分布 \(U[-a,a]\),
\[ \operatorname{Var}(w)=\frac{a^2}{3}. \]
令
\[ \frac{a^2}{3} = \frac{2}{n_{\text{in}}+n_{\text{out}}}, \]
得到
\[ a=\sqrt{\frac{6}{n_{\text{in}}+n_{\text{out}}}}. \]
He Initialization
ReLU 会把大约一半 activation 截成 \(0\)。为了补偿方差损失,He initialization 使用
\[ \operatorname{Var}(w)=\frac{2}{n_{\text{in}}}. \]
For ReLU-like activations, initialize \[ w\sim \mathcal{N}\left(0,\frac{2}{n_{\text{in}}}\right) \] or the corresponding uniform distribution.
若 pre-activation \(a\) 近似对称,ReLU 输出
\[ h=\max(0,a) \]
会把约一半值置零。粗略地,二阶矩约减半:
\[ \mathbb{E}[h^2]\approx\frac12\mathbb{E}[a^2]. \]
线性层产生
\[ \operatorname{Var}(a) \approx n_{\text{in}}\operatorname{Var}(w)\operatorname{Var}(x). \]
为了 ReLU 后保持方差:
\[ \frac12 n_{\text{in}}\operatorname{Var}(w)\approx1, \]
所以
\[ \operatorname{Var}(w)\approx\frac{2}{n_{\text{in}}}. \]
GELU/SiLU 不像 ReLU 那样精确截半,但它们也是非线性门控,实际框架默认初始化通常仍围绕类似的方差传播原则设计。
Backward Variance and Jacobians
Forward variance 只解释了一半。训练是否顺利还取决于 loss gradient 能否穿过层。对一层
\[ h_{\ell+1}=\phi(W_\ell h_\ell), \]
局部 Jacobian 是
\[ J_\ell = \frac{\partial h_{\ell+1}}{\partial h_\ell} = \operatorname{diag}(\phi'(a_\ell))W_\ell. \]
整个网络的反向梯度会乘上
\[ J_{L-1}^{\top}J_{L-2}^{\top}\cdots J_0^{\top}. \]
如果这些 Jacobian 的奇异值多数小于 \(1\),梯度消失;多数大于 \(1\),梯度爆炸。理想情况不是每个权重很小,而是 Jacobian product 的尺度接近稳定。
A network has approximate dynamical isometry at initialization when the singular values of the input-output Jacobian concentrate near \(1\), so gradients propagate through depth without strong contraction or expansion.
正交初始化的直觉来自这里。若线性网络中每个 \(W_\ell\) 都是 orthogonal matrix,则
\[ \|W_\ell v\|_2=\|v\|_2. \]
因此纯线性时梯度范数不会因为矩阵乘法本身改变。非线性、残差、normalization、attention 都会破坏这个简单结论,但“看 Jacobian singular values”仍然是理解初始化的好语言。
For a deep linear network with orthogonal square weight matrices, backpropagated gradient norms are preserved across layers.
若 \(W^\top W=I\),则对任意向量 \(g\),
\[ \|W^\top g\|_2^2 = g^\top W W^\top g = g^\top g = \|g\|_2^2. \]
多层相乘仍然保持范数,因为每一层都保持范数。
这也是为什么仅看 parameter std 不够:两个矩阵可以有相同 std,但 singular value distribution 完全不同。训练大模型时更常监控 activation RMS、residual RMS、attention entropy、gradient norm,而不是只打印权重方差。
Residual Networks
残差结构
\[ x_{\ell+1}=x_\ell+F_\ell(x_\ell) \]
让梯度有 identity path,可以缓解深层网络退化。但如果每个 residual branch 初始方差过大,层数很多时 residual stream 仍会膨胀。
现代 Transformer 常用几类技巧:
| Technique | Purpose |
|---|---|
| LayerNorm / RMSNorm | normalize residual stream statistics |
| residual branch scaling | avoid accumulation across depth |
| small output projection init | make residual updates small initially |
| zero-init final norm/proj in blocks | start near identity mapping |
Transformer Initialization
Transformer 中最敏感的是 attention projection、MLP projection、embedding scale 和 residual stream。常见经验包括:
\[ W\sim\mathcal{N}(0,\sigma^2), \qquad \sigma\approx0.02 \]
以及对 residual projection 使用 depth-aware scaling:
\[ \sigma_{\text{proj}} \propto \frac{1}{\sqrt{2L}}. \]
直觉是:每层 residual 都往同一条 stream 写入信息,层数越深,单层写入幅度越应该保守。
Attention Logit Scale at Initialization
Self-attention 的 logits 是
\[ s_{ij} = \frac{q_i^\top k_j}{\sqrt{d_h}}. \]
如果 \(q\) 和 \(k\) 的每个维度方差约为 \(\sigma_q^2\) 和 \(\sigma_k^2\),且近似独立,则
\[ \operatorname{Var}(q_i^\top k_j) = d_h\sigma_q^2\sigma_k^2. \]
除以 \(\sqrt{d_h}\) 后
\[ \operatorname{Var}(s_{ij}) \approx \sigma_q^2\sigma_k^2. \]
所以 \(\sqrt{d_h}\) scaling 的作用是让 attention logit variance 不随 head dimension 线性增长。若 Q/K projection 初始化过大,softmax 初始就会非常尖锐;若过小,attention 接近均匀,早期几乎不区分 token。
At step 0, attention entropy near zero usually means QK logits are too large or masks are wrong. Entropy near the maximum for all heads can mean QK logits are too small, though uniform attention can also be normal for some early layers.
一个简单检查:
attn = torch.softmax(scores.masked_fill(mask, -torch.inf), dim=-1)
entropy = -(attn * attn.clamp_min(1e-9).log()).sum(dim=-1).mean()如果使用 additive mask,mask 后不能留下有限的大负数错误广播;否则初始化诊断会把 mask bug 误判成 QK scale 问题。
Residual Variance Accumulation
设 residual block 为
\[ x_{\ell+1}=x_\ell+u_\ell. \]
若粗略假设 \(x_\ell\) 与 update \(u_\ell\) 不相关,则
\[ \operatorname{Var}(x_{\ell+1}) = \operatorname{Var}(x_\ell)+\operatorname{Var}(u_\ell). \]
如果每层 update 方差都相同,经过 \(L\) 层后 residual stream 方差可能线性增长:
\[ \operatorname{Var}(x_L) \approx \operatorname{Var}(x_0)+L\sigma_u^2. \]
为了让总方差保持 \(O(1)\),每层 update 方差应随深度缩小:
\[ \sigma_u^2\propto\frac1L, \qquad \sigma_u\propto\frac1{\sqrt{L}}. \]
Transformer 中常见的 residual projection scaling
\[ \sigma_{\text{proj}}\propto\frac1{\sqrt{2L}} \]
就是这个直觉的一个版本。分母里的 \(2\) 来自每层通常有 attention 和 MLP 两个 residual branches。
If a network has \(O(L)\) independent residual updates and each update is initialized with standard deviation \(O(1/\sqrt{L})\), then the total residual variance remains \(O(1)\) at initialization.
设每层 residual update 方差为 \(c/L\),并粗略假设不同层 update 不相关,则
\[ \operatorname{Var}(x_L) \approx \operatorname{Var}(x_0) + \sum_{\ell=0}^{L-1}\operatorname{Var}(u_\ell) = \operatorname{Var}(x_0)+c. \]
因此总方差随深度保持 \(O(1)\)。真实网络中 update 并不独立,但这个估算解释了为什么深度越大,residual branch 初始输出越应该保守。
Pre-LN, Post-LN, and Residual Scale
Transformer block 有两种常见 norm placement。Post-LN 写作
\[ x_{\ell+1} = \operatorname{LN}(x_\ell+F_\ell(x_\ell)). \]
Pre-LN 写作
\[ x_{\ell+1} = x_\ell+F_\ell(\operatorname{LN}(x_\ell)). \]
Pre-LN 的梯度有一条更直接的 residual path,因此通常更容易训练深层模型;Post-LN 的 residual stream 每层被 norm 重置,但梯度要穿过更多 normalization 和 branch 结构,深层训练更敏感。
| placement | forward scale | backward path | initialization concern |
|---|---|---|---|
| Post-LN | normalized after update | gradient crosses LN after addition | deep stacks can be fragile |
| Pre-LN | residual stream can drift | identity gradient path | residual update scale matters |
| Sandwich norm | more normalization points | stronger scale control | extra compute and tuning |
Pre-LN 并不意味着 residual stream 不会漂移。因为每层写入
\[ x_{\ell+1}-x_\ell=F_\ell(\operatorname{LN}(x_\ell)), \]
如果 \(F_\ell\) 初始过大,residual RMS 仍会随层数增长。很多 LLM recipe 会对 attention output projection 和 MLP down projection 使用更小的 std,正是为了控制写入 residual stream 的幅度。
Symmetry Breaking
所有权重不能初始化成同一个值。若同一层所有 hidden units 参数相同,则它们会收到相同梯度,永远学成同一个 neuron。
Zero initialization is acceptable for some biases, but not for hidden-layer weights. Without random asymmetry, neurons in the same layer remain identical under gradient descent.
考虑同一层两个 hidden units \(a,b\),若初始化
\[ w_a=w_b, \qquad b_a=b_b, \]
则对任意输入 \(x\),
\[ h_a(x)=h_b(x). \]
下一层若也对它们对称,则 loss 对这两个 units 的梯度相同:
\[ \nabla_{w_a}L=\nabla_{w_b}L. \]
一次更新后仍有
\[ w_a'=w_b'. \]
归纳可知它们永远保持相同,相当于浪费了一个 neuron。随机初始化的最基本作用就是打破这种排列对称性。
Embeddings and Output Heads
语言模型里 embedding table
\[ E\in\mathbb{R}^{|\mathcal{V}|\times d} \]
和 LM head
\[ W_{\text{out}}\in\mathbb{R}^{d\times|\mathcal{V}|} \]
也需要初始化。若使用 tied embeddings:
\[ W_{\text{out}}=E^\top, \]
输入 token identity 和输出分类器共享参数。这减少参数,也让输入/输出 token geometry 对齐。
初始 logits 的尺度很重要。若 hidden state \(h\) 和 output weights \(w_k\) 方差过大,
\[ z_k=h^\top w_k \]
会使 softmax 过早尖锐;若过小,则所有 token 近似均匀,早期梯度也可能缺乏区分。很多 Transformer recipe 使用较小的 normal init,例如 \(\sigma=0.02\),再配合 residual scaling 和 normalization 稳定训练。
更形式化地,若 \(h_j\) 和 \(w_{k,j}\) 独立、零均值,方差分别为 \(\sigma_h^2\) 和 \(\sigma_w^2\),则
\[ \operatorname{Var}(z_k) = \operatorname{Var}\left(\sum_{j=1}^{d}h_jw_{k,j}\right) = d\sigma_h^2\sigma_w^2. \]
所以 LM head 初始化不仅取决于 \(W_{\text{out}}\),也取决于 residual stream 的 hidden scale。若模型使用 tied embeddings,embedding 初始化同时影响输入表示和输出分类器。
When adding special tokens after pretraining, newly resized embedding rows are randomly initialized unless explicitly handled. Badly initialized special-token embeddings can destabilize early fine-tuning or make control tokens hard to learn.
常见做法是把新增 token row 初始化为已有 token embedding 的均值附近:
old_n = old_embed.weight.size(0)
model.resize_token_embeddings(new_n)
with torch.no_grad():
emb = model.get_input_embeddings().weight
emb[old_n:].normal_(mean=0.0, std=emb[:old_n].std().item())如果是 instruct/chat 模板 token,也可以用已有 special tokens 的均值初始化,避免它们一开始落在完全陌生的 embedding 区域。
Initialization and Normalization
LayerNorm/RMSNorm 改变了初始化问题。因为每个 block 读入前会 normalize,forward activation scale 更容易控制。但这不意味着初始化不重要:
- residual branch 输出仍会累积;
- attention logits 的初始尺度影响 softmax entropy;
- MLP gate 的初始尺度影响饱和程度;
- embedding/logit scale 影响 early cross entropy;
- optimizer moments 会继承初始梯度统计。
Pre-LN Transformer 常常比无 normalization 的深层网络更容易训练,但它仍需要合理的 residual projection init。
LayerNorm 通常初始化为
\[ \gamma=\mathbf{1}, \qquad \beta=\mathbf{0}. \]
RMSNorm 没有 mean-centering 和 bias,通常只初始化 scale 为 \(1\)。如果把 norm scale 初始化得太小,residual branch 读到的输入会被整体压小;太大则会放大 branch 写入。
有些 residual architecture 会把 block 末端的 scale 初始化为 \(0\) 或很小,使网络一开始接近 identity:
\[ x_{\ell+1}=x_\ell+\alpha_\ell F_\ell(x_\ell), \qquad \alpha_\ell\approx0. \]
这在很深的 residual net、adapter、LoRA 或 diffusion U-Net 中都很常见:先保证基模型行为不被随机分支破坏,再让训练逐渐打开新路径。
Width Scaling and muP Intuition
宽度变大时,如果初始化和学习率不随宽度配套变化,同一套超参数可能不再稳定。普通 fan-in 初始化让每个 pre-activation variance 保持 \(O(1)\):
\[ W_{ij}\sim \mathcal{N}\left(0,\frac{1}{n_{\text{in}}}\right). \]
但不同参数类型对输出的贡献不同:hidden weights、embedding、output head、residual projections 的宽度缩放并不完全一样。最大更新参数化(muP)的核心思想是选择一套 width scaling,使不同宽度模型在训练早期有相似的 feature learning dynamics。
A width-consistent parameterization chooses initialization and learning-rate scaling so that changing hidden width preserves comparable activation scales and update magnitudes.
对课程笔记来说,不需要马上掌握 muP 的所有细节,但要知道一个工程原则:如果你把模型宽度放大很多,不能只复制小模型的 init 和 LR,然后期待训练动力学完全相同。需要至少监控:
| signal | why it matters across width |
|---|---|
| activation RMS | forward scale should remain comparable |
| update-to-weight ratio | optimizer step should not dominate weights |
| logit std | CE entropy should start in similar regime |
| grad norm by module | wider modules may receive different gradient scale |
| loss after first N steps | catches silent scaling mismatch |
Adapter and LoRA Initialization
在 fine-tuning 中,常希望新增模块一开始不改变 base model。以 LoRA 为例:
\[ W_{\text{eff}} = W_0+\frac{\alpha}{r}BA. \]
常见初始化是 \(A\) 随机、\(B=0\)。这样初始时
\[ BA=0, \qquad W_{\text{eff}}=W_0, \]
模型行为完全等于 base model,但梯度仍能更新 \(B\):
\[ \frac{\partial L}{\partial B} = \frac{\alpha}{r} \frac{\partial L}{\partial W_{\text{eff}}}A^\top. \]
If both LoRA factors are initialized to zero, then the gradients for both factors can vanish at the first step. Initialize one factor randomly and the other to zero so the adapter starts as a no-op but remains trainable.
同样的思想适用于 residual adapters:最后一层 projection 或 scale 可以初始化为 0,让 adapter 初始不扰动主干;但 adapter 内部前面的层需要有随机性,否则没有可学习方向。
Practical Checklist
| component | common initialization | check |
|---|---|---|
| Linear + tanh | Xavier | activation variance stable |
| Linear + ReLU/GELU | He or framework default | dead units rare |
| Conv | fan-in/out includes kernel area and groups | output RMS stable |
| Transformer residual projections | smaller/depth-scaled std | residual RMS not growing by depth |
| Q/K projections | std keeps attention entropy reasonable | no saturated softmax at step 0 |
| Bias | often zero | no unwanted offset |
| LayerNorm/RMSNorm scale | one | normalized branch starts neutral |
| LM head | small normal or tied embedding | logit std and CE sane |
| Adapter/LoRA output | zero/no-op path | base model preserved |
Minimal PyTorch Pattern
import torch
from torch import nn
def init_mlp(module: nn.Module):
if isinstance(module, nn.Linear):
nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
if module.bias is not None:
nn.init.zeros_(module.bias)
model.apply(init_mlp)这里的 isinstance 只放在外围初始化适配层里。核心模型定义本身不需要为了初始化到处写特殊分支。
Transformer 风格的初始化通常需要知道层数和模块名字。一个简化版本:
def init_transformer(module: nn.Module, n_layers: int, std: float = 0.02):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
nn.init.zeros_(module.bias)
name = getattr(module, "_init_name", "")
if name.endswith(("attn_out", "mlp_down")):
module.weight.data.mul_((2 * n_layers) ** -0.5)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=std)真实项目中更推荐在模块构造时给关键 projection 明确命名,或在模型的 init_weights() 中按模块路径处理,避免靠字符串猜错。
Debugging Bad Initialization
初始化问题通常在前几百步就会暴露:
| Symptom | Likely issue |
|---|---|
| loss immediately NaN | logits/activations too large, LR too high |
| gradients vanish | saturated sigmoid/tanh, too-small weights |
| dead ReLUs | negative bias or bad scale |
| attention entropy near zero at step 0 | Q/K scale too large |
| attention entropy uniform forever | Q/K or LR too small |
| residual norm grows with depth | residual branch init too large |
| first-step CE far below \(\log V\) randomly | logits too sharp or target leakage |
| update/weight ratio huge | init too small, LR too high, or bad optimizer scale |
可以在第一批 batch 上记录:
with torch.no_grad():
for name, p in model.named_parameters():
if p.ndim >= 2:
print(name, float(p.std()))更有用的是记录每层 activation RMS、attention entropy、gradient norm。初始化是否合理,不应只看权重分布,还要看信号穿过模型后的统计。
用 forward hook 记录 activation RMS:
stats = {}
def save_rms(name):
def hook(_module, _inp, out):
if isinstance(out, tuple):
out = out[0]
stats[name] = out.detach().float().pow(2).mean().sqrt().item()
return hook
for name, module in model.named_modules():
if name.endswith(("attn", "mlp", "norm")):
module.register_forward_hook(save_rms(name))一次 dummy batch 后,如果 RMS 随层数单调爆炸或快速衰减,通常比等 loss NaN 更早暴露问题。
Implementation Checklist
初始化/调试模型时可以按下面顺序查:
- linear/conv 的 weight layout 是否符合
nn.init的 fan-in/fan-out 假设; - activation 和 init 是否匹配,ReLU/GELU 不要误用 tanh 的小尺度;
- residual output projection 是否有 depth-aware scaling;
- attention QK logits 在 step 0 是否过尖或全均匀;
- embedding 和 LM head 是否 tied,新增 token 是否合理初始化;
- LayerNorm/RMSNorm scale 是否从 1 开始;
- adapter/LoRA 是否初始为 no-op 但保留可学习梯度;
- 不同宽度模型是否比较了 activation RMS、logit std 和 update ratio;
- first batch forward/backward 是否无 NaN/Inf;
- 每层 activation RMS 和 grad norm 是否随深度系统性漂移。
两个 smoke tests:
# 1. if embeddings are tied, they should really share storage
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
# 2. initialization produces finite first-step statistics
logits = model(input_ids).logits
assert torch.isfinite(logits).all()
assert logits.float().std() < 10.0第二个阈值不是数学定理,只是粗暴保护:随机初始化的语言模型如果 logits 标准差已经很大,softmax 往往会过早饱和,训练第一步就可能不稳。