目录

Pytorch Batch Size 与学习率缩放规则

在分布式深度学习训练中,当改变 batch size 时,需要相应调整学习率以维持训练动态的稳定性与收敛效率。最常见的两种调整规则是 线性缩放η=kη\eta' = k\eta)和 平方根缩放η=kη\eta' = \sqrt{k}\,\eta),以及对于极大 batch 采用的 线性缩放 + 长 warmup。下面详细解释这些规则是如何从理论分析和实践经验中得出的。


  • 小 batch:梯度估计噪声大,每一步的方向不太准,因此要用小学习率防止震荡。
  • 大 batch:梯度估计更准确,方向更可靠,可以用更大的步长加速收敛。

如果 batch 扩大 kk 倍,相当于一次更新看到了 kk 倍多的样本,因此更新的“信息量”增加了。为了使一次更新能达到原来 kk 步累积的效果,可以将学习率也扩大 kk 倍。这样,在相同的 epoch 数内,大 batch 的更新步数变少,但每一步更大,总体的参数变化量大致相当。

考虑标准 SGD 更新:

θt+1=θtηg~t,g~t=1Bi=1BLi(θt) \theta_{t+1} = \theta_t - \eta \tilde{g}_t, \quad \tilde{g}_t = \frac{1}{B}\sum_{i=1}^B \nabla L_i(\theta_t)

假设梯度变化缓慢,真实梯度为 gtE[g~t]g_t \approx \mathbb{E}[\tilde{g}_t]。忽略噪声,一次更新后参数变化约为 ηgt-\eta g_t

若 batch 扩大为 B=kBB' = kB,学习率设为 η\eta',则一次更新后参数变化约为 ηgt-\eta' g_t。为了使一次更新与原来 kk 次更新(每次学习率 η\eta)的总效果相当,即:

ηgt=kηgtη=kη \eta' g_t = k \eta g_t \quad \Rightarrow \quad \eta' = k\eta

这就是线性规则的由来。

对于强凸二次损失,可以严格证明:为了在相同的 epoch 数内达到相同的次优性(suboptimality),学习率应与 batch size 成正比(Bottou, 2012)。更一般地,Goyal et al. (2017) 在论文中给出了启发式论证:当 batch 扩大 kk 倍时,梯度的方差减小为原来的 1/k1/k,为了保持参数更新的信噪比(SNR)不变,需要将学习率扩大 kk 倍。

信噪比(SNR) 定义为:

SNR=E[Δθ]Tr(Var(Δθ))=ηgηTr(Σ)/B=gTr(Σ)/B \text{SNR} = \frac{\|\mathbb{E}[\Delta \theta]\|}{\sqrt{\text{Tr}(\text{Var}(\Delta \theta))}} = \frac{\eta \|g\|}{\eta \sqrt{\text{Tr}(\Sigma)/B}} = \frac{\|g\|}{\sqrt{\text{Tr}(\Sigma)/B}}

可见 SNR 与 η\eta 无关,但与 BB 有关。实际上,当我们增大 batch 时,SNR 自动增大(梯度更准)。若希望保持每次更新对损失函数的影响(即期望损失下降量)相似,需要调整 η\eta。对于 SGD 的期望损失下降,有近似公式:

E[L(θt+1)L(θt)]ηL2+η22BTr(2LΣ) \mathbb{E}[L(\theta_{t+1}) - L(\theta_t)] \approx -\eta \|\nabla L\|^2 + \frac{\eta^2}{2B} \text{Tr}(\nabla^2 L \cdot \Sigma)

当 batch 扩大 kk 倍时,要使一阶项主导并保持下降量相当,可令 η=kη\eta' = k\eta,这样一阶项扩大 kk 倍,但二阶项(噪声项)扩大 k2/(kB)=k/Bk^2 / (kB) = k/B,即二阶项也扩大了 kk 倍(相对于原 batch 的二阶项)。因此,在初期梯度较大时,一阶项占优,线性缩放可使下降速度匹配。

在随机梯度下降(SGD)中,我们关心损失函数在每次迭代后的期望下降量。下面给出推导过程。

SGD 的期望损失下降公式推导

设损失函数 L(θ)L(\theta) 关于参数 θ\theta 光滑。在第 tt 步,参数为 θt\theta_t,真实梯度为 gt=L(θt)g_t = \nabla L(\theta_t)。我们从数据集中随机抽取一个 mini-batch,计算梯度估计:

g~t=1Bi=1BLi(θt), \tilde{g}_t = \frac{1}{B} \sum_{i=1}^{B} \nabla L_i(\theta_t),

其中 LiL_i 是单个样本的损失,且假设各样本梯度独立同分布,满足:

E[Li]=gt,Cov(Li)=Σ. \mathbb{E}[\nabla L_i] = g_t, \quad \text{Cov}(\nabla L_i) = \Sigma.

