LLaDA / Masked Diffusion LM

不同于 GPT 系列 autoregressive language model,LLaDA 属于 masked diffusion language model (MDM)。它把文本生成从“从左到右逐 token 采样”改写为“从 masked sequence 逐步去噪”。这使它更接近 diffusion 的训练思想,但变量空间是离散 token 而不是连续图像 latent。


Running Example: Fill Several Masks

假设真实句子是:

I really like diffusion models

forward process 在某个时间 \(t\) 把它腐蚀成:

I [M] like [M] models

模型的任务不是预测“下一个 token”,而是同时根据上下文预测两个 masked positions:

\[ p_\theta(\text{really}\mid I,[M],like,[M],models,t), \]

\[ p_\theta(\text{diffusion}\mid I,[M],like,[M],models,t). \]

采样时可以先填高置信度位置,再用新填入的 token 继续推断剩余 mask。于是生成过程像 iterative refinement,而不是严格从左到右。

Masked Diffusion Model

Masked diffusion models (MDMs) (Austin et al., 2021a; Lou et al., 2023; Ou et al., 2024) 以不同于 autoregressive models 的方式定义模型分布 \(p_\theta(x_0)\)。这些模型引入了一个由时间 \(t \in [0, 1]\) 索引的 forward process \(\{x_t\}\)。该过程逐渐且独立地 mask 序列 \(x_0\) 中的所有 tokens。在时间 \(t = 0\) 时,数据点 \(x_0\) 被完全观测到,没有任何 masks;而对于 \(t \in (0, 1]\)\(x_t\) 表示在期望中具有不同 mask ratios 的 latent variables。形式上,给定 \(x_0\)\(x_t\) 的 conditional distribution 由一个完全分解的形式定义: \[ q_{t|0}(x_t|x_0) = \prod_{i=1}^{L} q_{t|0}(x^i_t|x^i_0), \quad (7) \] 其中每个 token 的 conditional distribution 由下式给出: \[ q_{t|0}(x^i_t|x^i_0) = \begin{cases} 1 - t, & x^i_t = x^i_0, \\ t, & x^i_t = \text{M}. \end{cases} \quad (8) \]

NoteDefinition: Masked Diffusion Language Model

A masked diffusion language model defines a corruption process that progressively replaces clean tokens with a mask token, and learns a reverse process that predicts masked tokens from partially observed sequences.

LLaDA 的关键不是“把 BERT 做大”这么简单。BERT 的 masked language modeling 通常服务于 representation learning;LLaDA 把 masked denoising 组织成 generative process,目标是在采样时从全 mask 或部分 mask 状态逐步恢复完整序列。换句话说,BERT 学的是“这个句子里缺了几个词时该怎么补”;LLaDA 学的是“从噪声文本分布出发,沿着一条可控的去噪轨迹走回自然语言分布”。

这里最容易混淆的是 noise level \(t\) 的含义。在 BERT-style MLM 中,mask ratio 常常是一个数据增强超参数;在 MDM 中,\(t\) 是概率模型的一部分。模型不仅要知道当前位置是 [MASK],还要知道当前整句处于哪个噪声强度:\(t\) 小时,周围上下文基本可信,任务接近局部填空;\(t\) 大时,大部分 token 都被擦掉,模型必须先恢复主题、句法骨架和长程一致性。

TipTeaching View: What Is the Latent Variable?

For an AR LM, the latent state during decoding is a prefix. For a masked diffusion LM, the latent state is a partially observed sequence plus a noise level. This is why the same cross entropy can train two very different generative stories.

从工程角度看,可以把一条训练样本拆成四个对象:

object shape meaning
input_ids [B, L] 被 corruption 后的 \(x_t\),其中部分位置为 mask_token_id
labels [B, L] 原始 \(x_0\),但只在 masked positions 上监督
loss_mask [B, L] 哪些位置参与 CE,通常等价于 input_ids == mask_token_id
t [B] or [B, 1] 每个样本的 noise level / mask ratio

注意 labels 不是“右移一位”的 next-token label。MDM 的监督位置由 corruption process 产生,而 AR LM 的监督位置由 causal factorization 产生。两者都用 cross entropy,但训练语义完全不同。

Absorbing Mask Process

对于每个 token,forward process 是一个 absorbing process:一旦变成 [MASK],之后保持 [MASK]。对单个位置:

\[ q(x_t\mid x_0) = (1-t)\delta_{x_0}+t\delta_{\text{M}}. \]

因此未被 mask 的概率是 \(1-t\),被 mask 的概率是 \(t\)。对长度 \(L\) 的序列,mask 数量

\[ N_t=\sum_{i=1}^{L}\mathbf{1}[x_t^i=\text{M}] \]

满足二项分布

\[ N_t\sim\mathrm{Binomial}(L,t), \]

于是

\[ \mathbb{E}[N_t]=Lt. \]

方差为

