目录

smalldiffusion 核心模块:diffusion.py

本文件是 smalldiffusion 的核心,包含噪声调度(Schedule)、训练循环(training_loop)和采样算法(samples),总计不到 100 行代码。

diffusion.py
├── Schedule (基类)
│   ├── ScheduleLogLinear
│   ├── ScheduleDDPM
│   ├── ScheduleLDM
│   ├── ScheduleSigmoid
│   └── ScheduleCosine
├── sigmas_from_betas()      # 辅助函数:β → σ 转换
├── generate_train_sample()  # 辅助函数:生成训练样本
├── training_loop()          # 训练循环生成器
└── samples()                # 采样生成器

Schedule 管理扩散过程中噪声水平 σ\sigma 的递增序列。它是所有调度策略的基类。

扩散模型的训练和采样都依赖于一个预定义的噪声调度:

  • 训练时:随机采样一个 σ\sigma 值,决定给数据加多少噪声
  • 采样时:按递减的 σ\sigma 序列逐步去噪
class Schedule:
    def __init__(self, sigmas: torch.FloatTensor):
        self.sigmas = sigmas  # 递增的 σ 序列

    def __getitem__(self, i) -> torch.FloatTensor:
        return self.sigmas[i]  # 支持索引和切片

    def __len__(self) -> int:
        return len(self.sigmas)

    def sample_sigmas(self, steps: int) -> torch.FloatTensor:
        """采样时使用:从完整调度中子采样 steps+1 个递减的 σ 值"""
        ...

    def sample_batch(self, x0: torch.FloatTensor) -> torch.FloatTensor:
        """训练时使用:随机采样一批 σ 值"""
        ...

该方法在采样阶段调用,从完整的 NN 步调度中选取 steps 个时间步,返回 steps + 1 个递减的 σ\sigma 值(包含起始和终止值)。

采用 “trailing” 间距策略(参考 Table 2, arXiv:2305.08891):

def sample_sigmas(self, steps: int) -> torch.FloatTensor:
    indices = list((len(self) * (1 - np.arange(0, steps)/steps))
                   .round().astype(np.int64) - 1)
    return self[indices + [0]]

工作原理:

  • np.arange(0, steps)/steps 生成 [0, 1/steps, 2/steps, ..., (steps-1)/steps]
  • 1 - ... 翻转为递减序列
  • 乘以 len(self) 并四舍五入得到索引
  • 末尾追加索引 0(最小 σ\sigma

示例:N=1000, steps=5,则选取索引约为 [999, 799, 599, 399, 199, 0],返回 6 个 σ\sigma 值。

训练时调用,为每个样本随机采样一个 σ\sigma 值:

def sample_batch(self, x0: torch.FloatTensor) -> torch.FloatTensor:
    batchsize = x0.shape[0]
    return self[torch.randint(len(self), (batchsize,))].to(x0)

[0, N) 均匀随机选取索引,返回对应的 σ\sigma 值,并转移到与 x0 相同的设备。


β\beta 序列转换为 σ\sigma 序列的工具函数。

给定 βt\beta_t 序列,累积乘积 αˉt=s=1t(1βs)\bar{\alpha}_t = \prod_{s=1}^{t}(1-\beta_s),则:

σt=1αˉt1\sigma_t = \sqrt{\frac{1}{\bar{\alpha}_t} - 1}
def sigmas_from_betas(betas: torch.FloatTensor):
    return (1/torch.cumprod(1.0 - betas, dim=0) - 1).sqrt()

大多数经典扩散模型论文(DDPM、LDM 等)使用 β\beta 参数化定义调度,而 smalldiffusion 内部统一使用 σ\sigma 参数化。此函数是两种参数化之间的桥梁。


在对数空间中线性插值的简单调度,σ\sigmasigma_minsigma_max 呈对数线性增长。

class ScheduleLogLinear(Schedule):
    def __init__(self, N: int, sigma_min: float=0.02, sigma_max: float=10):
        super().__init__(torch.logspace(math.log10(sigma_min), math.log10(sigma_max), N))
  • 玩具模型和小数据集
  • 快速实验和原型验证
  • Scaled 修饰器配合使用效果好(U-Net 示例中使用)
from smalldiffusion import ScheduleLogLinear

schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
print(f"σ 范围: [{schedule[0]:.4f}, {schedule[-1]:.4f}]")
print(f"总步数: {len(schedule)}")
# 采样时子采样 20 步
sigmas = schedule.sample_sigmas(20)
print(f"采样 σ 序列长度: {len(sigmas)}")  # 21

复现 DDPM 论文 (Ho et al., 2020) 中的线性 β\beta 调度。

class ScheduleDDPM(Schedule):
    def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02):
        super().__init__(sigmas_from_betas(torch.linspace(beta_start, beta_end, N)))

β\betabeta_startbeta_end 线性增长:

βt=βstart+tN1(βendβstart)\beta_t = \beta_{\text{start}} + \frac{t}{N-1}(\beta_{\text{end}} - \beta_{\text{start}})

