目录

smalldiffusion 模型:model_dit.py

本文件实现了 DiT (Peebles & Xie, 2022) 架构,一种基于 Transformer 的扩散模型。

model_dit.py
├── PatchEmbed          # 图像 Patch 嵌入
├── Modulation          # 自适应调制层
├── ModulatedLayerNorm  # 调制 LayerNorm
├── DiTBlock            # DiT Transformer 块
├── get_pos_embed()     # 2D 正弦余弦位置编码
└── DiT                 # 完整 DiT 模型

输入图像 (B, C, H, W)
PatchEmbed ──→ (B, N, D)  ← + pos_embed (1, N, D)
    │  σ ──→ SigmaEmbedderSinCos ──→ y (B, D)
    │  cond ──→ CondEmbedderLabel ──→ + y
DiTBlock × depth  ← y 作为条件
ModulatedLayerNorm ← y
Linear (D → patch_size² × C)
Unpatchify ──→ 输出 (B, C, H, W)

将图像分割为不重叠的 patch 并线性嵌入到高维空间。

class PatchEmbed(nn.Module):
    def __init__(self, patch_size=16, channels=3, embed_dim=768, bias=True):
        super().__init__()
        self.proj = nn.Conv2d(channels, embed_dim,
                              stride=patch_size, kernel_size=patch_size, bias=bias)
        self.init()

    def init(self):
        w = self.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.proj.bias, 0)

    def forward(self, x):
        return rearrange(self.proj(x), 'b c h w -> b (h w) c')

使用 Conv2d 实现 patch 化:

  • kernel_size=patch_size, stride=patch_size:不重叠地将图像分割为 patch
  • 输入 (B, C, H, W) → 卷积输出 (B, D, H/p, W/p) → 重排为 (B, N, D)
  • 其中 N = (H/p) × (W/p) 是 patch 数量,D = embed_dim

使用 Xavier 均匀初始化(模拟 nn.Linear 的初始化),偏置初始化为 0。


从条件向量 yy 生成 nn 组调制参数的模块。

class Modulation(nn.Module):
    def __init__(self, dim, n):
        super().__init__()
        self.n = n
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(dim, n * dim, bias=True))
        nn.init.constant_(self.proj[-1].weight, 0)
        nn.init.constant_(self.proj[-1].bias, 0)

    def forward(self, y):
        return [m.unsqueeze(1) for m in self.proj(y).chunk(self.n, dim=1)]
  1. 输入条件向量 yy:形状 (B, D)
  2. 通过 SiLU → Linear 映射到 (B, n*D)
  3. 沿维度 1 分割为 nn 组,每组 (B, D)
  4. 每组增加维度变为 (B, 1, D) 用于广播

线性层的权重和偏置初始化为 0,使得训练初期调制参数为零向量,模型行为接近无条件的标准 Transformer。这是一种常见的稳定训练技巧。


带自适应调制的 LayerNorm,即 adaLN(adaptive Layer Normalization)。

class ModulatedLayerNorm(nn.LayerNorm):
    def __init__(self, dim, **kwargs):
        super().__init__(dim, **kwargs)
        self.modulation = Modulation(dim, 2)

    def forward(self, x, y):
        scale, shift = self.modulation(y)
        return super().forward(x) * (1 + scale) + shift
adaLN(x,y)=LN(x)(1+γ(y))+β(y)\text{adaLN}(x, y) = \text{LN}(x) \cdot (1 + \gamma(y)) + \beta(y)

其中 γ,β\gamma, \beta 是从条件 yy 生成的 scale 和 shift 参数。

标准 LayerNorm 的 scale/shift 是可学习参数,对所有输入相同。adaLN 使这些参数依赖于条件信息(噪声水平 + 类别标签),让模型在不同条件下有不同的归一化行为。这是 DiT 论文中性能最好的条件注入方式。


DiT 的核心 Transformer 块,包含调制注意力和调制 MLP。

class DiTBlock(nn.Module):
    def __init__(self, head_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        dim = head_dim * num_heads
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm1 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(head_dim, num_heads=num_heads, qkv_bias=True)
        self.norm2 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, dim, bias=True),
        )
        self.scale_modulation = Modulation(dim, 2)

    def forward(self, x, y):
        gate_msa, gate_mlp = self.scale_modulation(y)
        x = x + gate_msa * self.attn(self.norm1(x, y))
        x = x + gate_mlp * self.mlp(self.norm2(x, y))
        return x
输入 x (B, N, D), y (B, D)
├─ gate_msa, gate_mlp = Modulation(y)     # 门控参数
├─ h = adaLN_1(x, y)                      # 调制归一化
├─ h = Attention(h)                        # 多头自注意力
├─ x = x + gate_msa * h                   # 门控残差连接
├─ h = adaLN_2(x, y)                      # 调制归一化
├─ h = MLP(h)                             # 前馈网络
├─ x = x + gate_mlp * h                   # 门控残差连接
输出 x (B, N, D)
  1. adaLN-Zero:使用 elementwise_affine=False 的 LayerNorm,所有 scale/shift 完全由条件 yy 决定
  2. 门控机制gate_msagate_mlp 对注意力和 MLP 的输出进行缩放,初始化为 0 使训练初期残差连接为恒等映射
  3. MLP 扩展比mlp_ratio=4.0,隐藏层维度是输入的 4 倍(标准 Transformer 设计)
  4. 近似 GELU:使用 tanh 近似的 GELU 激活函数