\[ \mathrm{Var}(N_t)=Lt(1-t). \]

这个方差在实现里很重要。若 \(L\) 很短或 \(t\) 很小,某个 batch 里可能出现 \(N_t=0\) 的样本,即没有任何 token 需要预测;若直接除以 masked token 数量,就会产生除零或极大梯度权重。因此真实 collator 往往要做至少一个保护:

  1. 对每条样本强制至少 mask 一个 token;
  2. 或者在 loss normalization 里用 clamp_min(1)
  3. 或者用 fixed-count corruption:先采样 \(n=\lceil Lt\rceil\),再从 \(L\) 个位置中无放回选择 \(n\) 个 mask。

每个位置独立以概率 \(t\) 被 mask。令

\[ Z_i=\mathbf{1}[x_t^i=\text{M}]. \]

\(\mathbb{E}[Z_i]=t\)。线性期望给出

\[ \mathbb{E}[N_t] = \mathbb{E}\left[\sum_i Z_i\right] = \sum_i\mathbb{E}[Z_i] = Lt. \]

因为 \(Z_i\) 独立,且 \(\mathrm{Var}(Z_i)=t(1-t)\),所以

\[ \mathrm{Var}(N_t) = \sum_i\mathrm{Var}(Z_i) = Lt(1-t). \]

Bernoulli masking 和 fixed-count masking 的差别可以这样理解:

corruption advantage caveat
Bernoulli per token 数学最干净,\(q(x_t\mid x_0)\) 完全分解 batch 内 mask 数量波动,短序列可能无监督 token
fixed count 每条样本计算量更稳定,便于课程实验复现 严格来说位置之间不再独立,推导要改成超几何/组合形式

这个过程的好处是简单、稳定、适合大词表。它的缺点也明显:噪声状态很特殊,全部信息都被压成一个 [MASK] 符号,不像连续 diffusion 那样保留带噪的细粒度信息。因此 MDM 的建模难点并不是“怎么给 token 加高斯噪声”,而是“只看到离散缺口时,如何用上下文和时间条件恢复全局一致的文本”。

Posterior of the Absorbing Process

forward process 定义了从 clean sentence \(x_0\) 到 noisy sentence \(x_t\) 的腐蚀;reverse sampler 需要从较高噪声 \(t\) 走到较低噪声 \(s<t\)。因此一个关键量是:

\[ q(x_s^i\mid x_t^i,x_0^i), \qquad 0\le s<t\le 1. \]

对 absorbing mask,可以把每个 token 的 mask time 看成随机变量 \(\tau_i\sim\mathrm{Uniform}(0,1)\)。若 \(\tau_i\le t\),则 \(x_t^i=\text{M}\);若 \(\tau_i>t\),则 \(x_t^i=x_0^i\)。于是:

\[ x_t^i= \begin{cases} \text{M},& \tau_i\le t,\\ x_0^i,& \tau_i>t. \end{cases} \]

如果 \(x_t^i\ne\text{M}\),那么它还没有被吸收,且因为 \(s<t\),在更早时间 \(s\) 也必然没有被吸收:

\[ q(x_s^i=x_0^i\mid x_t^i=x_0^i,x_0^i)=1. \]

如果 \(x_t^i=\text{M}\),则只知道 \(\tau_i\le t\)。它在时间 \(s\) 已经 mask 的概率是

\[ q(x_s^i=\text{M}\mid x_t^i=\text{M},x_0^i) = \Pr(\tau_i\le s\mid \tau_i\le t) = \frac{s}{t}. \]

所以它在从 \(t\) 走回 \(s\) 时被恢复成 clean token 的概率是

\[ q(x_s^i=x_0^i\mid x_t^i=\text{M},x_0^i) = 1-\frac{s}{t}. \]

ImportantTheorem: Absorbing-Mask Posterior

For the Bernoulli absorbing mask process with \(0\le s<t\le1\), \[ q(x_s^i\mid x_t^i=\text{M},x_0^i) = \frac{s}{t}\delta_{\text{M}} + \left(1-\frac{s}{t}\right)\delta_{x_0^i}, \] and if \(x_t^i=x_0^i\), then \(q(x_s^i=x_0^i\mid x_t^i,x_0^i)=1\).

用吸收时间 \(\tau_i\) 表示 token 第一次变成 mask 的时间。因为

\[ q(x_t^i=\text{M}\mid x_0^i)=t, \]

可以取等价构造 \(\tau_i\sim\mathrm{Uniform}(0,1)\),并令 \(x_t^i=\text{M}\) 当且仅当 \(\tau_i\le t\)。若 \(x_t^i=\text{M}\),条件事件是 \(\tau_i\le t\)。于是

\[ \Pr(x_s^i=\text{M}\mid x_t^i=\text{M},x_0^i) = \Pr(\tau_i\le s\mid\tau_i\le t) = \frac{s}{t}. \]