于是,梯度估计的均值和协方差为:

E[g~t]=gt,Cov(g~t)=1BΣ. \mathbb{E}[\tilde{g}_t] = g_t, \quad \text{Cov}(\tilde{g}_t) = \frac{1}{B} \Sigma.

SGD 更新规则为:

θt+1=θtηg~t, \theta_{t+1} = \theta_t - \eta \tilde{g}_t,

其中 η\eta 是学习率。


对损失函数 LLθt\theta_t 处进行二阶泰勒展开:

L(θt+1)=L(θt)+L(θt)(θt+1θt)+12(θt+1θt)2L(θt)(θt+1θt)+O(η3). L(\theta_{t+1}) = L(\theta_t) + \nabla L(\theta_t)^\top (\theta_{t+1} - \theta_t) + \frac{1}{2} (\theta_{t+1} - \theta_t)^\top \nabla^2 L(\theta_t) (\theta_{t+1} - \theta_t) + O(\eta^3).

代入 θt+1θt=ηg~t\theta_{t+1} - \theta_t = -\eta \tilde{g}_t,并记 Ht=2L(θt)H_t = \nabla^2 L(\theta_t),得:

L(θt+1)L(θt)=ηgtg~t+12η2g~tHtg~t+O(η3). L(\theta_{t+1}) - L(\theta_t) = -\eta g_t^\top \tilde{g}_t + \frac{1}{2} \eta^2 \tilde{g}_t^\top H_t \tilde{g}_t + O(\eta^3).

对上述等式两边取期望(关于 mini-batch 的随机性):

E[L(θt+1)L(θt)]=ηgtE[g~t]+12η2E[g~tHtg~t]+O(η3). \mathbb{E}[L(\theta_{t+1}) - L(\theta_t)] = -\eta g_t^\top \mathbb{E}[\tilde{g}_t] + \frac{1}{2} \eta^2 \mathbb{E}[\tilde{g}_t^\top H_t \tilde{g}_t] + O(\eta^3).

E[g~t]=gt\mathbb{E}[\tilde{g}_t] = g_t,得:

ηgtgt=ηgt2. -\eta g_t^\top g_t = -\eta \|g_t\|^2.

计算 E[g~tHtg~t]\mathbb{E}[\tilde{g}_t^\top H_t \tilde{g}_t]。由于 HtH_t 是确定性的(在给定 θt\theta_t 下),有:

E[g~tHtg~t]=E[Tr(Htg~tg~t)]=Tr(HtE[g~tg~t]). \mathbb{E}[\tilde{g}_t^\top H_t \tilde{g}_t] = \mathbb{E}[\operatorname{Tr}(H_t \tilde{g}_t \tilde{g}_t^\top)] = \operatorname{Tr}\left(H_t \mathbb{E}[\tilde{g}_t \tilde{g}_t^\top]\right).

E[g~tg~t]=Cov(g~t)+E[g~t]E[g~t]=1BΣ+gtgt. \mathbb{E}[\tilde{g}_t \tilde{g}_t^\top] = \operatorname{Cov}(\tilde{g}_t) + \mathbb{E}[\tilde{g}_t] \mathbb{E}[\tilde{g}_t]^\top = \frac{1}{B} \Sigma + g_t g_t^\top.

因此,

E[g~tHtg~t]=1BTr(HtΣ)+gtHtgt. \mathbb{E}[\tilde{g}_t^\top H_t \tilde{g}_t] = \frac{1}{B} \operatorname{Tr}(H_t \Sigma) + g_t^\top H_t g_t.

于是二阶项贡献为:

12η2(gtHtgt+1BTr(HtΣ)). \frac{1}{2} \eta^2 \left( g_t^\top H_t g_t + \frac{1}{B} \operatorname{Tr}(H_t \Sigma) \right).

将一阶和二阶结果合并:

E[L(θt+1)L(θt)]=ηgt2+12η2gtHtgt+η22BTr(HtΣ)+O(η3). \mathbb{E}[L(\theta_{t+1}) - L(\theta_t)] = -\eta \|g_t\|^2 + \frac{1}{2} \eta^2 g_t^\top H_t g_t + \frac{\eta^2}{2B} \operatorname{Tr}(H_t \Sigma) + O(\eta^3).

在随机梯度下降的经典分析中,常忽略确定性二阶项 12η2gtHtgt\frac{1}{2} \eta^2 g_t^\top H_t g_t,因为它相对于一阶项是更高阶的小量(尤其当学习率较小时),且在很多情况下该项并不主导收敛行为。因此得到常用近似:

