LLaDA / Masked Diffusion LM
不同于 GPT 系列 autoregressive language model,LLaDA 属于 masked diffusion language model (MDM)。它把文本生成从“从左到右逐 token 采样”改写为“从 masked sequence 逐步去噪”。这使它更接近 diffusion 的训练思想,但变量空间是离散 token 而不是连续图像 latent。
Running Example: Fill Several Masks
假设真实句子是:
I really like diffusion models
forward process 在某个时间 \(t\) 把它腐蚀成:
I [M] like [M] models
模型的任务不是预测“下一个 token”,而是同时根据上下文预测两个 masked positions:
\[ p_\theta(\text{really}\mid I,[M],like,[M],models,t), \]
\[ p_\theta(\text{diffusion}\mid I,[M],like,[M],models,t). \]
采样时可以先填高置信度位置,再用新填入的 token 继续推断剩余 mask。于是生成过程像 iterative refinement,而不是严格从左到右。
Masked Diffusion Model
Masked diffusion models (MDMs) (Austin et al., 2021a; Lou et al., 2023; Ou et al., 2024) 以不同于 autoregressive models 的方式定义模型分布 \(p_\theta(x_0)\)。这些模型引入了一个由时间 \(t \in [0, 1]\) 索引的 forward process \(\{x_t\}\)。该过程逐渐且独立地 mask 序列 \(x_0\) 中的所有 tokens。在时间 \(t = 0\) 时,数据点 \(x_0\) 被完全观测到,没有任何 masks;而对于 \(t \in (0, 1]\),\(x_t\) 表示在期望中具有不同 mask ratios 的 latent variables。形式上,给定 \(x_0\) 时 \(x_t\) 的 conditional distribution 由一个完全分解的形式定义: \[ q_{t|0}(x_t|x_0) = \prod_{i=1}^{L} q_{t|0}(x^i_t|x^i_0), \quad (7) \] 其中每个 token 的 conditional distribution 由下式给出: \[ q_{t|0}(x^i_t|x^i_0) = \begin{cases} 1 - t, & x^i_t = x^i_0, \\ t, & x^i_t = \text{M}. \end{cases} \quad (8) \]
A masked diffusion language model defines a corruption process that progressively replaces clean tokens with a mask token, and learns a reverse process that predicts masked tokens from partially observed sequences.
LLaDA 的关键不是“把 BERT 做大”这么简单。BERT 的 masked language modeling 通常服务于 representation learning;LLaDA 把 masked denoising 组织成 generative process,目标是在采样时从全 mask 或部分 mask 状态逐步恢复完整序列。换句话说,BERT 学的是“这个句子里缺了几个词时该怎么补”;LLaDA 学的是“从噪声文本分布出发,沿着一条可控的去噪轨迹走回自然语言分布”。
这里最容易混淆的是 noise level \(t\) 的含义。在 BERT-style MLM 中,mask ratio 常常是一个数据增强超参数;在 MDM 中,\(t\) 是概率模型的一部分。模型不仅要知道当前位置是 [MASK],还要知道当前整句处于哪个噪声强度:\(t\) 小时,周围上下文基本可信,任务接近局部填空;\(t\) 大时,大部分 token 都被擦掉,模型必须先恢复主题、句法骨架和长程一致性。
For an AR LM, the latent state during decoding is a prefix. For a masked diffusion LM, the latent state is a partially observed sequence plus a noise level. This is why the same cross entropy can train two very different generative stories.
从工程角度看,可以把一条训练样本拆成四个对象:
| object | shape | meaning |
|---|---|---|
input_ids |
[B, L] |
被 corruption 后的 \(x_t\),其中部分位置为 mask_token_id |
labels |
[B, L] |
原始 \(x_0\),但只在 masked positions 上监督 |
loss_mask |
[B, L] |
哪些位置参与 CE,通常等价于 input_ids == mask_token_id |
t |
[B] or [B, 1] |
每个样本的 noise level / mask ratio |
注意 labels 不是“右移一位”的 next-token label。MDM 的监督位置由 corruption process 产生,而 AR LM 的监督位置由 causal factorization 产生。两者都用 cross entropy,但训练语义完全不同。
Absorbing Mask Process
对于每个 token,forward process 是一个 absorbing process:一旦变成 [MASK],之后保持 [MASK]。对单个位置:
\[ q(x_t\mid x_0) = (1-t)\delta_{x_0}+t\delta_{\text{M}}. \]
因此未被 mask 的概率是 \(1-t\),被 mask 的概率是 \(t\)。对长度 \(L\) 的序列,mask 数量
\[ N_t=\sum_{i=1}^{L}\mathbf{1}[x_t^i=\text{M}] \]
满足二项分布
\[ N_t\sim\mathrm{Binomial}(L,t), \]
于是
\[ \mathbb{E}[N_t]=Lt. \]
方差为
\[ \mathrm{Var}(N_t)=Lt(1-t). \]
这个方差在实现里很重要。若 \(L\) 很短或 \(t\) 很小,某个 batch 里可能出现 \(N_t=0\) 的样本,即没有任何 token 需要预测;若直接除以 masked token 数量,就会产生除零或极大梯度权重。因此真实 collator 往往要做至少一个保护:
- 对每条样本强制至少 mask 一个 token;
- 或者在 loss normalization 里用
clamp_min(1); - 或者用 fixed-count corruption:先采样 \(n=\lceil Lt\rceil\),再从 \(L\) 个位置中无放回选择 \(n\) 个 mask。
每个位置独立以概率 \(t\) 被 mask。令
\[ Z_i=\mathbf{1}[x_t^i=\text{M}]. \]
则 \(\mathbb{E}[Z_i]=t\)。线性期望给出
\[ \mathbb{E}[N_t] = \mathbb{E}\left[\sum_i Z_i\right] = \sum_i\mathbb{E}[Z_i] = Lt. \]
因为 \(Z_i\) 独立,且 \(\mathrm{Var}(Z_i)=t(1-t)\),所以
\[ \mathrm{Var}(N_t) = \sum_i\mathrm{Var}(Z_i) = Lt(1-t). \]
Bernoulli masking 和 fixed-count masking 的差别可以这样理解:
| corruption | advantage | caveat |
|---|---|---|
| Bernoulli per token | 数学最干净,\(q(x_t\mid x_0)\) 完全分解 | batch 内 mask 数量波动,短序列可能无监督 token |
| fixed count | 每条样本计算量更稳定,便于课程实验复现 | 严格来说位置之间不再独立,推导要改成超几何/组合形式 |
这个过程的好处是简单、稳定、适合大词表。它的缺点也明显:噪声状态很特殊,全部信息都被压成一个 [MASK] 符号,不像连续 diffusion 那样保留带噪的细粒度信息。因此 MDM 的建模难点并不是“怎么给 token 加高斯噪声”,而是“只看到离散缺口时,如何用上下文和时间条件恢复全局一致的文本”。
Posterior of the Absorbing Process
forward process 定义了从 clean sentence \(x_0\) 到 noisy sentence \(x_t\) 的腐蚀;reverse sampler 需要从较高噪声 \(t\) 走到较低噪声 \(s<t\)。因此一个关键量是:
\[ q(x_s^i\mid x_t^i,x_0^i), \qquad 0\le s<t\le 1. \]
对 absorbing mask,可以把每个 token 的 mask time 看成随机变量 \(\tau_i\sim\mathrm{Uniform}(0,1)\)。若 \(\tau_i\le t\),则 \(x_t^i=\text{M}\);若 \(\tau_i>t\),则 \(x_t^i=x_0^i\)。于是:
\[ x_t^i= \begin{cases} \text{M},& \tau_i\le t,\\ x_0^i,& \tau_i>t. \end{cases} \]
如果 \(x_t^i\ne\text{M}\),那么它还没有被吸收,且因为 \(s<t\),在更早时间 \(s\) 也必然没有被吸收:
\[ q(x_s^i=x_0^i\mid x_t^i=x_0^i,x_0^i)=1. \]
如果 \(x_t^i=\text{M}\),则只知道 \(\tau_i\le t\)。它在时间 \(s\) 已经 mask 的概率是
\[ q(x_s^i=\text{M}\mid x_t^i=\text{M},x_0^i) = \Pr(\tau_i\le s\mid \tau_i\le t) = \frac{s}{t}. \]
所以它在从 \(t\) 走回 \(s\) 时被恢复成 clean token 的概率是
\[ q(x_s^i=x_0^i\mid x_t^i=\text{M},x_0^i) = 1-\frac{s}{t}. \]
For the Bernoulli absorbing mask process with \(0\le s<t\le1\), \[ q(x_s^i\mid x_t^i=\text{M},x_0^i) = \frac{s}{t}\delta_{\text{M}} + \left(1-\frac{s}{t}\right)\delta_{x_0^i}, \] and if \(x_t^i=x_0^i\), then \(q(x_s^i=x_0^i\mid x_t^i,x_0^i)=1\).
用吸收时间 \(\tau_i\) 表示 token 第一次变成 mask 的时间。因为
\[ q(x_t^i=\text{M}\mid x_0^i)=t, \]
可以取等价构造 \(\tau_i\sim\mathrm{Uniform}(0,1)\),并令 \(x_t^i=\text{M}\) 当且仅当 \(\tau_i\le t\)。若 \(x_t^i=\text{M}\),条件事件是 \(\tau_i\le t\)。于是
\[ \Pr(x_s^i=\text{M}\mid x_t^i=\text{M},x_0^i) = \Pr(\tau_i\le s\mid\tau_i\le t) = \frac{s}{t}. \]
补事件就是 \(x_s^i=x_0^i\)。若 \(x_t^i=x_0^i\),则 \(\tau_i>t\),从而 \(\tau_i>s\),所以 \(x_s^i=x_0^i\)。
这个后验告诉我们 reverse step 应该做什么:对当前还是 [MASK] 的位置,以概率 \(1-s/t\) 解除 mask,并用模型预测的 clean token 分布替代未知的 \(x_0^i\)。也就是说,模型不是直接学
\[ p_\theta(x_s\mid x_t), \]
而是先学 clean-token predictor
\[ p_\theta(x_0^i\mid x_t,t), \]
再把它和 absorbing posterior 组合成从 \(t\) 到 \(s\) 的 transition。
一个 posterior-style reverse step 可以写成:
@torch.no_grad()
def posterior_reverse_step(model, ids, t, s, mask_id, temperature):
# ids: [B, L], current x_t
active = ids.eq(mask_id)
logits = model(input_ids=ids, noise_level=t).logits
probs = torch.softmax(logits / temperature, dim=-1)
proposal = torch.multinomial(
probs.view(-1, probs.size(-1)),
num_samples=1,
).view_as(ids)
unmask_prob = (1.0 - s / t).clamp(min=0.0, max=1.0)
coin = torch.rand_like(ids.float()) < unmask_prob[:, None]
reveal = active & coin
return torch.where(reveal, proposal, ids)这个版本和 confidence-based remasking 的区别很重要:posterior step 按理论 transition 随机决定释放概率;confidence sampler 按模型置信度选择释放哪些位置。前者更贴近 forward process 的后验结构,后者更像工程启发式,常用于提高生成质量或减少早期错误 commit。
The absorbing posterior tells how many masked variables should become clean when moving from \(t\) to \(s\). A confidence policy decides which positions to commit using model scores. Mixing them without tracking the schedule changes the sampler.
Reverse Process
反向过程不是预测下一个 token,而是在任意 mask pattern 下预测被 mask 的 token:
\[ p_\theta(x_0\mid x_t) = \prod_{i:x_t^i=\text{M}} p_\theta(x_0^i\mid x_t,t). \]
训练目标通常是 masked-token cross entropy:
\[ \mathcal{L}_{\text{MDM}} = \mathbb{E}_{t,x_0,x_t} \left[ - \sum_{i:x_t^i=\text{M}} \log p_\theta(x_0^i\mid x_t,t) \right]. \]
更具体地,Transformer 给每个位置输出 hidden state:
\[ h_i=f_\theta(x_t,t)_i\in\mathbb{R}^{d}. \]
LM head 把它映射成词表 logits:
\[ z_i=W_{\text{out}}h_i+b_{\text{out}}\in\mathbb{R}^{|V|}, \qquad p_\theta(v\mid x_t,t,i) = \frac{\exp z_{i,v}}{\sum_{u\in V}\exp z_{i,u}}. \]
然后只在 \(x_t^i=\text{M}\) 的位置计算
\[ \ell_i = -\log p_\theta(x_0^i\mid x_t,t,i). \]
在 PyTorch/Hugging Face 风格实现中,最常见写法是把非监督位置设成 ignore index:
labels = clean_ids.clone()
labels[~is_masked] = -100
loss = F.cross_entropy(
logits.view(-1, vocab_size),
labels.view(-1),
ignore_index=-100,
reduction="sum",
)
denom = is_masked.sum().clamp_min(1)
loss = loss / denom这里的 reduction="sum" 加手动归一化比直接 reduction="mean" 更透明,因为后面常常还要乘上 \(\lambda(t)\)、按 token 数量做全局平均,或者在分布式训练中先 all-reduce masked token 计数。
The target at a masked position is the original clean token \(x_0^i\), not [MASK]. The mask token is an input-side sentinel, not a desired output class.
一个细节是:模型是否也对 unmasked positions 输出 logits?通常会输出,因为 Transformer 是全位置并行计算;但 loss 不监督这些位置。监督 unmasked positions 会让模型学会 copy identity shortcut,甚至把任务变成“所有位置重构自己”,削弱 denoising 学习。训练时的关键是:上下文可以看见 unmasked tokens,梯度主要从 masked tokens 回传。
Likelihood Bound Intuition
自回归模型直接分解
\[ \log p_\theta(x_0) = \sum_i\log p_\theta(x_i\mid x_{<i}). \]
Masked diffusion 不按固定顺序生成,因此通常通过 denoising objective 优化 likelihood bound。直觉是:如果模型在任意噪声强度 \(t\) 下都能预测被 mask 的 token,那么它就学到了从噪声分布回到数据分布的反向过程。
对 absorbing mask,训练时常见形式可以看作 weighted masked CE:
\[ \mathcal{L} = \mathbb{E}_{t,x_0,x_t} \left[ \frac{1}{t} \sum_{i:x_t^i=\text{M}} -\log p_\theta(x_0^i\mid x_t,t) \right], \]
其中权重的具体形式依赖时间参数化和推导。重要的是:loss 不是只在固定 15% mask ratio 上训练,而是在一系列 mask ratios 上训练,使模型学会多噪声强度恢复。
为什么会出现类似 \(1/t\) 的权重?可以从一个很朴素的连续时间直觉看。已知某个位置在时间 \(t\) 已经是 mask,那么它从 clean token 被吸收到 mask 的“历史”发生在 \([0,t]\) 中。若 forward process 是均匀参数化的 Bernoulli corruption,则
\[ q(x_t^i=\text{M}\mid x_0^i)=t. \]
在训练时我们只看已经 masked 的位置;这相当于对“被选中监督的位置”做条件化。一个 masked position 的条件密度会带有
\[ \frac{1}{q(x_t^i=\text{M}\mid x_0^i)} = \frac{1}{t} \]
这样的校正因子。严格的 variational bound 推导会从反向转移核或 continuous-time absorbing process 出发,但工程上要记住的事实是:低噪声处 masked token 少,若不加权或不归一化,它们对总 loss 的贡献会被 batch 采样频率淹没;高噪声处 masked token 多,若只按 token 求和,又可能主导梯度。
因此实际实现里通常要同时区分三件事:
| quantity | role | common bug |
|---|---|---|
| \(\lambda(t)\) | 理论或经验的 noise-level weight | 把所有 \(t\) 当成同等难度,导致 sampler 早期/后期一端崩 |
| masked-token normalization | 控制不同样本的有效 token 数 | 直接 sum loss,使长句和高 mask ratio 样本主导训练 |
| batch/global normalization | DDP/FSDP 下的跨卡平均 | 每卡平均再平均,遇到 padding/mask 数不均时产生偏差 |
一个更稳定的 token-level 写法是
\[ \mathcal{L}_{\text{batch}} = \frac{ \sum_{b=1}^{B}\lambda(t_b) \sum_{i=1}^{L}m_{b,i}\, [-\log p_\theta(x_{0,b}^i\mid x_{t,b},t_b)] }{ \sum_{b=1}^{B}\sum_{i=1}^{L}m_{b,i} }, \]
其中 \(m_{b,i}=\mathbf{1}[x_{t,b}^i=\text{M}]\)。若在分布式训练中每张卡的 masked token 数不同,分子和分母都应该 all-reduce 后再相除。
Standard MLM usually trains at a fixed or narrow masking ratio for representation learning. Masked diffusion LM trains across noise levels and uses the model as a reverse generative process.
Why This Is Interesting
AR LM 的采样有严格顺序:第 \(t\) 个 token 依赖前 \(t-1\) 个 token。MDM 则允许并行地填充多个 mask,因此有潜力减少 decoding steps,并且能自然处理 infilling、editing、bidirectional constraints。
| Aspect | AR LM | Masked diffusion LM |
|---|---|---|
| factorization | left-to-right | iterative denoising |
| training target | next-token CE | masked-token CE |
| generation | sequential | partially parallel |
| editing | indirect | natural |
| challenge | long decoding latency | schedule and discrete denoising quality |
Image diffusion often relies on continuous Gaussian noise and score matching. Text diffusion must operate over categorical variables, mask states, or relaxed token distributions. The math analogy is useful, but the transition kernel and reverse parameterization are different.
Architecture: Why a Vanilla Transformer Can Work
LLaDA-style models can use a Transformer backbone with bidirectional attention because the input is a partially observed sequence, not a causal prefix. The model sees all unmasked tokens:
left context [M] right context
and predicts the masked position using both sides. That is exactly what AR decoding cannot do directly. The architecture therefore changes the mask:
| Model | Attention mask | Prediction |
|---|---|---|
| GPT-style AR | causal | next token |
| BERT MLM | bidirectional | randomly masked tokens |
| LLaDA-style MDM | bidirectional over partial sequence | denoising across time |
The same Transformer block can be used, but its probabilistic meaning changes because the conditioning pattern changes.
一个常见输入参数化是
\[ h_i^{(0)} = E[x_t^i]+P_i+T(t), \]
其中 \(E[x_t^i]\) 是 token embedding,\(P_i\) 是 position embedding 或 RoPE 产生的位置相位,\(T(t)\) 是 noise-level embedding。\(T(t)\) 可以来自 sinusoidal embedding 加 MLP,也可以把离散 diffusion step 嵌入成 learned embedding。它的作用是告诉模型“当前 [MASK] 应该补到多细”:在 \(t\approx 1\) 时,模型应该先恢复粗粒度语义;在 \(t\approx 0\) 时,模型应该做精确的局部修补。
Attention mask 也要分清三层含义:
| mask | used in | meaning |
|---|---|---|
| denoising mask | input ids and loss | 哪些 token 被替换为 [MASK],哪些位置参与监督 |
| padding mask | attention and loss | 哪些位置是 batch padding,既不能被 attention 看到,也不能算 loss |
| causal mask | AR Transformer attention | LLaDA-style bidirectional denoising 通常不用 causal mask |
因此 MDM 的 attention score 通常是
\[ A_{ij} = \frac{q_i^\top k_j}{\sqrt{d_h}} + \mathrm{pad\_mask}_{ij}, \]
而不是
\[ A_{ij} = \frac{q_i^\top k_j}{\sqrt{d_h}} + \mathrm{causal\_mask}_{ij} + \mathrm{pad\_mask}_{ij}. \]
这意味着第 \(i\) 个 mask 位置可以看见右侧 token,也可以看见其他 mask 位置的 [MASK] embedding。它不能看见的是 padding;它也不应该在训练 label 中被要求预测 padding token。
attention_mask usually means padding visibility. loss_mask means supervised denoising positions. Reusing one tensor for both is a common source of silent bugs.
Remasking and Confidence
Sampling is not only “fill masks once”. A practical sampler often allows remasking: low-confidence generated tokens can be masked again and regenerated. Let
\[ c_i=\max_v p_\theta(v\mid x_t,t) \]
be confidence at position \(i\). A simple schedule chooses a target number of masked positions \(m_s\) at step \(s\), then keeps the \(m_s\) least confident positions masked:
predict all masked positions
compute confidence
commit high-confidence tokens
remask low-confidence tokens
repeat
This gives the model a way to revise earlier mistakes, closer to image diffusion refinement.
Remasking is a sampling operation that turns selected uncertain generated tokens back into mask tokens so that later denoising steps can revise them.
confidence 的定义不是唯一的。常见选择包括:
| score | formula | behavior |
|---|---|---|
| max probability | \(\max_v p_\theta(v\mid x_t,t)\) | 简单,偏好尖锐分布 |
| negative entropy | \(-\sum_v p_v\log p_v\) | 惩罚整体不确定性,不只看 top-1 |
| margin | \(p_{(1)}-p_{(2)}\) | 关心第一名和第二名是否接近 |
| sampled-token probability | \(p_\theta(\hat{x}_i\mid x_t,t)\) | 和实际采样出的 token 绑定,适合 stochastic decoding |
要特别注意:confidence policy 不等于训练目标。训练目标拟合 \(p_\theta(x_0^i\mid x_t,t)\);采样策略决定哪些位置先 commit,哪些位置继续保留为 mask。两个模型 loss 相同,可能因为 remasking 策略不同而生成质量和速度差别很大。
一个具体例子:假设某位置 top-1 是 bank,概率 \(0.42\);top-2 是 river,概率 \(0.39\)。max probability 看起来不低,但 margin 只有 \(0.03\),说明语义分歧很大。若这一步过早 commit bank,后续 token 可能围绕金融语境继续生成;若 remask 它,模型在更多上下文恢复后可能转向河岸语境。
Generation Length
AR models naturally decide length by emitting EOS. Masked diffusion models often need a target length or a length prediction mechanism. If generation begins from
[M] [M] [M] ... [M]
then the number of mask slots already fixes the maximum output length. Practical systems may:
- predict output length first;
- sample several candidate lengths;
- include EOS and allow unused positions after EOS;
- use task-specific length constraints。
Masked diffusion can reduce sequential token dependency, but it does not remove the need to decide how many token slots to denoise.
Sampling Schedule
采样时从全 mask 序列开始:
[M] [M] [M] [M] [M]
在每一步,模型对所有 mask positions 给出 categorical distribution。一个简单策略是:
- 计算每个 mask position 的最大概率;
- 选择置信度最高的一部分位置填入 token;
- 保留低置信度位置为
[M]; - 重复直到没有 mask。
例如:
step 0: [M] [M] [M] [M] [M]
step 1: I [M] like [M] models
step 2: I really like [M] models
step 3: I really like diffusion models
这种过程比 AR 采样更像“先定骨架,再补细节”。但它也引入一个新问题:如果早期填错高置信度 token,后续步骤可能围绕错误 token 继续自洽。
采样 schedule 可以写成“剩余 mask 比例” \(r_s\),其中 \(s=0,\ldots,S\),\(r_0=1\),\(r_S=0\)。第 \(s\) 步之后保留
\[ m_s=\lceil r_s L\rceil \]
个低置信度位置为 mask。常见 schedule:
| schedule | \(r_s\) | intuition |
|---|---|---|
| linear | \(1-\frac{s}{S}\) | 每步释放差不多数量的 token |
| cosine | \(\cos^2\left(\frac{\pi s}{2S}\right)\) | 前期保守,后期快速细化 |
| square | \(\left(1-\frac{s}{S}\right)^2\) | 前期释放更多,后期保留少量难点反复修 |
A mask schedule specifies the intended number of active mask positions after each denoising step. A sampler is schedule-consistent only if the actual active mask count follows this contract up to rounding and frozen-token constraints.
schedule 不是只画一条曲线。实现中至少要明确四类位置:
| state | meaning | can change? |
|---|---|---|
| frozen | prompt/source/padding positions | no |
| active mask | 当前需要预测的位置 | yes |
| committed | 已经填入 token 的位置 | maybe, if remasking allowed |
| finished | EOS 后或 unused slots | no |
设可生成位置数为 \(L_{\text{gen}}\),第 \(s\) 步目标剩余 mask 数是
\[ m_s=\left\lceil r_s L_{\text{gen}}\right\rceil. \]
若 prompt tokens 被 frozen,不能把它们计入 \(L_{\text{gen}}\);若 EOS 后的 slots 也 frozen,也要从可 remask 集合里排除。否则 sampler 看似按 schedule 运行,实际会把用户输入或 padding 当成可编辑 token。
一个 schedule-consistent commit 函数可以先算“当前还要保留多少 active masks”,再只在 editable positions 内做 top-\(k\):
def apply_confidence_schedule(ids, proposal, confidence, editable, step, steps, mask_id):
# editable: positions allowed to be mask/proposal, excludes prompt and padding
gen_len = editable.sum(dim=1)
ratio = math.cos(0.5 * math.pi * step / steps) ** 2
keep = torch.ceil(ratio * gen_len).long()
out = torch.where(editable & ids.eq(mask_id), proposal, ids)
score = confidence.masked_fill(~editable, float("inf"))
for b in range(ids.size(0)):
k = min(int(keep[b]), int(editable[b].sum()))
if k > 0:
pos = score[b].topk(k=k, largest=False).indices
out[b, pos] = mask_id
return out这个代码仍然很朴素,但它显式区分了 editable 和 active。active 是当前仍为 [MASK] 的位置;editable 是允许被 sampler 改写的位置。条件生成、infilling 和编辑任务里,这两个集合不能混。
极简 sampler 可以写成:
@torch.no_grad()
def mdm_sample(model, length, steps, mask_id, temperature):
ids = torch.full((1, length), mask_id, device=model.device)
for s in range(1, steps + 1):
t = torch.tensor([1.0 - (s - 1) / steps], device=ids.device)
logits = model(input_ids=ids, noise_level=t).logits
probs = torch.softmax(logits / temperature, dim=-1)
pred = torch.multinomial(
probs.view(-1, probs.size(-1)),
num_samples=1,
).view_as(ids)
conf = probs.gather(-1, pred[..., None]).squeeze(-1)
active = ids.eq(mask_id)
proposal = torch.where(active, pred, ids)
keep = int(math.ceil((1.0 - s / steps) * length))
if keep == 0:
ids = proposal
break
conf = conf.masked_fill(~active, float("inf"))
remask_pos = conf.topk(k=keep, largest=False).indices
ids = proposal
ids.scatter_(1, remask_pos, mask_id)
return ids这段代码故意写得朴素:真实实现需要处理 batch、padding、prompt constraints、EOS、temperature/top-p、重复惩罚,以及“已经 commit 的 token 是否允许再次 remask”。但它揭示了 MDM sampler 的核心循环:predict、score、commit/remask、repeat。
For conditional generation, prompt/context tokens should normally be excluded from remasking and loss. Otherwise the sampler may rewrite the instruction or source text.
Parallel Decoding Consistency
MDM 每一步会并行提议多个 token。并行的好处是减少 sequential depth,风险是局部提议之间可能互相不一致。例如两个位置分别生成:
The capital of France is [M], and it is located in [M].
早期 step 可能同时提议 Paris 和 Germany。如果两个位置都被高置信度 commit,后续模型会在错误上下文上继续修补。remasking 的本质就是给 sampler 一个撤销机制。
可以把一次 denoising step 拆成三个张量:
| tensor | role |
|---|---|
proposal_ids |
当前模型为 active mask 给出的候选 token |
commit_mask |
本步决定接受哪些 proposal |
next_ids |
接受 proposal 后再按 schedule/remask 得到的下一状态 |
工程上不要直接覆盖 ids 后再计算 confidence,否则会混淆“模型基于哪个状态预测”和“本步之后状态是什么”。更稳的顺序是:
logits = model(ids, noise_level=t).logits
proposal_ids, confidence = sample_or_argmax(logits)
commit_mask = choose_commit_positions(ids, confidence, editable, schedule)
next_ids = ids.clone()
next_ids[commit_mask] = proposal_ids[commit_mask]
next_ids[remask_mask] = mask_idRun the sampler with a fixed prompt and assert that prompt token ids are identical before and after every denoising step.
Objective as a Weighted Denoising Problem
如果 \(t\) 越大,mask ratio 越高,任务越难;\(t\) 越小,任务越接近局部填空。实际训练通常会采样不同 \(t\),让模型学会从不同噪声强度恢复文本:
\[ \mathcal{L} = \mathbb{E}_{t\sim p(t)} \mathbb{E}_{x_t\sim q(x_t\mid x_0)} \left[ \lambda(t) \sum_{i:x_t^i=\text{M}} -\log p_\theta(x_0^i\mid x_t,t) \right]. \]
\(\lambda(t)\) 可以调节不同噪声强度的权重。直觉上,高 mask ratio 训练全局语义建模,低 mask ratio 训练局部语法和精确恢复。
训练 batch 构造可以抽象成:
def corrupt_for_mdm(input_ids, attention_mask, mask_id, t):
# input_ids: [B, L], attention_mask: 1 for real tokens, 0 for padding
clean = input_ids.clone()
valid = attention_mask.bool()
prob = t[:, None].expand_as(input_ids)
noise = torch.rand_like(prob.float())
is_masked = (noise < prob) & valid
# optional: force at least one masked token per nonempty sequence
empty = is_masked.sum(dim=1).eq(0)
if empty.any():
first_valid = valid.float().argmax(dim=1)
is_masked[empty, first_valid[empty]] = True
corrupted = torch.where(
is_masked,
torch.full_like(input_ids, mask_id),
input_ids,
)
labels = clean.masked_fill(~is_masked, -100)
return corrupted, labels, is_masked然后 training step 里做:
t = sample_noise_level(batch_size, device=input_ids.device)
corrupted, labels, is_masked = corrupt_for_mdm(
input_ids=input_ids,
attention_mask=attention_mask,
mask_id=tokenizer.mask_token_id,
t=t,
)
out = model(
input_ids=corrupted,
attention_mask=attention_mask,
noise_level=t,
)
token_loss = F.cross_entropy(
out.logits.view(-1, out.logits.size(-1)),
labels.view(-1),
ignore_index=-100,
reduction="none",
).view_as(labels)
weight = lambda_t(t)[:, None]
loss_num = (token_loss * is_masked * weight).sum()
loss_den = is_masked.sum().clamp_min(1)
loss = loss_num / loss_den这里有几个实现点值得反复检查:
- padding 位置不能被 corruption 成
[MASK]; - padding 位置不能参与 loss;
attention_mask仍然传给模型,防止 padding 被其他 token attend;t要和样本绑定,而不是和 token 随机独立绑定,否则同一句内部的噪声等级解释会变乱;- 如果使用 sequence packing,packed sample 之间的 block diagonal attention mask 也要保留,否则不同文档会互相泄漏上下文。
如果要用 fixed-count corruption,上面的 noise < prob 可以替换成每条样本对 valid positions 采样 top-\(k\)。例如令 \(k_b=\max(1,\lfloor t_b n_b\rceil)\),其中 \(n_b\) 是第 \(b\) 条样本的非 padding token 数。这样每条样本的监督 token 数更稳定,但数学上不再是逐位置独立 Bernoulli。
A Tiny Numerical Example
句子长度为 \(L=4\),在 \(t=0.5\) 时每个 token 独立以 \(0.5\) 概率被 mask。真实句子:
deep learning is fun
某次 corruption 得到:
deep [M] is [M]
模型输出:
\[ p_\theta(x_2=\text{learning}\mid x_t)=0.7, \qquad p_\theta(x_4=\text{fun}\mid x_t)=0.4. \]
这次样本的 denoising loss 是
\[ -\log 0.7-\log 0.4\approx1.273. \]
这和普通 cross entropy 一样,只是 supervised positions 由 forward corruption process 随机决定。
Comparison With AR Decoding
Suppose output length is \(L\) and diffusion sampling uses \(S\) denoising steps. AR decoding requires \(L\) model calls, one per token. Masked diffusion requires roughly \(S\) calls, each predicting many positions in parallel.
If \(S\ll L\), MDM can reduce sequential depth. But each denoising step processes the whole sequence, and there is no standard KV cache reuse like AR decoding because tokens can be revised bidirectionally.
| Cost dimension | AR LM | Masked diffusion LM |
|---|---|---|
| sequential steps | \(L\) | \(S\) |
| per step input | prefix/new token with cache | full partial sequence |
| cache reuse | KV cache natural | harder because tokens change |
| editing/infilling | awkward | natural |
| length handling | EOS natural | needs length control |
所以 MDM 的优势不是“必然更快”,而是把串行深度从 token-level 改成 denoising-step-level。真实速度取决于 \(S\)、序列长度、cache 机制、硬件 kernel 和 remasking 策略。
KV cache 的差别尤其关键。AR decoding 在第 \(i\) 步只新增一个 token,过去 token 的 key/value 不再改变,因此可以缓存:
\[ K_{\le i}=[K_{\le i-1};K_i], \qquad V_{\le i}=[V_{\le i-1};V_i]. \]
MDM 的第 \(s\) 步会把多个 [MASK] 改成 token,也可能把低置信度 token 改回 [MASK]。由于 self-attention 是双向的,任意位置 token 改变都会改变后续层中许多位置的 hidden states:
\[ h_j^{(\ell+1)} = \sum_i \mathrm{softmax}\left( \frac{q_j^\top k_i}{\sqrt{d_h}} \right)v_i. \]
只要某个 \(x_i\) 变了,\(k_i,v_i\) 变;进一步,所有 attend 到 \(i\) 的 \(h_j\) 也变;下一层的 \(q_j,k_j,v_j\) 又随之变化。这就是为什么 naive MDM sampler 每一步通常要重新跑整段序列,而不能像 GPT decode 那样只追加一个 token。
可能的工程优化方向包括:
| idea | what it tries to reuse | limitation |
|---|---|---|
| frozen committed tokens | 不再 remask 的 token 的局部表示 | 双向层中它们仍受其他 token 变化影响 |
| block-wise denoising | 把序列切块,只更新活跃 block | 跨 block attention 会带来一致性问题 |
| shallow refinement | 前几层缓存、后几层更新 | 近似推理,质量需要重新验证 |
| fewer denoising steps | 用 schedule 减少 \(S\) | 过早 commit 会损害质量 |
所以系统层面的判断不是“\(S<L\) 就赢”。更合理的粗略比较是:
\[ \text{AR cost} \approx L\cdot C_{\text{decode-with-cache}}, \qquad \text{MDM cost} \approx S\cdot C_{\text{full-sequence}}. \]
当序列很长、\(S\) 足够小、并且 full-sequence kernel 吞吐高时,MDM 有机会减少 wall-clock;当 \(S\) 接近 \(L\) 或 remasking 太频繁时,它可能比 AR 更贵。
Failure Modes
Masked diffusion LM 的常见问题:
| Failure | Cause |
|---|---|
| repetition | low-confidence positions keep reinforcing same local pattern |
| inconsistency | parallel fills conflict before later revision |
| length mismatch | target length or EOS handling wrong |
| weak exact copying | no left-to-right pointer-like generation |
| slow wall-clock | too many denoising steps over full sequence |
这些问题和 AR LM 不一样。AR 的错误常来自早期 token 造成 exposure bias;MDM 的错误常来自并行局部决策之间不一致,或 schedule 太早 commit 错误 token。
Implementation Checklist
实现一个 LLaDA-style MDM 时,可以按下面顺序自检:
- tokenizer 里确实有专用
mask_token_id,且它不会和 padding/EOS 混用; - corruption 只作用于真实 token,不作用于 padding 和固定 prompt;
labels只在 masked positions 保留 clean token,其余位置为 ignore index;- attention 使用 bidirectional visibility,但 padding/packing boundary 仍然被 mask;
- 模型显式接收 noise level \(t\) 或 discrete step embedding;
- loss 同时考虑 \(\lambda(t)\) 和 masked-token normalization;
- 分布式训练时 all-reduce loss numerator 和 denominator,而不是每卡先局部平均;
- sampler 明确区分 proposal token、committed token、remasked token;
- conditional generation 时 prompt/source tokens 被 frozen;
- length control、EOS handling 和 unused slots 有一致规则;
- 推理 benchmark 同时报告 denoising steps、full-sequence tokens processed、wall-clock、quality;
- posterior-style reverse step 是否满足 \(1-s/t\) 的 unmask probability;
- sampler 是否区分
editable、active mask、committed和frozen; - prompt/source/padding tokens 在每一步采样后是否保持不变;
- 对同一个 checkpoint 比较不同 remasking confidence policy,而不是只看训练 loss。
一个小实验很能暴露问题:取一批真实句子,固定 \(t\in\{0.1,0.3,0.6,0.9\}\),分别统计 masked token CE、top-1 accuracy、sequence-level exact recovery。若 \(t=0.1\) 表现很好但 \(t=0.9\) 崩,说明模型会局部填空但缺全局语义;若高 \(t\) 还行但低 \(t\) 不稳,通常是 loss 权重、time embedding 或 label mask 写错。
再加两个 sampler tests:
- 从 \(t=1\) 到 \(s=0.5\) 的 posterior-style step,统计大量样本的剩余 mask 比例,应接近 \(0.5\);
- 对 conditional generation,记录每一步 prompt span 的 token ids,确认 sampler 从未改写 frozen span。
Connection to Chapter 2
LLaDA is the natural continuation of discrete modeling: token prediction remains cross entropy over a vocabulary, but the conditioning pattern is denoising rather than purely causal. This connects BERT-style masked modeling, diffusion-style iterative refinement, and LLM-scale generation.
References
- Nie et al., Large Language Diffusion Models