补事件就是 \(x_s^i=x_0^i\)。若 \(x_t^i=x_0^i\),则 \(\tau_i>t\),从而 \(\tau_i>s\),所以 \(x_s^i=x_0^i\)

这个后验告诉我们 reverse step 应该做什么:对当前还是 [MASK] 的位置,以概率 \(1-s/t\) 解除 mask,并用模型预测的 clean token 分布替代未知的 \(x_0^i\)。也就是说,模型不是直接学

\[ p_\theta(x_s\mid x_t), \]

而是先学 clean-token predictor

\[ p_\theta(x_0^i\mid x_t,t), \]

再把它和 absorbing posterior 组合成从 \(t\)\(s\) 的 transition。

一个 posterior-style reverse step 可以写成:

@torch.no_grad()
def posterior_reverse_step(model, ids, t, s, mask_id, temperature):
    # ids: [B, L], current x_t
    active = ids.eq(mask_id)
    logits = model(input_ids=ids, noise_level=t).logits
    probs = torch.softmax(logits / temperature, dim=-1)
    proposal = torch.multinomial(
        probs.view(-1, probs.size(-1)),
        num_samples=1,
    ).view_as(ids)

    unmask_prob = (1.0 - s / t).clamp(min=0.0, max=1.0)
    coin = torch.rand_like(ids.float()) < unmask_prob[:, None]
    reveal = active & coin
    return torch.where(reveal, proposal, ids)

这个版本和 confidence-based remasking 的区别很重要:posterior step 按理论 transition 随机决定释放概率;confidence sampler 按模型置信度选择释放哪些位置。前者更贴近 forward process 的后验结构,后者更像工程启发式,常用于提高生成质量或减少早期错误 commit。

WarningPitfall: Reverse Step and Confidence Policy Are Different

The absorbing posterior tells how many masked variables should become clean when moving from \(t\) to \(s\). A confidence policy decides which positions to commit using model scores. Mixing them without tracking the schedule changes the sampler.

Reverse Process

反向过程不是预测下一个 token,而是在任意 mask pattern 下预测被 mask 的 token:

\[ p_\theta(x_0\mid x_t) = \prod_{i:x_t^i=\text{M}} p_\theta(x_0^i\mid x_t,t). \]

训练目标通常是 masked-token cross entropy:

\[ \mathcal{L}_{\text{MDM}} = \mathbb{E}_{t,x_0,x_t} \left[ - \sum_{i:x_t^i=\text{M}} \log p_\theta(x_0^i\mid x_t,t) \right]. \]

更具体地,Transformer 给每个位置输出 hidden state:

\[ h_i=f_\theta(x_t,t)_i\in\mathbb{R}^{d}. \]

LM head 把它映射成词表 logits:

\[ z_i=W_{\text{out}}h_i+b_{\text{out}}\in\mathbb{R}^{|V|}, \qquad p_\theta(v\mid x_t,t,i) = \frac{\exp z_{i,v}}{\sum_{u\in V}\exp z_{i,u}}. \]

然后只在 \(x_t^i=\text{M}\) 的位置计算

\[ \ell_i = -\log p_\theta(x_0^i\mid x_t,t,i). \]

在 PyTorch/Hugging Face 风格实现中,最常见写法是把非监督位置设成 ignore index:

labels = clean_ids.clone()
labels[~is_masked] = -100
loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    labels.view(-1),
    ignore_index=-100,
    reduction="sum",
)
denom = is_masked.sum().clamp_min(1)
loss = loss / denom

这里的 reduction="sum" 加手动归一化比直接 reduction="mean" 更透明,因为后面常常还要乘上 \(\lambda(t)\)、按 token 数量做全局平均,或者在分布式训练中先 all-reduce masked token 计数。

WarningPitfall: Do Not Predict the Mask Token Itself

The target at a masked position is the original clean token \(x_0^i\), not [MASK]. The mask token is an input-side sentinel, not a desired output class.

一个细节是:模型是否也对 unmasked positions 输出 logits?通常会输出,因为 Transformer 是全位置并行计算;但 loss 不监督这些位置。监督 unmasked positions 会让模型学会 copy identity shortcut,甚至把任务变成“所有位置重构自己”,削弱 denoising 学习。训练时的关键是:上下文可以看见 unmasked tokens,梯度主要从 masked tokens 回传。

Likelihood Bound Intuition

自回归模型直接分解

\[ \log p_\theta(x_0) = \sum_i\log p_\theta(x_i\mid x_{<i}). \]

Masked diffusion 不按固定顺序生成,因此通常通过 denoising objective 优化 likelihood bound。直觉是:如果模型在任意噪声强度 \(t\) 下都能预测被 mask 的 token,那么它就学到了从噪声分布回到数据分布的反向过程。

对 absorbing mask,训练时常见形式可以看作 weighted masked CE:

\[ \mathcal{L} = \mathbb{E}_{t,x_0,x_t} \left[ \frac{1}{t} \sum_{i:x_t^i=\text{M}} -\log p_\theta(x_0^i\mid x_t,t) \right], \]

其中权重的具体形式依赖时间参数化和推导。重要的是:loss 不是只在固定 15% mask ratio 上训练,而是在一系列 mask ratios 上训练,使模型学会多噪声强度恢复。

为什么会出现类似 \(1/t\) 的权重?可以从一个很朴素的连续时间直觉看。已知某个位置在时间 \(t\) 已经是 mask,那么它从 clean token 被吸收到 mask 的“历史”发生在 \([0,t]\) 中。若 forward process 是均匀参数化的 Bernoulli corruption,则

\[ q(x_t^i=\text{M}\mid x_0^i)=t. \]

在训练时我们只看已经 masked 的位置;这相当于对“被选中监督的位置”做条件化。一个 masked position 的条件密度会带有

\[ \frac{1}{q(x_t^i=\text{M}\mid x_0^i)} = \frac{1}{t} \]

这样的校正因子。严格的 variational bound 推导会从反向转移核或 continuous-time absorbing process 出发,但工程上要记住的事实是:低噪声处 masked token 少,若不加权或不归一化,它们对总 loss 的贡献会被 batch 采样频率淹没;高噪声处 masked token 多,若只按 token 求和,又可能主导梯度。

因此实际实现里通常要同时区分三件事:

quantity role common bug
\(\lambda(t)\) 理论或经验的 noise-level weight 把所有 \(t\) 当成同等难度,导致 sampler 早期/后期一端崩
masked-token normalization 控制不同样本的有效 token 数 直接 sum loss,使长句和高 mask ratio 样本主导训练
batch/global normalization DDP/FSDP 下的跨卡平均 每卡平均再平均,遇到 padding/mask 数不均时产生偏差

一个更稳定的 token-level 写法是

\[ \mathcal{L}_{\text{batch}} = \frac{ \sum_{b=1}^{B}\lambda(t_b) \sum_{i=1}^{L}m_{b,i}\, [-\log p_\theta(x_{0,b}^i\mid x_{t,b},t_b)] }{ \sum_{b=1}^{B}\sum_{i=1}^{L}m_{b,i} }, \]

其中 \(m_{b,i}=\mathbf{1}[x_{t,b}^i=\text{M}]\)。若在分布式训练中每张卡的 masked token 数不同,分子和分母都应该 all-reduce 后再相除。

WarningPitfall: MDM Training Is Not Standard BERT MLM

Standard MLM usually trains at a fixed or narrow masking ratio for representation learning. Masked diffusion LM trains across noise levels and uses the model as a reverse generative process.

Why This Is Interesting

AR LM 的采样有严格顺序:第 \(t\) 个 token 依赖前 \(t-1\) 个 token。MDM 则允许并行地填充多个 mask,因此有潜力减少 decoding steps,并且能自然处理 infilling、editing、bidirectional constraints。

Aspect AR LM Masked diffusion LM
factorization left-to-right iterative denoising
training target next-token CE masked-token CE
generation sequential partially parallel
editing indirect natural
challenge long decoding latency schedule and discrete denoising quality
WarningPitfall: Diffusion in Text Is Not Gaussian Diffusion

Image diffusion often relies on continuous Gaussian noise and score matching. Text diffusion must operate over categorical variables, mask states, or relaxed token distributions. The math analogy is useful, but the transition kernel and reverse parameterization are different.

Architecture: Why a Vanilla Transformer Can Work

LLaDA-style models can use a Transformer backbone with bidirectional attention because the input is a partially observed sequence, not a causal prefix. The model sees all unmasked tokens:

left context  [M]  right context

and predicts the masked position using both sides. That is exactly what AR decoding cannot do directly. The architecture therefore changes the mask:

Model Attention mask Prediction
GPT-style AR causal next token
BERT MLM bidirectional randomly masked tokens
LLaDA-style MDM bidirectional over partial sequence denoising across time

The same Transformer block can be used, but its probabilistic meaning changes because the conditioning pattern changes.

一个常见输入参数化是

\[ h_i^{(0)} = E[x_t^i]+P_i+T(t), \]

其中 \(E[x_t^i]\) 是 token embedding,\(P_i\) 是 position embedding 或 RoPE 产生的位置相位,\(T(t)\) 是 noise-level embedding。\(T(t)\) 可以来自 sinusoidal embedding 加 MLP,也可以把离散 diffusion step 嵌入成 learned embedding。它的作用是告诉模型“当前 [MASK] 应该补到多细”:在 \(t\approx 1\) 时,模型应该先恢复粗粒度语义;在 \(t\approx 0\) 时,模型应该做精确的局部修补。

Attention mask 也要分清三层含义:

mask used in meaning
denoising mask input ids and loss 哪些 token 被替换为 [MASK],哪些位置参与监督
padding mask attention and loss 哪些位置是 batch padding,既不能被 attention 看到,也不能算 loss
causal mask AR Transformer attention LLaDA-style bidirectional denoising 通常不用 causal mask

因此 MDM 的 attention score 通常是

\[ A_{ij} = \frac{q_i^\top k_j}{\sqrt{d_h}} + \mathrm{pad\_mask}_{ij}, \]

而不是

\[ A_{ij} = \frac{q_i^\top k_j}{\sqrt{d_h}} + \mathrm{causal\_mask}_{ij} + \mathrm{pad\_mask}_{ij}. \]

这意味着第 \(i\) 个 mask 位置可以看见右侧 token,也可以看见其他 mask 位置的 [MASK] embedding。它不能看见的是 padding;它也不应该在训练 label 中被要求预测 padding token。

NoteImplementation Distinction

attention_mask usually means padding visibility. loss_mask means supervised denoising positions. Reusing one tensor for both is a common source of silent bugs.

Remasking and Confidence

Sampling is not only “fill masks once”. A practical sampler often allows remasking: low-confidence generated tokens can be masked again and regenerated. Let

\[ c_i=\max_v p_\theta(v\mid x_t,t) \]

be confidence at position \(i\). A simple schedule chooses a target number of masked positions \(m_s\) at step \(s\), then keeps the \(m_s\) least confident positions masked:

predict all masked positions
compute confidence
commit high-confidence tokens
remask low-confidence tokens
repeat

This gives the model a way to revise earlier mistakes, closer to image diffusion refinement.

NoteDefinition: Remasking

Remasking is a sampling operation that turns selected uncertain generated tokens back into mask tokens so that later denoising steps can revise them.

confidence 的定义不是唯一的。常见选择包括:

score formula behavior
max probability \(\max_v p_\theta(v\mid x_t,t)\) 简单,偏好尖锐分布
negative entropy \(-\sum_v p_v\log p_v\) 惩罚整体不确定性,不只看 top-1
margin \(p_{(1)}-p_{(2)}\) 关心第一名和第二名是否接近
sampled-token probability \(p_\theta(\hat{x}_i\mid x_t,t)\) 和实际采样出的 token 绑定,适合 stochastic decoding

要特别注意:confidence policy 不等于训练目标。训练目标拟合 \(p_\theta(x_0^i\mid x_t,t)\);采样策略决定哪些位置先 commit,哪些位置继续保留为 mask。两个模型 loss 相同,可能因为 remasking 策略不同而生成质量和速度差别很大。

一个具体例子:假设某位置 top-1 是 bank,概率 \(0.42\);top-2 是 river,概率 \(0.39\)。max probability 看起来不低,但 margin 只有 \(0.03\),说明语义分歧很大。若这一步过早 commit bank,后续 token 可能围绕金融语境继续生成;若 remask 它,模型在更多上下文恢复后可能转向河岸语境。

Generation Length

AR models naturally decide length by emitting EOS. Masked diffusion models often need a target length or a length prediction mechanism. If generation begins from

[M] [M] [M] ... [M]

then the number of mask slots already fixes the maximum output length. Practical systems may:

  1. predict output length first;
  2. sample several candidate lengths;
  3. include EOS and allow unused positions after EOS;
  4. use task-specific length constraints。
WarningPitfall: Parallel Decoding Still Needs Length Control

Masked diffusion can reduce sequential token dependency, but it does not remove the need to decide how many token slots to denoise.

Sampling Schedule

采样时从全 mask 序列开始:

[M] [M] [M] [M] [M]

在每一步,模型对所有 mask positions 给出 categorical distribution。一个简单策略是:

  1. 计算每个 mask position 的最大概率;
  2. 选择置信度最高的一部分位置填入 token;
  3. 保留低置信度位置为 [M]
  4. 重复直到没有 mask。

例如:

step 0: [M] [M] [M] [M] [M]
step 1: I   [M] like [M] models
step 2: I   really like [M] models
step 3: I   really like diffusion models

这种过程比 AR 采样更像“先定骨架,再补细节”。但它也引入一个新问题:如果早期填错高置信度 token,后续步骤可能围绕错误 token 继续自洽。

采样 schedule 可以写成“剩余 mask 比例” \(r_s\),其中 \(s=0,\ldots,S\)\(r_0=1\)\(r_S=0\)。第 \(s\) 步之后保留

\[ m_s=\lceil r_s L\rceil \]

个低置信度位置为 mask。常见 schedule:

schedule \(r_s\) intuition
linear \(1-\frac{s}{S}\) 每步释放差不多数量的 token
cosine \(\cos^2\left(\frac{\pi s}{2S}\right)\) 前期保守,后期快速细化
square \(\left(1-\frac{s}{S}\right)^2\) 前期释放更多,后期保留少量难点反复修
NoteDefinition: Mask Schedule Contract