然后通过 sigmas_from_betas 转换为 σ\sigma 序列。

  • 像素空间图像扩散模型
  • 与 HuggingFace Diffusers 的 DDIMScheduler / DDPMScheduler 等价

复现潜空间扩散模型(如 Stable Diffusion)使用的 “scaled linear” β\beta 调度。

class ScheduleLDM(Schedule):
    def __init__(self, N: int=1000, beta_start: float=0.00085, beta_end: float=0.012):
        super().__init__(sigmas_from_betas(torch.linspace(beta_start**0.5, beta_end**0.5, N)**2))

与 DDPM 不同,LDM 对 β\sqrt{\beta} 做线性插值后再平方:

βt=(βstart+tN1(βendβstart))2\beta_t = \left(\sqrt{\beta_{\text{start}}} + \frac{t}{N-1}(\sqrt{\beta_{\text{end}}} - \sqrt{\beta_{\text{start}}})\right)^2

这使得 β\beta 的增长更平缓,适合潜空间中的扩散。

  • Stable Diffusion 等潜空间扩散模型
  • 默认参数与 CompVis/stable-diffusion-v1-4 的调度一致

使用 Sigmoid 函数定义 β\beta 调度,来自 GeoDiff

class ScheduleSigmoid(Schedule):
    def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02):
        betas = torch.sigmoid(torch.linspace(-6, 6, N)) * (beta_end - beta_start) + beta_start
        super().__init__(sigmas_from_betas(betas))
βt=sigmoid(6+12tN1)(βendβstart)+βstart\beta_t = \text{sigmoid}\left(-6 + \frac{12t}{N-1}\right) \cdot (\beta_{\text{end}} - \beta_{\text{start}}) + \beta_{\text{start}}

Sigmoid 形状使得 β\beta 在中间区域变化最快,两端变化缓慢,形成 S 形曲线。

  • 分子构象生成(GeoDiff)
  • CIFAR-10 训练示例中使用此调度

使用余弦函数定义 αˉ\bar{\alpha} 调度,来自 iDDPM (Nichol & Dhariwal, 2021)

class ScheduleCosine(Schedule):
    def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02, max_beta: float=0.999):
        alpha_bar = lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
        betas = [min(1 - alpha_bar((i+1)/N)/alpha_bar(i/N), max_beta)
                 for i in range(N)]
        super().__init__(sigmas_from_betas(torch.tensor(betas, dtype=torch.float32)))
αˉ(t)=cos2(t+0.0081.008π2)\bar{\alpha}(t) = \cos^2\left(\frac{t + 0.008}{1.008} \cdot \frac{\pi}{2}\right)βt=min(1αˉ(t+1/N)αˉ(t/N), βmax)\beta_t = \min\left(1 - \frac{\bar{\alpha}(t+1/N)}{\bar{\alpha}(t/N)},\ \beta_{\max}\right)

偏移量 0.008 防止 t=0t=0β\beta 过小;max_beta 截断防止数值不稳定。

  • 改进的 DDPM 训练
  • 在低噪声区域提供更均匀的信噪比变化

为训练生成 (x0,σ,ε,cond)(x_0, \sigma, \varepsilon, \text{cond}) 四元组的辅助函数。

def generate_train_sample(x0, schedule, conditional=False):
    cond = x0[1] if conditional else None
    x0   = x0[0] if conditional else x0
    sigma = schedule.sample_batch(x0)
    while len(sigma.shape) < len(x0.shape):
        sigma = sigma.unsqueeze(-1)
    eps = torch.randn_like(x0)
    return x0, sigma, eps, cond
  1. 条件处理:若 conditional=Truex0(data, labels) 元组,拆分为数据和条件
  2. 采样 σ:从调度中随机采样一批 σ\sigma
  3. 维度对齐:将 σ\sigma 扩展到与 x0x_0 相同的维度数(用于广播),例如图像数据 [B, C, H, W] 需要 σ\sigma 形状为 [B, 1, 1, 1]
  4. 生成噪声:采样与 x0x_0 同形状的标准正态噪声 ε\varepsilon

一个 Python 生成器函数,实现完整的扩散模型训练循环。

def training_loop(loader, model, schedule, accelerator=None,
                  epochs=10000, lr=1e-3, conditional=False):
    accelerator = accelerator or Accelerator()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
    for _ in (pbar := tqdm(range(epochs))):
        for x0 in loader:
            model.train()
            optimizer.zero_grad()
            x0, sigma, eps, cond = generate_train_sample(x0, schedule, conditional)
            loss = model.get_loss(x0, sigma, eps, cond=cond)
            yield SimpleNamespace(**locals())
            accelerator.backward(loss)
            optimizer.step()
参数类型说明
loaderDataLoaderPyTorch 数据加载器
modelnn.Module扩散模型(需实现 get_loss 方法)
scheduleSchedule噪声调度
acceleratorAcceleratorHuggingFace Accelerate 实例(可选,默认自动创建)
epochsint训练轮数,默认 10000
lrfloat学习率,默认 1e-3
conditionalbool是否为条件生成,默认 False
  1. 生成器模式:使用 yield 而非回调,调用者可以在每个训练步后执行自定义逻辑(如记录损失、保存检查点)
  2. Accelerate 集成:通过 accelerator.prepare() 自动处理多 GPU 分布式训练
  3. 命名空间暴露yield SimpleNamespace(**locals()) 将所有局部变量(loss, x0, sigma, eps, pbar 等)暴露给调用者
