4.4 Transformer Architecture


A Transformer block combines attention, an MLP, normalization, and residual connections. LLMs can be enormous, but the basic block remains highly regular.

Running Example: One Token Through One Block

Assume batch size \(2\), sequence length \(5\), and hidden size \(d=768\). The residual stream entering a given layer is

\[ X\in\mathbb{R}^{2\times5\times768}. \]

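Self-attention does three things:

  1. for each token, it produces \(Q,K,V\), each still of shape \(2\times5\times768\) (in standard multi-head attention, before the per-head reshape);
  2. each token attends to the visible tokens in the same sequence, mixing in context;
  3. it outputs an update \(\Delta X_{\text{attn}}\) with the same shape.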

A residual connection does not replace the original representation; it adds to it:

\[ X' = X+\Delta X_{\text{attn}}. \]

The MLP then applies a nonlinear transformation to each token independently and writes the result back to the residual stream:

\[ X_{\text{out}}=X'+\Delta X_{\text{mlp}}. \]

This is the main teaching thread for the Transformer block: attention handles token mixing, the MLP handles token-wise computation, and the residual stream accumulates information across layers.

Architecture Config as a Tensor Contract

A model config is not mere metadata; it is the tensor contract of the forward pass. A decoder-only block is determined by at least these fields:

| Config field | Tensor consequence |
| --- | --- |
| hidden_size = d | residual stream shape [B, T, d] |
| num_hidden_layers = L | number of repeated blocks |
| num_attention_heads = H_q | query heads and attention output layout |
| num_key_value_heads = H_{kv} | KV cache heads for GQA/MQA |
| head_dim = d_h | per-head channel dimension |
| intermediate_size = d_mid | MLP hidden width |
| vocab_size = V | embedding and LM-head rows |
| max_position_embeddings or RoPE config | valid position/cache offsets |
| tie_word_embeddings | whether input embedding and output head share storage |
| norm_eps | normalization numerical contract |

The most basic shape constraints are:

\[ d=H_q d_h, \qquad H_q \bmod H_{kv}=0. \]

The first guarantees that the query heads can be reshaped back into the residual stream; the second guarantees that in GQA each KV head serves an integer number of query heads. If we set

\[ G=\frac{H_q}{H_{kv}}, \]

then a common mapping from the \(h\)-th query head (0-indexed) to the KV head it reads is

\[ h_{kv}=\left\lfloor\frac{h}{G}\right\rfloor. \]
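The floor mapping above can be sketched in a few lines of plain Python. The function name `kv_head_for_query` and the head counts are illustrative assumptions, not taken from any particular model; real kernels may use a different (for example interleaved) grouping.

```python
def kv_head_for_query(h: int, n_q: int, n_kv: int) -> int:
    """Map query head index h to its KV head under contiguous grouping."""
    if n_q % n_kv != 0:
        raise ValueError("n_q must be divisible by n_kv")
    if not 0 <= h < n_q:
        raise IndexError("query head index out of range")
    group_size = n_q // n_kv  # G in the text
    return h // group_size    # floor(h / G)


# Illustrative sizes only: 8 query heads sharing 2 KV heads.
mapping = [kv_head_for_query(h, n_q=8, n_kv=2) for h in range(8)]
print(mapping)
```

With `n_kv=1` the same function describes MQA (every query head reads KV head 0), and with `n_kv == n_q` it reduces to ordinary multi-head attention.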

Definition: Architecture Contract

An architecture contract is the set of tensor shapes, normalization order, projection layouts, cache layouts, and parameter-sharing rules that a checkpoint assumes during forward and generation.

A minimal config validator:

def validate_decoder_config(cfg):
    if cfg.hidden_size != cfg.num_attention_heads * cfg.head_dim:
        raise ValueError("hidden_size must equal num_attention_heads * head_dim")
    if cfg.num_attention_heads % cfg.num_key_value_heads != 0:
        raise ValueError("query heads must be divisible by key/value heads")
    if cfg.intermediate_size <= cfg.hidden_size:
        raise ValueError("intermediate_size is unexpectedly small")
    if cfg.vocab_size <= 0:
        raise ValueError("vocab_size must be positive")

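Real models also need to validate the RoPE base/scaling, bias flags, norm type, activation function, embedding tying, sliding-window attention, number of MoE experts, and so on. Many checkpoints load "successfully", but if these contract items are wrong, the outputs no longer carry the original model's semantics.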

Pitfall: Matching Tensor Rank Is Not Enough

A checkpoint can load without shape errors but still be semantically wrong if head grouping, norm epsilon, RoPE scaling, activation function, or tied-embedding policy differs from the training architecture.
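One way to catch this class of error is to diff the fields that change semantics without changing any tensor shape. The sketch below works on plain dicts; `contract_mismatches` is a hypothetical helper, and the field names `hidden_act` and `rope_theta` are assumed names for the activation function and RoPE base, not fields defined in this section.

```python
SEMANTIC_FIELDS = (
    "num_attention_heads",
    "num_key_value_heads",
    "head_dim",
    "norm_eps",
    "tie_word_embeddings",
    "hidden_act",   # assumed field name for the activation function
    "rope_theta",   # assumed field name for the RoPE base
)


def contract_mismatches(train_cfg: dict, load_cfg: dict, fields=SEMANTIC_FIELDS):
    """Return {field: (train_value, load_value)} for every field that differs."""
    diffs = {}
    for name in fields:
        a, b = train_cfg.get(name), load_cfg.get(name)
        if a != b:
            diffs[name] = (a, b)
    return diffs


# Illustrative values only.
train_cfg = {
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "head_dim": 128,
    "norm_eps": 1e-5,
    "tie_word_embeddings": False,
    "hidden_act": "silu",
    "rope_theta": 10000.0,
}
# Neither change below alters a weight shape, so a state_dict would still load.
load_cfg = dict(train_cfg, norm_eps=1e-6, rope_theta=1000000.0)
print(contract_mismatches(train_cfg, load_cfg))
```

Both differences would pass a shape-only check, which is exactly the failure mode the pitfall describes.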

Block Structure

Pre-LN decoder-only Transformer:

\[ \tilde{x}_\ell = x_\ell+ \operatorname{MHA}(\operatorname{Norm}(x_\ell)), \]

\[ x_{\ell+1} = \tilde{x}_\ell+ \operatorname{MLP}(\operatorname{Norm}(\tilde{x}_\ell)). \]

Definition: Residual Stream

The residual stream is the sequence of hidden vectors passed through layers via additive residual connections. Attention and MLP blocks read from and write updates into this stream.

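The residual stream is the backbone of the Transformer. The attention and MLP sublayers in each layer do not replace the representation; they incrementally write updates onto the same stream.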

Pre-LN vs Post-LN

The original Transformer is usually written in Post-LN form:

\[ x_{\ell+1} = \operatorname{LN}(x_\ell+F_\ell(x_\ell)). \]

Modern LLMs more commonly use Pre-LN:

\[ x_{\ell+1} = x_\ell+F_\ell(\operatorname{LN}(x_\ell)). \]

The difference looks like a mere matter of where LayerNorm sits, but it matters a great deal when training deep networks. Pre-LN has a direct identity gradient path:

\[ \frac{\partial x_{\ell+1}}{\partial x_\ell} = I+ \frac{\partial F_\ell(\operatorname{LN}(x_\ell))} {\partial x_\ell}. \]

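Even if the gradient of \(F_\ell\) is unstable, the \(I\) in the residual lets gradients propagate across layers more easily. In Post-LN, LayerNorm wraps the sum of the residual and the sublayer output: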

\[ \frac{\partial x_{\ell+1}}{\partial x_\ell} = J_{\operatorname{LN}} \left(I+\frac{\partial F_\ell}{\partial x_\ell}\right), \]

so the gradient must pass through \(J_{\operatorname{LN}}\) at every layer, which makes deep training more sensitive.

Theorem: Pre-LN Provides an Identity Gradient Path

In a Pre-LN residual block \(x_{\ell+1}=x_\ell+F_\ell(\operatorname{LN}(x_\ell))\), the Jacobian contains an additive identity term, giving gradients a direct route through depth.

\[ x_{\ell+1}=x_\ell+F_\ell(\operatorname{LN}(x_\ell)) \]

Differentiating:

\[ \frac{\partial x_{\ell+1}}{\partial x_\ell} = \frac{\partial x_\ell}{\partial x_\ell} + \frac{\partial F_\ell(\operatorname{LN}(x_\ell))} {\partial x_\ell} = I+J_FJ_{\operatorname{LN}}. \]

During backpropagation, the upstream gradient is multiplied by

\[ I+J_FJ_{\operatorname{LN}}, \]

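where the identity term corresponds to the direct residual path. This result does not say Pre-LN is always better than Post-LN; it explains why Pre-LN makes very deep stacks easier to train.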
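To see what the identity term buys across many layers, it can help to unroll the Jacobian over a stack of \(L\) Pre-LN blocks. The shorthand \(A_\ell=J_{F_\ell}J_{\operatorname{LN},\ell}\) is introduced here only for this sketch:

\[ \frac{\partial x_L}{\partial x_0} = (I+A_{L-1})(I+A_{L-2})\cdots(I+A_0) = I+\sum_{\ell=0}^{L-1}A_\ell+\sum_{\ell>m}A_\ell A_m+\cdots. \]

The leading \(I\) survives however the \(A_\ell\) behave, so part of the gradient reaches \(x_0\) without passing through any sublayer. Unrolling the Post-LN Jacobian in the same way gives the ordered product \(\prod_\ell J_{\operatorname{LN},\ell}\,(I+J_{F_\ell})\), in which every term of the expansion carries all \(L\) LayerNorm Jacobians.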

MLP

The Transformer MLP is usually a two-layer feed-forward network:

\[ \operatorname{MLP}(x) = W_2\phi(W_1x+b_1)+b_2. \]

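Modern LLMs often use a gated MLP such as SwiGLU:

\[ \operatorname{SwiGLU}(x) = \left(\operatorname{SiLU}(xW_g)\odot xW_u\right)W_d. \]

The MLP provides token-wise nonlinear computation, while attention provides token mixing. The two divide the work: attention moves information, and the MLP transforms it.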

Gated MLP Details

Many LLMs use a gated MLP instead of a plain FFN. The plain FFN has the form:

\[ \operatorname{FFN}(x)=W_2\phi(W_1x), \]

with a parameter count of roughly

\[ 2dd_{\text{ff}}. \]

SwiGLU is usually written as:

\[ \operatorname{SwiGLU}(x) = \left(\operatorname{SiLU}(xW_g)\odot xW_u\right)W_d. \]

where

\[ W_u,W_g\in\mathbb{R}^{d\times d_{\text{mid}}}, \qquad W_d\in\mathbb{R}^{d_{\text{mid}}\times d}. \]

Its parameter count is roughly

\[ 3dd_{\text{mid}}. \]

To keep the gated MLP's parameter count close to that of a plain FFN with \(d_{\text{ff}}=4d\), one commonly picks

\[ 3dd_{\text{mid}}\approx 2d(4d), \qquad d_{\text{mid}}\approx\frac{8}{3}d. \]

This is why the intermediate size in many LLMs is not exactly \(4d\) but close to \(2.67d\), rounded to a hardware-friendly multiple.
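The rounding step can be made concrete with a small helper. This is a sketch under assumptions: the name `swiglu_intermediate_size` and the default multiple of 256 are chosen for illustration, and real model families use their own rounding rules.

```python
def swiglu_intermediate_size(d_model: int, multiple_of: int = 256) -> int:
    """Round 8/3 * d_model up to the nearest multiple of `multiple_of`."""
    # Integer ceil-division avoids float rounding: ceil(8d / (3m)).
    units = -(-8 * d_model // (3 * multiple_of))
    return units * multiple_of


for d in (768, 4096):
    d_mid = swiglu_intermediate_size(d)
    # Ratio of gated-MLP params (3*d*d_mid) to plain-FFN params (2*d*4d).
    ratio = (3 * d * d_mid) / (8 * d * d)
    print(d, d_mid, round(ratio, 4))
```

The ratio stays close to 1, which is the point of the \(\frac{8}{3}d\) choice: rounding up adds only a small amount over the plain-FFN budget.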

Definition: Token Mixing vs Channel Mixing

Attention performs token mixing because each position reads other positions. The MLP performs channel mixing because it transforms each token vector independently across hidden dimensions.

SwiGLU implementations usually fuse the gate and up projections into one large matrix to reduce kernel launches:

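import torch
from torch import nn


class SwiGLUMLP(nn.Module):
    def __init__(self, d_model, d_mid):
        super().__init__()
        # Fused gate and up projections: one matmul produces both halves.
        self.gate_up = nn.Linear(d_model, 2 * d_mid, bias=False)
        self.down = nn.Linear(d_mid, d_model, bias=False)

    def forward(self, x):
        # Split the fused output into gate and up, each [B, T, d_mid].
        gate, up = self.gate_up(x).chunk(2, dim=-1)
        return self.down(torch.nn.functional.silu(gate) * up)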

In terms of shapes:

x:       [B, T, d]
gate/up: [B, T, d_mid]
out:     [B, T, d]

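The MLP does not exchange information across tokens, so it can be viewed as one nonlinear function shared by every position. This property makes the MLP easy to parallelize in both prefill and decode, but it also means that cross-token copying, coreference, and long-range dependencies must come from attention or another token-mixing module.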

Normalization

LayerNorm:

\[ \operatorname{LN}(x) = \gamma\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}+\beta. \]

RMSNorm:

\[ \operatorname{RMSNorm}(x) = g\frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2+\epsilon}}. \]

RMSNorm drops mean centering, is simpler to compute, and is common in many LLMs.

LayerNorm computes the mean and variance over the hidden dimension of each token:

\[ \mu=\frac1d\sum_i x_i, \qquad \sigma^2=\frac1d\sum_i(x_i-\mu)^2. \]

RMSNorm keeps only the second-moment scale:

\[ \operatorname{rms}(x) = \sqrt{\frac1d\sum_i x_i^2+\epsilon}. \]

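It does not remove the mean, so it is not an exactly equivalent replacement. In LLMs, RMSNorm often appears together with Pre-LN, RoPE, and SwiGLU, forming a decoder-only block that trains stably and is cheap to compute.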
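The difference between the two norms is easy to see on a single toy vector. The sketch below follows the formulas above in plain Python and, as a simplifying assumption, omits the learned parameters (\(\gamma=1,\beta=0\) for LayerNorm and \(g=1\) for RMSNorm).

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm over one hidden vector, without learned gamma/beta."""
    d = len(x)
    mu = sum(x) / d
    var = sum((v - mu) ** 2 for v in x) / d
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def rms_norm(x, eps=1e-5):
    """RMSNorm over one hidden vector, without the learned gain."""
    d = len(x)
    rms = math.sqrt(sum(v * v for v in x) / d + eps)
    return [v / rms for v in x]


x = [3.0, 4.0, 5.0, 8.0]  # toy hidden vector with a clearly nonzero mean
ln_out = layer_norm(x)
rms_out = rms_norm(x)
print("LayerNorm output mean:", sum(ln_out) / len(ln_out))
print("RMSNorm output mean:  ", sum(rms_out) / len(rms_out))
```

LayerNorm's output is centered, while RMSNorm only rescales the vector, so the input's mean offset survives in the output.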

Norm Placement Variants

Modern Transformer blocks come in several common variants:

| Variant | Equations | Training behavior |
| --- | --- | --- |
| Post-LN | \(x_{\ell+1}=\operatorname{LN}(x_\ell+F(x_\ell))\) | original Transformer, deep training harder |
| Pre-LN | \(x_{\ell+1}=x_\ell+F(\operatorname{LN}(x_\ell))\) | stable deep training, common in LLMs |
| Sandwich-LN | extra norm around sublayer | more normalization, higher cost |
| NormFormer-style | norms inside attention/MLP paths | stabilizes activations in some settings |

The cost of Pre-LN is that the output may lack a uniform normalization, so decoder-only LLMs usually add a final norm at the end:

\[ h_{\text{final}} = \operatorname{Norm}(x_L), \qquad z=h_{\text{final}}E^\top. \]

If the final norm is omitted, the logit scale can drift with depth and over the course of training.

One Decoder Block With Shapes

Assume

\[ X\in\mathbb{R}^{B\times T\times d}. \]

Pre-LN attention:

\[ U=\operatorname{Norm}(X), \]

\[ Q,K,V=UW^Q,UW^K,UW^V, \]

which are reshaped to

\[ Q\in\mathbb{R}^{B\times H_q\times T\times d_h}, \quad K,V\in\mathbb{R}^{B\times H_{kv}\times T\times d_h}. \]

With GQA, the KV heads must be mapped or broadcast to the query head groups. The attention output is reshaped back to

\[ A\in\mathbb{R}^{B\times T\times d}. \]

Residual:

\[ X'=X+A. \]

MLP:

\[ Y=X'+\operatorname{MLP}(\operatorname{Norm}(X')). \]

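So every layer has the same input and output shape, \([B,T,d]\). The Transformer's depth comes from repeating one interface; the complexity hides in the \(T^2\) of attention and the \(d_{\text{mid}}\) of the MLP.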
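The shape walk-through above can be written down as a small bookkeeping function. It computes no attention; it only records the shapes a Pre-LN decoder block passes through. The function name and the example sizes (12 query heads, 4 KV heads) are illustrative assumptions.

```python
def decoder_block_shapes(B, T, d, n_q, n_kv):
    """Return the tensor shapes one Pre-LN decoder block passes through."""
    if d % n_q != 0 or n_q % n_kv != 0:
        raise ValueError("need d = n_q * head_dim and n_q % n_kv == 0")
    d_h = d // n_q
    return {
        "x": (B, T, d),                 # residual stream in
        "q": (B, n_q, T, d_h),
        "k": (B, n_kv, T, d_h),         # fewer heads under GQA
        "v": (B, n_kv, T, d_h),
        "attn_logits": (B, n_q, T, T),  # the T^2 term
        "attn_out": (B, T, d),
        "y": (B, T, d),                 # residual stream out
    }


shapes = decoder_block_shapes(B=2, T=5, d=768, n_q=12, n_kv=4)
for name, shape in shapes.items():
    print(f"{name:12s} {shape}")
```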

Minimal Decoder Block Implementation

A modern Pre-LN decoder block can be written as:

class DecoderBlock(nn.Module):
    def __init__(self, attn, mlp, norm1, norm2):
        super().__init__()
        self.attn = attn
        self.mlp = mlp
        self.norm1 = norm1
        self.norm2 = norm2

    def forward(self, x, *, attention_mask, position_ids, cache=None):
        attn_out, cache = self.attn(
            self.norm1(x),
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache=cache,
        )
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x, cache

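This skeleton exposes four interfaces:

  1. attention_mask: visibility constraints;
  2. position_ids: RoPE/position offset;
  3. cache: decode-time KV state;
  4. residual stream x: the fixed communication format between blocks.

During training, cache=None and the input is the full [B,T,d]; during decode, x may hold only the last token, [B,1,d], while the cache stores the historical K/V.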

Training, Prefill, and Decode Contracts

The same decoder block has a different shape contract in each of the three stages:

| Stage | Input hidden | Attention keys/values | Output |
| --- | --- | --- | --- |
| training | [B, T, d] | computed from same sequence | [B, T, d] |
| prefill | [B, T_prompt, d] | computed and appended to cache | [B, T_prompt, d] |
| decode | [B, 1, d] | new KV plus cached history | [B, 1, d] |

The KV cache usually stores projected keys/values, not residual states:

\[ K_{\text{cache}},V_{\text{cache}} \in \mathbb{R}^{B\times H_{kv}\times T_{\text{cache}}\times d_h}. \]

At decode step \(t\), the input is the hidden state of one new token

\[ x_t\in\mathbb{R}^{B\times 1\times d}. \]

It produces

\[ q_t\in\mathbb{R}^{B\times H_q\times 1\times d_h}, \qquad k_t,v_t\in\mathbb{R}^{B\times H_{kv}\times 1\times d_h}. \]

Then \(k_t,v_t\) are appended to the cache, and the query attends to keys of length \(T_{\text{cache}}+1\). The attention logits have shape

\[ B\times H_q\times 1\times (T_{\text{cache}}+1). \]

Definition: Cache Position

Cache position is the absolute or logical position assigned to newly decoded tokens when applying positional encodings and appending KV states.

RoPE/position ids must use the cache position rather than simply starting from 0:

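def decode_step(block, x_new, cache, cache_position):
    # x_new: [B, 1, d]; cache_position: [B], absolute position of the new token
    position_ids = cache_position[:, None]
    # make_decode_mask is a helper assumed to be defined elsewhere.
    attention_mask = make_decode_mask(cache, x_new)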
    y, cache = block(
        x_new,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cache=cache,
    )
    return y, cache

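If position ids are written as 0 at every decode step, the shapes are entirely correct, but the RoPE relative-position information breaks. In long contexts this kind of bug shows up as the model seeming to "forget the earlier text" or as degenerate generation.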
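The bookkeeping behind cache positions can be checked without any tensors. The sketch below simulates one request (batch size 1) and tracks only integers; `simulate_generation` and the default of 8 query heads are assumptions made for illustration.

```python
def simulate_generation(prompt_len: int, new_tokens: int, n_q: int = 8):
    """Track position id, cache length and logits shape for batch size 1."""
    cache_len = prompt_len  # prefill wrote K/V for positions 0..prompt_len-1
    steps = []
    for _ in range(new_tokens):
        position_id = cache_len           # next absolute position, never reset to 0
        cache_len += 1                    # append k_t, v_t for the new token
        logits_shape = (1, n_q, 1, cache_len)
        steps.append((position_id, cache_len, logits_shape))
    return steps


for position_id, cache_len, logits_shape in simulate_generation(5, 3):
    print(position_id, cache_len, logits_shape)
```

Each step's position id equals the cache length before the append, which is the invariant a decode loop needs to preserve.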

Pitfall: Decode Shape Correctness Does Not Imply Cache Correctness

Autoregressive decode can have valid tensor shapes while using wrong cache positions, wrong KV head grouping, stale masks, or mismatched cache lengths.

Training and inference differ in one more way: training usually uses a full causal mask

\[ M\in\{0,1\}^{T\times T}, \]

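whereas in a single-token decode step the causal relation is already guaranteed by the cache append order, so often only padding, prefix, sliding-window, or packed-request masks need handling. Do not mechanically carry training mask logic over to the decode kernel.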
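A small sketch makes the relationship between the two masks explicit. It uses nested lists with the convention that 1 means "visible", and both function names are illustrative. Padding, prefix, and sliding-window handling are deliberately left out.

```python
def causal_mask(T: int):
    """Full training mask: entry [i][j] is 1 when query i may see key j."""
    return [[1 if j <= i else 0 for j in range(T)] for i in range(T)]


def decode_row(cache_len: int):
    """Mask row for one new token attending to cache_len keys plus itself."""
    return [1] * (cache_len + 1)


full = causal_mask(4)
for row in full:
    print(row)
print(decode_row(3))
```

The decode row for a cache of length 3 matches the last row of the 4-by-4 training mask: the causal structure is already implied by what has been appended, so a decode kernel does not need to materialize the triangle.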

Decoder-Only, Encoder-Only, Encoder-Decoder

| Architecture | Mask | Objective | Example |
| --- | --- | --- | --- |
| encoder-only | bidirectional | masked LM / token classification | BERT |
| decoder-only | causal | next-token prediction | GPT/LLaMA/Qwen |
| encoder-decoder | source bidirectional + target causal | seq2seq | T5/translation |

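Decoder-only is the mainstream choice for LLMs not because other architectures do not work, but because causal next-token prediction fits large-scale text generation, in-context learning, and a unified interface very well.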

Encoder-Decoder as Conditional Factorization

Encoder-decoder models usually model

\[ p_\theta(y_{1:T_y}\mid x_{1:T_x}) = \prod_{t=1}^{T_y} p_\theta(y_t\mid y_{<t},x_{1:T_x}). \]

The encoder applies bidirectional self-attention to the source tokens and produces the memory:

\[ H_x=\operatorname{Encoder}(x_{1:T_x}). \]

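The decoder applies causal self-attention to the target prefix, then reads \(H_x\) through cross-attention. Compared with a decoder-only model that concatenates source and target into one prompt, an encoder-decoder has a more explicit conditional structure, but its interface is less uniform.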

Cross-Attention Shapes

Let the source length be \(S\) and the target length be \(T\). The encoder outputs

\[ H_x\in\mathbb{R}^{B\times S\times d}. \]

The decoder hidden state is

\[ Y\in\mathbb{R}^{B\times T\times d}. \]

In self-attention, Q/K/V all come from \(Y\), and a causal target mask is used. Cross-attention is instead:

\[ Q_y=YW_Q, \qquad K_x=H_xW_K, \qquad V_x=H_xW_V. \]

So the cross-attention logits have shape

\[ B\times H\times T\times S. \]

This differs from the \(T\times T\) of decoder-only self-attention. The cross-attention mask is also not a causal mask but a source padding mask:

\[ M_{t,s}=1 \quad\text{if source token }s\text{ is real}. \]

Definition: Cross-Attention

Cross-attention is attention where queries come from the decoder state and keys/values come from encoder memory.

A decoder layer in an encoder-decoder model has three residual branches:

target residual
  -> causal self-attention over target prefix
  -> cross-attention over encoder memory
  -> MLP

A minimal skeleton:

class EncoderDecoderLayer(nn.Module):
    def __init__(self, self_attn, cross_attn, mlp, norm1, norm2, norm3):
        super().__init__()
        self.self_attn = self_attn
        self.cross_attn = cross_attn
        self.mlp = mlp
        self.norm1 = norm1
        self.norm2 = norm2
        self.norm3 = norm3

    def forward(self, y, encoder_hidden, *, tgt_mask, src_mask, cache=None):
        self_out, cache = self.self_attn(
            self.norm1(y),
            attention_mask=tgt_mask,
            cache=cache,
        )
        y = y + self_out

        cross_out = self.cross_attn(
            self.norm2(y),
            key_value_states=encoder_hidden,
            attention_mask=src_mask,
        )
        y = y + cross_out
        y = y + self.mlp(self.norm3(y))
        return y, cache

In seq2seq inference, the encoder memory \(H_x\) is fixed for the whole generation, so the cross-attention K/V can be projected in advance or cached; the decoder self-attention cache still grows with each generated token.

Pitfall: Source Mask and Target Mask Are Different Objects

In encoder-decoder models, the target self-attention mask is causal, while the source cross-attention mask is usually a padding mask. Reusing one for the other silently changes the conditional factorization.
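Building the two masks side by side shows that they differ in shape as well as in meaning. The sketch below uses nested lists with 1 meaning "visible"; the function names and the example source lengths are assumptions for illustration.

```python
def target_causal_mask(T: int):
    """Shape [T][T]: target position t sees target positions s <= t."""
    return [[1 if s <= t else 0 for s in range(T)] for t in range(T)]


def source_padding_mask(src_lengths, S: int, T: int):
    """Shape [B][T][S]: every target row sees the same real source tokens."""
    return [
        [[1 if s < n else 0 for s in range(S)] for _ in range(T)]
        for n in src_lengths
    ]


T, S = 3, 4
tgt = target_causal_mask(T)
src = source_padding_mask([4, 2], S=S, T=T)  # second source is padded to length 4
print(tgt)
print(src[1][0])
```

The target mask is square and lower-triangular; the source mask is rectangular and identical for every target row, so neither can stand in for the other.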

Scaling

The parameter count comes mainly from the embedding, the attention projections, and the MLP. Per layer:

\[ \text{params per layer} \approx 4d^2 + 2d d_{\text{ff}} \]

With \(d_{\text{ff}}\approx4d\), each layer has about \(12d^2\) parameters. Training compute is approximately proportional to the parameter count and the token count:

\[ \text{training FLOPs}\approx 6ND, \]

where \(N\) is the parameter count and \(D\) is the number of training tokens.

In more detail, the main parameters of a single decoder block are:

| Component | Approx params |
| --- | --- |
| QKV projections | \(3d^2\) for MHA, less for GQA KV projections |
| output projection | \(d^2\) |
| FFN | \(2dd_{\text{ff}}\) |
| SwiGLU FFN | \(3dd_{\text{mid}}\) |
| norms | \(O(d)\) |

With a plain FFN and \(d_{\text{ff}}=4d\), each layer has about

\[ 4d^2+8d^2=12d^2. \]

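With SwiGLU and \(d_{\text{mid}}\approx\frac83d\), the MLP is still about \(8d^2\), so the per-layer total stays close to \(12d^2\) in practice. This is the rough rule commonly used to estimate LLM parameter counts.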
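Both rules of thumb fit in two one-line functions. The sizes below (\(d=4096\), 32 layers, \(10^{12}\) training tokens) are illustrative assumptions, and the estimate ignores embeddings, norms, and any GQA savings.

```python
def approx_dense_params(d: int, n_layers: int) -> int:
    """Rough non-embedding parameter count: 12 * d^2 per layer."""
    return 12 * d * d * n_layers


def approx_training_flops(n_params: float, n_tokens: float) -> float:
    """Rough training compute: 6 * N * D."""
    return 6.0 * n_params * n_tokens


n = approx_dense_params(d=4096, n_layers=32)
flops = approx_training_flops(n, 1e12)
print(f"params ~ {n / 1e9:.2f}B, training FLOPs ~ {flops:.2e}")
```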

Exact Parameter Count with GQA and SwiGLU

The rough \(12d^2\) is handy, but when reading a model card you should be able to do the more precise accounting. Let

\[ d=H_qd_h, \qquad H_{kv}\le H_q. \]

If Q/K/V/O all have no bias, the attention projection parameters are:

\[ N_Q=d(H_qd_h)=d^2, \]

\[ N_K=N_V=d(H_{kv}d_h), \]

\[ N_O=(H_qd_h)d=d^2. \]

So

\[ N_{\text{attn}} = 2d^2+2dH_{kv}d_h. \]

When \(H_{kv}=H_q\),

\[ N_{\text{attn}}=4d^2. \]

With GQA and \(H_{kv}<H_q\), both the K/V parameters and the KV cache shrink, but Q/O remain at the \(d^2\) level.

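Without bias, the SwiGLU parameters are

\[ N_{\text{mlp}} = N_{\text{gate}}+N_{\text{up}}+N_{\text{down}} = 3dd_{\text{mid}}. \]

The norm parameters per layer are usually \(O(d)\); for two RMSNorm scales,

\[ N_{\text{norm}}\approx 2d. \]

A single dense decoder block is therefore approximately: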

\[ N_{\text{layer}} = 2d^2+2dH_{kv}d_h+3dd_{\text{mid}}+O(d). \]

Implementation Contract: Count What the Config Actually Builds

Parameter count depends on GQA heads, bias flags, gated-MLP width, norm type, embedding tying, and MoE routing. The shortcut \(12d^2L\) is a sanity check, not a checkpoint contract.

A small estimator:

def dense_decoder_layer_params(d, n_q, n_kv, head_dim, d_mid, bias=False):
    q = d * n_q * head_dim
    k = d * n_kv * head_dim
    v = d * n_kv * head_dim
    o = n_q * head_dim * d
    mlp = 3 * d * d_mid
    norms = 2 * d
    attn_bias = (n_q + 2 * n_kv) * head_dim + d
    mlp_bias = 2 * d_mid + d
    bias_terms = attn_bias + mlp_bias if bias else 0
    return q + k + v + o + mlp + norms + bias_terms

With tied embeddings, the input embedding and the LM head are the same weight:

\[ E\in\mathbb{R}^{V\times d}, \qquad z=hE^\top. \]

If untied, there is an extra \(Vd\) output matrix. For large-vocabulary models this is not a small term.
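Putting the per-layer formula and the embedding terms together gives a whole-model count. The sketch below assumes a bias-free dense decoder with SwiGLU MLPs, two RMSNorm scales per layer, and one final norm. The configuration values are assumptions chosen to resemble a 7B-class model, not numbers taken from this section.

```python
def dense_decoder_params(d, n_layers, n_q, n_kv, head_dim, d_mid, vocab, tied):
    """Total parameters of a bias-free dense decoder with SwiGLU and RMSNorm."""
    attn = 2 * d * n_q * head_dim + 2 * d * n_kv * head_dim  # Q, O + K, V
    mlp = 3 * d * d_mid                                      # gate, up, down
    norms = 2 * d                                            # two RMSNorm scales
    per_layer = attn + mlp + norms
    embed = vocab * d
    lm_head = 0 if tied else vocab * d                       # extra V*d when untied
    final_norm = d
    return n_layers * per_layer + embed + lm_head + final_norm


# Assumed, illustrative configuration.
cfg = dict(d=4096, n_layers=32, n_q=32, n_kv=32, head_dim=128,
           d_mid=11008, vocab=32000)
untied = dense_decoder_params(**cfg, tied=False)
tied = dense_decoder_params(**cfg, tied=True)
print(untied, tied, untied - tied)
```

The gap between the two totals is exactly \(Vd\), which for this vocabulary is about 131M parameters.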

FLOP Accounting per Layer

Parameter count and FLOPs are related but not identical. For decoder-only prefill, a single layer costs roughly:

| Component | Approx FLOPs |
| --- | --- |
| QKV projections | \(6BTd^2\) |
| output projection | \(2BTd^2\) |
| attention scores \(QK^\top\) | \(2BH T^2 d_h\) |
| attention-value \(PV\) | \(2BH T^2 d_h\) |
| FFN with \(d_{\text{ff}}\) | \(4BTdd_{\text{ff}}\) |

Here a multiply-add is counted as 2 FLOPs. If \(H d_h=d\), attention scores plus attention-value cost about:

\[ 4BT^2d. \]

With a plain FFN and \(d_{\text{ff}}=4d\), the FFN costs about:

\[ 16BTd^2. \]

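So at short context the MLP and linear projections can cost more than the attention matrix; at long context the \(T^2\) term quickly dominates. Understanding this accounting is what tells you which part FlashAttention, GQA, MLP fusion, and tensor parallelism each optimize.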
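The crossover between the two regimes follows directly from the table. Under the same approximations (MHA projections, plain FFN with \(d_{\text{ff}}=4d\), no savings from the causal mask), the linear terms sum to \(24BTd^2\) and the attention terms to \(4BT^2d\), so they balance at \(T=6d\). The function name and the sizes below are illustrative assumptions.

```python
def prefill_layer_flops(B, T, d, d_ff):
    """Return (linear, attention) FLOPs for one layer, 2 FLOPs per multiply-add."""
    linear = 8 * B * T * d * d + 4 * B * T * d * d_ff  # QKV + output + FFN
    attention = 4 * B * T * T * d                      # QK^T + PV with H * d_h = d
    return linear, attention


d = 1024
for T in (512, 6 * d, 32768):
    linear, attention = prefill_layer_flops(B=1, T=T, d=d, d_ff=4 * d)
    print(T, f"attention/linear = {attention / linear:.3f}")
```

With GQA the QKV term shrinks, which moves the crossover to somewhat shorter sequences; the sketch does not model that.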

Residual Scale and Deep Stacks

If every layer writes an unscaled update:

\[ x_{L} = x_0+\sum_{\ell=0}^{L-1}\Delta x_\ell, \]

then under a rough independence assumption the residual variance can grow with \(L\):

\[ \operatorname{Var}(x_L) \approx \operatorname{Var}(x_0) + \sum_\ell \operatorname{Var}(\Delta x_\ell). \]

Deep Transformers are therefore often paired with:

  1. Pre-LN/RMSNorm to control the input scale;
  2. careful initialization to control the update scale;
  3. residual scaling or depth-dependent init;
  4. warmup and gradient clipping to control early instability.

Pitfall: Residual Connections Do Not Automatically Fix Scale

Residual paths help gradient flow, but activation scale can still drift through depth if update magnitudes are poorly initialized or normalized.
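The variance growth can be illustrated with a scalar toy simulation: each "layer" adds an independent Gaussian update to a unit-variance start. This is a sketch of the independence assumption only, not of a real network; `residual_variance`, the sample count, and the update scales are assumptions for illustration.

```python
import random


def residual_variance(n_layers, update_std, n_samples=20000, seed=0):
    """Empirical variance of x_L = x_0 + sum of independent Gaussian updates."""
    rng = random.Random(seed)
    finals = []
    for _ in range(n_samples):
        x = rng.gauss(0.0, 1.0)              # x_0 with unit variance
        for _ in range(n_layers):
            x += rng.gauss(0.0, update_std)  # independent update per layer
        finals.append(x)
    mean = sum(finals) / n_samples
    return sum((v - mean) ** 2 for v in finals) / n_samples


L = 16
unscaled = residual_variance(L, update_std=0.5)
scaled = residual_variance(L, update_std=0.5 / L ** 0.5)  # depth-dependent scaling
print(f"unscaled: {unscaled:.2f}, scaled by 1/sqrt(L): {scaled:.2f}")
```

Under these assumptions the unscaled stack should land near \(1+L\sigma^2\), while scaling each update by \(1/\sqrt{L}\) keeps the added variance independent of depth.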

Residual Stream as a Communication Bus

The residual stream can be viewed as a shared bus. Every attention head and MLP reads the current state from the bus and then writes an update:

residual stream
  -> norm
  -> attention heads read/write
  -> residual stream
  -> norm
  -> MLP read/write
  -> residual stream

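This structure has two consequences:

  1. the interface between layers is fixed, which makes scaling depth easy;
  2. information is not stored in any single module but is distributed across directions in the residual stream.

From a mechanistic interpretability perspective, many features can be viewed as directions in the residual stream; attention heads and MLPs are the operators that read and write these directions.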

From Architecture to Training

The Transformer architecture is only the foundation. The final capabilities come from:

  1. tokenizer and data mixture;
  2. pretraining objective;
  3. optimizer and schedule;
  4. context length curriculum;
  5. instruction tuning;
  6. preference/RL alignment;
  7. inference-time sampling and systems.

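This is also why the second half of this chapter moves directly on to LLaDA, Mamba, and inference infrastructure: a modern LLM can no longer be understood through "model structure" alone; the training paradigm and system constraints must be considered together.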

Checkpoint Compatibility Audit

Before loading a Transformer checkpoint, audit the config as a contract rather than only checking whether the state_dict matches. At minimum, check:

| Contract item | Failure mode |
| --- | --- |
| hidden_size, heads, head_dim | QKV reshape or GQA grouping wrong |
| norm type and norm_eps | activation scale differs from training |
| activation function | MLP branch computes different nonlinear map |
| bias flags | projection semantics and parameter count differ |
| RoPE base/scaling | long-context positions interpreted differently |
| vocab size and tokenizer ids | embedding rows point to wrong tokens |
| tied embeddings | LM head storage and gradients differ |
| attention implementation | mask convention or cache layout mismatch |
| MoE config if present | activated parameters and routing differ |

A practical audit function can collect the key tensor shapes for printing:

def audit_decoder_shapes(model):
    rows = []
    for name, p in model.named_parameters():
        if any(key in name for key in ("q_proj", "k_proj", "v_proj", "o_proj")):
            rows.append((name, tuple(p.shape)))
        if any(key in name for key in ("gate_proj", "up_proj", "down_proj")):
            rows.append((name, tuple(p.shape)))
        if "embed" in name or "lm_head" in name:
            rows.append((name, tuple(p.shape)))
    return rows

For tied embeddings, you can also check whether the storage is shared:

def tied_embedding_ok(model):
    emb = model.get_input_embeddings().weight
    head = model.get_output_embeddings().weight
    return emb.data_ptr() == head.data_ptr()

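If the config declares tied embeddings but they are not actually tied, both training and inference carry an extra output matrix; if the config declares untied but you force tying, both the model capacity and the checkpoint semantics change.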

Implementation Contract: Config, Tokenizer, and Checkpoint Must Agree

For LLMs, architecture config, tokenizer files, and checkpoint tensors form one artifact. Valid tensor shapes alone do not prove the model is semantically loaded correctly.

Implementation Checklist

When implementing or reading a Transformer architecture, check:

  1. whether the block is Pre-LN, Post-LN, or another norm placement;
  2. whether attention and the MLP both write back to the same residual stream;
  3. whether a final norm exists;
  4. whether the MLP is a plain FFN, GEGLU, SwiGLU, or MoE;
  5. whether the gate/up/down shapes of the gated MLP are correct;
  6. whether hidden size, head count, and head dim divide evenly;
  7. whether the GQA/MQA KV heads match the attention kernel and cache;
  8. whether position ids are consistent across train/prefill/decode;
  9. whether parameter count and FLOP bottleneck are estimated separately;
  10. whether the checkpoint stores the architecture config, tokenizer config, and RoPE/cache parameters;
  11. whether the config satisfies \(d=H_qd_h\) and \(H_q\bmod H_{kv}=0\);
  12. whether tied/untied embeddings agree with the checkpoint and tokenizer;
  13. whether the prefill/decode KV cache shape uses \(H_{kv}\) rather than \(H_q\);
  14. whether the encoder-decoder source padding mask and target causal mask are kept separate;
  15. whether norm type, activation, bias flags, and RoPE scaling match the training config.