A mask schedule specifies the intended number of active mask positions after each denoising step. A sampler is schedule-consistent only if the actual active mask count follows this contract up to rounding and frozen-token constraints.

schedule 不是只画一条曲线。实现中至少要明确四类位置:

state meaning can change?
frozen prompt/source/padding positions no
active mask 当前需要预测的位置 yes
committed 已经填入 token 的位置 maybe, if remasking allowed
finished EOS 后或 unused slots no

设可生成位置数为 \(L_{\text{gen}}\),第 \(s\) 步目标剩余 mask 数是

\[ m_s=\left\lceil r_s L_{\text{gen}}\right\rceil. \]

若 prompt tokens 被 frozen,不能把它们计入 \(L_{\text{gen}}\);若 EOS 后的 slots 也 frozen,也要从可 remask 集合里排除。否则 sampler 看似按 schedule 运行,实际会把用户输入或 padding 当成可编辑 token。

一个 schedule-consistent commit 函数可以先算“当前还要保留多少 active masks”,再只在 editable positions 内做 top-\(k\)

def apply_confidence_schedule(ids, proposal, confidence, editable, step, steps, mask_id):
    # editable: positions allowed to be mask/proposal, excludes prompt and padding
    gen_len = editable.sum(dim=1)
    ratio = math.cos(0.5 * math.pi * step / steps) ** 2
    keep = torch.ceil(ratio * gen_len).long()

    out = torch.where(editable & ids.eq(mask_id), proposal, ids)
    score = confidence.masked_fill(~editable, float("inf"))

    for b in range(ids.size(0)):
        k = min(int(keep[b]), int(editable[b].sum()))
        if k > 0:
            pos = score[b].topk(k=k, largest=False).indices
            out[b, pos] = mask_id
    return out

这个代码仍然很朴素,但它显式区分了 editableactiveactive 是当前仍为 [MASK] 的位置;editable 是允许被 sampler 改写的位置。条件生成、infilling 和编辑任务里,这两个集合不能混。

极简 sampler 可以写成:

@torch.no_grad()
def mdm_sample(model, length, steps, mask_id, temperature):
    ids = torch.full((1, length), mask_id, device=model.device)

    for s in range(1, steps + 1):
        t = torch.tensor([1.0 - (s - 1) / steps], device=ids.device)
        logits = model(input_ids=ids, noise_level=t).logits
        probs = torch.softmax(logits / temperature, dim=-1)

        pred = torch.multinomial(
            probs.view(-1, probs.size(-1)),
            num_samples=1,
        ).view_as(ids)
        conf = probs.gather(-1, pred[..., None]).squeeze(-1)

        active = ids.eq(mask_id)
        proposal = torch.where(active, pred, ids)

        keep = int(math.ceil((1.0 - s / steps) * length))
        if keep == 0:
            ids = proposal
            break

        conf = conf.masked_fill(~active, float("inf"))
        remask_pos = conf.topk(k=keep, largest=False).indices
        ids = proposal
        ids.scatter_(1, remask_pos, mask_id)

    return ids

这段代码故意写得朴素:真实实现需要处理 batch、padding、prompt constraints、EOS、temperature/top-p、重复惩罚,以及“已经 commit 的 token 是否允许再次 remask”。但它揭示了 MDM sampler 的核心循环:predict、score、commit/remask、repeat。

WarningPitfall: Prompt Tokens Should Usually Be Frozen

For conditional generation, prompt/context tokens should normally be excluded from remasking and loss. Otherwise the sampler may rewrite the instruction or source text.

Parallel Decoding Consistency

MDM 每一步会并行提议多个 token。并行的好处是减少 sequential depth,风险是局部提议之间可能互相不一致。例如两个位置分别生成:

The capital of France is [M], and it is located in [M].

早期 step 可能同时提议 ParisGermany。如果两个位置都被高置信度 commit,后续模型会在错误上下文上继续修补。remasking 的本质就是给 sampler 一个撤销机制。

可以把一次 denoising step 拆成三个张量:

tensor role
proposal_ids 当前模型为 active mask 给出的候选 token
commit_mask 本步决定接受哪些 proposal
next_ids 接受 proposal 后再按 schedule/remask 得到的下一状态

工程上不要直接覆盖 ids 后再计算 confidence,否则会混淆“模型基于哪个状态预测”和“本步之后状态是什么”。更稳的顺序是:

logits = model(ids, noise_level=t).logits
proposal_ids, confidence = sample_or_argmax(logits)
commit_mask = choose_commit_positions(ids, confidence, editable, schedule)
next_ids = ids.clone()
next_ids[commit_mask] = proposal_ids[commit_mask]
next_ids[remask_mask] = mask_id
TipSmoke Test: Frozen Prompt Invariance

Run the sampler with a fixed prompt and assert that prompt token ids are identical before and after every denoising step.

