目录

smalldiffusion 项目总览

smalldiffusion 是一个教学与实验导向的扩散模型库,核心训练和采样代码不到 100 行。它的设计目标是:

  • 提供可读、可理解的扩散模型实现
  • 支持从 2D 玩具数据到 Stable Diffusion 级别的预训练模型
  • 方便研究者快速实验新的采样算法和模型架构

论文参考:Permenter and Yuan, arXiv:2306.04848

扩散模型的核心思想是:将数据逐步加噪直到变成纯噪声,然后学习逆过程。

给定干净数据 x0x_0,前向过程生成带噪样本:

xt=x0+σε,εN(0,I)x_t = x_0 + \sigma \cdot \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I)

其中 σ\sigma 是噪声水平。smalldiffusion 使用 σ\sigma 参数化(而非常见的 α\alpha/αˉ\bar{\alpha} 参数化),两者的关系为:

σ=1αˉ1,αˉ=11+σ2\sigma = \sqrt{\frac{1}{\bar{\alpha}} - 1}, \quad \bar{\alpha} = \frac{1}{1 + \sigma^2}

代码中的 alpha(sigma) 函数即计算 αˉ\bar{\alpha}

def alpha(sigma):
    return 1 / (1 + sigma**2)

模型学习预测噪声 εθ(xt,σ)\varepsilon_\theta(x_t, \sigma),采样时从纯噪声 xTN(0,σT2I)x_T \sim \mathcal{N}(0, \sigma_T^2 I) 出发,逐步去噪。smalldiffusion 的采样公式(5 行代码)统一了 DDPM、DDIM 和加速采样:

xt1=xt(σtσp)εˉ+ηzx_{t-1} = x_t - (\sigma_t - \sigma_p) \cdot \bar{\varepsilon} + \eta \cdot z

其中:

  • εˉ\bar{\varepsilon} 是当前和上一步噪声预测的加权平均(由 gam 控制)
  • σp\sigma_pη\eta 由参数 mu 控制确定性/随机性比例
  • zN(0,I)z \sim \mathcal{N}(0, I) 是随机噪声

项目由三个核心概念组成:数据(Data)模型(Model)调度(Schedule),它们通过 training_loopsamples 两个函数协作。

┌─────────────────────────────────────────────────────┐
│                    diffusion.py                      │
│  ┌──────────────┐  ┌──────────────┐                 │
│  │ training_loop│  │   samples    │                 │
│  └──────┬───────┘  └──────┬───────┘                 │
│         │                 │                          │
│  ┌──────┴─────────────────┴──────┐                  │
│  │         Schedule 系列          │                  │
│  │ LogLinear│DDPM│LDM│Sigmoid│Cos│                  │
│  └───────────────────────────────┘                  │
└─────────────────────────────────────────────────────┘
         │                 │
    ┌────┴────┐      ┌────┴────┐
    │ data.py │      │model.py │
    │Swissroll│      │ModelMixin│
    │Datasaur.│      │Scaled   │
    │TreeData │      │PredX0   │
    │MappedDS │      │PredV    │
    └─────────┘      │TimeInput│
                     │CondMLP  │
                     │IdealDen.│
                     ├─────────┤
                     │model_dit│
                     │  DiT    │
                     ├─────────┤
                     │model_unet│
                     │  Unet   │
                     └─────────┘
模块职责关键导出
diffusion.py噪声调度、训练循环、采样算法Schedule, ScheduleLogLinear, ScheduleDDPM, ScheduleLDM, ScheduleSigmoid, ScheduleCosine, training_loop, samples
data.py数据集定义与预处理工具Swissroll, DatasaurusDozen, TreeDataset, MappedDataset, img_train_transform, img_normalize
model.py模型基类、修饰器、基础组件ModelMixin, Scaled, PredX0, PredV, TimeInputMLP, ConditionalMLP, IdealDenoiser, Attention, SigmaEmbedderSinCos, CondEmbedderLabel, CondSequential
model_dit.pyDiffusion Transformer 实现DiT
model_unet.pyU-Net 实现Unet

训练阶段:

  1. DataLoader 产出一批数据 x0x_0
  2. Schedule.sample_batch() 随机采样噪声水平 σ\sigma
  3. generate_train_sample() 生成 (x0,σ,ε)(x_0, \sigma, \varepsilon) 三元组
  4. 模型前向传播预测噪声,计算 MSE 损失
  5. 反向传播更新参数

采样阶段:

  1. Schedule.sample_sigmas(steps) 生成递减的 σ\sigma 序列
  2. xTN(0,σ02I)x_T \sim \mathcal{N}(0, \sigma_0^2 I) 开始
  3. 每步调用模型预测噪声,按采样公式更新 xtx_t
  4. 最终得到生成样本 x0x_0
特性TimeInputMLPDiTUnet
适用数据2D 玩具数据图像图像
参数量级~10K~10M~10M
时间嵌入sin/cos (2维)SigmaEmbedderSinCos + MLPSigmaEmbedderSinCos + MLP
核心结构全连接层 + GELUTransformer Block + ModulationResNet Block + Attention + Skip Connection
条件生成ConditionalMLP 变体通过 cond_embed 参数通过 cond_embed 参数
输入缩放可选 (Scaled)可选 (Scaled)通常使用 Scaled
训练速度快(秒级)中等(小时级)中等(小时级)
生成质量仅适合简单分布FashionMNIST FID ~5-6CIFAR-10 FID ~3-4
Schedule公式特点典型用途默认参数
ScheduleLogLinearσ\sigma 在对数空间线性增长玩具模型、小数据集N=200, σ_min=0.02, σ_max=10
ScheduleDDPM线性 β\beta 调度像素空间图像扩散N=1000, β_start=0.0001, β_end=0.02
ScheduleLDM缩放线性 β\beta 调度潜空间扩散 (Stable Diffusion)N=1000, β_start=0.00085, β_end=0.012
ScheduleSigmoidSigmoid 形状的 β\beta 调度分子构象生成 (GeoDiff)N=1000, β_start=0.0001, β_end=0.02
ScheduleCosine余弦 αˉ\bar{\alpha} 调度改进的 DDPM (iDDPM)N=1000, max_beta=0.999
算法gammu特点
DDPM10.5随机采样,需要较多步数
DDIM10确定性采样,可用较少步数
加速采样20利用历史噪声预测加速收敛
dependencies = [
  "accelerate",   # 多 GPU 训练支持
  "numpy",        # 数值计算
  "torchvision",  # 图像变换、数据集
  "torch",        # 深度学习框架
  "tqdm",         # 进度条
  "einops",       # 张量重排
]

相关内容