生成 2D 正弦余弦位置编码,为每个 patch 提供空间位置信息。

def get_pos_embed(in_dim, patch_size, dim, N=10000):
    n = in_dim // patch_size                                          # 每边 patch 数
    assert dim % 4 == 0, 'Embedding dimension must be multiple of 4!'
    omega = 1/N**np.linspace(0, 1, dim // 4, endpoint=False)          # 频率
    freqs = np.outer(np.arange(n), omega)                             # (n, dim/4)
    embeds = repeat(np.stack([np.sin(freqs), np.cos(freqs)]),
                       ' b n d -> b n k d', k=n)                      # (2, n, n, dim/4)
    embeds_2d = np.concatenate([
        rearrange(embeds, 'b n k d -> (k n) (b d)'),                  # (n², dim/2)
        rearrange(embeds, 'b n k d -> (n k) (b d)'),                  # (n², dim/2)
    ], axis=1)                                                        # (n², dim)
    return nn.Parameter(torch.tensor(embeds_2d).float().unsqueeze(0),
                        requires_grad=False)                           # (1, n², dim)
  1. 计算每边的 patch 数 n = in_dim / patch_size
  2. 生成频率向量 ωk=1/Nk/(d/4)\omega_k = 1/N^{k/(d/4)},其中 N=10000N=10000
  3. 对行坐标和列坐标分别计算正弦/余弦编码
  4. 将行编码和列编码拼接,得到 dim 维的 2D 位置编码
  5. 返回不可训练的参数(requires_grad=False

嵌入维度 dim 必须是 4 的倍数,因为需要分配给:行-sin、行-cos、列-sin、列-cos 各 dim/4 维。


完整的 Diffusion Transformer 模型,组合上述所有组件。

class DiT(nn.Module, ModelMixin):
    def __init__(self, in_dim=32, channels=3, patch_size=2, depth=12,
                 head_dim=64, num_heads=6, mlp_ratio=4.0,
                 sig_embed=None, cond_embed=None):
        super().__init__()
        self.input_dims = (channels, in_dim, in_dim)
        dim = head_dim * num_heads

        self.pos_embed = get_pos_embed(in_dim, patch_size, dim)
        self.x_embed = PatchEmbed(patch_size, channels, dim, bias=True)
        self.sig_embed = sig_embed or SigmaEmbedderSinCos(dim)
        self.cond_embed = cond_embed

        self.blocks = CondSequential(*[
            DiTBlock(head_dim, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])

        self.final_norm = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.final_linear = nn.Linear(dim, patch_size**2 * channels)
        self.init()
参数默认值说明
in_dim32输入图像边长
channels3输入通道数
patch_size2Patch 大小
depth12Transformer 块数量
head_dim64每个注意力头的维度
num_heads6注意力头数量
mlp_ratio4.0MLP 隐藏层扩展比
sig_embedNone自定义 σ 嵌入器(默认 SigmaEmbedderSinCos)
cond_embedNone条件嵌入器(None 表示无条件模型)
def init(self):
    def _basic_init(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
    self.apply(_basic_init)
    nn.init.normal_(self.sig_embed.mlp[0].weight, std=0.02)
    nn.init.normal_(self.sig_embed.mlp[2].weight, std=0.02)
    nn.init.constant_(self.final_linear.weight, 0)
    nn.init.constant_(self.final_linear.bias, 0)

初始化策略:

  1. 所有 Linear 层使用 Xavier 均匀初始化
  2. σ 嵌入 MLP 使用小标准差正态初始化
  3. 最终输出层零初始化(训练初期输出接近零,即预测"无噪声")

将 patch 序列重组为图像:

def unpatchify(self, x):
    patches = self.in_dim // self.patch_size
    return rearrange(x, 'b (ph pw) (psh psw c) -> b c (ph psh) (pw psw)',
                     ph=patches, pw=patches,
                     psh=self.patch_size, psw=self.patch_size)

(B, N, patch_size² × C)(B, C, H, W)

def forward(self, x, sigma, cond=None):
    x = self.x_embed(x) + self.pos_embed            # (B, N, D)
    y = self.sig_embed(x.shape[0], sigma.squeeze())  # (B, D)
    if self.cond_embed is not None:
        assert cond is not None and x.shape[0] == cond.shape[0]
        y += self.cond_embed(cond)                   # (B, D)
    x = self.blocks(x, y)                            # (B, N, D)
    x = self.final_linear(self.final_norm(x, y))     # (B, N, p²C)
    return self.unpatchify(x)                         # (B, C, H, W)

条件信息(σ 嵌入 + 可选的类别嵌入)通过相加合并为统一的条件向量 yy,然后通过 adaLN 注入每个 Transformer 块。

from smalldiffusion import DiT, CondEmbedderLabel, ScheduleDDPM, training_loop, samples
from torch.utils.data import DataLoader

# 无条件 DiT
model = DiT(in_dim=28, channels=1, patch_size=2, depth=6,
            head_dim=32, num_heads=6)

# 条件 DiT(10 类)
model = DiT(in_dim=28, channels=1, patch_size=2, depth=6,
            head_dim=32, num_heads=6,
            cond_embed=CondEmbedderLabel(32*6, 10, dropout_prob=0.1))

相关内容