SSM & Mamba


State Space Model (SSM)

A state space model describes a sequence with a continuous-time dynamical system:

\[ h'(t)=Ah(t)+Bx(t), \qquad y(t)=Ch(t)+Dx(t). \]

After discretization:

\[ h_t=\bar{A}h_{t-1}+\bar{B}x_t, \qquad y_t=Ch_t+Dx_t. \]

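The key advantage of an SSM is that its recurrence can be rewritten as a convolution or a scan, so training can run in parallel while inference stays a simple linear-time recurrence.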

Definition: State Space Model

A state space model represents a sequence by a latent state \(h_t\) updated over time and an output \(y_t\) read from that state: \[ h_t=\bar{A}h_{t-1}+\bar{B}x_t, \qquad y_t=Ch_t+Dx_t. \]

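Here the state \(h_t\) acts as compressed memory. Attention stores the K/V of every past token; an SSM stores only a fixed-size state. The two have very different inductive biases.

To keep the notation straight, fix a minimal set of symbols first. Let the token embedding dimension be \(d\), the sequence length \(T\), and the batch size \(B\) (in shape annotations \(B\) is the batch size; in equations it is the input matrix). Mamba/SSM layers also have a state dimension \(N\). It is neither the vocabulary size nor the hidden size; it is the number of state coordinates each channel uses to carry the memory of its dynamical system.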

| symbol | typical shape | meaning |
|---|---|---|
| \(x_t\) | \(d\) | current token representation |
| \(h_t\) | \(d_{\text{inner}}\times N\) | recurrent SSM state |
| \(A\) | \(d_{\text{inner}}\times N\) or structured | state transition rates |
| \(B_t\) | \(N\), broadcast to \(d_{\text{inner}}\times N\) | input-dependent write vector |
| \(C_t\) | \(N\), broadcast to \(d_{\text{inner}}\times N\) | input-dependent read vector |
| \(\Delta_t\) | \(d_{\text{inner}}\) | token-dependent step size / timescale |

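The key intuition: the state is not a single vector but a collection of small dynamical systems, one per channel. Each channel maintains \(N\) state coordinates, and the result is projected back to the model dimension. This keeps the model linear-time while letting every hidden channel express several decay timescales.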

Running Example: Remembering a Topic

Imagine a long document that states its topic early on:

This document is about GPU scheduling ...

Several thousand tokens later, the model sees:

the allocation policy should ...

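Attention lets the current position retrieve directly from every earlier token, at the cost of a KV cache that grows with the context. SSM/Mamba instead maintains a compressed state: on seeing GPU scheduling it writes the "current topic" into the state, and on seeing allocation policy later it reads the relevant information back out.

This is long-term state compression, not exact copying of earlier text. Mamba's strength and its risk are therefore both clear: it retains long-range state cheaply, but it cannot explicitly look back at an arbitrary token the way attention can.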

Definition: Selective State Space Model

A selective SSM makes the transition or input matrices depend on the current token, allowing the model to choose what to remember or forget based on content.

Discretization

A continuous SSM cannot run directly on a token sequence; it has to be discretized. With step size \(\Delta\), zero-order hold (ZOH) gives

\[ \bar{A}=\exp(\Delta A), \]

\[ \bar{B} = (\Delta A)^{-1}(\exp(\Delta A)-I)\Delta B. \]

The recurrence then becomes

\[ h_t=\bar{A}h_{t-1}+\bar{B}x_t. \]

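If \(A\) is diagonal or otherwise structured, \(\exp(\Delta A)\) can be computed efficiently. This is what makes SSMs practical to engineer.

To derive the formulas, assume the input \(x(t)\) is held constant at \(x_t\) on the interval \([t,t+\Delta]\). The continuous system

\[ h'(s)=Ah(s)+Bx_t \]

has the solution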

\[ h(t+\Delta) = e^{\Delta A}h(t) + \int_0^\Delta e^{(\Delta-\tau)A}B x_t\,d\tau. \]

If \(A\) is invertible,

\[ \int_0^\Delta e^{(\Delta-\tau)A}d\tau = A^{-1}(e^{\Delta A}-I). \]

Therefore

\[ \bar{A}=e^{\Delta A}, \qquad \bar{B}=A^{-1}(e^{\Delta A}-I)B. \]

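Since \((\Delta A)^{-1}(\exp(\Delta A)-I)\,\Delta B = A^{-1}(e^{\Delta A}-I)B\), this agrees with the zero-order-hold formula stated earlier. Different papers and implementations use equivalent parameterizations; the point is that a continuous dynamical system can be discretized into a token-level recurrence.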
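As a sanity check on the derivation, the scalar case can be verified numerically. The sketch below is an illustration only: the function names and the values of `a`, `b`, `dt` are made up, and it compares the zero-order-hold update against a brute-force integration of the continuous system with the input held constant.

```python
import math

def zoh_discretize(a: float, b: float, dt: float) -> tuple[float, float]:
    """Zero-order-hold discretization of the scalar system h' = a*h + b*x."""
    a_bar = math.exp(dt * a)
    # expm1 keeps (exp(dt*a) - 1) accurate when dt*a is close to zero.
    b_bar = math.expm1(dt * a) / a * b
    return a_bar, b_bar

def simulate_continuous(a, b, x, h0, dt, substeps=20_000):
    """Brute-force Euler integration of h' = a*h + b*x with x held constant."""
    h = h0
    step = dt / substeps
    for _ in range(substeps):
        h += step * (a * h + b * x)
    return h

# Hypothetical values chosen only for illustration.
a, b, dt = -2.0, 0.5, 0.1
h0, x = 1.0, 3.0

a_bar, b_bar = zoh_discretize(a, b, dt)
h_zoh = a_bar * h0 + b_bar * x          # one discrete step
h_ref = simulate_continuous(a, b, x, h0, dt)
euler_b_bar = dt * b                    # first-order approximation of b_bar

print(f"ZOH step: {h_zoh:.6f}  integrated: {h_ref:.6f}")
print(f"b_bar (ZOH): {b_bar:.6f}  b_bar (Euler): {euler_b_bar:.6f}")
```

The Euler value `dt * b` is the first-order approximation of `b_bar`; the two get closer as `dt * |a|` shrinks.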

Stability and Timescales

After discretization, the most important numerical condition is that the state must not blow up. If a scalar channel satisfies

\[ h_t=\bar{a}h_{t-1}+\bar{b}x_t, \]

then long-run stability generally requires

\[ |\bar{a}|<1. \]

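In terms of the continuous parameters, if the eigenvalues of \(A\) have negative real parts and \(\Delta>0\), then

\[ \bar{A}=\exp(\Delta A) \]

has eigenvalues of modulus less than \(1\). This explains why many SSM implementations parameterize \(A\) to be negative, for example with

\[ A=-\exp(A_{\log}) \]

or a similar form, which guarantees that the continuous system decays. \(\Delta_t\) must also be positive; otherwise the "step size" has no physical meaning and the stability argument for \(\exp(\Delta_t A)\) breaks down. A common choice is: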

\[ \Delta_t=\mathrm{softplus}(\tilde{\Delta}_t+b_\Delta). \]
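The two constraints can be checked together in a few lines. This is a minimal scalar sketch, assuming the parameterization written above; `discrete_decay` and the sampled raw values are hypothetical. Whatever the unconstrained parameters are, the discrete decay lands strictly between 0 and 1.

```python
import math

def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z)); strictly positive for finite z."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def discrete_decay(a_log: float, dt_raw: float, dt_bias: float = 0.0) -> float:
    a = -math.exp(a_log)              # continuous rate, always negative
    dt = softplus(dt_raw + dt_bias)   # step size, always positive
    return math.exp(dt * a)           # discrete decay exp(dt * a)

# Sweep a few unconstrained values, including fairly extreme ones.
decays = [
    discrete_decay(a_log, dt_raw)
    for a_log in (-3.0, 0.0, 2.0)
    for dt_raw in (-5.0, 0.0, 5.0)
]
for d in decays:
    print(f"{d:.3e}")
```

In floating point, very large `dt * |a|` can still round the decay to exactly 0, which is one reason mixed-precision checks on `exp(dt * A)` are worth keeping.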

Definition: Timescale

The timescale of a stable scalar SSM channel is the rough number of steps over which its state remains influential before exponential decay makes it negligible.

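In the scalar case, if

\[ \bar{a}=\exp(\Delta a),\qquad a<0, \]

then the memory weight after \(k\) steps is

\[ \bar{a}^{k}=\exp(k\Delta a). \]

When \(k\Delta |a|\approx 1\), the weight has decayed to about \(e^{-1}\). The effective timescale is therefore approximately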

\[ \tau\approx \frac{1}{\Delta |a|}. \]

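This gives selectivity a concrete meaning: by adjusting \(\Delta_t\), the model changes the timescale on which the current token affects memory. When \(\Delta_t\) is large, the old state decays quickly and the new input is written strongly; when \(\Delta_t\) is small, the old state persists and the current token perturbs it only gently.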
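A quick numeric illustration of the timescale formula, as a sketch with made-up values for \(a\) and \(\Delta\): after \(\tau\) steps the memory weight is \(e^{-1}\) by construction, and shrinking \(\Delta\) stretches \(\tau\) proportionally.

```python
import math

def timescale(dt: float, a: float) -> float:
    """Approximate number of steps for the memory weight to fall to 1/e."""
    return 1.0 / (dt * abs(a))

a = -0.5  # hypothetical continuous decay rate
for dt in (0.01, 0.1, 1.0):
    tau = timescale(dt, a)
    a_bar = math.exp(dt * a)
    weight_after_tau = a_bar ** tau  # exp(tau * dt * a) = exp(-1)
    print(f"dt={dt:<5} a_bar={a_bar:.4f} tau={tau:7.1f} weight={weight_after_tau:.4f}")
```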

Pitfall: Stable Does Not Mean Useful

Constraining \(A\) to be stable prevents explosion, but an overly small decay rate can make the state too inert, while an overly large decay rate can erase long-range information.

Convolution View

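Unrolling the recurrence from an initial state \(h_0\), with inputs \(x_1,\ldots,x_t\):

\[ h_t = \bar{A}^t h_0 + \sum_{k=0}^{t-1} \bar{A}^{k}\bar{B}x_{t-k}. \]

If \(h_0=0\) and the skip term \(Dx_t\) is dropped, the output is

\[ y_t = C h_t = \sum_{k=0}^{t-1} C\bar{A}^{k}\bar{B}x_{t-k}. \]

This is a convolution with the kernel: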

\[ K_k=C\bar{A}^{k}\bar{B}. \]

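So an SSM behaves like an RNN in that it can be run recurrently, and like a CNN in that it can be trained in parallel as a convolution. This is what sets it apart from an ordinary RNN.

If \(A\) is diagonal, write \(\bar{a}_n\) for the discrete decay of the \(n\)-th state coordinate, \(\bar{b}_n\) for its input coefficient, and \(c_n\) for its readout coefficient. The \(k\)-th kernel entry is then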

\[ K_k = \sum_{n=1}^{N} c_n \bar{a}_n^k \bar{b}_n. \]

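The SSM kernel is thus a linear combination of exponentially decaying basis functions. Different \(\bar{a}_n\) correspond to different timescales: channels with \(\bar{a}_n\) near \(1\) retain long-range information, and channels near \(0\) see only a local neighborhood. Much of the design of classical structured SSMs aims at making these timescales cover a wide enough range.

For a diagonal \(\bar{A}=\mathrm{diag}(\bar{a}_1,\ldots,\bar{a}_N)\) we have

\[ \bar{A}^k = \mathrm{diag}(\bar{a}_1^k,\ldots,\bar{a}_N^k). \]

Let \(C=(c_1,\ldots,c_N)\) and \(\bar{B}=(\bar{b}_1,\ldots,\bar{b}_N)^\top\); then

\[ C\bar{A}^k\bar{B} = \sum_{n=1}^{N}c_n\bar{a}_n^k\bar{b}_n. \]

So the SSM convolution kernel is a mixture of exponential basis functions.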
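The equivalence between the recurrent and convolutional views is easy to test for a small diagonal SSM. The sketch below uses plain Python lists and made-up coefficients; the function names are illustrative only. Lists are 0-indexed, so `xs[0]` plays the role of \(x_1\) and the state starts at zero.

```python
def ssm_recurrent(a_bar, b_bar, c, xs):
    """Run a diagonal SSM step by step, starting from a zero state."""
    h = [0.0] * len(a_bar)
    ys = []
    for x in xs:
        h = [a * hn + b * x for a, b, hn in zip(a_bar, b_bar, h)]
        ys.append(sum(cn * hn for cn, hn in zip(c, h)))
    return ys

def ssm_kernel(a_bar, b_bar, c, length):
    """K_k = sum_n c_n * a_bar_n**k * b_bar_n for k = 0 .. length-1."""
    return [
        sum(cn * (a ** k) * b for a, b, cn in zip(a_bar, b_bar, c))
        for k in range(length)
    ]

def ssm_convolution(a_bar, b_bar, c, xs):
    """Causal convolution of the input with the SSM kernel."""
    kernel = ssm_kernel(a_bar, b_bar, c, len(xs))
    return [
        sum(kernel[k] * xs[t - k] for k in range(t + 1))
        for t in range(len(xs))
    ]

# Hypothetical coefficients: one slow, one medium, one fast channel.
a_bar = [0.9, 0.5, 0.1]
b_bar = [1.0, 0.5, 0.25]
c = [0.3, -0.2, 1.0]
xs = [1.0, 0.0, -2.0, 0.5, 3.0]

y_rec = ssm_recurrent(a_bar, b_bar, c, xs)
y_conv = ssm_convolution(a_bar, b_bar, c, xs)
print(max(abs(p - q) for p, q in zip(y_rec, y_conv)))
```

The naive convolution here costs \(O(T^2)\); it is meant as a correctness reference, not as the fast path.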

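This also explains the limitation of a fixed SSM: if \(A,B,C\) are all fixed, the same kernel is applied at every token. The model can learn long-range filters that are useful in general, but it cannot change its write strength or read direction on the fly when it sees a particular token.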

Why Classical SSMs Struggle With Tokens

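If \(\bar{A},\bar{B},C\) are fixed across all positions, the SSM is a linear time-invariant (LTI) system. It is good at modeling smooth signals and long-range convolutions, but discrete language poses a problem: each token calls for a different policy on what information to keep.

For example: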

The password is 4931. ... What is the password?

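On seeing 4931 the model should write strongly into the state; on ordinary function words it should write little, or let what it wrote fade quickly. A fixed SSM kernel cannot easily change its memory policy based on content.

Mamba's core change is selectivity: some of the SSM parameters depend on the current input.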
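A toy scalar channel makes the difference visible. In this sketch the step sizes are picked by hand to stand in for what a learned projection might produce, so every number is an assumption. One salient value is followed by 100 filler tokens; with a fixed step size the salient value is washed out, while a token-dependent step size keeps most of it.

```python
import math

def run_channel(xs, dts, a=-1.0):
    """Scalar SSM channel with zero-order-hold discretization and b = 1.

    With a = -1 the update is a convex blend:
    h_t = exp(-dt_t) * h_{t-1} + (1 - exp(-dt_t)) * x_t.
    """
    h = 0.0
    for x, dt in zip(xs, dts):
        a_bar = math.exp(dt * a)
        b_bar = math.expm1(dt * a) / a
        h = a_bar * h + b_bar * x
    return h

# One salient value (1.0) followed by 100 filler tokens (0.05).
xs = [1.0] + [0.05] * 100

fixed_dts = [0.5] * len(xs)             # LTI: same step size everywhere
selective_dts = [10.0] + [0.001] * 100  # hand-picked, not learned

h_fixed = run_channel(xs, fixed_dts)
h_selective = run_channel(xs, selective_dts)
print(f"fixed: {h_fixed:.3f}  selective: {h_selective:.3f}")
```

The selective run writes hard on the salient token (large step) and then barely moves (small steps), which is the behavior the password example asks for.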

Mamba

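Mamba's core is the selective scan. Unlike attention, it never builds an explicit \(T\times T\) attention matrix; instead it maintains a state that is updated recurrently along the sequence:

\[ h_t=A(x_t)h_{t-1}+B(x_t)x_t, \qquad y_t=C(x_t)h_t. \]

This gives time complexity linear in the sequence length:

\[ O(Td) \]

whereas standard attention is

\[ O(T^2d). \]

Selectivity in Mamba means that parameters such as \(\Delta,B,C\) are produced dynamically from the input token: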

\[ \Delta_t,B_t,C_t = s_\theta(x_t). \]

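The model can therefore choose a different memory timescale at each token. For example, it can use a large \(\Delta_t\) at a section heading to write it strongly into the state, and a small \(\Delta_t\) at a stopword so that the token barely perturbs what is already stored.

A simplified flow closer to the Mamba block: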

x
 -> linear projection
 -> depthwise convolution for local mixing
 -> selective SSM scan
 -> gate
 -> output projection

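The block combines local convolution, an input-dependent SSM, and gating. Unlike a Transformer block, a Mamba block needs no explicit attention matrix, and it does not divide the work between a mixing sublayer and a standard MLP sublayer in the same way.

Definition: Selectivity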

Selectivity means that the state update parameters, such as \(\Delta_t\), \(B_t\), or \(C_t\), are functions of the current input token. This lets the model decide when to write, keep, or read information.

Closer to the implementation, the block splits into two branches. Given the input

\[ X\in\mathbb{R}^{B\times T\times d}, \]

first apply the input projection:

\[ [U,Z]=XW_{\text{in}}, \qquad U,Z\in\mathbb{R}^{B\times T\times d_{\text{inner}}}. \]

\(U\) goes through a depthwise causal convolution for local mixing:

\[ \tilde{U}=\mathrm{SiLU}(\mathrm{DWConv}_{\text{causal}}(U)). \]

The selective SSM parameters are then generated from \(\tilde{U}_t\):

\[ \Delta_t,B_t,C_t = \mathrm{Linear}(\tilde{U}_t). \]

The SSM scan produces \(Y_t\), and the gate branch gives

\[ G_t=\mathrm{SiLU}(Z_t), \qquad O_t=(Y_t\odot G_t)W_{\text{out}}. \]

So a Mamba block can be viewed as:

input
  -> in_proj -> x branch -> causal depthwise conv -> selective SSM -> multiply by gate
            -> z branch -> SiLU gate -----------------------------^
  -> out_proj

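The depthwise causal convolution does not replace the SSM; it provides a short-range pattern detector. Many local structures in language, such as punctuation, affixes, brackets, and code indentation, do not need the long-range state at all. Running a small convolution over the neighborhood first and feeding the result to the SSM gives the write/read parameters more local context.

Pitfall: Causal Conv Must Be Causal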

If the depthwise convolution sees future tokens during language modeling, the block leaks labels. Padding and trimming around the convolution are as important as the SSM recurrence itself.
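Causality of the convolution can be tested directly: perturb a later token and confirm that earlier outputs do not move. The following is a minimal list-based sketch, not the layout a real kernel would use; `causal_depthwise_conv`, the tap values, and the inputs are all hypothetical.

```python
def causal_depthwise_conv(x, weights):
    """Depthwise causal conv over a [T][D] list-of-lists input.

    weights[d] holds the K taps for channel d. The last tap multiplies
    the current token; earlier taps multiply earlier tokens.
    """
    T, D = len(x), len(x[0])
    K = len(weights[0])
    # Left-pad with K-1 zero rows so position t only sees t-K+1 .. t.
    padded = [[0.0] * D for _ in range(K - 1)] + [list(row) for row in x]
    out = []
    for t in range(T):
        out.append([
            sum(weights[d][j] * padded[t + j][d] for j in range(K))
            for d in range(D)
        ])
    return out

x = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0], [-2.0, 4.0], [1.5, 1.5]]
w = [[0.2, 0.3, 0.5], [1.0, -1.0, 0.5]]  # two channels, three taps each

# Perturb token 3 and compare outputs before and after.
x_perturbed = [list(row) for row in x]
x_perturbed[3] = [100.0, -100.0]

out = causal_depthwise_conv(x, w)
out_perturbed = causal_depthwise_conv(x_perturbed, w)
print(out[:3] == out_perturbed[:3])  # outputs strictly before the change
```

If the padding were symmetric instead of left-only, outputs before position 3 would change and the comparison would fail.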

A simplified shape checklist:

| tensor | shape | note |
|---|---|---|
| x | [B, T, d_model] | residual stream input |
| u, z | [B, T, d_inner] | projected branches |
| dt | [B, T, d_inner] | positive step sizes after softplus |
| B_t, C_t | [B, T, N] or broadcasted | input-dependent write/read parameters |
| state | [B, d_inner, N] | recurrent memory |
| y | [B, T, d_inner] | scan output before gate/out projection |

Selective Scan

In a Mamba-style diagonal SSM, the update for one channel can be written as

\[ \bar{a}_{t,n}=\exp(\Delta_t a_n), \]

\[ \bar{b}_{t,n} \approx \Delta_t B_{t,n} \]

or with the more exact zero-order-hold form. Each hidden channel and state coordinate then satisfies

\[ h_{t,n} = \bar{a}_{t,n}h_{t-1,n} + \bar{b}_{t,n}\tilde{u}_t, \]

with the readout

\[ y_t=\sum_{n=1}^{N} C_{t,n}h_{t,n}. \]

Here \(\bar{a}_{t,n}\), \(\bar{b}_{t,n}\), and \(C_{t,n}\) can all vary with the token, so there is no longer a fixed convolution kernel. This is the dividing line between a selective SSM and a classical LTI SSM.

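An ordinary recurrence looks as if it must be computed one step at a time:

\[ h_t=a_t h_{t-1}+b_t. \]

But this kind of affine recurrence composes:

\[ h_t = a_t(a_{t-1}h_{t-2}+b_{t-1})+b_t = (a_ta_{t-1})h_{t-2}+(a_tb_{t-1}+b_t). \]

Treat each step as a pair \((a_t,b_t)\) with the composition

\[ (a_2,b_2)\circ(a_1,b_1) = (a_2a_1,\;a_2b_1+b_2). \]

This composition is associative, so a parallel scan can accelerate it. That is Mamba's engineering value: it keeps the state compression of a recurrence while still parallelizing training.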

Theorem: Affine Recurrence Composition Is Associative

Define \[ (a_2,b_2)\circ(a_1,b_1) = (a_2a_1,\ a_2b_1+b_2). \] Then \(\circ\) is associative.

Left-associated:

\[ ((a_3,b_3)\circ(a_2,b_2))\circ(a_1,b_1) = (a_3a_2,\ a_3b_2+b_3)\circ(a_1,b_1) \]

\[ = (a_3a_2a_1,\ a_3a_2b_1+a_3b_2+b_3). \]

Right-associated:

\[ (a_3,b_3)\circ((a_2,b_2)\circ(a_1,b_1)) = (a_3,b_3)\circ(a_2a_1,\ a_2b_1+b_2) \]

\[ = (a_3a_2a_1,\ a_3(a_2b_1+b_2)+b_3) = (a_3a_2a_1,\ a_3a_2b_1+a_3b_2+b_3). \]

The two results agree, so the operator is associative. Parallel prefix scan uses exactly this property to parallelize a sequential recurrence.
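To see associativity doing actual work, here is a small Hillis-Steele style inclusive scan over \((a_t,b_t)\) pairs, checked against the sequential recurrence. It is a scalar sketch in plain Python with made-up inputs; it runs serially, but within each round no position depends on another position's new value, which is what a parallel implementation exploits.

```python
def combine(later, earlier):
    """(a2, b2) o (a1, b1): apply `earlier` first, then `later`."""
    a2, b2 = later
    a1, b1 = earlier
    return (a2 * a1, a2 * b1 + b2)

def sequential_scan(pairs, h0=0.0):
    """Reference: h_t = a_t * h_{t-1} + b_t, one step at a time."""
    h, out = h0, []
    for a, b in pairs:
        h = a * h + b
        out.append(h)
    return out

def hillis_steele_scan(pairs, h0=0.0):
    """Inclusive prefix scan in ceil(log2(T)) rounds."""
    prefix = list(pairs)
    offset = 1
    while offset < len(prefix):
        # Every position reads only the previous round's values.
        prefix = [
            combine(prefix[i], prefix[i - offset]) if i >= offset else prefix[i]
            for i in range(len(prefix))
        ]
        offset *= 2
    # prefix[t] = (A, B) such that h_t = A * h0 + B.
    return [A * h0 + B for A, B in prefix]

pairs = [(0.9, 1.0), (0.5, -2.0), (1.1, 0.3), (0.7, 0.0),
         (0.2, 4.0), (0.95, -1.0), (0.6, 0.5)]
ref = sequential_scan(pairs, h0=0.3)
got = hillis_steele_scan(pairs, h0=0.3)
print(max(abs(r - g) for r, g in zip(ref, got)))
```

The argument order in `combine` matters: the operator is associative but not commutative, so the later block must always sit on the left.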

In a tensor implementation the pair \((a_t,b_t)\) is not a scalar but a tensor that broadcasts to [B, d_inner, N]. The composition still holds elementwise:

\[ (A_2,B_2)\circ(A_1,B_1) = (A_2\odot A_1,\ A_2\odot B_1+B_2). \]

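This lets an implementation run a prefix scan inside each chunk and pass the final state between chunks. A conceptual version:

import torch

def scan_chunk(a, b, h0):
    # a, b: [B, chunk, d_inner, N]; h0: [B, d_inner, N]
    h = h0
    ys = []
    for i in range(a.size(1)):
        h = a[:, i] * h + b[:, i]
        ys.append(h)
    # Returns every state in the chunk plus the final state to carry forward.
    return torch.stack(ys, dim=1), h

def chunked_scan(a, b, chunk_size):
    # Zero initial state; the final state of each chunk seeds the next one.
    h = torch.zeros_like(b[:, 0])
    outs = []
    for start in range(0, a.size(1), chunk_size):
        end = start + chunk_size
        y, h = scan_chunk(a[:, start:end], b[:, start:end], h)
        outs.append(y)
    return torch.cat(outs, dim=1)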

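This code is still a Python loop, not a high-performance implementation. A real kernel fuses the in-chunk scan, state reads and writes, the multiplication by \(C_t\), the gate, and partial recomputation to cut memory traffic. The point of chunked scan is that the SSM state can be carried across chunks, as in inference, while training can still use a parallel scan inside each chunk for throughput.

Tip: Kernel Intuition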

The theoretical win is associativity. The practical win comes only when the scan avoids materializing large intermediate states in slow memory.

Chunk Summaries and Parallel Scan

The chunked_scan pseudocode above still runs serially, one chunk after another. To parallelize for real, each chunk must produce not only its internal outputs but also a chunk summary. For the scalar affine recurrence

\[ h_t=a_t h_{t-1}+b_t, \]

the whole chunk from position \(l\) to \(r\) can be compressed into a single pair

\[ S_{l:r}=(A_{l:r},B_{l:r}), \]

such that

\[ h_r=A_{l:r}h_{l-1}+B_{l:r}, \]

where

\[ A_{l:r}=\prod_{t=l}^{r}a_t, \]

\[ B_{l:r} = b_r+a_rb_{r-1}+a_ra_{r-1}b_{r-2}+\cdots+ \left(\prod_{t=l+1}^{r}a_t\right)b_l. \]

This is the chunk's "transfer function". Given many chunk summaries, we can run a prefix scan over the summaries to obtain each chunk's initial state; every chunk can then compute all of its own hidden states in parallel.

Theorem: Chunk Summaries Preserve the Recurrence

If a chunk summary \(S_{l:r}=(A_{l:r},B_{l:r})\) satisfies \[ h_r=A_{l:r}h_{l-1}+B_{l:r}, \] then composing chunk summaries with the same affine operator gives the exact state at chunk boundaries.

Let the summaries of two adjacent chunks be

\[ S_{l:m}=(A_1,B_1),\qquad S_{m+1:r}=(A_2,B_2). \]

The first chunk gives

\[ h_m=A_1h_{l-1}+B_1. \]

The second chunk gives

\[ h_r=A_2h_m+B_2. \]

Substituting,

\[ h_r=A_2(A_1h_{l-1}+B_1)+B_2 = (A_2A_1)h_{l-1}+(A_2B_1+B_2). \]

This is exactly the definition of

\[ (A_2,B_2)\circ(A_1,B_1). \]

Composing chunk summaries is therefore equivalent to the token-by-token recurrence.

Conceptually, a flow closer to a GPU scan looks like this:

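import torch

def summarize_chunk(a, b):
    # a, b: [B, C, D, N], where C is the chunk size.
    # Returns (A, B) such that h_end = A * h_start + B for this chunk.
    A = torch.ones_like(a[:, 0])
    B = torch.zeros_like(b[:, 0])
    for i in range(a.size(1)):
        A, B = a[:, i] * A, a[:, i] * B + b[:, i]
    return A, B

def apply_chunk(a, b, h0):
    # Replays the recurrence inside one chunk from its initial state h0.
    h = h0
    out = []
    for i in range(a.size(1)):
        h = a[:, i] * h + b[:, i]
        out.append(h)
    return torch.stack(out, dim=1), h

def split_chunks(a, b, chunk_size):
    # Reference helper: yields aligned (a_chunk, b_chunk) slices along time.
    for start in range(0, a.size(1), chunk_size):
        yield a[:, start:start + chunk_size], b[:, start:start + chunk_size]

def exclusive_prefix_apply_summaries(summaries):
    # Sequential reference for the exclusive prefix over chunk summaries:
    # chunk k starts from the state produced by the chunks before it only.
    h = torch.zeros_like(summaries[0][1])
    init_states = []
    for A, B in summaries:
        init_states.append(h)
        h = A * h + B
    return init_states

def two_level_scan(a, b, chunk_size):
    chunks = list(split_chunks(a, b, chunk_size))
    summaries = [summarize_chunk(ca, cb) for ca, cb in chunks]
    init_states = exclusive_prefix_apply_summaries(summaries)

    outs = []
    for (ca, cb), h0 in zip(chunks, init_states):
        y, _ = apply_chunk(ca, cb, h0)
        outs.append(y)
    return torch.cat(outs, dim=1)

In a real kernel, exclusive_prefix_apply_summaries is not a Python loop as in this reference but an exclusive parallel prefix over the chunk summaries: the initial state of the \(k\)-th chunk comes from the \(k-1\) chunks before it and excludes that chunk's own summary. The per-chunk summarize and apply steps are independent across chunks and can run in parallel as well. Production kernels also fuse the computation of a, b, C_t and the readout, so the full [B,T,d_{\text{inner}},N] state is never written to HBM.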

Pitfall: Chunk Boundaries Are Part of Correctness

If chunk summaries are computed with the wrong order, or if the initial state of each chunk is off by one, loss can look plausible while long-range behavior is wrong.

A direct test is to run the same sequence with several chunk sizes; the outputs must agree:

ref = chunked_scan(a, b, chunk_size=1)
for size in [2, 4, 16, 64]:
    got = chunked_scan(a, b, chunk_size=size)
    assert torch.allclose(ref, got, atol=1e-4, rtol=1e-4)

Backward Recurrence and Recomputation

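To understand the backward pass of selective scan, start from the minimal scalar form:

\[ h_t=a_t h_{t-1}+b_t, \qquad y_t=c_t h_t. \]

Let the upstream gradient be \(\bar{y}_t=\partial L/\partial y_t\) and the adjoint of the hidden state be \(\bar{h}_t=\partial L/\partial h_t\) (in this subsection a bar denotes a gradient, not a discretized parameter). Since \(h_t\) affects both \(y_t\) and the future \(h_{t+1}\), the reverse recurrence is

\[ \bar{h}_t = c_t\bar{y}_t+a_{t+1}\bar{h}_{t+1}. \]

with \(\bar{h}_{T+1}=0\) past the last step. The local gradients at each step are then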

\[ \bar{c}_t=\bar{y}_t h_t, \qquad \bar{a}_t=\bar{h}_t h_{t-1}, \qquad \bar{b}_t=\bar{h}_t. \]

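The tensor version simply replaces the products with elementwise products and the matching reductions. In Mamba, \(a_t\) itself comes from

\[ a_t=\exp(\Delta_t A), \]

so the chain rule has to continue: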

\[ \frac{\partial a_t}{\partial \Delta_t} = A\exp(\Delta_t A) = Aa_t, \qquad \frac{\partial a_t}{\partial A} = \Delta_t a_t. \]

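This is why implementations must decide carefully whether to save or recompute \(h_t,a_t,\Delta_t\). Saving every \(h_t\) puts a [B,T,d_{\text{inner}},N] tensor into activation memory; not saving them means recomputing the forward hidden states from each chunk's initial state during the backward pass.

Definition: Recompute Backward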

Recompute backward saves memory by storing only compact boundary states and recomputing within-chunk forward states during the backward pass.

A conceptual backward loop:

import torch

def scan_backward(a, b, c, h, grad_y):
    # a, b, c, h: [B, T, D, N]; grad_y: [B, T, D].
    # h contains h_t, or is recomputed chunk by chunk.
    grad_h_next = torch.zeros_like(h[:, 0])  # carries a_{t+1} * grad_h_{t+1}
    grad_a = torch.zeros_like(a)
    grad_b = torch.zeros_like(b)
    grad_c = torch.zeros_like(c)

    for t in reversed(range(a.size(1))):
        grad_h = grad_y[:, t][..., None] * c[:, t] + grad_h_next
        h_prev = h[:, t - 1] if t > 0 else torch.zeros_like(h[:, 0])

        grad_c[:, t] = grad_y[:, t][..., None] * h[:, t]
        grad_a[:, t] = grad_h * h_prev
        grad_b[:, t] = grad_h
        grad_h_next = grad_h * a[:, t]

    return grad_a, grad_b, grad_c

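This code omits the gate, projections, softplus, ZOH, and chunk boundaries, but it exposes an engineering fact: the scan backward is itself a reverse recurrence. High-performance implementations typically split the backward pass into an in-chunk reverse scan, gradient propagation through chunk summaries, and gradient accumulation for the parameter projections.

Pitfall: Forward Equivalence Does Not Prove Backward Equivalence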

A custom scan kernel must be checked against autograd for both outputs and gradients. Matching forward values alone does not catch incorrect chunk-boundary gradients.
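The same idea can be tried on the scalar recurrence without any framework: compute the gradients with the reverse recurrence and compare them with central finite differences. This is a sketch with made-up parameters and the assumed loss \(L=\sum_t y_t\); in a real project the reference would be autograd on a naive PyTorch loop.

```python
def forward(a, b, c):
    """h_t = a_t * h_{t-1} + b_t with zero initial state, y_t = c_t * h_t."""
    h, hs, ys = 0.0, [], []
    for at, bt, ct in zip(a, b, c):
        h = at * h + bt
        hs.append(h)
        ys.append(ct * h)
    return hs, ys

def backward(a, c, hs, grad_y):
    """Reverse recurrence; returns (grad_a, grad_b, grad_c)."""
    T = len(a)
    grad_a, grad_b, grad_c = [0.0] * T, [0.0] * T, [0.0] * T
    carry = 0.0  # a_{t+1} * grad_h_{t+1}; zero past the last step
    for t in reversed(range(T)):
        grad_h = c[t] * grad_y[t] + carry
        h_prev = hs[t - 1] if t > 0 else 0.0
        grad_c[t] = grad_y[t] * hs[t]
        grad_a[t] = grad_h * h_prev
        grad_b[t] = grad_h
        carry = a[t] * grad_h
    return grad_a, grad_b, grad_c

def numeric_grad(loss_fn, params, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    grads = []
    for i in range(len(params)):
        up = params[:i] + [params[i] + eps] + params[i + 1:]
        down = params[:i] + [params[i] - eps] + params[i + 1:]
        grads.append((loss_fn(up) - loss_fn(down)) / (2 * eps))
    return grads

def loss(a_, b_, c_):
    return sum(forward(a_, b_, c_)[1])  # L = sum_t y_t, so dL/dy_t = 1

a = [0.9, 0.5, 0.8, 0.7]
b = [1.0, -0.5, 0.25, 2.0]
c = [0.3, 1.2, -0.7, 0.4]

hs, _ = forward(a, b, c)
grad_a, grad_b, grad_c = backward(a, c, hs, grad_y=[1.0] * len(a))
num_a = numeric_grad(lambda p: loss(p, b, c), a)
print(max(abs(x - y) for x, y in zip(grad_a, num_a)))
```

A chunked implementation would add one more thing to compare: the gradient that flows across each chunk boundary, which a forward-only test never exercises.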

Hardware-Aware Parallelism

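A selective SSM makes its parameters input-dependent, so the classical fixed-convolution trick no longer applies directly. Mamba's engineering answer is a hardware-aware scan: fuse projection, scan, gating, and the necessary recomputation on the GPU, and avoid materializing the huge intermediate state in HBM.

In training, the key bottleneck is memory traffic rather than theoretical FLOPs. A slow implementation can be \(O(T)\) on paper and still run slower than attention because it keeps reading and writing large tensors. The Mamba paper stresses the hardware-aware algorithm precisely because linear complexity only pays off when paired with a suitable kernel.

Roughly, if every intermediate state is materialized, the state tensor grows as

\[ B\times T\times d_{\text{inner}}\times N. \]

Even with a small \(N\), this tensor becomes very heavy once \(T\) is long. A better fused scan keeps only the outputs it needs and recomputes some intermediates in the backward pass, in the spirit of activation checkpointing: spend extra compute to save memory and HBM bandwidth.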
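Some back-of-the-envelope arithmetic shows the scale. The configuration below is hypothetical and chosen only so the numbers come out round; the comparison is between storing every state of one layer and storing one state per chunk boundary for later recomputation.

```python
def full_state_bytes(batch, seq_len, d_inner, state_dim, bytes_per_elem=2):
    """Bytes needed to store every h_t of one layer: [B, T, d_inner, N]."""
    return batch * seq_len * d_inner * state_dim * bytes_per_elem

def boundary_state_bytes(batch, seq_len, d_inner, state_dim, chunk_size,
                         bytes_per_elem=2):
    """Bytes if only one state per chunk boundary is kept."""
    num_chunks = -(-seq_len // chunk_size)  # ceiling division
    return batch * num_chunks * d_inner * state_dim * bytes_per_elem

# Hypothetical configuration, 2 bytes per element (fp16/bf16).
cfg = dict(batch=8, seq_len=8192, d_inner=4096, state_dim=16)
full = full_state_bytes(**cfg)
boundary = boundary_state_bytes(**cfg, chunk_size=256)

print(f"full state:      {full / 2**30:.2f} GiB per layer")
print(f"boundary states: {boundary / 2**20:.2f} MiB per layer")
```

The ratio between the two is simply the chunk size, which is why chunk size becomes a direct knob for trading recomputation against memory.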

| implementation choice | benefit | risk |
|---|---|---|
| Python loop recurrence | low conceptual complexity | many kernel launches; training cannot be parallelized |
| materialized full state | easy backward | [B,T,d_inner,N] is huge; heavy HBM pressure |
| fused chunked scan | avoids most state writes | complex kernel; high layout/debug cost |
| recompute in backward | saves activation memory | more backward FLOPs |

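This is why a complexity table is not enough when reading about Mamba. \(O(T)\) is a necessary condition, not a sufficient one. A good Mamba kernel has to handle all of the following:

  1. local caching for the causal convolution;
  2. the selective parameter projections;
  3. the softplus on \(\Delta\) and the exponentiation of \(A\);
  4. the scan recurrence;
  5. the \(C_t\) readout, the gate, and the output projection;
  6. recomputation and gradient accumulation in the backward pass.

Pitfall: Linear-Time Formula Does Not Guarantee Fast Kernels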

An \(O(T)\) recurrence can still be slow if it materializes large states or launches many small kernels. Efficient SSMs need scan-friendly layouts and fused GPU kernels.

Mamba vs. Transformer

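| Aspect | Transformer | Mamba/SSM |
|---|---|---|
| token mixing | attention over all tokens | recurrent state update |
| training parallelism | high | high (scan-based) |
| inference memory | KV cache grows with context | fixed-size state |
| long context cost | quadratic attention / KV memory | linear recurrence |
| weakness | expensive long context | less direct content-addressed retrieval |

Mamba is not a silver bullet that replaces all attention. It is a different inductive bias for sequence modeling: a selective dynamical system carries long-range information, and a fixed-size state relieves KV-cache pressure.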

Memory and Retrieval Trade-Off

Attention computes

\[ o_t=\sum_{j\le t}\alpha_{tj}v_j, \]

where each past token can be addressed directly. Mamba computes

\[ h_t=A_t h_{t-1}+B_t x_t, \qquad y_t=C_t h_t. \]

Everything from the past must pass through \(h_t\). If the state dimension is \(N\), each of the \(d_{\text{inner}}\) channels has only \(N\) coordinates in which to compress history. This is efficient, but exact retrieval of an arbitrary old token is not as natural as with attention.

Pitfall: Fixed State Is Not the Same as Unlimited Memory

Mamba’s recurrent state can summarize long histories, but information not preserved in the state is lost. Attention pays memory to keep explicit token-level access.

Autoregressive Inference State

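In causal language modeling, Mamba inference still generates token by token, but no step needs the KV of the whole history. Each layer typically keeps two kinds of cache:

| cache | shape intuition | purpose |
|---|---|---|
| convolution state | [B, d_inner, conv_width] | keeps the most recent tokens for the depthwise causal conv |
| SSM state | [B, d_inner, N] | keeps the long-range recurrent memory |

A single decode step can be sketched as: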

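def step(x_t, conv_state, ssm_state, params):
    # Schematic single-token step. in_proj, roll_and_insert, depthwise_conv_step,
    # param_proj, input_write, silu, and out_proj stand for the block's own layers.
    # x_t: [B, d_model]; conv_state: [B, d_inner, conv_width]; ssm_state: [B, d_inner, N]
    u_t, z_t = in_proj(x_t).chunk(2, dim=-1)

    conv_state = roll_and_insert(conv_state, u_t)
    u_t = silu(depthwise_conv_step(conv_state))

    dt_t, B_t, C_t = param_proj(u_t)             # dt_t: [B, d_inner]; B_t, C_t: [B, N]
    a_t = torch.exp(dt_t[..., None] * params.A)  # [B, d_inner, N]
    b_t = input_write(dt_t, B_t, u_t)            # [B, d_inner, N]

    ssm_state = a_t * ssm_state + b_t
    # Broadcast C_t over d_inner, then sum out the state dimension.
    y_t = (C_t[:, None, :] * ssm_state).sum(dim=-1)
    y_t = y_t * silu(z_t)

    out_t = out_proj(y_t)
    return out_t, conv_state, ssm_state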

Compared with a Transformer KV cache, the Mamba cache does not grow with the generated length \(L\):

\[ \text{Mamba cache per layer} \approx B(d_{\text{inner}}w+d_{\text{inner}}N), \]

where \(w\) is the causal conv width. The Transformer KV cache is instead approximately

\[ \text{KV cache per layer} \approx 2B L n_{\text{kv}} d_{\text{head}}. \]

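Long streaming generation, edge devices, and memory-constrained serving are therefore attractive settings for SSMs. The cost is equally clear: a Transformer keeps token-addressable memory, whereas Mamba keeps a compressed dynamical state.

Pitfall: Reset State at Sequence Boundaries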
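The two cache formulas can be compared numerically. The sizes below are hypothetical, picked for illustration rather than taken from any particular model; the functions count elements per layer, so multiply by bytes per element and layer count for a real budget.

```python
def mamba_cache_elems(batch, d_inner, conv_width, state_dim):
    """Per-layer cache elements: conv state plus SSM state."""
    return batch * (d_inner * conv_width + d_inner * state_dim)

def kv_cache_elems(batch, context_len, n_kv_heads, d_head):
    """Per-layer cache elements: keys and values for every cached token."""
    return 2 * batch * context_len * n_kv_heads * d_head

# Hypothetical sizes for illustration only.
mamba = mamba_cache_elems(batch=1, d_inner=4096, conv_width=4, state_dim=16)
for context_len in (40, 4096, 131072):
    kv = kv_cache_elems(batch=1, context_len=context_len, n_kv_heads=8, d_head=128)
    print(f"L={context_len:>6}: KV/Mamba cache ratio = {kv / mamba:.1f}")
```

With these particular sizes the two caches are equal at 40 tokens of context; beyond that point the KV cache keeps growing linearly while the recurrent cache stays constant.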

When batching independent prompts or packed streams, recurrent states must be reset at true sequence boundaries. Otherwise information leaks across users or documents.

Continuous Batching State Table

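A Transformer server usually manages KV cache per request; a Mamba server does the same, only with a much smaller state. Under continuous batching, batch rows are reshuffled constantly: a request in row 3 at this step may be in row 0 at the next. conv_state[i] and ssm_state[i] therefore cannot serve as a long-lived identity; state must be looked up by request id:

import torch

def new_request_state(num_layers, d_inner, conv_width, state_dim):
    # Fresh zero state for one request, covering every layer.
    return {
        "conv": torch.zeros(num_layers, d_inner, conv_width),
        "ssm": torch.zeros(num_layers, d_inner, state_dim),
        "pos": 0,
    }

state_table = {}  # req_id -> per-request state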

Each decode tick gathers the state of the active requests, runs one step, and scatters the results back:

def decode_tick(active_reqs):
    # mamba_step is a placeholder for the model's batched single-token step.
    conv = torch.stack([state_table[r.id]["conv"] for r in active_reqs])
    ssm = torch.stack([state_table[r.id]["ssm"] for r in active_reqs])
    token = torch.tensor([r.last_token for r in active_reqs], device=conv.device)

    logits, conv, ssm = mamba_step(token, conv, ssm)

    # Scatter the updated state back by request id, not by batch row.
    for i, req in enumerate(active_reqs):
        state_table[req.id]["conv"] = conv[i]
        state_table[req.id]["ssm"] = ssm[i]
        state_table[req.id]["pos"] += 1
    return logits

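The gather/scatter looks like an implementation detail, but it defines the isolation boundary between users. When a request finishes or is cancelled, its state must be deleted; when a batch row is reused, it must not inherit the previous row's recurrent memory.

Pitfall: Batch Row Is Not Request Identity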

In continuous batching, request identity must be stored outside the tensor row index. Otherwise state can leak when rows are compacted or reused.
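A row-permutation test captures this requirement. The sketch below replaces the model with a toy scalar recurrence (`toy_step` is a stand-in, not a Mamba step) and replays the same requests under two different row orders; with state keyed by request id, each request's outputs are identical in both runs.

```python
def toy_step(token, state):
    """Stand-in for one recurrent decode step on a single request."""
    new_state = 0.9 * state + token
    return new_state, new_state  # (output, updated state)

def decode_tick(state_table, active):
    """active: list of (request_id, token); row order is arbitrary."""
    # Gather by request id, never by row index.
    states = [state_table[req_id] for req_id, _ in active]
    outputs = {}
    for (req_id, token), state in zip(active, states):
        out, new_state = toy_step(token, state)
        state_table[req_id] = new_state  # scatter back by request id
        outputs[req_id] = out
    return outputs

def run(schedule):
    """Replay a schedule of ticks and collect each request's outputs."""
    table = {"a": 0.0, "b": 0.0}
    history = {"a": [], "b": []}
    for active in schedule:
        for req_id, out in decode_tick(table, active).items():
            history[req_id].append(out)
    return history

tokens = {"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]}
in_order = [[("a", tokens["a"][i]), ("b", tokens["b"][i])] for i in range(3)]
# Same requests and tokens, but rows are swapped on odd ticks.
shuffled = [list(reversed(tick)) if i % 2 else tick
            for i, tick in enumerate(in_order)]

print(run(in_order) == run(shuffled))
```

Keying the table by row index instead would make request "a" pick up the state of request "b" on the swapped ticks, which is exactly the leak described above.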

Reset Masks for Packing and Streaming

Packed sequences need resets during training as well. Let reset_t=1 mark the first token of a new sample; the recurrence should then clear the old state before the update:

\[ h_{t-1}^{\text{used}} = (1-r_t)h_{t-1}, \qquad h_t=a_t h_{t-1}^{\text{used}}+b_t. \]

where \(r_t\in\{0,1\}\). If the batch holds several packed streams, reset typically has shape [B,T] and is broadcast to [B,T,d_{\text{inner}},N]:

import torch

def selective_scan_with_reset(a, b, reset):
    # a, b: [B, T, D, N]; reset: [B, T] boolean, True at the first token of a sample.
    h = torch.zeros_like(b[:, 0])
    ys = []
    for t in range(a.size(1)):
        # Zero the carried state for rows that start a new sample at step t.
        keep = (~reset[:, t]).to(h.dtype)[:, None, None]
        h = h * keep
        h = a[:, t] * h + b[:, t]
        ys.append(h)
    return torch.stack(ys, dim=1)

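Mind the timing of the reset: if x_t is the first token of a new sample, reset before consuming x_t; otherwise the first output of the new sample has already seen the previous sample's state.

Tip: Smoke Test: Reset Equivalence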

Run two documents separately, then run them packed with a reset mask at the boundary. The outputs for the second document should match the separate run.
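The smoke test is short enough to write out for a scalar recurrence. This sketch uses made-up coefficients and plain Python lists; it also includes a deliberately wrong run without the boundary reset, to show that the comparison actually detects leakage.

```python
def scan_with_reset(a, b, reset):
    """Scalar scan; reset[t] is True when token t starts a new document."""
    h, out = 0.0, []
    for at, bt, rt in zip(a, b, reset):
        if rt:
            h = 0.0  # clear the state before consuming token t
        h = at * h + bt
        out.append(h)
    return out

# Each document is a pair of lists: decays a_t and inputs b_t.
doc1 = ([0.9, 0.8, 0.7], [1.0, 2.0, 3.0])
doc2 = ([0.6, 0.5], [4.0, 5.0])

separate = scan_with_reset(*doc2, reset=[True, False])
packed = scan_with_reset(doc1[0] + doc2[0], doc1[1] + doc2[1],
                         reset=[True, False, False, True, False])
leaky = scan_with_reset(doc1[0] + doc2[0], doc1[1] + doc2[1],
                        reset=[True, False, False, False, False])

print(packed[3:] == separate)  # packed run with the boundary reset
print(leaky[3:] == separate)   # packed run missing the boundary reset
```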

Minimal Mamba-Like Pseudocode

This is not the real fused implementation, but it shows the dataflow:

import torch

def selective_scan(u, dt, A, B, C):
    # u:  [B, T, d_inner]
    # dt: [B, T, d_inner]
    # A:  [d_inner, N] with negative entries
    # B:  [B, T, N]
    # C:  [B, T, N]
    batch, steps, width = u.shape
    state_dim = A.size(-1)
    h = torch.zeros(batch, width, state_dim, device=u.device, dtype=u.dtype)
    ys = []

    for t in range(steps):
        # Discretize per token: decay from exp(dt * A), write strength from dt * B.
        at = torch.exp(dt[:, t, :, None] * A[None])
        bt = dt[:, t, :, None] * B[:, t, None, :] * u[:, t, :, None]
        h = at * h + bt

        # Read out by contracting the state dimension with C_t.
        y = (C[:, t, None, :] * h).sum(dim=-1)
        ys.append(y)

    return torch.stack(ys, dim=1)

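A real implementation turns this loop into a scan kernel and also handles a more exact discretization, projections, the gate, normalization, layout, and backward recomputation. The value of the snippet above is that it exposes the shape of the state: h holds an \(N\)-dimensional state for each expanded channel, not the whole history of the sequence.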

Why It Matters for LLM Systems

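In long-context and edge-inference settings, KV-cache memory growth is a core bottleneck. SSM/Mamba is attractive because its inference state stays fixed in size instead of growing linearly with the context. Attention's explicit retrieval remains very strong, though, so hybrid architectures stay appealing: local or global attention handles retrieval, and the SSM carries the compressed long-range state.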

When to Prefer Which

| Situation | Attention | SSM/Mamba |
|---|---|---|
| exact copying | direct token retrieval | harder unless state stores it |
| very long streams | KV cache expensive | fixed recurrent state |
| batch prefill | mature kernels | scan kernels needed |
| edge inference | memory bottleneck | small state |
| tool/code tasks | precise dependencies | may need hybrid attention |
| audio/genomics | long dense signals | natural sequence backbone |

The design question is not “Transformer or Mamba forever”. It is which memory mechanism matches the task: explicit addressable memory, compressed dynamical state, or a hybrid.

Implementation Checklist

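When implementing or debugging a Mamba-style block, check the following in order:

  1. whether the parameterization of \(A\) guarantees stability, for example negative continuous eigenvalues;
  2. whether \(\Delta_t\) passes through softplus or another positivity constraint;
  3. whether the causal depthwise convolution sees only current and past tokens;
  4. whether the scan output matches a naive Python recurrence elementwise on a small batch;
  5. whether chunked scan passes the final state correctly across chunk boundaries;
  6. whether scan outputs agree across chunk sizes;
  7. whether the custom scan kernel's backward matches an autograd reference;
  8. whether packed sequences or multi-user batches reset the recurrent state at true boundaries;
  9. whether continuous batching manages state by request id rather than by batch row;
  10. whether training mode leaks no future tokens through convolution padding;
  11. whether the inference cache stores both the conv state and the SSM state;
  12. whether exp(dt * A) underflows or overflows in mixed precision;
  13. whether backward runs out of memory from materializing the [B,T,d_inner,N] state;
  14. whether long-context benchmarks report tokens/s, memory, state size, and exact-copy task performance together;
  15. whether comparisons with Transformers distinguish prefill, decode, streaming, and retrieval-heavy workloads.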

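A minimal unit test is very useful: generate very short random u, dt, A, B, C, compute the output with both the naive loop and the scan kernel, and check the maximum error: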

ref = selective_scan_loop(u, dt, A, B, C)
fast = selective_scan_kernel(u, dt, A, B, C)
err = (ref - fast).abs().max()
assert err < 1e-4

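Then run four further checks, starting with two causality tests:

  1. perturb a future token \(x_{t+1}\) and confirm that \(y_t\) does not change;
  2. pack two independent sequences into one batch row and confirm that, after the reset, the second is unaffected by the first one's state;
  3. run gradcheck on the custom scan kernel, or compare its gradients with a PyTorch reference backward;
  4. permute batch rows under continuous batching and confirm that each request's output does not depend on its row id.

Pitfall: Passing Perplexity Is Not Enough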

A Mamba implementation can achieve plausible loss while still leaking future tokens through convolution padding or packed-sequence state carry. Causality and reset tests are separate correctness checks.

References