from smalldiffusion import training_loop

# 基本用法:收集损失
trainer = training_loop(loader, model, schedule, epochs=100)
losses = [ns.loss.item() for ns in trainer]

# 高级用法:自定义训练逻辑
for ns in training_loop(loader, model, schedule, epochs=100):
    ns.pbar.set_description(f'Loss={ns.loss.item():.5f}')
    if ns.loss.item() < 0.01:
        break  # 提前停止
for each epoch:
    for each batch x0 from loader:
        1. model.train()
        2. optimizer.zero_grad()
        3. (x0, σ, ε, cond) = generate_train_sample(x0, schedule)
        4. loss = model.get_loss(x0, σ, ε, cond)
        5. yield namespace   调用者可在此处插入逻辑
        6. loss.backward()
        7. optimizer.step()

扩散模型的通用采样生成器,仅用 5 行核心代码统一了 DDPM、DDIM 和加速采样算法。

@torch.no_grad()
def samples(model, sigmas, gam=1., mu=0., cfg_scale=0.,
            batchsize=1, xt=None, cond=None, accelerator=None):
    model.eval()
    accelerator = accelerator or Accelerator()
    xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0] if xt is None else xt
    if cond is not None:
        assert cond.shape[0] == xt.shape[0], 'cond must have same shape as x!'
        cond = cond.to(xt.device)
    eps = None
    for i, (sig, sig_prev) in enumerate(pairwise(sigmas)):
        eps_prev, eps = eps, model.predict_eps_cfg(xt, sig.to(xt), cond, cfg_scale)
        eps_av = eps * gam + eps_prev * (1-gam)  if i > 0 else eps
        sig_p = (sig_prev/sig**mu)**(1/(1-mu))
        eta = (sig_prev**2 - sig_p**2).sqrt()
        xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)
        yield xt
参数类型默认值说明
modelnn.Module-扩散模型
sigmasFloatTensor-递减的 σ 序列(N+1 个值对应 N 步采样)
gamfloat1.0噪声预测平均权重,建议 ≥ 1
mufloat0.0随机性控制参数,范围 [0, 1)
cfg_scalefloat0.0Classifier-Free Guidance 强度,0 表示不使用
batchsizeint1生成样本数
xtFloatTensorNone自定义初始噪声(可选)
condTensorNone条件信息(可选)
acceleratorAcceleratorNone多 GPU 支持

每一步从 σt\sigma_tσt1\sigma_{t-1} 的更新:

第 1 步:噪声预测平均

εˉ=γεt+(1γ)εt+1\bar{\varepsilon} = \gamma \cdot \varepsilon_t + (1-\gamma) \cdot \varepsilon_{t+1}

gam=1 时退化为仅使用当前预测;gam=2 时利用历史预测进行外推加速。

第 2 步:计算中间 σ

σp=(σt1σtμ)1/(1μ)\sigma_p = \left(\frac{\sigma_{t-1}}{\sigma_t^\mu}\right)^{1/(1-\mu)}

mu=0σp=σt1\sigma_p = \sigma_{t-1}(确定性);当 mu=0.5 时引入随机性(DDPM 行为)。

第 3 步:计算随机噪声幅度

η=σt12σp2\eta = \sqrt{\sigma_{t-1}^2 - \sigma_p^2}

第 4 步:更新

xt1=xt(σtσp)εˉ+ηz,zN(0,I)x_{t-1} = x_t - (\sigma_t - \sigma_p) \cdot \bar{\varepsilon} + \eta \cdot z, \quad z \sim \mathcal{N}(0, I)
算法gammu行为
DDPM10.5σp<σt1\sigma_p < \sigma_{t-1}η>0\eta > 0,随机采样
DDIM10σp=σt1\sigma_p = \sigma_{t-1}η=0\eta = 0,确定性采样
加速20利用 εt+1\varepsilon_{t+1} 外推,确定性,更少步数
from smalldiffusion import samples, ScheduleLogLinear

schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)

# DDIM 采样(确定性)
*intermediates, x0 = samples(model, schedule.sample_sigmas(20), gam=1, mu=0)

# DDPM 采样(随机)
*intermediates, x0 = samples(model, schedule.sample_sigmas(50), gam=1, mu=0.5)

# 加速采样
*intermediates, x0 = samples(model, schedule.sample_sigmas(10), gam=2)

# 条件采样 + CFG
import torch
cond = torch.tensor([0, 1, 2, 3])  # 4 个类别标签
*intermediates, x0 = samples(model, schedule.sample_sigmas(20),
                              gam=1.6, batchsize=4,
                              cond=cond, cfg_scale=4.0)

samples 是生成器,每步 yield 当前的 xtx_t。这允许:

  • 可视化去噪过程的中间结果
  • 使用 *xt, x0 = samples(...) 解包获取所有中间步和最终结果
  • 提前终止采样

相关内容