在分布式深度学习训练中,当改变 batch size 时,需要相应调整学习率以维持训练动态的稳定性与收敛效率。最常见的两种调整规则是 线性缩放(η′=kη)和 平方根缩放(η′=kη),以及对于极大 batch 采用的 线性缩放 + 长 warmup。下面详细解释这些规则是如何从理论分析和实践经验中得出的。
- 小 batch:梯度估计噪声大,每一步的方向不太准,因此要用小学习率防止震荡。
- 大 batch:梯度估计更准确,方向更可靠,可以用更大的步长加速收敛。
如果 batch 扩大 k 倍,相当于一次更新看到了 k 倍多的样本,因此更新的“信息量”增加了。为了使一次更新能达到原来 k 步累积的效果,可以将学习率也扩大 k 倍。这样,在相同的 epoch 数内,大 batch 的更新步数变少,但每一步更大,总体的参数变化量大致相当。
考虑标准 SGD 更新:
θt+1=θt−ηg~t,g~t=B1i=1∑B∇Li(θt)假设梯度变化缓慢,真实梯度为 gt≈E[g~t]。忽略噪声,一次更新后参数变化约为 −ηgt。
若 batch 扩大为 B′=kB,学习率设为 η′,则一次更新后参数变化约为 −η′gt。为了使一次更新与原来 k 次更新(每次学习率 η)的总效果相当,即:
η′gt=kηgt⇒η′=kη这就是线性规则的由来。
对于强凸二次损失,可以严格证明:为了在相同的 epoch 数内达到相同的次优性(suboptimality),学习率应与 batch size 成正比(Bottou, 2012)。更一般地,Goyal et al. (2017) 在论文中给出了启发式论证:当 batch 扩大 k 倍时,梯度的方差减小为原来的 1/k,为了保持参数更新的信噪比(SNR)不变,需要将学习率扩大 k 倍。
信噪比(SNR) 定义为:
SNR=Tr(Var(Δθ))∥E[Δθ]∥=ηTr(Σ)/Bη∥g∥=Tr(Σ)/B∥g∥可见 SNR 与 η 无关,但与 B 有关。实际上,当我们增大 batch 时,SNR 自动增大(梯度更准)。若希望保持每次更新对损失函数的影响(即期望损失下降量)相似,需要调整 η。对于 SGD 的期望损失下降,有近似公式:
E[L(θt+1)−L(θt)]≈−η∥∇L∥2+2Bη2Tr(∇2L⋅Σ)当 batch 扩大 k 倍时,要使一阶项主导并保持下降量相当,可令 η′=kη,这样一阶项扩大 k 倍,但二阶项(噪声项)扩大 k2/(kB)=k/B,即二阶项也扩大了 k 倍(相对于原 batch 的二阶项)。因此,在初期梯度较大时,一阶项占优,线性缩放可使下降速度匹配。
在随机梯度下降(SGD)中,我们关心损失函数在每次迭代后的期望下降量。下面给出推导过程。
设损失函数 L(θ) 关于参数 θ 光滑。在第 t 步,参数为 θt,真实梯度为 gt=∇L(θt)。我们从数据集中随机抽取一个 mini-batch,计算梯度估计:
g~t=B1i=1∑B∇Li(θt),其中 Li 是单个样本的损失,且假设各样本梯度独立同分布,满足:
E[∇Li]=gt,Cov(∇Li)=Σ.于是,梯度估计的均值和协方差为:
E[g~t]=gt,Cov(g~t)=B1Σ.SGD 更新规则为:
θt+1=θt−ηg~t,其中 η 是学习率。
对损失函数 L 在 θt 处进行二阶泰勒展开:
L(θt+1)=L(θt)+∇L(θt)⊤(θt+1−θt)+21(θt+1−θt)⊤∇2L(θt)(θt+1−θt)+O(η3).代入 θt+1−θt=−ηg~t,并记 Ht=∇2L(θt),得:
L(θt+1)−L(θt)=−ηgt⊤g~t+21η2g~t⊤Htg~t+O(η3).
对上述等式两边取期望(关于 mini-batch 的随机性):
E[L(θt+1)−L(θt)]=−ηgt⊤E[g~t]+21η2E[g~t⊤Htg~t]+O(η3).由 E[g~t]=gt,得:
−ηgt⊤gt=−η∥gt∥2.计算 E[g~t⊤Htg~t]。由于 Ht 是确定性的(在给定 θt 下),有:
E[g~t⊤Htg~t]=E[Tr(Htg~tg~t⊤)]=Tr(HtE[g~tg~t⊤]).而
E[g~tg~t⊤]=Cov(g~t)+E[g~t]E[g~t]⊤=B1Σ+gtgt⊤.因此,
E[g~t⊤Htg~t]=B1Tr(HtΣ)+gt⊤Htgt.于是二阶项贡献为:
21η2(gt⊤Htgt+B1Tr(HtΣ)).
将一阶和二阶结果合并:
E[L(θt+1)−L(θt)]=−η∥gt∥2+21η2gt⊤Htgt+2Bη2Tr(HtΣ)+O(η3).在随机梯度下降的经典分析中,常忽略确定性二阶项 21η2gt⊤Htgt,因为它相对于一阶项是更高阶的小量(尤其当学习率较小时),且在很多情况下该项并不主导收敛行为。因此得到常用近似:
E[L(θt+1)−L(θt)]≈−η∥∇L∥2+2Bη2Tr(∇2L⋅Σ).这个公式揭示了:
- 第一项是确定性梯度下降带来的损失下降,与学习率成正比。
- 第二项是 mini-batch 随机性导致的“噪声项”,它可能减缓下降甚至使损失上升(当 η 过大时),且与 batch size B 成反比,说明增大 batch 可减小噪声影响。
考虑参数更新的方差。更新量 Δθ=−ηg~ 的协方差矩阵为:
Var(Δθ)=η2⋅B1Σ当 batch 扩大为 B′=kB,新学习率为 η′ 时,更新量的方差变为:
Var(Δθ′)=η′2⋅kB1Σ如果我们希望保持更新量的方差不变(即噪声引起的波动幅度不变),则需:
η′2⋅kB1Σ=η2⋅B1Σ⇒kη′2=η2⇒η′=kη这种观点在 batch 极大时更为合理。因为当 batch 很大时,梯度估计的方差已经非常小,此时再线性放大学习率会导致更新量过大(因为更新量的期望与方差之比会变得很大),可能使参数跨过最优区域。保持方差不变可以避免训练不稳定。
平方根缩放也可以从二阶优化的角度理解。当 batch 极大时,梯度接近真实梯度,相当于确定性优化。此时,一步最优的步长应与 Hessian 矩阵的 Lipschitz 常数相关。如果采用一阶方法,通常学习率需要与 Lipschitz 常数的倒数成正比,而这个常数与 batch 无关。因此,不能无限放大学习率,平方根缩放是一种折中。
当 batch 极大(例如从 256 扩大到 32k)时,直接使用线性缩放(η′=kη)往往会导致训练初期发散或性能下降。原因在于:
- 初始阶段参数随机初始化,梯度可能很大,直接使用大学习率容易导致梯度爆炸。
- 模型需要时间适应,尤其是 Batch Normalization 的统计量也需要调整。
- 梯度噪声极低,但参数初值离最优解很远,大步长可能使参数跳出合适区域。
Warmup 是指在训练开始的前 Tw 步或前几个 epoch 使用较小的学习率(例如从 0 线性增加到目标学习率)。它的作用是:
- 让模型在初期以温和的步伐探索参数空间,避免剧烈震荡。
- 逐渐积累梯度的稳定性,使 BN 的 running statistics 适应大学习率。
- 为后续的大学习率打下基础。
对于极大 batch,即使采用平方根缩放,也可能需要 warmup。但许多实验表明,线性缩放配合足够长的 warmup 可以达到与平方根缩放相当甚至更好的效果。这是因为 warmup 阶段实际上起到了“软化”线性缩放的作用:初期学习率小,相当于学习率自动遵循了某种渐进增长,最终达到 kη。只要 warmup 长度足够长,模型就能平稳过渡。
通常 warmup 的 epoch 数会随 batch 扩大而增加。例如,原始 batch 256 可能用 5 epoch warmup,当 batch 扩大到 8k 时,可能需要 100 epoch 甚至更长。一种经验法则是:warmup 步数与 batch 扩大倍数成正比,即保持 warmup 期间看到的样本总数不变(即 warmup 的 epoch 数不变?不,因为 batch 变大,每 epoch 步数变少,所以需要更多 epoch 才能看到相同数量的样本)。更常见的做法是固定 warmup 的迭代步数,例如从 5000 步增加到 20000 步。
| 缩放方式 | 公式 | 推导依据 | 适用场景 |
|---|
| 线性缩放 | η′=kη | 保持一次更新与原来 k 次更新的期望变化量相同,或保持收敛速度一致 | batch 扩大倍数适中(如 ≤32 倍),梯度噪声仍明显 |
| 平方根缩放 | η′=kη | 保持参数更新量的方差不变,避免大步长带来的不稳定 | batch 极大(如 128 倍以上),梯度噪声可忽略 |
| 线性 + 长 warmup | 先用 warmup 再线性缩放 | 通过初期小学习率平滑过渡,使模型适应大学习率,再逐步放大 | 超大 batch 时作为线性缩放的补充,也可替代平方根缩放 |