4.2B Attention Implementation and Kernels
上一节讲了 attention 的数学定义:
\[ \operatorname{Attention}(Q,K,V) = \operatorname{softmax} \left( \frac{QK^\top}{\sqrt{d_h}} \right)V. \]
这一节把它落到实现:tensor shape 怎么排,mask 怎么 broadcast,softmax 怎么做数值稳定,为什么 naive attention 浪费显存,FlashAttention 为什么能保持 exact attention 却减少 HBM 读写。
Tensor Shapes
假设:
- batch size \(B\);
- sequence length \(T\);
- hidden size \(d\);
- heads \(H\);
- head dimension \(d_h=d/H\)。
输入 residual stream:
\[ X\in\mathbb{R}^{B\times T\times d}. \]
线性投影:
\[ Q=XW^Q,\quad K=XW^K,\quad V=XW^V. \]
reshape 成 multi-head:
\[ Q,K,V\in\mathbb{R}^{B\times H\times T\times d_h}. \]
attention logits:
\[ S=\frac{QK^\top}{\sqrt{d_h}} \in \mathbb{R}^{B\times H\times T\times T}. \]
最后输出:
\[ O=\operatorname{softmax}(S)V \in \mathbb{R}^{B\times H\times T\times d_h}. \]
再 transpose/reshape 回:
\[ O\in\mathbb{R}^{B\times T\times d}. \]
view After transpose Can Be Wrong
After transposing dimensions, tensors are often non-contiguous. Use .contiguous().view(...) or .reshape(...) carefully, otherwise PyTorch may error or silently copy.
QKV Projection Layout
实现 multi-head attention 时,QKV projection 有两种常见写法:
self.q_proj = nn.Linear(d_model, num_q_heads * head_dim, bias=False)
self.k_proj = nn.Linear(d_model, num_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(d_model, num_kv_heads * head_dim, bias=False)或者 fused QKV:
self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)fused QKV 的好处是减少 kernel launch,并把三个 GEMM 合并成一个更大的 GEMM;坏处是 GQA/MQA、不同 bias policy、LoRA target module、checkpoint surgery 会更复杂。一个 dense MHA block 中,fused projection 输出:
x: [B, T, d_model]
qkv: [B, T, 3 * H * Dh]
q: [B, H, T, Dh]
k: [B, H, T, Dh]
v: [B, H, T, Dh]
常见 shape 代码:
qkv = self.qkv_proj(x)
qkv = qkv.view(B, T, 3, H, Dh)
q, k, v = qkv.unbind(dim=2)
q = q.transpose(1, 2) # [B, H, T, Dh]
k = k.transpose(1, 2)
v = v.transpose(1, 2)如果是 GQA,query heads 和 KV heads 不同:
q: [B, Hq, T, Dh]
k: [B, Hkv, T, Dh]
v: [B, Hkv, T, Dh]
其中 \(H_q\) 必须能被 \(H_{kv}\) 整除。一个朴素 repeat:
def repeat_kv(x, groups):
# x: [B, Hkv, T, Dh]
B, Hkv, T, Dh = x.shape
x = x[:, :, None, :, :].expand(B, Hkv, groups, T, Dh)
return x.reshape(B, Hkv * groups, T, Dh)但高性能 kernel 往往不会 materialize repeat,而是在 kernel 内部把 query head 映射到 KV head:
\[ h_{kv}=\left\lfloor \frac{h_q}{H_q/H_{kv}}\right\rfloor. \]
A head layout contract specifies whether tensors are stored as [B,H,T,Dh] or [B,T,H,Dh], how Q heads map to KV heads, and whether KV repetition is materialized or handled inside the attention kernel.
这件事直接影响 LoRA 和 checkpoint。若 checkpoint 使用 fused qkv_proj.weight,而代码写成 separate q_proj/k_proj/v_proj,就要按输出维度切权重;若模型是 GQA,k_proj/v_proj 的输出维度不是 d_model,而是 Hkv * Dh。
A single qkv_proj tensor can hide whether the model uses MHA, GQA, biases, QK-Norm, or different head counts. Always recover semantic shapes before editing adapters or loading checkpoints.
Layout for Kernels
不同 API 对布局要求不同。教学里常写 [B,H,T,Dh],因为和公式接近;有些 kernel 内部偏好 [B,T,H,Dh] 或 packed layout,因为连续内存访问更好。关键检查是 stride:
print(q.shape, q.stride())例如:
q = q.view(B, T, H, Dh).transpose(1, 2)得到的 [B,H,T,Dh] 通常不是 contiguous。若后面调用支持 strided input 的 fused attention,可能没问题;若后面要 view(B, T, H * Dh),必须先 transpose 回 [B,T,H,Dh] 并 contiguous:
out = out.transpose(1, 2).contiguous().view(B, T, H * Dh)不要在每一步都盲目 .contiguous()。它会复制数据。正确策略是:在 kernel 需要 contiguous 或 reshape 语义要求 contiguous 时,明确支付这次复制成本。
Naive PyTorch Attention
概念实现:
import math
import torch
def attention(q, k, v, attn_mask):
# q, k, v: [B, H, T, Dh]
scores = q @ k.transpose(-2, -1)
scores = scores / math.sqrt(q.size(-1))
scores = scores + attn_mask
probs = torch.softmax(scores, dim=-1)
out = probs @ v
return out这个实现很好懂,但会显式 materialize:
\[ S\in\mathbb{R}^{B\times H\times T\times T} \]
和
\[ P=\operatorname{softmax}(S) \in\mathbb{R}^{B\times H\times T\times T}. \]
当 \(T=8192\) 时,\(T^2\approx 67\) million。对每个 batch/head 都保存完整矩阵,显存很快爆炸。
Mask Broadcasting
常见 mask 有两种:
- causal mask: \([1,1,T,T]\);
- padding mask: \([B,1,1,T]\)。
它们加到 scores 上:
\[ S'=S+M_{\text{causal}}+M_{\text{pad}}. \]
causal mask:
\[ M_{ij}^{\text{causal}} = \begin{cases} 0,& j\leq i,\\ -\infty,& j>i. \end{cases} \]
padding mask:
\[ M_{b,j}^{\text{pad}} = \begin{cases} 0,& \text{token }j\text{ is real},\\ -\infty,& \text{token }j\text{ is pad}. \end{cases} \]
An additive attention mask is added to attention logits before softmax. Valid positions receive \(0\), while invalid positions receive a large negative value so their softmax probability becomes zero.
工程上常用 torch.finfo(dtype).min 或一个足够大的负数,而不是真的写 Python -inf。但 FP16 下过大的负数和某些 fused kernels 可能有细节差异,最好使用框架推荐 API。
Boolean Mask vs Additive Bias
mask contract 必须写清楚:True 是 visible 还是 masked?PyTorch 不同 API 的约定可能不同。讲义里建议统一成:
visible_mask: bool, True means token can be attended to
additive_bias: float, 0 for visible and large negative for masked
二者转换:
def visible_to_bias(visible, dtype):
bias = torch.zeros_like(visible, dtype=dtype)
return bias.masked_fill(~visible, torch.finfo(dtype).min)attention score 的完整形状通常是:
scores: [B, H, Tq, Tk]
causal_mask: [1, 1, Tq, Tk]
padding_mask: [B, 1, 1, Tk]
若是 packed sequence,还可能需要 block-diagonal mask:
doc_id: [B, T]
visible[b, i, j] = (j <= i) and (doc_id[b, i] == doc_id[b, j])
这表示同一个 packed row 里的两个样本不应该互相 attention。只加 causal mask 不够,因为前一个样本的尾 token 仍然在后一个样本的过去。
In packed training, a causal mask prevents future leakage but still allows later documents to attend to earlier documents unless a document-boundary mask is also applied.
Empty Rows and Padding-Only Queries
softmax 需要至少一个 visible key。若某一行全被 mask 掉,数学上:
\[ \sum_j e^{-\infty}=0, \]
分母为 \(0\),输出未定义。padding-only query、左 padding 的前缀位置、错误的 block mask 都可能产生 all-masked rows。
稳妥的训练策略:
- padding query 的 loss 置为
-100; - attention 实现确保 padding query 至少能 attend 到某个 dummy/self 位置,或在输出后把 padding query 清零;
- debug 时显式检查 visible count:
visible_count = visible_mask.sum(dim=-1)
if (visible_count == 0).any():
raise ValueError("attention mask has all-masked query rows")Stable softmax prevents overflow, but it cannot define a probability distribution when every key is masked. Mask construction must avoid or explicitly handle empty rows.
Numerically Stable Softmax
softmax:
\[ p_i=\frac{e^{s_i}}{\sum_j e^{s_j}}. \]
如果 \(s_i\) 很大,\(e^{s_i}\) 会 overflow。稳定做法减去最大值:
\[ p_i = \frac{e^{s_i-m}}{\sum_j e^{s_j-m}}, \qquad m=\max_j s_j. \]
因为分子分母同时乘以 \(e^{-m}\):
\[ \frac{e^{s_i}}{\sum_j e^{s_j}} = \frac{e^{s_i}e^{-m}}{\sum_j e^{s_j}e^{-m}} = \frac{e^{s_i-m}}{\sum_j e^{s_j-m}}. \]
减去最大值不会改变 softmax 结果,但会让最大 exponent 变成 \(e^0=1\),避免 overflow。
Dropout in Attention
Transformer 训练时常对 attention probabilities 做 dropout:
\[ \tilde{P}_{ij} = \frac{M_{ij}P_{ij}}{1-p}, \qquad M_{ij}\sim\operatorname{Bernoulli}(1-p). \]
dropout 在 softmax 后、乘 \(V\) 前:
\[ O=\tilde{P}V. \]
推理时 dropout 关闭。这个细节很重要:attention dropout 不是 mask future tokens,而是训练正则化。
Memory Cost
naive attention 的主要中间张量:
| Tensor | Shape | Elements |
|---|---|---|
| scores | \(B,H,T,T\) | \(BHT^2\) |
| probabilities | \(B,H,T,T\) | \(BHT^2\) |
| output | \(B,H,T,d_h\) | \(BTd\) |
对于 \(B=4,H=32,T=4096\),仅 scores 的元素数:
\[ 4\times32\times4096^2\approx2.15\times10^9. \]
若 FP16 是 2 bytes,scores 单独约:
\[ 4.3\text{ GB}. \]
再加 probabilities、backward 需要的保存项、QKV、activation,显存压力非常大。
FlashAttention: IO-Aware Exact Attention
FlashAttention 的核心不是近似 attention,而是不把完整 \(T\times T\) attention matrix 写回 HBM。它把 \(Q,K,V\) 分块放入 SRAM,在 block 内计算 partial softmax 和 partial output,并用 online softmax 维护全局归一化。
An IO-aware algorithm is designed around the cost of moving data between memory levels, such as GPU HBM and on-chip SRAM, rather than only counting arithmetic FLOPs.
标准 attention 的问题是:
read Q,K -> write scores to HBM -> read scores -> write probs -> read probs,V -> write output
FlashAttention 更像:
stream K,V blocks through SRAM
update output and softmax normalizer online
write only final output
这减少 HBM traffic,因此 wall-clock 更快、显存更省。
Online Softmax
关键数学问题:如果不一次性看到所有 logits,怎么得到 exact softmax?
对一行 logits \(s_1,\ldots,s_T\),定义:
\[ m=\max_j s_j, \qquad \ell=\sum_j e^{s_j-m}. \]
softmax output 对 value 的加权和:
\[ o = \frac{1}{\ell} \sum_j e^{s_j-m}v_j. \]
现在分块处理。已有 block 的状态是 \((m_{\text{old}},\ell_{\text{old}},o_{\text{old}})\)。新 block 的最大值和归一化和为 \((m_{\text{new}},\ell_{\text{new}},o_{\text{new}})\)。合并最大值:
\[ m=\max(m_{\text{old}},m_{\text{new}}). \]
旧 normalizer 需要重标定:
\[ \ell = e^{m_{\text{old}}-m}\ell_{\text{old}} + e^{m_{\text{new}}-m}\ell_{\text{new}}. \]
输出也按同样权重合并:
\[ o = \frac{ e^{m_{\text{old}}-m}\ell_{\text{old}}o_{\text{old}} + e^{m_{\text{new}}-m}\ell_{\text{new}}o_{\text{new}} }{ \ell }. \]
旧 block 的未归一化 value sum 是:
\[ u_{\text{old}} = \sum_{j\in old}e^{s_j-m_{\text{old}}}v_j = \ell_{\text{old}}o_{\text{old}}. \]
如果全局最大值变成 \(m\),旧 block 的贡献需要乘:
\[ e^{m_{\text{old}}-m}. \]
因为
\[ e^{s_j-m} = e^{s_j-m_{\text{old}}}e^{m_{\text{old}}-m}. \]
新 block 同理。把旧新未归一化 sum 相加,再除以新的 normalizer \(\ell\),得到合并公式。
这说明 FlashAttention 不是改了 softmax 定义,而是在数学上等价地流式计算 softmax。
Tiled Attention Pseudocode
FlashAttention 的核心状态可以按每个 query row 维护三件事:
m: running max logit
l: running normalizer
o: running output numerator / normalized output
简化伪代码:
def tiled_attention(q_block, k_blocks, v_blocks):
# q_block: [Bq, Dh], each k/v block: [Bk, Dh]
m = torch.full((Bq,), -torch.inf, device=q_block.device)
l = torch.zeros((Bq,), device=q_block.device)
o = torch.zeros((Bq, Dh), device=q_block.device)
for k, v in zip(k_blocks, v_blocks):
s = q_block @ k.T / math.sqrt(Dh) # [Bq, Bk]
m_new = torch.maximum(m, s.max(dim=-1).values)
alpha_old = torch.exp(m - m_new)
p = torch.exp(s - m_new[:, None])
l_new = alpha_old * l + p.sum(dim=-1)
o = (alpha_old[:, None] * l[:, None] * o + p @ v) / l_new[:, None]
m, l = m_new, l_new
return o真实 kernel 会把这些放在 SRAM/register 里,并处理 dropout、causal boundary、backward、dtype accumulation。伪代码的价值是说明:FlashAttention 保存的是每行的 compact state,而不是完整 \(T\times T\) 概率矩阵。
An online attention state stores the running row maximum, softmax normalizer, and partial output needed to combine attention blocks exactly without materializing the full score matrix.
Causal FlashAttention
causal mask 在 block attention 中变成 block-level 和 element-level 两层判断:
- 如果 key block 完全在 query block 未来,可以跳过;
- 如果 block 与下三角边界相交,只对 block 内未来位置加 mask;
- 如果 key block 完全在过去,正常计算。
这样可以避免无意义的未来 block 计算。
设 query block 覆盖位置 \([q_0,q_1)\),key block 覆盖位置 \([k_0,k_1)\)。对 causal attention:
| Relation | Condition | Action |
|---|---|---|
| future block | \(k_0\geq q_1\) | skip block |
| past block | \(k_1\leq q_0+1\) | compute full block |
| boundary block | otherwise | compute with triangular mask |
这个 block-level 判断减少了大量无效计算。长序列中,causal attention 只需要下三角区域,理论上约为 full bidirectional attention 的一半;FlashAttention 还进一步减少 HBM traffic。
PyTorch SDPA
现代 PyTorch 提供 scaled_dot_product_attention,可以自动选择 math、memory-efficient、FlashAttention 后端:
import torch.nn.functional as F
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=True,
)实际训练中,优先用框架的 fused attention API,而不是手写 softmax(q @ k.T) @ v。
is_causal and attn_mask Can Interact
Different backends have constraints on whether custom masks, dropout, GQA, and causal mode can be fused. Always verify the selected kernel and numerical behavior for the target PyTorch/CUDA version.
SDPA Dropout and Mode Semantics
scaled_dot_product_attention 是函数 API,不知道你的 module 处于 train/eval。dropout_p 应该显式传:
dropout_p = self.attn_dropout if self.training else 0.0
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)若在 eval 时仍传非零 dropout_p,它可能继续随机丢 attention probability。反过来,如果训练时传 0.0,attention dropout 实际关闭。这个 bug 比 nn.Dropout 更隐蔽,因为函数式 SDPA 不会自动读取 model.training。
Functional attention APIs use the dropout_p argument directly. They do not automatically disable dropout in model.eval() unless the caller passes 0.0.
Backend Selection and Fallbacks
SDPA 后端大致有三类:
| Backend | Stores full scores? | Strength | Common fallback reason |
|---|---|---|---|
| math | yes | most general | always available |
| memory-efficient | no/full reduced | lower memory | mask/dropout/layout constraints |
| flash | no | fastest for supported shapes | dtype, device, head dim, mask form |
具体选择随 PyTorch/CUDA/GPU 而变,所以讲义和代码都不应假设“调用 SDPA 就一定是 FlashAttention”。工程上至少记录:
- PyTorch version;
- CUDA version and GPU;
- q/k/v dtype;
- q/k/v shape and head_dim;
- mask type;
- selected backend or profiler kernel name。
最实用的确认方式是 profiler:看 kernel 名、显存峰值和是否 materialize [B,H,T,T]。如果长上下文训练突然 OOM,很可能是某个自定义 mask 让 SDPA fallback 到 math backend。
Reference Equivalence Test
给 fused attention 写单元测试时,不要只检查 shape。应和 naive reference 在小尺寸上比较:
def attention_ref(q, k, v, visible):
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
scores = scores.masked_fill(~visible, torch.finfo(scores.dtype).min)
probs = torch.softmax(scores, dim=-1)
return probs @ v
with torch.no_grad():
out_ref = attention_ref(q.float(), k.float(), v.float(), visible)
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=visible,
dropout_p=0.0,
is_causal=False,
)
torch.testing.assert_close(out.float(), out_ref, rtol=1e-3, atol=1e-3)小尺寸 reference 可以用 FP32,fused path 用 BF16/FP16;容忍度要反映 dtype。这个测试能抓住 head transpose、mask polarity、scale、causal convention 和 GQA repeat 错误。
Backward Pass
attention backward 也很贵。naive 实现为了反传会保存 \(P=\operatorname{softmax}(S)\)。FlashAttention backward 通常重新计算部分 scores/probabilities,而不是保存完整 \(P\),用更多计算换更少显存。
这是一种常见深度学习系统 trade-off:
\[ \text{memory} \leftrightarrow \text{recompute}. \]
Activation checkpointing 也是同一思想:forward 少存中间激活,backward 重新算。
Prefill and Decode Attention
训练和 prefill 通常是 many-query attention:
q: [B, Hq, T, Dh]
k: [B, Hkv, T, Dh]
v: [B, Hkv, T, Dh]
decode 通常是 one-query attention:
q_new: [B, Hq, 1, Dh]
k_cache: [B, Hkv, T_cache, Dh]
v_cache: [B, Hkv, T_cache, Dh]
这两种工作负载很不同:
| Stage | Dominant work | Kernel concern |
|---|---|---|
| prefill/train | large GEMM-like attention over \(T\times T\) | tiling and HBM traffic |
| decode | one/few query rows over long KV cache | cache bandwidth and layout |
decode attention 的输出形状仍是 [B,Hq,1,Dh],但它会读取越来越长的 KV cache。此时 GQA/MQA 的收益非常直接:减少 Hkv 就减少每步读取的 K/V bytes。
KV cache 写入也有 contract:
k_new, v_new = project_kv(x_new)
k_new = apply_rope(k_new, position=cache_pos)
cache.write(layer_id, batch_slot, cache_pos, k_new, v_new)
k_all, v_all = cache.view(layer_id, batch_slot, upto=cache_pos + 1)位置编码必须在写入 cache 前按逻辑位置应用;cache 地址是物理位置,RoPE/position id 是逻辑位置。二者混淆会让 shape 全对但长上下文质量异常。
Decode attention has tiny query length and long cached key/value length. A kernel optimized for training prefill may underperform when decode is memory-bandwidth-bound.
Engineering Checklist
实现 attention 时至少检查:
| Item | Failure mode |
|---|---|
| QKV shape | heads 和 sequence 维度转错 |
| mask dtype | bool mask/additive mask 混用 |
| mask shape | broadcast 到错误维度 |
| causal convention | j <= i 和 shift label 对不上 |
| softmax dtype | FP16 overflow or underflow |
| dropout | eval 时忘记关闭 |
| contiguous layout | transpose 后 view 错 |
| kernel backend | 退回 slow math kernel |
| GQA heads | KV heads 与 Q heads 分组不一致 |
| dropout_p | functional SDPA 在 eval 时仍可能 dropout |
| packed documents | causal mask 没有隔开样本边界 |
| all-masked rows | padding query 产生 NaN |
| backend fallback | custom mask 触发 math kernel |
| KV cache position | 物理 cache 地址和逻辑 position 混用 |
attention 的数学公式很短,但工程实现是 LLM 性能和稳定性的核心。长上下文训练、低延迟推理、显存利用率,很多时候都卡在这一层。