目录

smalldiffusion 模型基础:model.py

本文件定义了所有模型共享的基类、预测模式修饰器、通用组件(注意力、嵌入)、玩具模型和理想去噪器。

model.py
├── ModelMixin                # 模型基类 Mixin
├── get_sigma_embeds()        # σ 嵌入函数
├── SigmaEmbedderSinCos      # σ 嵌入模块
├── alpha()                   # σ → α̅ 转换
├── Scaled()                  # 输入缩放修饰器
├── PredX0()                  # 预测 x0 修饰器
├── PredV()                   # 预测 v 修饰器
├── CondSequential            # 条件顺序容器
├── Attention                 # 多头注意力
├── CondEmbedderLabel         # 标签条件嵌入
├── TimeInputMLP              # 玩具 MLP 模型
├── ConditionalMLP            # 条件 MLP 模型
├── sq_norm()                 # 辅助函数
└── IdealDenoiser             # 理想去噪器

所有扩散模型的 Mixin 基类,提供三个关键方法。

smalldiffusion 中的模型必须满足以下协议:

  1. 继承 torch.nn.ModuleModelMixin
  2. 设置 input_dims 属性(不含 batch 维度的输入形状元组)
  3. 实现 forward(self, x, sigma, cond=None) 方法,返回与 x 同形状的预测噪声
class ModelMixin:
    def rand_input(self, batchsize):
        """生成标准正态随机输入,形状为 [batchsize, *input_dims]"""
        assert hasattr(self, 'input_dims'), 'Model must have "input_dims" attribute!'
        return torch.randn((batchsize,) + self.input_dims)

    def get_loss(self, x0, sigma, eps, cond=None, loss=nn.MSELoss):
        """计算训练损失:预测噪声与真实噪声的 MSE"""
        return loss()(eps, self(x0 + sigma * eps, sigma, cond=cond))

    def predict_eps(self, x, sigma, cond=None):
        """预测噪声 ε(默认直接调用 forward)"""
        return self(x, sigma, cond=cond)

    def predict_eps_cfg(self, x, sigma, cond, cfg_scale):
        """带 Classifier-Free Guidance 的噪声预测"""
        ...

生成采样初始噪声。形状由模型的 input_dims 决定:

  • 2D 模型:input_dims = (2,) → 输出形状 [B, 2]
  • 图像模型:input_dims = (3, 32, 32) → 输出形状 [B, 3, 32, 32]

默认实现假设模型预测噪声 ε\varepsilon

L=MSE(ε,fθ(x0+σε,σ))\mathcal{L} = \text{MSE}(\varepsilon, f_\theta(x_0 + \sigma \cdot \varepsilon, \sigma))

其中 x0+σεx_0 + \sigma \cdot \varepsilon 是加噪后的样本。此方法可被 PredX0PredV 修饰器覆盖。

实现 Classifier-Free Guidance (CFG)

def predict_eps_cfg(self, x, sigma, cond, cfg_scale):
    if cond is None or cfg_scale == 0:
        return self.predict_eps(x, sigma, cond=cond)
    assert sigma.shape == tuple(), 'CFG sampling only supports singleton sigma!'
    uncond = torch.full_like(cond, self.cond_embed.null_cond)
    eps_cond, eps_uncond = self.predict_eps(
        torch.cat([x, x]), sigma, torch.cat([cond, uncond])
    ).chunk(2)
    return eps_cond + cfg_scale * (eps_cond - eps_uncond)

CFG 公式:

ε^=εcond+s(εcondεuncond)\hat{\varepsilon} = \varepsilon_{\text{cond}} + s \cdot (\varepsilon_{\text{cond}} - \varepsilon_{\text{uncond}})

其中 sscfg_scale。当 s>0s > 0 时,模型输出被推向条件方向,远离无条件方向。

实现技巧: 将条件和无条件输入拼接成一个 batch 一次前向传播,避免两次调用模型。