E[L(θt+1)L(θt)]ηL2+η22BTr(2LΣ). \boxed{\mathbb{E}[L(\theta_{t+1}) - L(\theta_t)] \approx -\eta \|\nabla L\|^2 + \frac{\eta^2}{2B} \operatorname{Tr}(\nabla^2 L \cdot \Sigma)}.

这个公式揭示了:

  • 第一项是确定性梯度下降带来的损失下降,与学习率成正比。
  • 第二项是 mini-batch 随机性导致的“噪声项”,它可能减缓下降甚至使损失上升(当 η\eta 过大时),且与 batch size BB 成反比,说明增大 batch 可减小噪声影响。

考虑参数更新的方差。更新量 Δθ=ηg~\Delta \theta = -\eta \tilde{g} 的协方差矩阵为:

Var(Δθ)=η21BΣ \text{Var}(\Delta \theta) = \eta^2 \cdot \frac{1}{B} \Sigma

当 batch 扩大为 B=kBB' = kB,新学习率为 η\eta' 时,更新量的方差变为:

Var(Δθ)=η21kBΣ \text{Var}(\Delta \theta') = \eta'^2 \cdot \frac{1}{kB} \Sigma

如果我们希望保持更新量的方差不变(即噪声引起的波动幅度不变),则需:

η21kBΣ=η21BΣη2k=η2η=kη \eta'^2 \cdot \frac{1}{kB} \Sigma = \eta^2 \cdot \frac{1}{B} \Sigma \quad \Rightarrow \quad \frac{\eta'^2}{k} = \eta^2 \quad \Rightarrow \quad \eta' = \sqrt{k}\,\eta

这种观点在 batch 极大时更为合理。因为当 batch 很大时,梯度估计的方差已经非常小,此时再线性放大学习率会导致更新量过大(因为更新量的期望与方差之比会变得很大),可能使参数跨过最优区域。保持方差不变可以避免训练不稳定。

平方根缩放也可以从二阶优化的角度理解。当 batch 极大时,梯度接近真实梯度,相当于确定性优化。此时,一步最优的步长应与 Hessian 矩阵的 Lipschitz 常数相关。如果采用一阶方法,通常学习率需要与 Lipschitz 常数的倒数成正比,而这个常数与 batch 无关。因此,不能无限放大学习率,平方根缩放是一种折中。


当 batch 极大(例如从 256 扩大到 32k)时,直接使用线性缩放(η=kη\eta' = k\eta)往往会导致训练初期发散或性能下降。原因在于:

  • 初始阶段参数随机初始化,梯度可能很大,直接使用大学习率容易导致梯度爆炸。
  • 模型需要时间适应,尤其是 Batch Normalization 的统计量也需要调整。
  • 梯度噪声极低,但参数初值离最优解很远,大步长可能使参数跳出合适区域。

Warmup 是指在训练开始的前 TwT_w 步或前几个 epoch 使用较小的学习率(例如从 0 线性增加到目标学习率)。它的作用是:

  • 让模型在初期以温和的步伐探索参数空间,避免剧烈震荡。
  • 逐渐积累梯度的稳定性,使 BN 的 running statistics 适应大学习率。
  • 为后续的大学习率打下基础。

对于极大 batch,即使采用平方根缩放,也可能需要 warmup。但许多实验表明,线性缩放配合足够长的 warmup 可以达到与平方根缩放相当甚至更好的效果。这是因为 warmup 阶段实际上起到了“软化”线性缩放的作用:初期学习率小,相当于学习率自动遵循了某种渐进增长,最终达到 kηk\eta。只要 warmup 长度足够长,模型就能平稳过渡。

通常 warmup 的 epoch 数会随 batch 扩大而增加。例如,原始 batch 256 可能用 5 epoch warmup,当 batch 扩大到 8k 时,可能需要 100 epoch 甚至更长。一种经验法则是:warmup 步数与 batch 扩大倍数成正比,即保持 warmup 期间看到的样本总数不变(即 warmup 的 epoch 数不变?不,因为 batch 变大,每 epoch 步数变少,所以需要更多 epoch 才能看到相同数量的样本)。更常见的做法是固定 warmup 的迭代步数,例如从 5000 步增加到 20000 步。


缩放方式公式推导依据适用场景
线性缩放η=kη\eta' = k\eta保持一次更新与原来 kk 次更新的期望变化量相同,或保持收敛速度一致batch 扩大倍数适中(如 ≤32 倍),梯度噪声仍明显
平方根缩放η=kη\eta' = \sqrt{k}\,\eta保持参数更新量的方差不变,避免大步长带来的不稳定batch 极大(如 128 倍以上),梯度噪声可忽略
线性 + 长 warmup先用 warmup 再线性缩放通过初期小学习率平滑过渡,使模型适应大学习率,再逐步放大超大 batch 时作为线性缩放的补充,也可替代平方根缩放

相关内容