2.6 Generative Adversarial Networks
GAN 是深度生成模型中最具有博弈论气质的一类方法。它不显式写 likelihood,而是让 generator 和 discriminator 对抗:一个负责造样本,一个负责识别真假。
Running Example: Learning a 1D Gaussian
先从最小例子理解 GAN。真实数据来自
\[ x\sim \mathcal{N}(3,1). \]
generator 从简单噪声 \(z\sim\mathcal{N}(0,1)\) 生成
\[ G_\theta(z)=az+b. \]
如果一开始 \(a=1,b=0\),那么 generator 生成的是 \(\mathcal{N}(0,1)\),明显偏左。discriminator 看到真实样本大多在 \(3\) 附近,假样本大多在 \(0\) 附近,于是学会:
\[ D(x)\approx \begin{cases} 1, & x\approx3,\\ 0, & x\approx0. \end{cases} \]
generator 的更新方向来自 discriminator 的反馈:它需要把 \(G(z)\) 推到 discriminator 更容易判真的区域,也就是把 \(b\) 往 \(3\) 推。这个例子里,GAN 训练不是直接比较均值方差,而是通过“真假分类器的梯度”间接移动生成分布。
GAN does not ask the generator to match each data point. It asks the generator to move its distribution into regions where a learned discriminator can no longer separate fake from real.
Minimax Objective
Given data distribution \(p_{\text{data}}\), latent prior \(z\sim p(z)\), generator \(G_\theta(z)\), and discriminator \(D_\phi(x)\), GAN solves \[ \min_\theta \max_\phi V(D_\phi,G_\theta) = \mathbb{E}_{x\sim p_{\text{data}}}\log D_\phi(x) + \mathbb{E}_{z\sim p(z)} \log(1-D_\phi(G_\theta(z))). \]
discriminator 学习真假分类,generator 学习让假样本被判真。这个训练目标把生成建模变成了一个 two-player game。
从二分类角度看,discriminator 在做 density-ratio estimation。若真实样本和生成样本先验各占一半,则 Bayes 最优后验为
\[ D^\star(x) = P(y=\text{real}\mid x) = \frac{p_{\text{data}}(x)} {p_{\text{data}}(x)+p_g(x)}. \]
因此 logit
\[ s^\star(x) = \log\frac{D^\star(x)}{1-D^\star(x)} = \log\frac{p_{\text{data}}(x)}{p_g(x)}. \]
也就是 log density ratio。GAN 的一个核心直觉是:我们不显式估计 \(p_{\text{data}}(x)\),而是训练一个分类器,让它的边界和 logit 为 generator 提供“哪里更像真实数据”的方向。
The density ratio between two distributions is \(p(x)/q(x)\). A calibrated discriminator trained to separate samples from \(p\) and \(q\) implicitly estimates this ratio through its logits.
Optimal Discriminator
固定 generator 后,最优 discriminator 为
\[ D^\star(x) = \frac{p_{\text{data}}(x)} {p_{\text{data}}(x)+p_g(x)}. \]
With the optimal discriminator, the GAN value reduces to \[ V(D^\star,G) = -\log 4 +2\,\operatorname{JS}(p_{\text{data}}\Vert p_g). \] Thus the global optimum is reached when \(p_g=p_{\text{data}}\).
对每个 \(x\),最大化
\[ p_{\text{data}}(x)\log D(x)+p_g(x)\log(1-D(x)). \]
对 \(D(x)\) 求导:
\[ \frac{p_{\text{data}}(x)}{D(x)} - \frac{p_g(x)}{1-D(x)} =0. \]
解得
\[ D^\star(x) = \frac{p_{\text{data}}(x)} {p_{\text{data}}(x)+p_g(x)}. \]
令
\[ m(x)=\frac12(p_{\text{data}}(x)+p_g(x)). \]
代回 value:
\[ V(D^\star,G) = \int p_{\text{data}}(x) \log \frac{p_{\text{data}}(x)} {p_{\text{data}}(x)+p_g(x)} dx + \int p_g(x) \log \frac{p_g(x)} {p_{\text{data}}(x)+p_g(x)} dx. \]
因为 \(p_{\text{data}}+p_g=2m\),
\[ V(D^\star,G) = \int p_{\text{data}} \log \frac{p_{\text{data}}}{2m} dx + \int p_g \log \frac{p_g}{2m} dx. \]
拆出 \(\log 2\):
\[ V(D^\star,G) = -2\log2 + \operatorname{KL}(p_{\text{data}}\Vert m) + \operatorname{KL}(p_g\Vert m). \]
而
\[ \operatorname{JS}(p_{\text{data}}\Vert p_g) = \frac12\operatorname{KL}(p_{\text{data}}\Vert m) + \frac12\operatorname{KL}(p_g\Vert m). \]
所以
\[ V(D^\star,G) = -\log4 +2\operatorname{JS}(p_{\text{data}}\Vert p_g). \]
JS divergence 非负,且当且仅当两个分布相同为 0,因此全局最优为 \(p_g=p_{\text{data}}\)。
Why JS Can Saturate
JS divergence 的理论漂亮,但在高维生成里有一个很尖锐的问题:真实分布和生成分布的 support 常常几乎不重叠。比如真实数据落在一条 manifold 上,初始 generator 落在另一条 manifold 上。若二者支撑集不相交,则最优 discriminator 可以完美分类:
\[ D^\star(x)= \begin{cases} 1,&x\in \operatorname{supp}(p_{\text{data}}),\\ 0,&x\in \operatorname{supp}(p_g). \end{cases} \]
此时
\[ \operatorname{JS}(p_{\text{data}}\Vert p_g)=\log 2, \]
达到最大值。问题不是数值上“loss 大”,而是这个最大值对“两个 support 离多远”不敏感:差一点、差很远,JS 都饱和。
若 \(p\) 和 \(q\) 支撑集不相交,令 \(m=(p+q)/2\)。在 \(p\) 的支撑上,\(m=p/2\),所以
\[ \operatorname{KL}(p\Vert m) = \int p(x)\log\frac{p(x)}{p(x)/2}dx = \log 2. \]
同理,在 \(q\) 的支撑上,
\[ \operatorname{KL}(q\Vert m)=\log 2. \]
因此
\[ \operatorname{JS}(p\Vert q) = \frac12\log2+\frac12\log2 = \log2. \]
这解释了为什么 GAN 的训练早期会特别脆弱:discriminator 太容易赢,理论上它确实学到了最优分类器,但这个分类器不一定给 generator 一个平滑的“往哪里挪”的方向。WGAN 的动机正是把“真假能不能分开”换成“把一个分布搬到另一个分布需要多大代价”。
Non-Saturating Loss
原始 generator loss
\[ \mathbb{E}_{z}\log(1-D(G(z))) \]
在 discriminator 很强时梯度会饱和。实际训练常用 non-saturating loss:
\[ \mathcal{L}_G = -\mathbb{E}_{z}\log D(G(z)). \]
它不改变最优点,但给 generator 更强梯度。
Why the Original Loss Saturates
如果 discriminator 很快变强,对 fake samples 有
\[ D(G(z))\approx0. \]
原始 generator 最小化
\[ \mathcal{L}_G^{\text{sat}} = \mathbb{E}_z\log(1-D(G(z))). \]
对 discriminator logit \(s\),令 \(D=\sigma(s)\),则
\[ \frac{\partial}{\partial s}\log(1-\sigma(s)) = -\sigma(s). \]
当 \(D\approx0\) 时,梯度接近 0。也就是说 discriminator 越自信 fake 是 fake,原始 generator loss 反而越没有梯度。
Non-saturating loss 为
\[ \mathcal{L}_G^{\text{ns}} = -\mathbb{E}_z\log D(G(z)). \]
其 logit 梯度是
\[ \frac{\partial}{\partial s}[-\log\sigma(s)] = -(1-\sigma(s)). \]
当 \(D\approx0\) 时,梯度约为 \(-1\),generator 仍能收到强信号。
If the discriminator separates real and fake too easily, the generator may receive weak or uninformative gradients. GAN training needs a balanced game, not a discriminator that wins instantly.
Common GAN Loss Variants
实际代码通常直接让 discriminator 输出 logits \(s(x)\),而不是先 sigmoid 成概率。这样可以用 BCEWithLogitsLoss 避免 \(\log(0)\) 数值问题。
| variant | discriminator loss | generator loss | comment |
|---|---|---|---|
| minimax BCE | \(-\log\sigma(s_r)-\log(1-\sigma(s_f))\) | \(\log(1-\sigma(s_f))\) | original, generator saturates |
| non-saturating BCE | same as BCE | \(-\log\sigma(s_f)\) | common default |
| hinge GAN | \(\max(0,1-s_r)+\max(0,1+s_f)\) | \(-s_f\) | common in image GANs |
| least-squares GAN | \((s_r-1)^2+s_f^2\) | \((s_f-1)^2\) | smoother quadratic penalty |
| WGAN | \(-(s_r-s_f)\) for critic minimization | \(-s_f\) | critic score, no sigmoid |
hinge loss 的直觉是 margin classification:real logit 只要大于 \(1\) 就不再被强推高,fake logit 只要小于 \(-1\) 就不再被强推低。这能防止 discriminator 无限增大 logit scale,把容量留给边界附近样本。
用 logits 写 non-saturating GAN:
d_real = discriminator(real)
d_fake = discriminator(fake.detach())
d_loss = (
F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real))
+ F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
)
g_fake = discriminator(generator(z))
g_loss = F.binary_cross_entropy_with_logits(g_fake, torch.ones_like(g_fake))这个写法比 torch.log(discriminator(x)) 更稳,因为 sigmoid 和 log 被融合成数值稳定的 log-sum-exp 形式。
Mode Collapse
Mode collapse happens when the generator maps many latent inputs to a small set of outputs that fool the discriminator. The samples can look sharp but cover only part of the data distribution.
GAN 的问题不是“生成质量不够锐利”,而是分布覆盖困难。相比 VAE/Diffusion,GAN 很擅长生成局部逼真的样本,但 likelihood-free 训练让模式覆盖和评估都更难。
Mode collapse 可以从 game dynamics 理解。假设真实分布有多个 modes:
real: A B C D
fake: B
如果 generator 发现 mode B 很容易骗过当前 discriminator,它可能把大量 \(z\) 映射到 B 附近。discriminator 下一轮会惩罚 B,但 generator 又可能跳到另一个 mode。训练不是单目标 convex optimization,而是两个模型互相追逐,所以 oscillation 和 collapse 都很自然。
常见缓解方式:
| Method | Intuition |
|---|---|
| minibatch discrimination | discriminator sees sample diversity |
| unrolled GAN | generator anticipates discriminator updates |
| feature matching | match discriminator feature statistics |
| spectral norm | constrain discriminator Lipschitzness |
| WGAN-GP | smoother critic signal |
| EMA generator | average generator trajectory |
从映射角度看,mode collapse 是 generator 的 Jacobian 退化。若大量不同 \(z\) 被映射到相近样本,
\[ G(z_1)\approx G(z_2)\quad \text{for many }z_1\ne z_2, \]
则 latent space 的有效维度被压扁。可以检查局部 Jacobian
\[ J_G(z)=\frac{\partial G(z)}{\partial z} \]
的 singular values:若很多方向 singular value 接近 0,说明改变 latent 不怎么改变输出。当然真实图像模型不会频繁显式算完整 Jacobian,但这个视角能解释为什么 diversity loss、minibatch discrimination、latent regression 等方法会尝试强迫 \(z\) 的变化反映到样本变化上。
GAN collapse 诊断通常要结合多个信号:
| signal | collapse symptom |
|---|---|
| nearest-neighbor duplicate rate | generated samples repeat |
| class histogram | few classes dominate |
| precision high, recall low | samples sharp but coverage poor |
| latent interpolation | large intervals map to same output |
| discriminator fake logits | generator exploits narrow weakness |
WGAN
Wasserstein GAN 用 Earth Mover distance 替代 JS divergence。直觉上,JS divergence 在两个分布 support 不重叠时会饱和,而 Wasserstein distance 仍然能反映“要搬多远”。
\[ W(p_{\text{data}},p_g) = \sup_{\|f\|_L\leq1} \mathbb{E}_{x\sim p_{\text{data}}}f(x) - \mathbb{E}_{x\sim p_g}f(x). \]
critic 不再输出真假概率,而是输出分数。为了满足 1-Lipschitz constraint,WGAN-GP 加入 gradient penalty:
\[ \lambda \mathbb{E}_{\hat{x}} \left(\|\nabla_{\hat{x}}D(\hat{x})\|_2-1\right)^2. \]
A critic \(f\) is 1-Lipschitz if \[ |f(x)-f(y)|\le \|x-y\| \] for all \(x,y\). WGAN constrains the critic to this function class so that the dual objective equals Wasserstein-1 distance.
WGAN 的训练目标常写成 critic maximization:
\[ \max_{\phi} \mathbb{E}_{x\sim p_{\text{data}}}D_\phi(x) - \mathbb{E}_{z\sim p(z)}D_\phi(G_\theta(z)). \]
Generator 最小化
\[ \mathcal{L}_G = - \mathbb{E}_{z\sim p(z)}D_\phi(G_\theta(z)). \]
相比 sigmoid discriminator,critic 的输出不必压到 \([0,1]\),因此梯度更像“把 fake 往 real 分布方向推”的连续势能场。
Kantorovich-Rubinstein duality 告诉我们:
\[ W_1(p,q) = \sup_{\|f\|_L\le 1} \mathbb{E}_{x\sim p}f(x) - \mathbb{E}_{x\sim q}f(x). \]
这里 \(f\) 是 1-Lipschitz critic。它不是概率分类器,而是在学习一个“地势函数”:真实分布区域得分高,生成分布区域得分低,同时斜率被限制,不能无限陡。generator 沿着 critic 的梯度把样本往高分区域推。
一个 1D 例子很直观。设
\[ p=\delta_0,\qquad q=\delta_a. \]
则 Wasserstein-1 距离就是
\[ W_1(p,q)=|a|. \]
即使两个点分布完全不重叠,距离也随 \(a\) 连续变化;而 JS divergence 对任何 \(a\ne0\) 都是 \(\log2\)。这就是 WGAN 在 support mismatch 下更有训练信号的原因。
对任意 1-Lipschitz \(f\),
\[ f(0)-f(a)\le |a|. \]
取 \(f(x)=-x\) 当 \(a>0\),或 \(f(x)=x\) 当 \(a<0\),即可达到 \(|a|\)。因此 dual objective 的 supremum 是 \(|a|\)。
Original WGAN used weight clipping, but clipping can reduce critic capacity and cause optimization pathologies. Gradient penalty or spectral normalization is usually more stable.
WGAN-GP 通常在 real 和 fake 的线性插值点上施加 gradient penalty:
\[ \hat{x}=\epsilon x+(1-\epsilon)G(z), \qquad \epsilon\sim U(0,1). \]
实现时要对 \(\hat{x}\) 开启 gradient:
eps = torch.rand(real.size(0), *([1] * (real.ndim - 1)), device=real.device)
x_hat = eps * real + (1 - eps) * fake.detach()
x_hat.requires_grad_(True)
score_hat = critic(x_hat)
grad = torch.autograd.grad(
outputs=score_hat.sum(),
inputs=x_hat,
create_graph=True,
)[0]
gp = ((grad.flatten(1).norm(2, dim=1) - 1.0) ** 2).mean()常见 critic loss 写成 minimization:
\[ \mathcal{L}_D = \mathbb{E}D(G(z)) - \mathbb{E}D(x) + \lambda_{\text{gp}}\mathcal{L}_{\text{gp}}. \]
符号很容易写反:如果用 optimizer 最小化 loss,real score 应该被推高,fake score 应该被推低。
Spectral Normalization and R1 Regularization
Spectral normalization 约束每个线性/卷积层的 operator norm:
\[ \bar{W}=\frac{W}{\sigma_{\max}(W)}. \]
若每层激活是 1-Lipschitz,那么网络 Lipschitz 常数可由各层 spectral norm 的乘积控制。实际中用 power iteration 近似最大奇异值:
u = normalize(torch.randn(out_dim))
for _ in range(n_power_iter):
v = normalize(W.T @ u)
u = normalize(W @ v)
sigma = u @ W @ v
W_bar = W / sigmaR1 regularization 则在真实样本上惩罚 discriminator 梯度:
\[ \frac{\gamma}{2} \mathbb{E}_{x\sim p_{\text{data}}} \|\nabla_x D(x)\|_2^2. \]
它的直觉是让 discriminator 在真实数据 manifold 附近更平滑,避免过度尖锐的判别边界。和 WGAN-GP 不同,R1 常用于 non-saturating / logistic GAN,而不是把 critic 解释成 Wasserstein dual。
| regularizer | applied to | goal |
|---|---|---|
| gradient penalty | interpolation between real/fake | approximate 1-Lipschitz critic |
| spectral norm | discriminator weights | global Lipschitz control |
| R1 | real samples | smooth discriminator near data |
| path length regularization | generator mapping | smoother latent-to-image geometry |
Practical Stabilization
GAN 代码最常见的错误不是公式写错,而是训练节奏失衡。几个工程要点:
- discriminator update 通常可以多于 generator update;
- 训练 discriminator 时必须 detach fake samples;
- generator loss 不能 detach fake samples;
- discriminator 太强时要降低容量、加正则或减少更新;
- discriminator 太弱时 generator 梯度不可信;
- 用 EMA generator 做采样通常更稳;
- 监控 real/fake logits、gradient norm、sample diversity,而不是只看 loss。
GAN loss 本身不一定和视觉质量单调对应。一个 discriminator loss 上升可能表示 generator 变好了,也可能表示 discriminator 崩了。因此需要看样本、FID、precision/recall、mode coverage。
训练节奏可以看成控制 game balance。若每次 generator 更新前训练 \(n_D\) 次 discriminator:
repeat n_D times:
update D with real and detached fake
update G with non-detached fake
\(n_D\) 太小,discriminator 给出的梯度不可靠;\(n_D\) 太大,discriminator 可能过强,generator 梯度变差。WGAN 常用多个 critic steps;普通 non-saturating GAN 则更常 1:1 或小心调节。
EMA generator 是另一个常见技巧:
\[ \theta_{\text{ema}} \leftarrow \beta\theta_{\text{ema}}+(1-\beta)\theta. \]
采样时使用 \(G_{\theta_{\text{ema}}}\) 往往更稳,因为 adversarial training 的瞬时 generator 会沿着博弈轨迹振荡,EMA 相当于平滑这条轨迹。
TTUR and Update-Time Scales
GAN 不是单个 loss 的最小化,而是两个优化器耦合在一起。把参数写成 \(\theta\) for generator、\(\phi\) for discriminator,一个 alternating update 是:
\[ \phi_{k+1} = \phi_k-\eta_D\nabla_\phi L_D(\phi_k,\theta_k), \]
\[ \theta_{k+1} = \theta_k-\eta_G\nabla_\theta L_G(\theta_k,\phi_{k+1}). \]
Two-time-scale update rule (TTUR) 的直觉是:让 discriminator/critic 和 generator 使用不同学习率或不同更新频率,使 evaluator 保持“足够强但不过强”。
TTUR trains the generator and discriminator with different effective time scales, usually via different learning rates, update frequencies, or both.
有效 time scale 不只是 learning rate,还包括每轮更新次数。若每次 generator step 前做 \(n_D\) 次 discriminator step,则一种粗略尺度是:
\[ \tau_D\approx n_D\eta_D, \qquad \tau_G\approx \eta_G. \]
当 \(\tau_D\gg\tau_G\) 时,discriminator 很快接近当前 generator 的最佳响应,可能让 generator 梯度变弱或高度局部;当 \(\tau_D\ll\tau_G\) 时,generator 在追一个落后的 evaluator,容易收到错误梯度。工程上常用配置不是因为它们神秘,而是在调这个相对速度:
| Setting | Effect |
|---|---|
| larger \(\eta_D\) | faster discriminator response |
| larger \(\eta_G\) | more aggressive generator movement |
| larger \(n_D\) | critic closer to best response |
| stronger D regularization | slows and smooths discriminator |
| EMA decay larger | slower evaluation generator trajectory |
Increasing discriminator steps changes not only the number of gradients but also Adam moment estimates, regularization frequency, and wall-clock data exposure. It is not equivalent to only scaling the discriminator learning rate.
一个训练日志至少要记录:
g_steps
d_steps
d_steps_per_g_step
lr_g, lr_d
beta1_g, beta2_g, beta1_d, beta2_d
regularization_interval
ema_decay
否则不同 run 之间很难比较。两个 run 都叫 “same GAN loss”,但只要 \(n_D\)、TTUR 或 lazy regularization interval 不同,game dynamics 就已经不同。
Lazy Regularization and Scale Correction
R1、path length regularization、某些 gradient penalty 很贵,因为它们需要对输入或中间表示求梯度。实践中常做 lazy regularization:每 \(K\) 个 step 才加一次正则。
如果原目标是每步都优化:
\[ L(\phi) = L_{\text{main}}(\phi)+\lambda R(\phi), \]
lazy 版本在非正则 step 用 \(L_{\text{main}}\),在正则 step 用:
\[ L_{\text{lazy}}(\phi) = L_{\text{main}}(\phi)+K\lambda R(\phi). \]
If regularization is applied once every \(K\) steps and multiplied by \(K\), the average regularization gradient over a full interval equals the gradient of applying the regularizer every step, ignoring parameter drift within the interval.
假设在一个长度为 \(K\) 的小区间内参数近似不变为 \(\phi\)。每步正则的平均梯度是:
\[ \frac1K \sum_{j=1}^K \lambda\nabla R(\phi) = \lambda\nabla R(\phi). \]
lazy regularization 只在其中一步施加 \(K\lambda R(\phi)\),区间平均正则梯度为:
\[ \frac1K K\lambda\nabla R(\phi) = \lambda\nabla R(\phi). \]
所以乘以 \(K\) 后,在忽略区间内参数漂移时,平均正则强度一致。若不乘以 \(K\),正则强度会变成原来的 \(1/K\)。
R1 的 lazy 实现通常类似:
def r1_penalty(discriminator, real):
real = real.detach().requires_grad_(True)
real_logits = discriminator(real)
grad = torch.autograd.grad(
outputs=real_logits.sum(),
inputs=real,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
return grad.square().flatten(1).sum(dim=1).mean()
if step % r1_interval == 0:
penalty = r1_penalty(discriminator, real)
d_loss = d_loss + (r1_gamma * 0.5 * r1_interval) * penalty这里有两个容易写错的地方:
real.detach().requires_grad_(True)让 R1 对输入求梯度,但不把 dataloader 或 augmentation 图拉进来;create_graph=True让 penalty 能继续对 discriminator 参数反传。
Even with scale correction, lazy regularization produces intermittent large gradients. Adam’s moments see a different gradient sequence than dense regularization, so exact equivalence does not hold. Log the interval and tune learning rates with it.
Optimizer Boundary Contract
GAN 训练 loop 里必须显式隔离三条边界:
| Boundary | During D update | During G update |
|---|---|---|
| fake image tensor | detached | not detached |
| discriminator params | requires grad / step | may receive gradients but not stepped, or frozen |
| generator params | no gradient | requires grad / step |
为了省显存和避免误累积,很多代码在 G update 时临时冻结 discriminator 参数:
def set_requires_grad(module, flag):
for p in module.parameters():
p.requires_grad_(flag)
# D step
set_requires_grad(discriminator, True)
set_requires_grad(generator, False)
fake = generator(z).detach()
d_loss = d_loss_fn(discriminator(real), discriminator(fake))
d_opt.zero_grad(set_to_none=True)
d_loss.backward()
d_opt.step()
# G step
set_requires_grad(discriminator, False)
set_requires_grad(generator, True)
fake = generator(z)
g_loss = g_loss_fn(discriminator(fake))
g_opt.zero_grad(set_to_none=True)
g_loss.backward()
g_opt.step()冻结 discriminator 参数不等于不需要 discriminator 的 backward。G step 仍要通过 discriminator 把
\[ \nabla_x D_\phi(x) \]
传给 fake image,再传回 generator。冻结的是 \(\nabla_\phi\),不是 \(\nabla_x\)。如果把整个 discriminator forward 包在 torch.no_grad() 里,generator 就收不到梯度。
During generator update, discriminator parameters do not need gradients, but the discriminator computation must remain differentiable with respect to its input fake samples.
generator loss 可写成:
\[ L_G(\theta) = \ell(D_\phi(G_\theta(z))). \]
链式法则给:
\[ \nabla_\theta L_G = \frac{\partial \ell}{\partial D} \frac{\partial D_\phi(x)}{\partial x}\bigg|_{x=G_\theta(z)} \frac{\partial G_\theta(z)}{\partial \theta}. \]
这里不需要 \(\partial D_\phi/\partial \phi\),所以可以让 discriminator 参数 requires_grad=False。但必须保留 \(\partial D_\phi/\partial x\),所以不能 detach(fake),也不能在 G step 对 discriminator forward 使用 no_grad()。
两个额外 smoke tests 很有用:
def assert_g_step_has_input_grad(discriminator, generator, z):
discriminator.zero_grad(set_to_none=True)
generator.zero_grad(set_to_none=True)
fake = generator(z)
fake.retain_grad()
loss = -discriminator(fake).mean()
loss.backward()
assert fake.grad is not None
assert any(p.grad is not None for p in generator.parameters())
def assert_frozen_does_not_accumulate_param_grad(discriminator, generator, z):
discriminator.zero_grad(set_to_none=True)
generator.zero_grad(set_to_none=True)
set_requires_grad(discriminator, False)
fake = generator(z)
loss = -discriminator(fake).mean()
loss.backward()
assert all(p.grad is None for p in discriminator.parameters())
assert any(p.grad is not None for p in generator.parameters())这些测试比看 loss 曲线更直接:它们检查 GAN 训练最基本的梯度边界是否正确。
建议监控:
| metric | useful signal |
|---|---|
D(real) mean/logit |
real score 是否持续高 |
D(fake) mean/logit |
fake score 是否被分开 |
D(real)-D(fake) |
WGAN critic gap |
| grad norm of D/G | 是否一方梯度消失或爆炸 |
| generated sample grid | 人眼发现 collapse 很快 |
| diversity / recall | 是否覆盖模式 |
| EMA vs raw generator samples | 轨迹是否剧烈振荡 |
| \(n_D\), lr ratio, reg interval | game time scale 是否改变 |
| R1/GP value | 正则是否生效或爆炸 |
Evaluation: Quality and Coverage
生成模型评估至少有两个维度:
| Dimension | Question |
|---|---|
| fidelity | samples look realistic? |
| diversity | distribution covers real modes? |
FID 用 Inception feature 的均值和协方差近似比较真实样本和生成样本:
\[ \operatorname{FID} = \|\mu_r-\mu_g\|_2^2 + \operatorname{Tr} \left( \Sigma_r+\Sigma_g-2(\Sigma_r\Sigma_g)^{1/2} \right). \]
FID 低通常表示质量较好,但它依赖 feature extractor,也可能被数据预处理、样本数、重复样本影响。GAN 尤其要同时看 diversity,否则一个只生成少数高质量样本的模型可能欺骗部分指标。
FID 的估计对样本数敏感。真实使用时要固定:
- feature extractor;
- resize/crop/normalization;
- real statistics 的数据 split;
- generated sample count;
- random seed 或报告均值方差。
precision/recall for generative models 试图拆开 fidelity 和 coverage:
| metric | high means | low means |
|---|---|---|
| precision | generated samples lie near real manifold | artifacts / unrealistic samples |
| recall | generated samples cover real manifold | missing modes / collapse |
一个模型可能 precision 高、recall 低:少数样本非常逼真但模式覆盖差。GAN 论文或实验报告里,最好同时展示 sample grid、nearest real neighbors、class histogram 或 domain-specific coverage 指标。
Adversarial Training as a Paradigm
GAN 的意义不止是一个生成模型,它还定义了一种训练范式:目标不是直接拟合标签,而是通过一个 learned evaluator 提供训练信号。
| Paradigm | Generator-like component | Evaluator-like component |
|---|---|---|
| GAN | image/audio generator | discriminator |
| adversarial robustness | classifier | adversarial attacker |
| RLHF/RLAIF | policy model | reward model |
| data filtering | sample producer | quality classifier |
| diffusion guidance | denoiser | classifier/reward guidance |
从这个角度看,GAN 是后来“用模型监督模型”的早期形态。LLM post-training 中的 reward model、AI feedback、verifier,也都带有这种 learned evaluator 的味道。
Minimal Training Loop
for real in loader:
z = torch.randn(real.size(0), latent_dim, device=real.device)
fake = generator(z).detach()
real_logits = discriminator(real)
fake_logits = discriminator(fake)
d_loss = (
F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
+ F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
)
d_opt.zero_grad(set_to_none=True)
d_loss.backward()
d_opt.step()
z = torch.randn(real.size(0), latent_dim, device=real.device)
fake = generator(z)
fake_logits = discriminator(fake)
g_loss = F.binary_cross_entropy_with_logits(
fake_logits,
torch.ones_like(fake_logits),
)
g_opt.zero_grad(set_to_none=True)
g_loss.backward()
g_opt.step()真实工程里需要 spectral norm、gradient penalty、EMA generator、balanced update ratio 和大量监控。GAN 对训练细节非常敏感。
WGAN-GP 的训练 loop 结构稍有不同:
for real in loader:
for _ in range(n_critic):
z = torch.randn(real.size(0), latent_dim, device=real.device)
fake = generator(z).detach()
d_real = critic(real).mean()
d_fake = critic(fake).mean()
gp = gradient_penalty(critic, real, fake)
d_loss = d_fake - d_real + gp_weight * gp
d_opt.zero_grad(set_to_none=True)
d_loss.backward()
d_opt.step()
z = torch.randn(real.size(0), latent_dim, device=real.device)
fake = generator(z)
g_loss = -critic(fake).mean()
g_opt.zero_grad(set_to_none=True)
g_loss.backward()
g_opt.step()关键差异:
- critic 不用 sigmoid;
- critic loss 的符号要和 minimization 约定匹配;
- fake 在 critic update 中 detach;
- generator update 中不能 detach fake;
- gradient penalty 需要
create_graph=True,否则无法对 critic 参数回传二阶路径。
Implementation Checklist
实现或调试 GAN 时,可以按下面顺序检查:
- discriminator 输出是 logits 还是 probability,loss 是否匹配;
- discriminator update 中 fake 是否
.detach(); - generator update 中 fake 是否保留梯度;
D(real)、D(fake)、gradient norm 是否显示一方过强;- BCE/hinge/WGAN loss 的符号是否和 optimizer minimize 约定一致;
- WGAN critic 是否没有 sigmoid;
- gradient penalty 是否对 interpolated samples 开启
requires_grad; - spectral norm / R1 / GP 是否只加在 discriminator/critic 路径;
- EMA generator 是否只用于 evaluation/sampling 或按设计参与训练;
- G step 是否没有把 discriminator forward 放进
no_grad(); - lazy regularization 是否按 interval 做 scale correction;
- TTUR 的 lr、betas、\(n_D\) 是否记录;
- FID 统计是否固定 preprocessing、sample count 和 real split;
- 是否同时看 precision/recall 或 mode coverage;
- 是否保存 fixed latent grid,用同一组 \(z\) 观察训练过程。
两个 smoke tests:
# 1. detach boundary
fake = generator(z).detach()
d_loss = discriminator(fake).mean()
d_loss.backward()
assert all(p.grad is None for p in generator.parameters())
# 2. generator receives gradient
fake = generator(z)
g_loss = -discriminator(fake).mean()
g_loss.backward()
assert any(p.grad is not None for p in generator.parameters())
# 3. frozen discriminator still passes input gradients
discriminator.zero_grad(set_to_none=True)
generator.zero_grad(set_to_none=True)
set_requires_grad(discriminator, False)
fake = generator(z)
fake.retain_grad()
g_loss = -discriminator(fake).mean()
g_loss.backward()
assert fake.grad is not None
assert all(p.grad is None for p in discriminator.parameters())这两个测试很朴素,但能抓住最常见、也最隐蔽的 GAN 训练错误:该断的梯度没断,不该断的梯度断了。