Objective as a Weighted Denoising Problem

如果 \(t\) 越大,mask ratio 越高,任务越难;\(t\) 越小,任务越接近局部填空。实际训练通常会采样不同 \(t\),让模型学会从不同噪声强度恢复文本:

\[ \mathcal{L} = \mathbb{E}_{t\sim p(t)} \mathbb{E}_{x_t\sim q(x_t\mid x_0)} \left[ \lambda(t) \sum_{i:x_t^i=\text{M}} -\log p_\theta(x_0^i\mid x_t,t) \right]. \]

\(\lambda(t)\) 可以调节不同噪声强度的权重。直觉上,高 mask ratio 训练全局语义建模,低 mask ratio 训练局部语法和精确恢复。

训练 batch 构造可以抽象成:

def corrupt_for_mdm(input_ids, attention_mask, mask_id, t):
    # input_ids: [B, L], attention_mask: 1 for real tokens, 0 for padding
    clean = input_ids.clone()
    valid = attention_mask.bool()

    prob = t[:, None].expand_as(input_ids)
    noise = torch.rand_like(prob.float())
    is_masked = (noise < prob) & valid

    # optional: force at least one masked token per nonempty sequence
    empty = is_masked.sum(dim=1).eq(0)
    if empty.any():
        first_valid = valid.float().argmax(dim=1)
        is_masked[empty, first_valid[empty]] = True

    corrupted = torch.where(
        is_masked,
        torch.full_like(input_ids, mask_id),
        input_ids,
    )
    labels = clean.masked_fill(~is_masked, -100)
    return corrupted, labels, is_masked

然后 training step 里做:

t = sample_noise_level(batch_size, device=input_ids.device)
corrupted, labels, is_masked = corrupt_for_mdm(
    input_ids=input_ids,
    attention_mask=attention_mask,
    mask_id=tokenizer.mask_token_id,
    t=t,
)
out = model(
    input_ids=corrupted,
    attention_mask=attention_mask,
    noise_level=t,
)
token_loss = F.cross_entropy(
    out.logits.view(-1, out.logits.size(-1)),
    labels.view(-1),
    ignore_index=-100,
    reduction="none",
).view_as(labels)
weight = lambda_t(t)[:, None]
loss_num = (token_loss * is_masked * weight).sum()
loss_den = is_masked.sum().clamp_min(1)
loss = loss_num / loss_den

这里有几个实现点值得反复检查:

  1. padding 位置不能被 corruption 成 [MASK]
  2. padding 位置不能参与 loss;
  3. attention_mask 仍然传给模型,防止 padding 被其他 token attend;
  4. t 要和样本绑定,而不是和 token 随机独立绑定,否则同一句内部的噪声等级解释会变乱;
  5. 如果使用 sequence packing,packed sample 之间的 block diagonal attention mask 也要保留,否则不同文档会互相泄漏上下文。

如果要用 fixed-count corruption,上面的 noise < prob 可以替换成每条样本对 valid positions 采样 top-\(k\)。例如令 \(k_b=\max(1,\lfloor t_b n_b\rceil)\),其中 \(n_b\) 是第 \(b\) 条样本的非 padding token 数。这样每条样本的监督 token 数更稳定,但数学上不再是逐位置独立 Bernoulli。

A Tiny Numerical Example

句子长度为 \(L=4\),在 \(t=0.5\) 时每个 token 独立以 \(0.5\) 概率被 mask。真实句子:

deep learning is fun

某次 corruption 得到:

deep [M] is [M]

模型输出:

\[ p_\theta(x_2=\text{learning}\mid x_t)=0.7, \qquad p_\theta(x_4=\text{fun}\mid x_t)=0.4. \]

这次样本的 denoising loss 是

\[ -\log 0.7-\log 0.4\approx1.273. \]

这和普通 cross entropy 一样,只是 supervised positions 由 forward corruption process 随机决定。

Comparison With AR Decoding

Suppose output length is \(L\) and diffusion sampling uses \(S\) denoising steps. AR decoding requires \(L\) model calls, one per token. Masked diffusion requires roughly \(S\) calls, each predicting many positions in parallel.

If \(S\ll L\), MDM can reduce sequential depth. But each denoising step processes the whole sequence, and there is no standard KV cache reuse like AR decoding because tokens can be revised bidirectionally.

Cost dimension AR LM Masked diffusion LM
sequential steps \(L\) \(S\)
per step input prefix/new token with cache full partial sequence
cache reuse KV cache natural harder because tokens change
editing/infilling awkward natural
length handling EOS natural needs length control

所以 MDM 的优势不是“必然更快”,而是把串行深度从 token-level 改成 denoising-step-level。真实速度取决于 \(S\)、序列长度、cache 机制、硬件 kernel 和 remasking 策略。