将标量 σ\sigma 值编码为 2 维嵌入向量的函数。

def get_sigma_embeds(batches, sigma, scaling_factor=0.5, log_scale=True):
    if sigma.shape == torch.Size([]):
        sigma = sigma.unsqueeze(0).repeat(batches)
    else:
        assert sigma.shape == (batches,), 'sigma.shape == [] or [batches]!'
    if log_scale:
        sigma = torch.log(sigma)
    s = sigma.unsqueeze(1) * scaling_factor
    return torch.cat([torch.sin(s), torch.cos(s)], dim=1)
  1. 标量处理:若 σ\sigma 是标量,扩展为 batch 大小
  2. 对数缩放:默认取 log(σ)\log(\sigma),将指数级变化的 σ\sigma 压缩到线性范围
  3. 正弦/余弦编码[sin(sf),cos(sf)][\sin(s \cdot f), \cos(s \cdot f)],其中 ff 是缩放因子

输出形状:[B, 2]。这是一种极简的时间嵌入,仅用 2 维就能有效编码噪声水平。

标准 Transformer 位置编码使用多个频率,输出维度通常为 128 或 256。smalldiffusion 的实现仅用 1 个频率(2 维),但论文表明在扩散模型中效果相当。


get_sigma_embeds 的 2 维输出通过 MLP 映射到高维空间的模块。

class SigmaEmbedderSinCos(nn.Module):
    def __init__(self, hidden_size, scaling_factor=0.5, log_scale=True):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.log_scale = log_scale
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    def forward(self, batches, sigma):
        sig_embed = get_sigma_embeds(batches, sigma,
                                     self.scaling_factor, self.log_scale)  # (B, 2)
        return self.mlp(sig_embed)                                         # (B, D)
σ → [sin, cos] (2维) → Linear(2, D) → SiLU → Linear(D, D) → 输出 (D维)

DiTUnet 使用,将噪声水平信息注入模型。


σ\sigma 参数化到 αˉ\bar{\alpha} 参数化的转换函数。

def alpha(sigma):
    return 1 / (1 + sigma**2)
αˉ=11+σ2,σ=1αˉ1\bar{\alpha} = \frac{1}{1 + \sigma^2}, \quad \sigma = \sqrt{\frac{1}{\bar{\alpha}} - 1}

ScaledPredV 修饰器和 diffusers_wrapper.py 使用。


一个类修饰器(class decorator),对模型输入进行缩放,使不同噪声水平下输入的范数保持恒定。

def Scaled(cls: ModelMixin):
    def forward(self, x, sigma, cond=None):
        return cls.forward(self, x * alpha(sigma).sqrt(), sigma, cond=cond)
    return type(cls.__name__ + 'Scaled', (cls,), dict(forward=forward))

加噪样本 xt=x0+σεx_t = x_0 + \sigma \varepsilon 的期望范数随 σ\sigma 增大而增大。缩放因子 αˉ=11+σ2\sqrt{\bar{\alpha}} = \frac{1}{\sqrt{1+\sigma^2}} 将输入归一化:

x~t=αˉxt\tilde{x}_t = \sqrt{\bar{\alpha}} \cdot x_t

使得 E[x~t2]\mathbb{E}[\|\tilde{x}_t\|^2] 对所有 σ\sigma 近似恒定。

from smalldiffusion import Scaled, Unet

# 创建带输入缩放的 U-Net
model = Scaled(Unet)(28, 1, 1, ch=64, ch_mult=(1, 1, 2))
# 等价于创建了一个名为 "UnetScaled" 的新类

Scaled 使用 Python 的 type() 动态创建新类,继承原始类但覆盖 forward 方法。新类名为原类名 + “Scaled”。


将模型从预测噪声 ε\varepsilon 改为预测干净数据 x0x_0 的类修饰器。

