3.2 Optimizers


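The optimizer sits at the core of training dynamics. It is not just a tool for "making the loss go down". It also determines how noise enters the parameters, how curvature in different directions is rescaled, and which kind of basin the model finally settles in.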

What an Optimizer Actually Specifies

NoteDefinition: Optimizer State

An optimizer state is the collection of auxiliary variables used to compute parameter updates, such as momentum buffers, second-moment estimates, step counters, or factored statistics.

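An optimizer specifies at least three things:

  1. which direction to move the parameters in;
  2. how to scale that direction;
  3. which historical information to remember.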

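Plain gradient descent keeps no extra state. Momentum remembers first-order history, Adam remembers first- and second-order history, and Adafactor keeps factored second-order statistics. In practice the optimizer state often matters more than the formula, because it determines memory usage, checkpoint size, and the training trajectory after a resume.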

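Take AdamW as an example. With \(N\) model parameters, the parameters themselves take \(N\) numbers and the gradients take another \(N\). The first moment \(m\) takes \(N\) more, and the second moment \(v\) takes \(N\) more again. If everything is stored in FP32, the optimizer state alone is already twice the size of the parameters. In LLM training, ZeRO and FSDP optimizer-state sharding exist largely to address this storage problem.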

Gradient Descent

NoteDefinition: Gradient Descent

For objective \(L(\theta)\), gradient descent updates \[ \theta_{t+1} = \theta_t-\eta\nabla L(\theta_t), \] where \(\eta\) is the learning rate.

Under a first-order Taylor expansion,

\[ L(\theta+\Delta) \approx L(\theta)+\nabla L(\theta)^\top\Delta. \]

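If the step is constrained to \(\|\Delta\|_2\leq \epsilon\), the linear term is minimized (by Cauchy-Schwarz) at \(\Delta=-\epsilon\,\nabla L(\theta)/\|\nabla L(\theta)\|_2\), so the steepest descent direction is \(-\nabla L(\theta)\). This explains why the gradient is a locally decreasing direction. It also exposes its limitation: it only sees a local linear approximation.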

Quadratic Stability and Condition Number

Optimizer stability is easiest to see on a quadratic. Let

\[ L(\theta) = \frac{1}{2}\theta^\top H\theta, \qquad H\succ 0. \]

The gradient is \(g=H\theta\), and gradient descent becomes

\[ \theta_{t+1} = (I-\eta H)\theta_t. \]

\(H\) 的特征值为

\[ 0<\lambda_{\min}\le \lambda_i\le \lambda_{\max}, \]

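Then, in the eigenbasis of \(H\) (coordinates \(z_t\), made precise in the proof below), each direction is updated independently: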

\[ z_{t+1,i} = (1-\eta\lambda_i)z_{t,i}. \]

Convergence requires every direction to satisfy

\[ |1-\eta\lambda_i|<1, \]

Hence

\[ 0<\eta<\frac{2}{\lambda_{\max}}. \]

ImportantTheorem: Gradient Descent Stability on a Quadratic

For \(L(\theta)=\frac{1}{2}\theta^\top H\theta\) with \(H\succ0\), fixed-step gradient descent converges linearly if and only if \[ 0<\eta<\frac{2}{\lambda_{\max}(H)}. \] The slowest contraction is governed by the condition number \(\kappa=\lambda_{\max}/\lambda_{\min}\).

A symmetric positive definite matrix can be orthogonally diagonalized:

\[ H=Q\Lambda Q^\top. \]

\(z_t=Q^\top\theta_t\),则

\[ z_{t+1} = Q^\top(I-\eta H)Qz_t = (I-\eta\Lambda)z_t. \]

so the \(i\)-th coordinate satisfies

\[ z_{t+1,i}=(1-\eta\lambda_i)z_{t,i}. \]

A linear system converges if and only if every multiplier has absolute value less than 1, that is,

\[ -1<1-\eta\lambda_i<1. \]

Requiring this for all \(i\) simultaneously gives \(0<\eta<2/\lambda_{\max}\). With the optimal constant step size

\[ \eta^*=\frac{2}{\lambda_{\max}+\lambda_{\min}}, \]

the contraction factor in the worst direction is

\[ \rho^* = \frac{\kappa-1}{\kappa+1}. \]

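This step size balances the two extreme directions, since \(1-\eta^*\lambda_{\min}=-(1-\eta^*\lambda_{\max})=\frac{\lambda_{\max}-\lambda_{\min}}{\lambda_{\max}+\lambda_{\min}}\). Dividing numerator and denominator by \(\lambda_{\min}\) gives \(\rho^*\). The larger the condition number, the closer \(\rho^*\) is to 1 and the slower the descent.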
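
A minimal numerical check of the theorem. It assumes a diagonal \(H\), so the eigenvectors are the coordinate axes, with hypothetical eigenvalues 1 and 100. The helper names are hypothetical. At \(\eta^*\) the worst multiplier equals \((\kappa-1)/(\kappa+1)\), and a step just above \(2/\lambda_{\max}\) diverges.

```python
# Illustrative sketch; all function names here are hypothetical.


def gd_multipliers(eigs, lr):
    """Per-direction multipliers 1 - lr * lambda_i of GD on 0.5 * theta^T H theta."""
    return [1.0 - lr * lam for lam in eigs]


def run_gd(eigs, lr, steps, z0=1.0):
    """Run GD in the eigenbasis; each coordinate evolves independently."""
    z = [z0] * len(eigs)
    for _ in range(steps):
        z = [(1.0 - lr * lam) * zi for lam, zi in zip(eigs, z)]
    return z


eigs = [1.0, 100.0]  # hypothetical lambda_min, lambda_max
lam_min, lam_max = min(eigs), max(eigs)
kappa = lam_max / lam_min

lr_opt = 2.0 / (lam_max + lam_min)
rho = max(abs(m) for m in gd_multipliers(eigs, lr_opt))
print(rho, (kappa - 1) / (kappa + 1))  # both 99/101, about 0.9802

z_ok = run_gd(eigs, 0.99 * 2.0 / lam_max, steps=2000)   # just below the limit
z_bad = run_gd(eigs, 1.01 * 2.0 / lam_max, steps=2000)  # just above the limit
print(max(abs(v) for v in z_ok), max(abs(v) for v in z_bad))
```

The stable run shrinks every coordinate by roughly \(e^{-40}\). In the unstable run, the high-curvature coordinate has multiplier \(-1.02\) and grows without bound.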

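This is the mathematical version of the "ravine" problem. High-curvature directions cap the learning rate, while low-curvature directions crawl because that same learning rate is too small for them. The goal of preconditioning is to make the effective curvature look more like the identity. If some \(B\approx H\) is available, the update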

\[ \Delta=-\eta B^{-1}g \]

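amounts to optimizing, in a new coordinate system, a problem whose effective curvature \(B^{-1}H\) is close to the identity, that is, a much better conditioned problem.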
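
A sketch of the same effect. It assumes a diagonal \(H\) with hypothetical eigenvalues 1 and 100 and the idealized choice \(B=H\), which uses exact curvature that real optimizers only approximate. The helper names are hypothetical. Plain GD at its best constant step is still far from the minimum after 100 steps. The preconditioned update with \(\eta=1\) lands on it in a single step.

```python
# Illustrative sketch; all function names here are hypothetical.


def gd_step(theta, grad, lr):
    return [t - lr * g for t, g in zip(theta, grad)]


def preconditioned_step(theta, grad, lr, diag_b):
    # Delta = -lr * B^{-1} g with a diagonal metric B = diag(diag_b).
    return [t - lr * g / b for t, g, b in zip(theta, grad, diag_b)]


h = [1.0, 100.0]  # diagonal Hessian of L(theta) = 0.5 * sum_i h_i * theta_i^2


def grad_fn(theta):
    return [hi * ti for hi, ti in zip(h, theta)]


theta_gd = [1.0, 1.0]
lr_opt = 2.0 / (max(h) + min(h))
for _ in range(100):
    theta_gd = gd_step(theta_gd, grad_fn(theta_gd), lr_opt)

theta_pc = preconditioned_step([1.0, 1.0], grad_fn([1.0, 1.0]), lr=1.0, diag_b=h)
print(theta_gd)  # roughly [0.135, 0.135]
print(theta_pc)  # [0.0, 0.0]
```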

WarningPitfall: LR Instability Often Starts in the Largest-Curvature Directions

When loss spikes immediately after warmup or after changing batch/sequence length, the first suspect is often the largest effective curvature direction, not the average gradient norm.

Steepest Descent Depends on Geometry

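"Steepest descent" is not an absolute notion. It depends on which norm you use to measure the size of a step. If a general positive definite matrix \(B\succ 0\) defines the local quadratic distance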

\[ \|\Delta\|_B^2 = \Delta^\top B\Delta, \]

then the steepest descent problem for the linearized objective, with a quadratic penalty on the step, is

\[ \min_\Delta g^\top \Delta + \frac{1}{2\eta}\Delta^\top B\Delta, \qquad g=\nabla L(\theta). \]

The first-order condition gives

\[ g+\frac{1}{\eta}B\Delta=0, \]

so

\[ \Delta^* = -\eta B^{-1}g. \]

ImportantTheorem: Preconditioned Gradient Descent

Choosing a local metric \(B\) yields the update \[ \theta_{t+1} = \theta_t-\eta B^{-1}\nabla L(\theta_t). \] Thus adaptive optimizers can be interpreted as gradient descent under a changing diagonal geometry.

The objective

\[ \phi(\Delta)=g^\top\Delta+\frac{1}{2\eta}\Delta^\top B\Delta \]

is a strictly convex quadratic in \(\Delta\) because \(B\succ 0\). Differentiating with respect to \(\Delta\):

\[ \nabla_\Delta\phi = g+\frac{1}{\eta}B\Delta. \]

Setting the gradient to zero gives

\[ \Delta=-\eta B^{-1}g. \]

This is the unique minimizer. Plain GD corresponds to \(B=I\), Newton's method to \(B\approx \nabla^2 L(\theta)\), and Adam/RMSProp to a time-varying diagonal \(B_t\).

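This viewpoint is useful: an optimizer does not magically "tune" itself; it chooses a geometry. If a direction has seen large gradients historically, Adam treats it as scale-sensitive and shrinks the step. If a direction has seen small gradients historically, its step becomes relatively larger.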

Stochastic Gradient Descent

The full-batch gradient

\[ \nabla L(\theta) = \frac{1}{n}\sum_{i=1}^{n}\nabla \ell_i(\theta) \]

is too expensive to compute on large datasets. SGD estimates it with a minibatch \(\mathcal{B}\):

\[ g_t = \frac{1}{|\mathcal{B}|} \sum_{i\in\mathcal{B}}\nabla \ell_i(\theta_t). \]

If the batch is sampled uniformly, then \(\mathbb{E}[g_t]=\nabla L(\theta_t)\).

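SGD noise is not purely harmful. It can help the model escape sharp regions. As the learning rate decays, it also makes training behave like a stochastic dynamical system whose temperature is gradually lowered.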

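More precisely, unbiasedness depends on both the sampling scheme and the loss reduction. If example \(i\) is sampled with probability \(p_i\), the plain average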

\[ \frac{1}{B}\sum_{i\in\mathcal{B}}\nabla\ell_i(\theta) \]

generally estimates the weighted risk

\[ L_p(\theta)=\sum_i p_i\ell_i(\theta), \]

rather than the uniform empirical risk. If the target is still

\[ L(\theta)=\frac{1}{n}\sum_i\ell_i(\theta), \]

then an importance correction is needed:

\[ \hat g = \frac{1}{B}\sum_{i\in\mathcal{B}} \frac{1}{np_i}\nabla\ell_i(\theta). \]

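In class-imbalance tasks, the recipe must state explicitly whether a weighted sampler is meant to change the training distribution or to give an unbiased estimate of the original one.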
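
A small exact check of the two estimators. It uses scalar stand-ins for per-example gradients and a hypothetical sampling distribution \(p\), and the helper name is hypothetical. Instead of running Monte Carlo, it computes the expectation of a single draw directly, so the bias of the plain average shows up exactly. With replacement, a batch average has the same expectation as a single draw.

```python
# Illustrative sketch; the function name is hypothetical.


def expected_single_draw(values, probs, corrected):
    """E[estimate] for one example drawn with probability probs[i].

    corrected=True applies the importance weight 1 / (n * p_i).
    """
    n = len(values)
    total = 0.0
    for x, p in zip(values, probs):
        weight = 1.0 / (n * p) if corrected else 1.0
        total += p * weight * x
    return total


grads = [1.0, 2.0, 3.0, 10.0]   # hypothetical per-example gradients
probs = [0.1, 0.1, 0.1, 0.7]    # hypothetical sampler weights (sum to 1)

uniform_mean = sum(grads) / len(grads)                          # 4.0
naive = expected_single_draw(grads, probs, corrected=False)     # weighted risk, about 7.6
fixed = expected_single_draw(grads, probs, corrected=True)      # about 4.0
print(uniform_mean, naive, fixed)
```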

NoteDefinition: Optimized Training Distribution

The optimized training distribution is the distribution over examples induced by sampling weights, filtering, packing, loss masks, and reduction denominators. It may differ from the raw dataset distribution.

Sampling with and without replacement also gives different variances. The formula (derived in the next subsection)

\[ \operatorname{Var}(g_{\mathcal{B}})=\frac{1}{B}\Sigma \]

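corresponds to independent sampling. If batches of size \(B\) are drawn without replacement from a finite dataset, the variance carries a finite-population correction: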

\[ \operatorname{Var}(g_{\mathcal{B}}) \approx \frac{1}{B}\left(1-\frac{B}{n}\right)\Sigma. \]

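So a random permutation within an epoch is not plain independent sampling. Batches from the same epoch are disjoint subsets of a finite dataset, so their gradients are correlated. The closer the epoch gets to a full sweep, the more the remaining batches are determined by the ones already drawn.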
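
The correction can be checked exactly on a tiny dataset by enumerating every batch drawn without replacement. The exact finite-population factor is \((n-B)/(n-1)\), which the text approximates by \(1-B/n\). This sketch uses hypothetical scalar stand-ins for per-example gradients, and the helper name is hypothetical.

```python
# Illustrative sketch; the function name is hypothetical.
from itertools import combinations
from statistics import fmean, pvariance


def batch_mean_variance(values, batch_size):
    """Exact variance of the batch mean over all without-replacement batches."""
    means = [fmean(b) for b in combinations(values, batch_size)]
    return pvariance(means)


values = [0.0, 1.0, 3.0, 4.0, 7.0, 9.0]  # hypothetical per-example gradients
n, batch = len(values), 3
sigma2 = pvariance(values)  # population variance (1/n normalization)

exact = batch_mean_variance(values, batch)
iid = sigma2 / batch                                  # independent-sampling formula
fpc_exact = sigma2 / batch * (n - batch) / (n - 1)    # exact finite-population form
fpc_approx = sigma2 / batch * (1 - batch / n)         # approximation used in the text
print(exact, fpc_exact, fpc_approx, iid)  # about 2.0, 2.0, 1.67, 3.33
```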

Minibatch Noise and Batch Size

Let the per-example gradient be the random variable

\[ G_i(\theta)=\nabla \ell_i(\theta), \]

with mean \(\nabla L(\theta)\) and covariance \(\Sigma(\theta)\). If the minibatch is sampled independently and uniformly, the batch gradient

\[ g_{\mathcal{B}} = \frac{1}{B}\sum_{i\in\mathcal{B}}G_i \]

satisfies

\[ \mathbb{E}[g_{\mathcal{B}}] = \nabla L(\theta), \qquad \operatorname{Var}(g_{\mathcal{B}}) = \frac{1}{B}\Sigma(\theta). \]

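This explains the two effects of batch size:

  1. the larger the batch, the closer the gradient direction is to the full gradient;
  2. the larger the batch, the lower the SGD noise temperature.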

When the batch size grows, the learning rate can often be increased too, but not linearly without limit. The update

\[ \theta_{t+1} = \theta_t-\eta g_{\mathcal{B}} \]

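has noise covariance of roughly \(\eta^2\Sigma/B\). If \(B\) grows while \(\eta\) stays fixed, training becomes more deterministic and closer to full-batch optimization, which may make it more likely to settle into a sharp basin. If \(\eta\) is increased too much, the deterministic part becomes unstable on its own (the quadratic analysis above gives the limit \(\eta<2/\lambda_{\max}\)).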

An approximate stochastic differential equation view helps explain how learning rate and batch size are coupled. Write SGD as

\[ \theta_{t+1} = \theta_t-\eta\nabla L(\theta_t) - \eta\xi_t, \]

where

\[ \mathbb{E}[\xi_t]=0, \qquad \operatorname{Cov}(\xi_t)\approx \frac{1}{B}\Sigma(\theta_t). \]

The covariance scale of the random term is

\[ \operatorname{Cov}(\eta\xi_t) \approx \frac{\eta^2}{B}\Sigma. \]

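If the goal is to keep the per-step parameter noise scale roughly constant, \(\eta/\sqrt{B}\) is the key ratio. If the goal is to keep the diffusion strength per sample budget fixed, one arrives at the effective-temperature intuition \(\eta/B\). Real deep networks are also affected by momentum, Adam preconditioning, gradient clipping, the scheduler, and data repetition, so this is not an exact law. It does explain why large-batch training must retune LR and warmup.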
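
The two heuristics can be written as a small helper. This is a sketch only. The name `scale_lr` and its `rule` argument are hypothetical, and a real recipe still needs warmup and a stability check against \(2/\lambda_{\max}\).

```python
# Illustrative sketch; scale_lr and its arguments are hypothetical names.
import math


def scale_lr(base_lr, base_batch, new_batch, rule="sqrt"):
    """Heuristic LR rescaling when batch size changes.

    rule="sqrt":   keep per-step noise covariance eta^2 / B fixed  -> eta ~ sqrt(B)
    rule="linear": keep the effective temperature eta / B fixed     -> eta ~ B
    """
    ratio = new_batch / base_batch
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    if rule == "linear":
        return base_lr * ratio
    raise ValueError(f"unknown rule: {rule}")


base_lr, base_b, new_b = 1e-3, 256, 1024
for rule in ("sqrt", "linear"):
    lr = scale_lr(base_lr, base_b, new_b, rule)
    # Print the quantity each rule is designed to preserve.
    print(rule, lr, lr**2 / new_b, lr / new_b)
```

Going from 256 to 1024 examples per step, the square-root rule doubles the LR and the linear rule quadruples it. Neither rule guarantees stability.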

ImportantImplementation Contract: Log Batch in the Same Unit as the Objective

For image classification, examples per step may be enough. For language modeling, log valid tokens per optimizer step, because token count controls both compute and gradient-noise scale.

WarningPitfall: Larger Batch Is Not Just Faster Training

Increasing batch size changes both hardware efficiency and optimization noise. Matching throughput without matching optimization dynamics can change generalization.

Momentum

Momentum introduces a velocity variable:

\[ v_{t+1} = \beta v_t+g_t, \qquad \theta_{t+1} = \theta_t-\eta v_{t+1}. \]

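It accumulates gradients that keep pointing the same way and smooths out high-frequency oscillation. In a long, narrow valley, plain SGD zigzags from side to side, while momentum accelerates along the low-curvature direction.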

NoteDefinition: Exponential Moving Average

The recursion \(v_t=\beta v_{t-1}+(1-\beta)x_t\) with \(v_0=0\) expands to \[ v_t=(1-\beta)\sum_{k=0}^{t-1}\beta^k x_{t-k}. \] Recent observations receive larger weights, and the effective window size is roughly \(1/(1-\beta)\).
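
A quick check of the expansion. It assumes \(v_0=0\) and an arbitrary hypothetical input sequence, and the helper names are hypothetical. The recursion and the closed-form weighted sum agree. The weights sum to \(1-\beta^t\), which approaches 1 once \(t\) is well past the effective window.

```python
# Illustrative sketch; all function names here are hypothetical.


def ema_recursive(xs, beta):
    v = 0.0  # v_0 = 0
    for x in xs:
        v = beta * v + (1 - beta) * x
    return v


def ema_closed_form(xs, beta):
    t = len(xs)
    # v_t = (1 - beta) * sum_{k=0}^{t-1} beta^k * x_{t-k}; x is 1-indexed in the text.
    return (1 - beta) * sum(beta**k * xs[t - 1 - k] for k in range(t))


xs = [0.5, -1.0, 2.0, 3.5, 0.0, 1.25]  # hypothetical observations
beta = 0.9
print(ema_recursive(xs, beta), ema_closed_form(xs, beta))

t = 200
weight_sum = (1 - beta) * sum(beta**k for k in range(t))
print(weight_sum, 1 - beta**t)  # equal; close to 1 once t >> 1/(1-beta) = 10
```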

Momentum as a Dynamical System

For the quadratic

\[ L(\theta)=\frac{1}{2}\lambda\theta^2 \]

plain GD updates as

\[ \theta_{t+1} = (1-\eta\lambda)\theta_t. \]

Convergence requires

\[ |1-\eta\lambda|<1 \quad\Longleftrightarrow\quad 0<\eta<\frac{2}{\lambda}. \]

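So high-curvature directions cap the maximum learning rate. The momentum recursion can be written as a second-order difference equation. Starting from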

\[ v_{t+1}=\beta v_t+\lambda\theta_t, \qquad \theta_{t+1}=\theta_t-\eta v_{t+1}, \]

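and eliminating \(v_t\) with \(v_{t+1}=(\theta_t-\theta_{t+1})/\eta\) and \(v_t=(\theta_{t-1}-\theta_t)/\eta\) gives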

\[ \theta_{t+1} = (1+\beta-\eta\lambda)\theta_t - \beta\theta_{t-1}. \]

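This behaves like a damped spring: \(\lambda\) is the curvature, \(\eta\) the step size, and \(\beta\) the inertia. A suitable \(\beta\) builds up speed along low-curvature directions, but too large a \(\beta\) causes overshoot.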
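
The second-order recursion can be analyzed through its characteristic polynomial \(z^2-(1+\beta-\eta\lambda)z+\beta\). The standard stability conditions for a second-order linear recurrence imply that, for \(0\le\beta<1\), the iteration is stable exactly when \(0<\eta\lambda<2(1+\beta)\). Plain GD needs \(\eta\lambda<2\). The sketch below, whose helper name is hypothetical, computes the spectral radius directly for \(\beta=0.9\).

```python
# Illustrative sketch; the function name is hypothetical.
import cmath


def heavy_ball_radius(eta_lam, beta):
    """Largest |root| of z^2 - (1 + beta - eta*lam) z + beta = 0.

    The recursion theta_{t+1} = (1+beta-eta*lam) theta_t - beta theta_{t-1}
    converges iff this spectral radius is < 1.
    """
    a = 1.0 + beta - eta_lam
    disc = cmath.sqrt(a * a - 4.0 * beta)
    roots = ((a + disc) / 2.0, (a - disc) / 2.0)
    return max(abs(r) for r in roots)


beta = 0.9
for eta_lam in (0.1, 1.9, 3.7, 3.9):
    gd_stable = abs(1.0 - eta_lam) < 1.0       # plain GD: needs eta*lam < 2
    hb_radius = heavy_ball_radius(eta_lam, beta)
    print(eta_lam, gd_stable, round(hb_radius, 4))
```

At \(\eta\lambda=3.7\) plain GD diverges, but heavy ball still contracts with radius \(\sqrt{0.9}\approx0.949\). At \(3.9\), which exceeds \(2(1+\beta)=3.8\), heavy ball diverges as well.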

WarningPitfall: Momentum Can Hide Instability

Momentum may keep loss decreasing for a while even when the learning rate is too high. Instability often appears later as oscillation, sudden loss spikes, or exploding optimizer states.

Nesterov Momentum

The idea of Nesterov momentum is to look ahead along the momentum direction first, then evaluate the gradient at the look-ahead point:

\[ v_{t+1} = \beta v_t + \nabla L(\theta_t-\eta\beta v_t), \]

\[ \theta_{t+1} = \theta_t-\eta v_{t+1}. \]

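Intuitively, ordinary momentum "accelerates according to the current slope". Nesterov instead "first sees where inertia will carry it, then brakes or accelerates there". In convex optimization, Nesterov acceleration has stronger theoretical convergence rates. In deep learning it is often used as a stable variant of SGD momentum.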
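
A minimal sketch of both updates on the one-dimensional quadratic \(L(\theta)=\frac12\lambda\theta^2\), using the same \(v\)-form as the text. The values of \(\lambda\), \(\eta\), \(\beta\) and the helper names are hypothetical. For this quadratic, eliminating \(v\) from the Nesterov update gives \(\theta_{t+1}=(1+\beta)(1-\eta\lambda)\theta_t-\beta(1-\eta\lambda)\theta_{t-1}\). With the values below, its roots have modulus \(\sqrt{\beta(1-\eta\lambda)}=0.9\), versus \(\sqrt{\beta}\approx0.949\) for heavy ball, so Nesterov contracts faster in this setting.

```python
# Illustrative sketch; all function names here are hypothetical.


def heavy_ball(theta, lam, lr, beta, steps):
    v = 0.0
    for _ in range(steps):
        v = beta * v + lam * theta            # gradient at the current point
        theta = theta - lr * v
    return theta


def nesterov(theta, lam, lr, beta, steps):
    v = 0.0
    for _ in range(steps):
        lookahead = theta - lr * beta * v     # where inertia alone would carry us
        v = beta * v + lam * lookahead        # gradient at the look-ahead point
        theta = theta - lr * v
    return theta


lam, lr, beta = 1.0, 0.1, 0.9  # hypothetical curvature, step, momentum
for steps in (50, 100, 200):
    print(steps, abs(heavy_ball(1.0, lam, lr, beta, steps)),
          abs(nesterov(1.0, lam, lr, beta, steps)))
```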

AdaGrad, RMSProp, and Adam

Adaptive optimizers scale each coordinate by an estimate of gradient magnitude.

AdaGrad:

\[ s_t=s_{t-1}+g_t\odot g_t, \qquad \theta_{t+1} = \theta_t-\eta\frac{g_t}{\sqrt{s_t}+\epsilon}. \]

RMSProp:

\[ s_t=\rho s_{t-1}+(1-\rho)g_t\odot g_t, \qquad \theta_{t+1} = \theta_t-\eta\frac{g_t}{\sqrt{s_t}+\epsilon}. \]

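AdaGrad's denominator grows monotonically. For sparse features this is good: rarely active coordinates do not stall because the global step is too small, and frequently active coordinates slow down automatically. For long, nonstationary training runs it becomes a problem. Because \(s_t\) never decreases, the effective LR may become too small late in training.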

RMSProp replaces the cumulative sum with an exponential moving average. With \(s_0=0\),

\[ s_t = (1-\rho)\sum_{k=0}^{t-1}\rho^k g_{t-k}^2. \]

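so it only remembers the gradient scale over roughly the last \(1/(1-\rho)\) steps. That makes it better suited to deep networks whose loss surface keeps changing.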
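
A sketch comparing the effective per-coordinate step \(\eta/(\sqrt{s_t}+\epsilon)\) under a hypothetical constant gradient \(g_t=1\). The helper name and the default hyperparameters are hypothetical. AdaGrad's step keeps shrinking like \(\eta/\sqrt{t}\), while RMSProp's settles near \(\eta\) once \(t\gg 1/(1-\rho)\).

```python
# Illustrative sketch; the function name and defaults are hypothetical.
import math


def effective_lrs(grads, lr=1e-2, rho=0.99, eps=1e-8):
    """Return (AdaGrad, RMSProp) effective step sizes lr / (sqrt(s_t) + eps)."""
    s_ada, s_rms = 0.0, 0.0
    for g in grads:
        s_ada = s_ada + g * g                       # cumulative sum: only grows
        s_rms = rho * s_rms + (1 - rho) * g * g     # EMA: forgets old scale
    return lr / (math.sqrt(s_ada) + eps), lr / (math.sqrt(s_rms) + eps)


for t in (10, 1_000, 10_000):
    ada, rms = effective_lrs([1.0] * t)
    print(t, ada, rms)
```

After 10,000 steps AdaGrad's effective step is about \(\eta/100\), while RMSProp's is about \(\eta\).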

Adam combines first and second moment estimates:

\[ m_t=\beta_1m_{t-1}+(1-\beta_1)g_t, \]

\[ v_t=\beta_2v_{t-1}+(1-\beta_2)g_t^2. \]

Bias correction:

\[ \hat{m}_t=\frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t=\frac{v_t}{1-\beta_2^t}. \]

Update:

\[ \theta_{t+1} = \theta_t-\eta\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}. \]

Bias correction comes from a simple fact. If the gradient mean is roughly stationary at \(\mu\) and \(m_0=0\), then

\[ \mathbb{E}[m_t] = (1-\beta_1) \sum_{k=0}^{t-1} \beta_1^k\mu = (1-\beta_1^t)\mu. \]

So early on, \(m_t\) is pulled toward zero by its zero initialization and must be divided by \(1-\beta_1^t\). The same holds for the second moment \(v_t\).
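
A deterministic check. It assumes a hypothetical constant gradient \(\mu=0.5\) so the expectation is exact, and the helper name is hypothetical. The raw EMA starts near zero, and dividing by \(1-\beta_1^t\) recovers \(\mu\) at every step.

```python
# Illustrative sketch; the function name is hypothetical.


def ema_with_bias_correction(mu, beta, steps):
    m = 0.0  # m_0 = 0
    raw, corrected = [], []
    for t in range(1, steps + 1):
        m = beta * m + (1 - beta) * mu
        raw.append(m)                         # equals (1 - beta**t) * mu
        corrected.append(m / (1 - beta**t))   # bias-corrected estimate
    return raw, corrected


raw, corrected = ema_with_bias_correction(mu=0.5, beta=0.9, steps=5)
print([round(r, 4) for r in raw])        # starts at 0.05, not 0.5
print([round(c, 4) for c in corrected])  # 0.5 at every step
```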

NoteDefinition: Coordinate-Wise Adaptive Scaling

Coordinate-wise adaptive scaling means dividing each coordinate update by a history-dependent estimate of that coordinate’s gradient scale.

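An easily overlooked property is that Adam is approximately scale-invariant. If a coordinate's gradient is scaled by a factor \(c\) throughout training, then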

\[ \hat m_{t,i}\mapsto c\hat m_{t,i}, \qquad \hat v_{t,i}\mapsto c^2\hat v_{t,i}. \]

Ignoring \(\epsilon\),

\[ \frac{c\hat m_{t,i}}{\sqrt{c^2\hat v_{t,i}}} = \operatorname{sign}(c)\frac{\hat m_{t,i}}{\sqrt{\hat v_{t,i}}}. \]

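If \(c>0\), the update is essentially unchanged. This explains why Adam is fairly robust to unbalanced gradient scales. It also explains why \(\epsilon\) is not an irrelevant small number: when \(\sqrt{\hat v_{t,i}}\) is small, \(\epsilon\) dominates the denominator and Adam falls back to something closer to momentum SGD.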

| Regime | Denominator | Behavior |
|---|---|---|
| \(\sqrt{\hat v_i}\gg\epsilon\) | gradient scale estimate | scale-normalized Adam |
| \(\sqrt{\hat v_i}\approx\epsilon\) | mixed | sensitive to epsilon choice |
| \(\sqrt{\hat v_i}\ll\epsilon\) | epsilon floor | roughly momentum-like with LR \(\eta/\epsilon\) |

WarningPitfall: Epsilon Is a Numerical and Optimization Hyperparameter

Changing \(\epsilon\) changes more than numerical safety. It changes the effective learning rate for small-gradient coordinates.
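
A sketch of both effects on a single coordinate, using the Adam recursion from this section. The gradient sequence, scale factors, helper name, and defaults are hypothetical. When \(\sqrt{\hat v}\gg\epsilon\), rescaling all gradients by \(c>0\) leaves the update almost unchanged. Once the gradients are small enough for \(\epsilon\) to dominate, the update shrinks roughly in proportion to the gradient, like momentum SGD with learning rate \(\eta/\epsilon\).

```python
# Illustrative sketch; the function name and defaults are hypothetical.
import math


def adam_updates(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return the per-step Adam updates -lr * m_hat / (sqrt(v_hat) + eps)."""
    m = v = 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates


grads = [0.3, -0.1, 0.5, 0.2, 0.4]  # hypothetical gradient sequence

base = adam_updates(grads)
scaled = adam_updates([1000.0 * g for g in grads])
print(max(abs(a - b) for a, b in zip(base, scaled)))  # tiny: scale-invariant regime

tiny = adam_updates([1e-12 * g for g in grads])
print(max(abs(u) for u in tiny))  # far below lr: epsilon dominates the denominator
```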

Adam as Diagonal Preconditioning

Adam's update can be written as

\[ \theta_{t+1} = \theta_t-\eta D_t^{-1}\hat{m}_t, \qquad D_t=\operatorname{diag}(\sqrt{\hat{v}_t}+\epsilon). \]

This is exactly the diagonal metric from the preconditioning view above. Each parameter coordinate has its own effective learning rate:

\[ \eta_{t,i}^{\text{eff}} = \frac{\eta}{\sqrt{\hat{v}_{t,i}}+\epsilon}. \]

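If a coordinate's gradients are persistently large, \(\hat{v}_{t,i}\) is large and its effective LR is small. If its gradients are persistently small, its effective LR is large. This is one reason Adam tends to do well on sparse or badly scaled problems.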

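This mechanism also carries a risk. If a coordinate has tiny gradients early on and then suddenly receives a large gradient, Adam may give it an overly large step. \(\epsilon\), gradient clipping, and warmup all bear on this risk.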
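
A worked example of this risk. It assumes a coordinate whose gradient was exactly zero for many steps, so \(m=v=0\) and the bias corrections are essentially 1, and which then jumps to \(g\). One step gives \(m=(1-\beta_1)g\) and \(v=(1-\beta_2)g^2\), so the update magnitude is about \(\eta(1-\beta_1)/\sqrt{1-\beta_2}\), independent of \(|g|\). With the hypothetical settings below, this is about \(3.16\eta\) for \(\beta_2=0.999\) and about \(0.45\eta\) for \(\beta_2=0.95\). The helper name is hypothetical.

```python
# Illustrative sketch; the function name and defaults are hypothetical.
import math


def first_step_after_silence(g, lr, beta1, beta2, eps=1e-8, silent_steps=100_000):
    """Adam update for a gradient g arriving after `silent_steps` zero gradients."""
    m, v = 0.0, 0.0                      # zero gradients leave m and v at zero
    t = silent_steps + 1
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1**t)           # bias correction ~ 1 this late in training
    v_hat = v / (1 - beta2**t)
    return -lr * m_hat / (math.sqrt(v_hat) + eps)


lr = 1e-3
for beta2 in (0.999, 0.95):
    step = first_step_after_silence(g=5.0, lr=lr, beta1=0.9, beta2=beta2)
    print(beta2, abs(step) / lr)  # multiple of the nominal learning rate
```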

WarningPitfall: Adam Is Not Just Faster SGD

Adam changes the geometry of the update by normalizing coordinates with historical squared gradients. It often improves early optimization, but its implicit bias can differ from SGD and sometimes generalizes differently.

Numerical Details in Adam

Adam implementations contain several details that look minor but matter a lot:

| Detail | Why it matters |
|---|---|
| \(\epsilon\) placement | \(\sqrt{v}+\epsilon\) and \(\sqrt{v+\epsilon}\) are not identical |
| state dtype | FP32 states are more stable than FP16 states |
| step counter | bias correction depends on exact step index |
| skipped steps | AMP overflow should not advance state incorrectly |
| gradient accumulation | optimizer step count differs from backward count |

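In mixed-precision training, a common flow is:

  1. run forward/backward with BF16/FP16 activations;
  2. unscale the gradients (needed when FP16 loss scaling is used);
  3. clip the gradients;
  4. run the optimizer with FP32 master weights or FP32 optimizer states;
  5. after the update, sync the low-precision weights.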

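If the order is wrong, the clipping threshold changes completely. For example, clipping the still-scaled gradient and only then unscaling makes the effective threshold on the true gradient equal to the intended threshold divided by the loss scale.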
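
A pure-Python sketch of why the order matters, with a hypothetical loss scale of 1024 and clip threshold 1.0. No PyTorch is involved, and the helper name is hypothetical. Clipping the scaled gradient and then unscaling shrinks the true gradient far more than intended. In PyTorch AMP, this is why `GradScaler.unscale_(optimizer)` is called before `torch.nn.utils.clip_grad_norm_`.

```python
# Illustrative sketch; the function name is hypothetical.
import math


def clip_by_norm(grad, max_norm):
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / (norm + 1e-12))
    return [g * scale for g in grad]


true_grad = [0.6, 0.8]        # hypothetical true gradient, norm 1.0
loss_scale = 1024.0           # hypothetical FP16 loss scale
max_norm = 1.0
scaled = [g * loss_scale for g in true_grad]

# Correct order: unscale first, then clip against the intended threshold.
correct = clip_by_norm([g / loss_scale for g in scaled], max_norm)

# Wrong order: clip the scaled gradient, then unscale.
wrong = [g / loss_scale for g in clip_by_norm(scaled, max_norm)]

print(correct)  # about [0.6, 0.8]: norm 1.0 sits at the threshold, essentially unclipped
print(wrong)    # about [0.00059, 0.00078]: effective threshold became max_norm / loss_scale
```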

Optimizer State Memory Accounting

Let the parameter count be \(N\). The training memory of different optimizers can be roughly summarized as:

| Optimizer | Extra state per parameter | Typical state tensors |
|---|---|---|
| SGD | \(0\) | none |
| momentum SGD | \(1\) | velocity |
| Adam/AdamW | \(2\) | first moment \(m\), second moment \(v\) |
| Adafactor | sublinear for matrices | factored row/column second moments |

If parameters, gradients, and Adam moments are all stored in FP32, each parameter costs:

\[ 4\text{ bytes parameter} + 4\text{ bytes gradient} + 8\text{ bytes Adam states} =16\text{ bytes}. \]

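FP32 master weights, low-precision model weights, distributed buckets, activations, and temporary buffers all come on top of this, so the real peak is higher. Choosing an optimizer is therefore a memory-budget question as well as a convergence question.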

NoteDefinition: Optimizer Memory Multiplier

The optimizer memory multiplier is the ratio between optimizer state storage and parameter storage. AdamW has multiplier about \(2\) for first and second moments before sharding or quantization.

A simple estimation function:

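def adamw_state_gb(n_params, bytes_per_state=4):
    # AdamW keeps two state tensors per parameter: m and v.
    # Returns decimal gigabytes (1 GB = 1e9 bytes), matching the estimate below.
    return 2 * n_params * bytes_per_state / 1e9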

For a 7B-parameter model, FP32 Adam moments take about

\[ 2\times 7\cdot 10^9\times 4\text{ bytes} = 5.6\cdot 10^{10}\text{ bytes}\approx 56\text{ GB}\ (\approx 52\text{ GiB}). \]

This does not yet count parameters, gradients, or activations. That is why ZeRO/FSDP, 8-bit optimizers, Adafactor, LoRA, and similar techniques often come up alongside optimizer state.

Adafactor and Factored Second Moments

Adam's second moment \(v_t\) has the same shape as the parameters. For a matrix parameter

\[ W\in\mathbb{R}^{r\times c}, \]

the full second-moment state needs \(rc\) numbers. Adafactor's idea is to store only row and column statistics:

\[ R_t\in\mathbb{R}^{r}, \qquad C_t\in\mathbb{R}^{c}. \]

Given the gradient matrix \(G_t\), the update is approximately

\[ R_t = \rho R_{t-1} + (1-\rho)\operatorname{mean}_{j}(G_{t,ij}^2), \]

\[ C_t = \rho C_{t-1} + (1-\rho)\operatorname{mean}_{i}(G_{t,ij}^2). \]

The second-moment approximation is then reconstructed as an outer product:

\[ \hat V_{ij} \approx \frac{R_i C_j}{\operatorname{mean}(R)}. \]

The state size thus drops from \(rc\) to \(r+c\).

ImportantTheorem: Factored Second-Moment State Is Sublinear for Matrices

For a matrix parameter \(W\in\mathbb{R}^{r\times c}\), full Adam second-moment storage is \(O(rc)\), while Adafactor-style row/column factored storage is \(O(r+c)\).

Full Adam stores one second-moment value for every entry \(W_{ij}\), so the number of stored scalars is

\[ rc. \]

Factored storage keeps one row statistic for each row and one column statistic for each column:

\[ r+c. \]

For square matrices \(r=c=d\), this changes storage from \(d^2\) to \(2d\). For Transformer projection matrices with large hidden dimension, this is a large memory reduction.

A minimal factored denominator:

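def adafactor_denominator(row_stat, col_stat, eps):
    # row_stat: torch tensor of shape (r,); col_stat: torch tensor of shape (c,).
    # Normalize rows by mean(R) so that V_hat_ij = R_i * C_j / mean(R).
    row_scale = row_stat / row_stat.mean().clamp_min(eps)
    v_hat = row_scale[:, None] * col_stat[None, :]  # rank-1 outer product, shape (r, c)
    return v_hat.sqrt().add(eps)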

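The cost of Adafactor is approximation bias. It assumes the second moment can be roughly reconstructed from row and column marginal statistics. If the matrix has strong local structure, the factored estimate may be less accurate than the full \(v_t\). In practice Adafactor is often combined with update clipping, relative step sizes, or parameter-scale normalization.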
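
A pure-Python sketch of the reconstruction \(\hat V_{ij}=R_iC_j/\operatorname{mean}(R)\) on hypothetical squared-gradient matrices. For clarity it uses a single step with no EMA (\(\rho=0\)), and the helper names are hypothetical. The reconstruction is exact when \(G^2\) is rank one, that is, an outer product of row and column scales. It is only approximate when there is local structure, here a single large entry.

```python
# Illustrative sketch; all function names here are hypothetical.


def factored_second_moment(g2):
    """Rebuild an r x c second-moment estimate from r + c statistics."""
    r, c = len(g2), len(g2[0])
    row = [sum(g2[i]) / c for i in range(r)]                        # R_i = mean_j G_ij^2
    col = [sum(g2[i][j] for i in range(r)) / r for j in range(c)]   # C_j = mean_i G_ij^2
    row_mean = sum(row) / r
    return [[row[i] * col[j] / row_mean for j in range(c)] for i in range(r)]


def max_abs_error(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Rank-one case: G^2_ij = a_i * b_j is reconstructed exactly.
a, b = [1.0, 4.0, 9.0], [0.5, 2.0]
rank_one = [[ai * bj for bj in b] for ai in a]
print(max_abs_error(rank_one, factored_second_moment(rank_one)))  # ~0 (float rounding)

# Local structure: one outlier entry cannot be captured by row/column marginals.
spiky = [[1.0, 1.0], [1.0, 100.0]]
print(max_abs_error(spiky, factored_second_moment(spiky)))        # about 0.96 on unit entries
```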

WarningPitfall: Factored States Are Not Drop-In Adam States

Adafactor changes both memory use and update geometry. Switching from AdamW to Adafactor usually requires retuning learning rate, clipping, and sometimes weight decay.

AdamW and Weight Decay

Classic Adam with L2 penalty adds \(\lambda\theta\) into the gradient before adaptive normalization:

\[ g_t \leftarrow \nabla L(\theta_t)+\lambda\theta_t. \]

This is not the same as true weight decay under adaptive scaling. AdamW decouples the shrinkage:

\[ \theta_{t+1} = (1-\eta\lambda)\theta_t - \eta\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}. \]

AdamW is one of the default choices in Transformer/LLM training. It separates "the gradient that optimizes the loss" from "the decay that controls parameter norm".

Why are an L2 penalty and weight decay not equivalent in Adam? Denote Adam's diagonal denominator by

\[ D_t=\operatorname{diag}(\sqrt{\hat v_t}+\epsilon). \]

If the L2 penalty is added to the gradient, the decay part of the update becomes

\[ -\eta D_t^{-1}\lambda\theta_t. \]

\(i\) 个坐标是

\[ -\eta\lambda \frac{\theta_{t,i}}{\sqrt{\hat v_{t,i}}+\epsilon}. \]

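This means the same weight-decay coefficient gets rescaled by Adam's second-moment statistics. Coordinates with large historical gradients decay less, and coordinates with small historical gradients decay more. True decoupled weight decay is instead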

\[ -\eta\lambda\theta_t, \]

which does not pass through \(D_t^{-1}\). That is the core of AdamW.
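
A numerical illustration with two hypothetical coordinates. Both share the value \(\theta_i=1\) but have different denominators \(\sqrt{\hat v_{t,i}}+\epsilon\). The L2-through-Adam shrinkage differs by the ratio of the denominators, while decoupled decay is identical for both. The helper names and values are hypothetical.

```python
# Illustrative sketch; all function names here are hypothetical.


def l2_shrinkage(theta, denom, lr, wd):
    # Decay routed through the adaptive denominator: -lr * wd * theta / denom
    return [-lr * wd * t / d for t, d in zip(theta, denom)]


def decoupled_shrinkage(theta, lr, wd):
    # AdamW-style decay: -lr * wd * theta, independent of the denominator
    return [-lr * wd * t for t in theta]


theta = [1.0, 1.0]
denom = [1e-1, 1e-4]   # large vs small historical gradient scale
lr, wd = 1e-3, 0.1

print(l2_shrinkage(theta, denom, lr, wd))      # second coordinate decays 1000x more
print(decoupled_shrinkage(theta, lr, wd))      # same shrinkage for both
```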

ImportantTheorem: L2 Penalty Is Not Decoupled Weight Decay under Adaptive Scaling

For an adaptive update with diagonal preconditioner \(D_t^{-1}\), adding \(\lambda\theta\) to the gradient produces shrinkage \(-\eta D_t^{-1}\lambda\theta\), while decoupled weight decay produces shrinkage \(-\eta\lambda\theta\). They coincide only when \(D_t=I\). If all coordinates share one scalar denominator \(d_t\), they agree only up to rescaling \(\lambda\) by \(1/d_t\).

Adaptive gradient update with L2 penalty is

\[ \theta_{t+1} = \theta_t - \eta D_t^{-1}(\nabla L(\theta_t)+\lambda\theta_t). \]

Expanding,

\[ \theta_{t+1} = \theta_t - \eta D_t^{-1}\nabla L(\theta_t) - \eta\lambda D_t^{-1}\theta_t. \]

Decoupled weight decay instead writes

\[ \theta_{t+1} = (1-\eta\lambda)\theta_t - \eta D_t^{-1}\nabla L(\theta_t). \]

The decay terms match only when

\[ D_t^{-1}\theta_t=\theta_t \]

for the relevant coordinates, which generally fails when \(D_t\) has coordinate-wise adaptive entries.

A teaching version of an AdamW step follows. Real PyTorch adds foreach/fused kernels, dtype/device handling, and more complex state management, but the mathematical order is the same:

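import torch


@torch.no_grad()
def adamw_step(param, grad, state, lr, beta1, beta2, eps, weight_decay):
    # state holds "step" (int) and zero-initialized "m", "v" shaped like param.
    step = state["step"] + 1
    m = state["m"]
    v = state["v"]

    # Update the biased first and second moment estimates in place.
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Bias correction for the zero initialization of m and v.
    m_hat = m / (1 - beta1**step)
    v_hat = v / (1 - beta2**step)

    # Decoupled weight decay: shrink param directly, outside the Adam denominator.
    param.mul_(1 - lr * weight_decay)
    param.addcdiv_(m_hat, v_hat.sqrt().add(eps), value=-lr)

    state["step"] = step
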
WarningPitfall: Do Not Decay Parameters with No Gradient by Accident

Some training recipes apply weight decay only when a parameter participates in an optimizer step. Be explicit about whether frozen, unused, or sparse-gradient parameters should decay.

Parameter Groups and No-Decay Rules

In practice, the same weight decay is usually not applied to every parameter. Common rules are:

| Parameter type | Weight decay? | Reason |
|---|---|---|
| linear/conv weight | yes | controls weight norm |
| bias | no | shifting activations should not be penalized like weights |
| LayerNorm/RMSNorm scale | no | norm scale is already a calibration parameter |
| embedding | often yes, sometimes no | depends on model and tokenizer behavior |

In PyTorch these differences are expressed with parameter groups:

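import torch

# `model` is any torch.nn.Module defined elsewhere.
decay = []
no_decay = []

for name, p in model.named_parameters():
    if not p.requires_grad:
        continue  # frozen parameters stay out of the optimizer
    # Name-based heuristic: biases and normalization scales get no decay.
    if name.endswith("bias") or "norm" in name.lower():
        no_decay.append(p)
    else:
        decay.append(p)

opt = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.1},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=3e-4,
    betas=(0.9, 0.95),
)
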
WarningPitfall: Weight Decay Is Part of the Training Recipe

Changing no-decay rules can change final quality even when loss curves look similar. Record parameter grouping rules together with LR, batch size, and optimizer hyperparameters.

Parameter groups must also satisfy two engineering invariants:

  1. every trainable parameter belongs to exactly one group;
  2. frozen parameters must not enter any group that gets updated.

A check for both invariants:

def check_param_groups(model, optimizer):
    trainable = {
        id(p): name
        for name, p in model.named_parameters()
        if p.requires_grad
    }

    seen = {}
    for gi, group in enumerate(optimizer.param_groups):
        for p in group["params"]:
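            pid = id(p)
            # Invariant 2: frozen parameters must not sit in an updated group.
            if not p.requires_grad:
                raise ValueError(f"frozen parameter found in group {gi}")
            # Invariant 1: no parameter may appear in two groups.
            if pid in seen:
                raise ValueError(f"parameter appears in groups {seen[pid]} and {gi}")
            seen[pid] = gi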

    missing = sorted(name for pid, name in trainable.items() if pid not in seen)
    if missing:
        raise ValueError(f"trainable parameters not optimized: {missing[:8]}")

For LLMs, it is best to log the number of groups and the parameter count of each group, so the no-decay rules are visible:

for i, group in enumerate(opt.param_groups):
    n = sum(p.numel() for p in group["params"])
    print(i, "weight_decay=", group["weight_decay"], "params=", n)

That way, any unexpected change in parameter grouping shows up immediately when you switch models, add LoRA, tie embeddings, or resize the vocabulary.

ImportantImplementation Contract: Parameter Grouping Is Part of the Optimizer

An optimizer is not fully specified by its class and scalar hyperparameters. Parameter grouping rules determine which tensors receive decay, LR, betas, and updates.

Optimizer Choice in Deep Learning

| Optimizer | Strength | Weakness | Typical use |
|---|---|---|---|
| SGD | strong implicit bias, simple | slow early training | CNN, final fine-tuning |
| SGD + Momentum | stable acceleration | LR sensitive | vision training |
| AdaGrad | sparse feature friendly | denominator only grows | NLP sparse features |
| RMSProp | handles nonstationarity | less standard today | RNN/older RL |
| Adam | fast, robust defaults | can overfit/sharp minima | deep nets, pretraining |
| AdamW | decoupled regularization | more hyperparameters | Transformer/LLM |
| Adafactor | much lower second-moment memory | approximate geometry, retuning needed | memory-constrained Transformers |

Choosing Hyperparameters

Optimizer hyperparameters should be read as a coupled system:

| Hyperparameter | Effect | Common symptom if wrong |
|---|---|---|
| learning rate \(\eta\) | global step scale | divergence, slow learning |
| \(\beta_1\) | momentum memory | noisy updates or sluggish reaction |
| \(\beta_2\) | second-moment memory | unstable effective LR or slow adaptation |
| \(\epsilon\) | floor on denominator | bad behavior on tiny gradients |
| weight decay | norm control | overfit or underfit |
| warmup steps | early stability | early loss spike |
| clipping threshold | step outlier control | exploding or over-clipped gradients |
| state precision | moment numerical stability | drift, overflow, or memory pressure |
| parameter grouping | per-tensor recipe | silent no-decay or duplicate-update bugs |

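Common AdamW settings for LLM pretraining are \(\beta_1=0.9\) and \(\beta_2\) between \(0.95\) and \(0.999\). A smaller \(\beta_2\) lets the second-moment statistics adapt faster to nonstationary training. A larger \(\beta_2\) is smoother but reacts more slowly. No single combination is optimal at every scale, so judge it together with batch size, sequence length, the LR schedule, and loss-spike behavior.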
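
A small sketch quantifying the two \(\beta_2\) values quoted above. It computes the effective EMA window \(1/(1-\beta_2)\) and the number of steps until the bias-correction factor \(1-\beta_2^t\) reaches 0.99. The 0.99 target and the helper name are hypothetical illustrative choices.

```python
# Illustrative sketch; the function name and target are hypothetical.
import math


def second_moment_timescales(beta2, target=0.99):
    window = 1.0 / (1.0 - beta2)
    # Smallest t with 1 - beta2**t >= target, i.e. beta2**t <= 1 - target.
    steps_to_target = math.ceil(math.log(1.0 - target) / math.log(beta2))
    return window, steps_to_target


for beta2 in (0.95, 0.999):
    window, steps = second_moment_timescales(beta2)
    print(beta2, round(window), steps)
```

With \(\beta_2=0.95\) the window is about 20 steps and the correction reaches 0.99 after 90 steps. With \(\beta_2=0.999\) the window is about 1000 steps and it takes 4603 steps. This is one way to see why the larger value reacts more slowly.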

Minimal PyTorch Pattern

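import torch

model = torch.nn.Linear(128, 10)
criterion = torch.nn.CrossEntropyLoss()
# Stand-in data loader for illustration: 8 random batches of 32 examples.
loader = [(torch.randn(32, 128), torch.randint(0, 10, (32,))) for _ in range(8)]

opt = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)

for x, y in loader:
    opt.zero_grad(set_to_none=True)
    loss = criterion(model(x), y)
    loss.backward()
    opt.step()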

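set_to_none=True avoids unnecessary zero fills. Parameters that take no part in the next backward pass keep grad as None, which is cleaner for large-model training. PyTorch's built-in optimizers skip parameters whose grad is None, so such parameters are neither updated nor decayed in that step. In recent PyTorch versions set_to_none=True is also the default.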

Debugging Optimizer Behavior

When training is unstable, do not look only at the loss. Record at least:

  1. gradient norm;
  2. parameter norm;
  3. update norm;
  4. the ratio \(\|\Delta\theta\|/\|\theta\|\);
  5. effective learning rate statistics;
  6. skipped AMP steps;
  7. optimizer state dtype, and whether the state was restored correctly from the checkpoint.

A simple update-ratio check:

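import torch


@torch.no_grad()
def global_update_ratio(model, old_params):
    # old_params: detached copies taken right before opt.step(), e.g.
    #   old_params = [p.detach().clone() for p in model.parameters()]
    device = next(model.parameters()).device
    total_p = torch.zeros([], device=device)  # accumulate on device, sync once at the end
    total_u = torch.zeros([], device=device)
    for p, old in zip(model.parameters(), old_params):
        if p.requires_grad:
            total_p += p.float().norm().square()
            total_u += (p.float() - old.float()).norm().square()
    return (total_u.sqrt() / total_p.sqrt()).item()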

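If the update ratio stays close to zero for a long time, several causes are possible:

  1. the LR may be too small;
  2. gradients may be over-clipped;
  3. the loss mask may be too sparse;
  4. parameters may not have been added to the optimizer correctly.

If the update ratio suddenly spikes, check for loss spikes, AMP overflow, anomalous batches, weight decay, or the step counter after a resume.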

TipPractical Rule

Treat the optimizer, scheduler, clipping, precision, and parameter grouping as one recipe. Debugging them separately often misses the actual failure mode.