KV cache 的差别尤其关键。AR decoding 在第 \(i\) 步只新增一个 token,过去 token 的 key/value 不再改变,因此可以缓存:

\[ K_{\le i}=[K_{\le i-1};K_i], \qquad V_{\le i}=[V_{\le i-1};V_i]. \]

MDM 的第 \(s\) 步会把多个 [MASK] 改成 token,也可能把低置信度 token 改回 [MASK]。由于 self-attention 是双向的,任意位置 token 改变都会改变后续层中许多位置的 hidden states:

\[ h_j^{(\ell+1)} = \sum_i \mathrm{softmax}\left( \frac{q_j^\top k_i}{\sqrt{d_h}} \right)v_i. \]

只要某个 \(x_i\) 变了,\(k_i,v_i\) 变;进一步,所有 attend 到 \(i\)\(h_j\) 也变;下一层的 \(q_j,k_j,v_j\) 又随之变化。这就是为什么 naive MDM sampler 每一步通常要重新跑整段序列,而不能像 GPT decode 那样只追加一个 token。

可能的工程优化方向包括:

idea what it tries to reuse limitation
frozen committed tokens 不再 remask 的 token 的局部表示 双向层中它们仍受其他 token 变化影响
block-wise denoising 把序列切块,只更新活跃 block 跨 block attention 会带来一致性问题
shallow refinement 前几层缓存、后几层更新 近似推理,质量需要重新验证
fewer denoising steps 用 schedule 减少 \(S\) 过早 commit 会损害质量

所以系统层面的判断不是“\(S<L\) 就赢”。更合理的粗略比较是:

\[ \text{AR cost} \approx L\cdot C_{\text{decode-with-cache}}, \qquad \text{MDM cost} \approx S\cdot C_{\text{full-sequence}}. \]

当序列很长、\(S\) 足够小、并且 full-sequence kernel 吞吐高时,MDM 有机会减少 wall-clock;当 \(S\) 接近 \(L\) 或 remasking 太频繁时,它可能比 AR 更贵。

Failure Modes

Masked diffusion LM 的常见问题:

Failure Cause
repetition low-confidence positions keep reinforcing same local pattern
inconsistency parallel fills conflict before later revision
length mismatch target length or EOS handling wrong
weak exact copying no left-to-right pointer-like generation
slow wall-clock too many denoising steps over full sequence

这些问题和 AR LM 不一样。AR 的错误常来自早期 token 造成 exposure bias;MDM 的错误常来自并行局部决策之间不一致,或 schedule 太早 commit 错误 token。

Implementation Checklist

实现一个 LLaDA-style MDM 时,可以按下面顺序自检:

  1. tokenizer 里确实有专用 mask_token_id,且它不会和 padding/EOS 混用;
  2. corruption 只作用于真实 token,不作用于 padding 和固定 prompt;
  3. labels 只在 masked positions 保留 clean token,其余位置为 ignore index;
  4. attention 使用 bidirectional visibility,但 padding/packing boundary 仍然被 mask;
  5. 模型显式接收 noise level \(t\) 或 discrete step embedding;
  6. loss 同时考虑 \(\lambda(t)\) 和 masked-token normalization;
  7. 分布式训练时 all-reduce loss numerator 和 denominator,而不是每卡先局部平均;
  8. sampler 明确区分 proposal token、committed token、remasked token;
  9. conditional generation 时 prompt/source tokens 被 frozen;
  10. length control、EOS handling 和 unused slots 有一致规则;
  11. 推理 benchmark 同时报告 denoising steps、full-sequence tokens processed、wall-clock、quality;
  12. posterior-style reverse step 是否满足 \(1-s/t\) 的 unmask probability;
  13. sampler 是否区分 editableactive maskcommittedfrozen
  14. prompt/source/padding tokens 在每一步采样后是否保持不变;
  15. 对同一个 checkpoint 比较不同 remasking confidence policy,而不是只看训练 loss。

一个小实验很能暴露问题:取一批真实句子,固定 \(t\in\{0.1,0.3,0.6,0.9\}\),分别统计 masked token CE、top-1 accuracy、sequence-level exact recovery。若 \(t=0.1\) 表现很好但 \(t=0.9\) 崩,说明模型会局部填空但缺全局语义;若高 \(t\) 还行但低 \(t\) 不稳,通常是 loss 权重、time embedding 或 label mask 写错。

再加两个 sampler tests:

  1. \(t=1\)\(s=0.5\) 的 posterior-style step,统计大量样本的剩余 mask 比例,应接近 \(0.5\)
  2. 对 conditional generation,记录每一步 prompt span 的 token ids,确认 sampler 从未改写 frozen span。

Connection to Chapter 2

LLaDA is the natural continuation of discrete modeling: token prediction remains cross entropy over a vocabulary, but the conditioning pattern is denoising rather than purely causal. This connects BERT-style masked modeling, diffusion-style iterative refinement, and LLM-scale generation.

References