def PredX0(cls: ModelMixin):
    def get_loss(self, x0, sigma, eps, cond=None, loss=nn.MSELoss):
        return loss()(x0, self(x0 + sigma * eps, sigma, cond=cond))
    def predict_eps(self, x, sigma, cond=None):
        x0_hat = self(x, sigma, cond=cond)
        return (x - x0_hat) / sigma
    return type(cls.__name__ + 'PredX0', (cls,),
                dict(get_loss=get_loss, predict_eps=predict_eps))

若模型预测 x^0\hat{x}_0,可以反推噪声预测:

ε^=xtx^0σ\hat{\varepsilon} = \frac{x_t - \hat{x}_0}{\sigma}

因为 xt=x0+σεx_t = x_0 + \sigma \varepsilon,所以 ε=(xtx0)/σ\varepsilon = (x_t - x_0) / \sigma

  • get_loss:损失变为 MSE(x0,fθ(xt,σ))\text{MSE}(x_0, f_\theta(x_t, \sigma))
  • predict_eps:从 x^0\hat{x}_0 反推 ε^\hat{\varepsilon},使采样代码无需修改

将模型改为预测 velocity vv 的类修饰器,来自 Progressive Distillation

def PredV(cls: ModelMixin):
    def get_loss(self, x0, sigma, eps, cond=None, loss=nn.MSELoss):
        xt = x0 + sigma * eps
        v = alpha(sigma).sqrt() * eps - (1 - alpha(sigma)).sqrt() * x0
        return loss()(v, self(xt, sigma, cond=cond))
    def predict_eps(self, x, sigma, cond=None):
        v_hat = self(x, sigma, cond=cond)
        return alpha(sigma).sqrt() * (v_hat + (1 - alpha(sigma)).sqrt() * x)
    return type(cls.__name__ + 'PredV', (cls,),
                dict(get_loss=get_loss, predict_eps=predict_eps))

Velocity 定义为:

v=αˉε1αˉx0v = \sqrt{\bar{\alpha}} \cdot \varepsilon - \sqrt{1 - \bar{\alpha}} \cdot x_0

v^\hat{v} 反推噪声:

ε^=αˉ(v^+1αˉxt)\hat{\varepsilon} = \sqrt{\bar{\alpha}} \cdot (\hat{v} + \sqrt{1 - \bar{\alpha}} \cdot x_t)

在高噪声水平下,预测 ε\varepsilon 的信噪比很低;在低噪声水平下,预测 x0x_0 的信噪比很低。v-prediction 在两种极端情况下都有更均衡的信噪比。

修饰器可以组合使用:

from smalldiffusion import Scaled, PredX0, PredV, DiT

# 带输入缩放 + 预测 x0
model = Scaled(PredX0(DiT))(in_dim=16, channels=3, patch_size=2, depth=4)

# 带输入缩放 + 预测 v
model = Scaled(PredV(DiT))(in_dim=16, channels=3, patch_size=2, depth=4)

支持条件输入的 nn.Sequential 变体。

class CondSequential(nn.Sequential):
    def forward(self, x, cond):
        for module in self._modules.values():
            x = module(x, cond)
        return x

标准 nn.Sequential 只支持单输入。扩散模型的中间层需要同时接收特征 x 和条件信息 cond(如时间嵌入)。CondSequential(x, cond) 传递给每个子模块。

DiT 的 Transformer Block 序列和 Unet 的中间层使用。


标准多头自注意力模块。

