smalldiffusion 核心模块:diffusion.py
本文件是 smalldiffusion 的核心,包含噪声调度(Schedule)、训练循环(training_loop)和采样算法(samples),总计不到 100 行代码。
2.1 模块结构
diffusion.py
├── Schedule (基类)
│ ├── ScheduleLogLinear
│ ├── ScheduleDDPM
│ ├── ScheduleLDM
│ ├── ScheduleSigmoid
│ └── ScheduleCosine
├── sigmas_from_betas() # 辅助函数:β → σ 转换
├── generate_train_sample() # 辅助函数:生成训练样本
├── training_loop() # 训练循环生成器
└── samples() # 采样生成器2.2 Schedule 基类
是什么
Schedule 管理扩散过程中噪声水平 的递增序列。它是所有调度策略的基类。
为什么需要
扩散模型的训练和采样都依赖于一个预定义的噪声调度:
- 训练时:随机采样一个 值,决定给数据加多少噪声
- 采样时:按递减的 序列逐步去噪
接口定义
class Schedule:
def __init__(self, sigmas: torch.FloatTensor):
self.sigmas = sigmas # 递增的 σ 序列
def __getitem__(self, i) -> torch.FloatTensor:
return self.sigmas[i] # 支持索引和切片
def __len__(self) -> int:
return len(self.sigmas)
def sample_sigmas(self, steps: int) -> torch.FloatTensor:
"""采样时使用:从完整调度中子采样 steps+1 个递减的 σ 值"""
...
def sample_batch(self, x0: torch.FloatTensor) -> torch.FloatTensor:
"""训练时使用:随机采样一批 σ 值"""
...sample_sigmas(steps) 详解
该方法在采样阶段调用,从完整的 步调度中选取 steps 个时间步,返回 steps + 1 个递减的 值(包含起始和终止值)。
采用 “trailing” 间距策略(参考 Table 2, arXiv:2305.08891):
def sample_sigmas(self, steps: int) -> torch.FloatTensor:
indices = list((len(self) * (1 - np.arange(0, steps)/steps))
.round().astype(np.int64) - 1)
return self[indices + [0]]工作原理:
np.arange(0, steps)/steps生成[0, 1/steps, 2/steps, ..., (steps-1)/steps]1 - ...翻转为递减序列- 乘以
len(self)并四舍五入得到索引 - 末尾追加索引
0(最小 )
示例: 若 N=1000, steps=5,则选取索引约为 [999, 799, 599, 399, 199, 0],返回 6 个 值。
sample_batch(x0) 详解
训练时调用,为每个样本随机采样一个 值:
def sample_batch(self, x0: torch.FloatTensor) -> torch.FloatTensor:
batchsize = x0.shape[0]
return self[torch.randint(len(self), (batchsize,))].to(x0)从 [0, N) 均匀随机选取索引,返回对应的 值,并转移到与 x0 相同的设备。
2.3 sigmas_from_betas 辅助函数
是什么
将 序列转换为 序列的工具函数。
数学推导
给定 序列,累积乘积 ,则:
def sigmas_from_betas(betas: torch.FloatTensor):
return (1/torch.cumprod(1.0 - betas, dim=0) - 1).sqrt()为什么需要
大多数经典扩散模型论文(DDPM、LDM 等)使用 参数化定义调度,而 smalldiffusion 内部统一使用 参数化。此函数是两种参数化之间的桥梁。
2.4 ScheduleLogLinear
是什么
在对数空间中线性插值的简单调度, 从 sigma_min 到 sigma_max 呈对数线性增长。
class ScheduleLogLinear(Schedule):
def __init__(self, N: int, sigma_min: float=0.02, sigma_max: float=10):
super().__init__(torch.logspace(math.log10(sigma_min), math.log10(sigma_max), N))适用场景
- 玩具模型和小数据集
- 快速实验和原型验证
- 与
Scaled修饰器配合使用效果好(U-Net 示例中使用)
使用示例
from smalldiffusion import ScheduleLogLinear
schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
print(f"σ 范围: [{schedule[0]:.4f}, {schedule[-1]:.4f}]")
print(f"总步数: {len(schedule)}")
# 采样时子采样 20 步
sigmas = schedule.sample_sigmas(20)
print(f"采样 σ 序列长度: {len(sigmas)}") # 212.5 ScheduleDDPM
是什么
复现 DDPM 论文 (Ho et al., 2020) 中的线性 调度。
class ScheduleDDPM(Schedule):
def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02):
super().__init__(sigmas_from_betas(torch.linspace(beta_start, beta_end, N)))数学细节
从 beta_start 到 beta_end 线性增长:
然后通过 sigmas_from_betas 转换为 序列。
适用场景
- 像素空间图像扩散模型
- 与 HuggingFace Diffusers 的
DDIMScheduler/DDPMScheduler等价
2.6 ScheduleLDM
是什么
复现潜空间扩散模型(如 Stable Diffusion)使用的 “scaled linear” 调度。
class ScheduleLDM(Schedule):
def __init__(self, N: int=1000, beta_start: float=0.00085, beta_end: float=0.012):
super().__init__(sigmas_from_betas(torch.linspace(beta_start**0.5, beta_end**0.5, N)**2))数学细节
与 DDPM 不同,LDM 对 做线性插值后再平方:
这使得 的增长更平缓,适合潜空间中的扩散。
适用场景
- Stable Diffusion 等潜空间扩散模型
- 默认参数与
CompVis/stable-diffusion-v1-4的调度一致
2.7 ScheduleSigmoid
是什么
使用 Sigmoid 函数定义 调度,来自 GeoDiff。
class ScheduleSigmoid(Schedule):
def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02):
betas = torch.sigmoid(torch.linspace(-6, 6, N)) * (beta_end - beta_start) + beta_start
super().__init__(sigmas_from_betas(betas))数学细节
Sigmoid 形状使得 在中间区域变化最快,两端变化缓慢,形成 S 形曲线。
适用场景
- 分子构象生成(GeoDiff)
- CIFAR-10 训练示例中使用此调度
2.8 ScheduleCosine
是什么
使用余弦函数定义 调度,来自 iDDPM (Nichol & Dhariwal, 2021)。
class ScheduleCosine(Schedule):
def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02, max_beta: float=0.999):
alpha_bar = lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
betas = [min(1 - alpha_bar((i+1)/N)/alpha_bar(i/N), max_beta)
for i in range(N)]
super().__init__(sigmas_from_betas(torch.tensor(betas, dtype=torch.float32)))数学细节
偏移量 0.008 防止 时 过小;max_beta 截断防止数值不稳定。
适用场景
- 改进的 DDPM 训练
- 在低噪声区域提供更均匀的信噪比变化
2.9 generate_train_sample 函数
是什么
为训练生成 四元组的辅助函数。
def generate_train_sample(x0, schedule, conditional=False):
cond = x0[1] if conditional else None
x0 = x0[0] if conditional else x0
sigma = schedule.sample_batch(x0)
while len(sigma.shape) < len(x0.shape):
sigma = sigma.unsqueeze(-1)
eps = torch.randn_like(x0)
return x0, sigma, eps, cond工作流程
- 条件处理:若
conditional=True,x0是(data, labels)元组,拆分为数据和条件 - 采样 σ:从调度中随机采样一批 值
- 维度对齐:将 扩展到与 相同的维度数(用于广播),例如图像数据
[B, C, H, W]需要 形状为[B, 1, 1, 1] - 生成噪声:采样与 同形状的标准正态噪声
2.10 training_loop 函数
是什么
一个 Python 生成器函数,实现完整的扩散模型训练循环。
def training_loop(loader, model, schedule, accelerator=None,
epochs=10000, lr=1e-3, conditional=False):
accelerator = accelerator or Accelerator()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
for _ in (pbar := tqdm(range(epochs))):
for x0 in loader:
model.train()
optimizer.zero_grad()
x0, sigma, eps, cond = generate_train_sample(x0, schedule, conditional)
loss = model.get_loss(x0, sigma, eps, cond=cond)
yield SimpleNamespace(**locals())
accelerator.backward(loss)
optimizer.step()参数说明
| 参数 | 类型 | 说明 |
|---|---|---|
loader | DataLoader | PyTorch 数据加载器 |
model | nn.Module | 扩散模型(需实现 get_loss 方法) |
schedule | Schedule | 噪声调度 |
accelerator | Accelerator | HuggingFace Accelerate 实例(可选,默认自动创建) |
epochs | int | 训练轮数,默认 10000 |
lr | float | 学习率,默认 1e-3 |
conditional | bool | 是否为条件生成,默认 False |
设计亮点
- 生成器模式:使用
yield而非回调,调用者可以在每个训练步后执行自定义逻辑(如记录损失、保存检查点) - Accelerate 集成:通过
accelerator.prepare()自动处理多 GPU 分布式训练 - 命名空间暴露:
yield SimpleNamespace(**locals())将所有局部变量(loss,x0,sigma,eps,pbar等)暴露给调用者
使用示例
from smalldiffusion import training_loop
# 基本用法:收集损失
trainer = training_loop(loader, model, schedule, epochs=100)
losses = [ns.loss.item() for ns in trainer]
# 高级用法:自定义训练逻辑
for ns in training_loop(loader, model, schedule, epochs=100):
ns.pbar.set_description(f'Loss={ns.loss.item():.5f}')
if ns.loss.item() < 0.01:
break # 提前停止训练流程图
for each epoch:
for each batch x0 from loader:
1. model.train()
2. optimizer.zero_grad()
3. (x0, σ, ε, cond) = generate_train_sample(x0, schedule)
4. loss = model.get_loss(x0, σ, ε, cond)
5. yield namespace ← 调用者可在此处插入逻辑
6. loss.backward()
7. optimizer.step()2.11 samples 函数
是什么
扩散模型的通用采样生成器,仅用 5 行核心代码统一了 DDPM、DDIM 和加速采样算法。
@torch.no_grad()
def samples(model, sigmas, gam=1., mu=0., cfg_scale=0.,
batchsize=1, xt=None, cond=None, accelerator=None):
model.eval()
accelerator = accelerator or Accelerator()
xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0] if xt is None else xt
if cond is not None:
assert cond.shape[0] == xt.shape[0], 'cond must have same shape as x!'
cond = cond.to(xt.device)
eps = None
for i, (sig, sig_prev) in enumerate(pairwise(sigmas)):
eps_prev, eps = eps, model.predict_eps_cfg(xt, sig.to(xt), cond, cfg_scale)
eps_av = eps * gam + eps_prev * (1-gam) if i > 0 else eps
sig_p = (sig_prev/sig**mu)**(1/(1-mu))
eta = (sig_prev**2 - sig_p**2).sqrt()
xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)
yield xt参数说明
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
model | nn.Module | - | 扩散模型 |
sigmas | FloatTensor | - | 递减的 σ 序列(N+1 个值对应 N 步采样) |
gam | float | 1.0 | 噪声预测平均权重,建议 ≥ 1 |
mu | float | 0.0 | 随机性控制参数,范围 [0, 1) |
cfg_scale | float | 0.0 | Classifier-Free Guidance 强度,0 表示不使用 |
batchsize | int | 1 | 生成样本数 |
xt | FloatTensor | None | 自定义初始噪声(可选) |
cond | Tensor | None | 条件信息(可选) |
accelerator | Accelerator | None | 多 GPU 支持 |
核心采样公式推导
每一步从 到 的更新:
第 1 步:噪声预测平均
当 gam=1 时退化为仅使用当前预测;gam=2 时利用历史预测进行外推加速。
第 2 步:计算中间 σ
当 mu=0 时 (确定性);当 mu=0.5 时引入随机性(DDPM 行为)。
第 3 步:计算随机噪声幅度
第 4 步:更新
采样算法对应关系
| 算法 | gam | mu | 行为 |
|---|---|---|---|
| DDPM | 1 | 0.5 | ,,随机采样 |
| DDIM | 1 | 0 | ,,确定性采样 |
| 加速 | 2 | 0 | 利用 外推,确定性,更少步数 |
使用示例
from smalldiffusion import samples, ScheduleLogLinear
schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
# DDIM 采样(确定性)
*intermediates, x0 = samples(model, schedule.sample_sigmas(20), gam=1, mu=0)
# DDPM 采样(随机)
*intermediates, x0 = samples(model, schedule.sample_sigmas(50), gam=1, mu=0.5)
# 加速采样
*intermediates, x0 = samples(model, schedule.sample_sigmas(10), gam=2)
# 条件采样 + CFG
import torch
cond = torch.tensor([0, 1, 2, 3]) # 4 个类别标签
*intermediates, x0 = samples(model, schedule.sample_sigmas(20),
gam=1.6, batchsize=4,
cond=cond, cfg_scale=4.0)生成器特性
samples 是生成器,每步 yield 当前的 。这允许:
- 可视化去噪过程的中间结果
- 使用
*xt, x0 = samples(...)解包获取所有中间步和最终结果 - 提前终止采样