class Attention(nn.Module):
    def __init__(self, head_dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        dim = head_dim * num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # x: (B, N, D) → (B, N, D)
        q, k, v = rearrange(self.qkv(x), 'b n (qkv h k) -> qkv b h n k',
                            h=self.num_heads, k=self.head_dim)
        x = rearrange(F.scaled_dot_product_attention(q, k, v),
                      'b h n k -> b n (h k)')
        return self.proj(x)
参数说明
head_dim每个注意力头的维度
num_heads注意力头数量
qkv_biasQKV 投影是否使用偏置
  1. 线性投影生成 Q, K, V:(B, N, D) → (B, N, 3D) → 3 × (B, H, N, d)
  2. 缩放点积注意力:F.scaled_dot_product_attention(q, k, v)(PyTorch 原生实现,自动选择 Flash Attention 等优化)
  3. 拼接多头并投影:(B, H, N, d) → (B, N, D)

DiTUnetAttnBlock 共同使用。


将离散类别标签嵌入为连续向量的模块,支持训练时的随机 dropout(用于 Classifier-Free Guidance)。

class CondEmbedderLabel(nn.Module):
    def __init__(self, hidden_size, num_classes, dropout_prob=0.1):
        super().__init__()
        self.embeddings = nn.Embedding(num_classes + 1, hidden_size)
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.null_cond = num_classes
        self.dropout_prob = dropout_prob

    def forward(self, labels):  # (B,) → (B, D)
        if self.training:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            labels = torch.where(drop_ids, self.null_cond, labels)
        return self.mlp(self.embeddings(labels))
  • 嵌入表大小num_classes + 1,额外的一个位置用于"无条件"标签
  • null_cond:值为 num_classes,表示无条件输入
  • 训练时 dropout:以 dropout_prob 概率将标签替换为 null_cond,使模型同时学习条件和无条件生成
  • MLPSiLU → Linear,将嵌入映射到模型隐藏维度

CFG 采样需要模型同时能做条件和无条件预测。训练时随机丢弃条件信息,使模型学会在无条件下也能生成合理输出。


用于 2D 玩具数据的简单 MLP 模型。

class TimeInputMLP(nn.Module, ModelMixin):
    sigma_dim = 2
    def __init__(self, dim=2, output_dim=None, hidden_dims=(16,128,256,128,16)):
        super().__init__()
        layers = []
        for in_dim, out_dim in pairwise((dim + self.sigma_dim,) + hidden_dims):
            layers.extend([nn.Linear(in_dim, out_dim), nn.GELU()])
        layers.append(nn.Linear(hidden_dims[-1], output_dim or dim))
        self.net = nn.Sequential(*layers)
        self.input_dims = (dim,)

    def forward(self, x, sigma, cond=None):
        sigma_embeds = get_sigma_embeds(x.shape[0], sigma.squeeze())  # (B, 2)
        nn_input = torch.cat([x, sigma_embeds], dim=1)                 # (B, dim+2)
        return self.net(nn_input)

默认配置 hidden_dims=(16,128,256,128,16)

输入 (dim+2) → Linear → GELU → Linear → GELU → ... → Linear → 输出 (dim)
     4          16       128      256      128     16      2
参数默认值说明
dim2数据维度
output_dimNone输出维度(默认等于 dim)
hidden_dims(16,128,256,128,16)隐藏层维度序列
from smalldiffusion import TimeInputMLP
import torch

model = TimeInputMLP(dim=2, hidden_dims=(16, 128, 128, 16))
x = torch.randn(32, 2)       # batch of 2D points
sigma = torch.tensor(1.0)     # noise level
output = model(x, sigma)      # (32, 2)

TimeInputMLP 的条件版本,支持类别标签条件生成。

class ConditionalMLP(TimeInputMLP):
    def __init__(self, dim=2, hidden_dims=(16,128,256,128,16),
                 cond_dim=4, num_classes=10, dropout_prob=0.1):
        super().__init__(dim=dim+cond_dim, output_dim=dim, hidden_dims=hidden_dims)
        self.input_dims = (dim,)  # 覆盖父类设置
        self.cond_embed = CondEmbedderLabel(cond_dim, num_classes, dropout_prob)

    def forward(self, x, sigma, cond):
        cond_embeds = self.cond_embed(cond)                           # (B, cond_dim)
        sigma_embeds = get_sigma_embeds(x.shape[0], sigma.squeeze())  # (B, sigma_dim)
        nn_input = torch.cat([x, sigma_embeds, cond_embeds], dim=1)   # (B, dim+sigma_dim+cond_dim)
        return self.net(nn_input)
  • 继承 TimeInputMLP,但将输入维度扩展为 dim + cond_dim
  • input_dims 仍设为 (dim,),因为 rand_input 只需生成数据维度的噪声
  • 条件嵌入通过 CondEmbedderLabel 将标签映射为 cond_dim 维向量
  • 三部分拼接后送入 MLP:[x, sigma_embed, cond_embed]
from smalldiffusion import ConditionalMLP
import torch

model = ConditionalMLP(dim=2, cond_dim=4, num_classes=10)
x = torch.randn(32, 2)
sigma = torch.tensor(1.0)
cond = torch.randint(0, 10, (32,))
output = model(x, sigma, cond)  # (32, 2)

计算矩阵每行的平方范数并重复 k 次的辅助函数,被 IdealDenoiser 使用。

def sq_norm(M, k):
    # M: (b, n) → (b,) → (b, k)
    return (torch.norm(M, dim=1)**2).unsqueeze(1).repeat(1, k)

用于高效计算成对平方距离矩阵 xidj2\|x_i - d_j\|^2


给定数据集的理论最优去噪器(Bayes 最优估计器),用于验证和基准测试。

class IdealDenoiser(nn.Module, ModelMixin):
    def __init__(self, dataset):
        super().__init__()
        self.data = torch.stack([dataset[i] for i in range(len(dataset))])
        self.input_dims = self.data.shape[1:]

    def forward(self, x, sigma, cond=None):
        data = self.data.to(x)
        x_flat = x.flatten(start_dim=1)
        d_flat = data.flatten(start_dim=1)
        xb, xr = x_flat.shape
        db, dr = d_flat.shape
        # 计算成对平方距离
        sq_diffs = sq_norm(x_flat, db).T + sq_norm(d_flat, xb) - 2 * d_flat @ x_flat.T
        # Softmax 权重
        weights = F.softmax(-sq_diffs / 2 / sigma.squeeze()**2, dim=0)
        # 加权平均
        eps = torch.einsum('ij,i...->j...', weights, data)
        return (x - eps) / sigma

对于高斯噪声模型 xt=x0+σεx_t = x_0 + \sigma \varepsilon,Bayes 最优去噪器为:

x^0(xt)=E[x0xt]=ix0(i)exp(xtx0(i)22σ2)iexp(xtx0(i)22σ2)\hat{x}_0(x_t) = \mathbb{E}[x_0 | x_t] = \frac{\sum_{i} x_0^{(i)} \exp\left(-\frac{\|x_t - x_0^{(i)}\|^2}{2\sigma^2}\right)}{\sum_{i} \exp\left(-\frac{\|x_t - x_0^{(i)}\|^2}{2\sigma^2}\right)}

即数据集中所有点的 softmax 加权平均,权重与距离的负指数成正比。

然后转换为噪声预测:ε^=(xtx^0)/σ\hat{\varepsilon} = (x_t - \hat{x}_0) / \sigma

使用展开的平方距离公式避免显式计算差值矩阵:

xidj2=xi2+dj22xidj\|x_i - d_j\|^2 = \|x_i\|^2 + \|d_j\|^2 - 2 x_i \cdot d_j
  • 验证采样算法的正确性
  • 作为训练模型的性能上界
  • 不需要训练,直接从数据集构造
from smalldiffusion import IdealDenoiser, Swissroll, samples, ScheduleLogLinear
import numpy as np

dataset = Swissroll(np.pi/2, 5*np.pi, 100)
model = IdealDenoiser(dataset)
schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
*xt, x0 = samples(model, schedule.sample_sigmas(20), gam=2)

相关内容