目录

Analytic Diffusion Studio — 工具模块

13 — 工具模块

目录:src/local_diffusion/utils/

utils/
├── __init__.py              # 导出公共 API
├── wiener.py                # Wiener 滤波计算与存储
└── neural_networks.py       # UNet 网络定义

__init__.py 导出:

from .wiener import compute_wiener_filter, load_wiener_filter, save_wiener_filter
from .neural_networks import UNet

def compute_wiener_filter(dataloader, device, resolution, n_channels):
    """Compute the sample covariance matrix and mean of a dataset.

    Uses a two-pass scan over `dataloader` (first pass: mean; second
    pass: centered Gram accumulation), as described below.

    Returns:
        S: [n_pixels, n_pixels] sample covariance matrix.
        mean: [n_pixels] per-pixel mean vector, n_pixels = C*H*W.
    """

两遍扫描算法

第一遍:计算均值

# First pass: accumulate a running sum so the mean can be computed
# without holding the whole dataset in memory.
# NOTE(review): excerpt — sum_images / total_samples are initialized
# to zeros outside this snippet.
for batch in dataloader:
    images = batch[0].to(device).flatten(start_dim=1)  # [batch, n_pixels]
    sum_images += images.sum(dim=0)
    total_samples += images.shape[0]
mean = sum_images / total_samples

第二遍:计算协方差

# Second pass: accumulate centered outer products; batched matmul
# keeps this running at full GPU speed.
for batch in dataloader:
    images = batch[0].to(device).flatten(start_dim=1)
    centered = images - mean.unsqueeze(0)
    cov_accumulator += centered.T @ centered  # [n, n] Gram accumulation
S = cov_accumulator / (total_samples - 1)  # Bessel-corrected (unbiased) estimate

返回值

  • S:协方差矩阵 [n_pixels, n_pixels]
  • mean:均值向量 [n_pixels]

内存需求:协方差矩阵大小为 $n^2$(float32,共 $4n^2$ 字节),其中 $n = C \times H \times W$

  • MNIST (784):约 2.3 MB
  • CIFAR-10 (3072):约 36 MB
  • 64×64 RGB (12288):约 576 MB

为什么用两遍扫描? 一遍扫描的在线协方差算法(如 Welford)可以避免二次遍历数据,但两遍扫描同样数值稳定、实现更简单,且在 GPU 上更高效(可以利用矩阵乘法加速)。

def save_wiener_filter(U, LA, Vh, mean, save_path):
    """Persist the SVD components and the mean vector as .pt files.

    Creates `save_path` (a pathlib.Path) if needed, then writes
    U.pt, LA.pt, Vh.pt and mean.pt, moving each tensor to CPU first.
    """
    save_path.mkdir(parents=True, exist_ok=True)
    for name, tensor in (("U", U), ("LA", LA), ("Vh", Vh), ("mean", mean)):
        # Store CPU copies so files load on machines without the original device.
        torch.save(tensor.cpu(), save_path / f"{name}.pt")

保存 SVD 分解的四个组件到指定目录。

def load_wiener_filter(load_path, device=None):
    """Load the four saved Wiener-filter components from `load_path`.

    Returns (U, LA, Vh, mean). Raises FileNotFoundError if any file is
    missing (the caller handles it). `weights_only=True` uses PyTorch's
    safe loading mode, preventing arbitrary-object deserialization.
    """
    components = [
        torch.load(load_path / f"{name}.pt", map_location=device, weights_only=True)
        for name in ("U", "LA", "Vh", "mean")
    ]
    return tuple(components)

如果文件不存在,抛出 FileNotFoundError(调用方负责处理)。

weights_only=True:PyTorch 安全加载模式,防止反序列化攻击。


这是一个标准的 DDPM UNet 实现,用于 BaselineUNet 模型。

输入 x_t [B, C, H, W] + 时间步 t [B]
TimeEmbedding(t) → temb [B, tdim]
Head Conv (C → ch)
┌─ Encoder ──────────────────────┐
│  Level 0: ResBlock × 2         │
│  DownSample                    │
│  Level 1: ResBlock × 2         │
│  DownSample                    │
│  ...                           │
│  Level N: ResBlock × 2         │
└────────────────────────────────┘
Middle: ResBlock(attn=True) + ResBlock(attn=False)
┌─ Decoder ──────────────────────┐
│  Level N: ResBlock × 3 (+ skip)│
│  UpSample                      │
│  ...                           │
│  Level 0: ResBlock × 3 (+ skip)│
└────────────────────────────────┘
Tail: GroupNorm → Swish → Conv (ch → C)
输出 ε̂ [B, C, H, W]
class Swish(nn.Module):
    """Smooth non-linear activation: Swish(x) = x * sigmoid(x)."""

    def forward(self, x):
        # Gate every element by its own sigmoid.
        gate = torch.sigmoid(x)
        return gate * x

Swish 定义为 $\text{Swish}(x) = x \cdot \sigma(x)$,是一种平滑的非线性激活,在扩散模型中广泛使用。

class TimeEmbedding(nn.Module):
    """Embed discrete timesteps t in {0..T-1} into continuous vectors.

    A fixed sinusoidal table (as in Transformer positional encoding)
    feeds a two-layer MLP that projects d_model -> dim.
    """

    def __init__(self, T, d_model, dim):
        # Sinusoidal positional encoding: frequencies decay geometrically
        # with the (even) channel index.
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]
        # Interleave sin/cos per frequency, then flatten to [T, d_model].
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1).view(T, d_model)

        # NOTE(review): this excerpt omits super().__init__(); the full
        # source presumably calls it before registering submodules — confirm.
        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),  # frozen lookup table over timesteps
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

将离散时间步 $t \in \{0, 1, \dots, T-1\}$ 编码为连续向量:

  1. 正弦/余弦位置编码(类似 Transformer)
  2. 两层 MLP 投影到 tdim = ch * 4
class DownSample(nn.Module):
    """2x spatial downsampling via a stride-2 3x3 convolution (excerpt)."""

    def __init__(self, in_ch):
        # Channel count is preserved; only H and W are halved.
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)

使用 stride=2 的卷积实现 2× 下采样。

class UpSample(nn.Module):
    """2x spatial upsampling: nearest-neighbor interpolation then a 3x3 conv."""

    def __init__(self, in_ch):
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)

    def forward(self, x, temb):
        # `temb` is unused here; the parameter keeps the signature uniform
        # with ResBlock so the UNet can iterate its blocks generically.
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.main(x)  # convolution smooths the blocky nearest-neighbor output
        return x

先最近邻插值 2× 上采样,再卷积平滑。

class AttnBlock(nn.Module):
    """Self-attention over spatial positions (excerpt: __init__ only)."""

    def __init__(self, in_ch):
        self.group_norm = nn.GroupNorm(32, in_ch)
        # 1x1 convolutions produce per-pixel query/key/value projections.
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1)
        self.proj = nn.Conv2d(in_ch, in_ch, 1)  # output projection before residual add

标准自注意力机制:

  1. GroupNorm 归一化
  2. 1×1 卷积生成 Q、K、V
  3. 注意力权重:$W = \text{softmax}(QK^T / \sqrt{C})$
  4. 输出:$h = WV$
  5. 残差连接:$x + \text{proj}(h)$
class ResBlock(nn.Module):
    """Residual block with timestep-embedding injection and optional attention.

    NOTE(review): schematic excerpt — the nn.Sequential entries are shown
    as bare class names, not constructed modules; see the real source.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        self.block1 = nn.Sequential(GroupNorm, Swish, Conv2d)
        self.temb_proj = nn.Sequential(Swish, Linear(tdim, out_ch))
        self.block2 = nn.Sequential(GroupNorm, Swish, Dropout, Conv2d)
        # 1x1 conv aligns channel counts for the residual sum; identity otherwise.
        self.shortcut = Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else Identity()
        self.attn = AttnBlock(out_ch) if attn else Identity()

前向传播:

def forward(self, x, temb):
    """Apply the residual block; temb is broadcast-added after block1."""
    h = self.block1(x)
    h += self.temb_proj(temb)[:, :, None, None]  # inject time embedding per channel
    h = self.block2(h)
    h = h + self.shortcut(x)                      # residual connection
    h = self.attn(h)                               # optional self-attention
    return h
class FlattenLinear(nn.Module):
    """Linear layer over the flattened feature map concatenated with the
    time embedding. Unused in the current config (reserved for extension)."""

    def __init__(self, channels, height, width, tdim):
        # Input: flattened features (C*H*W) plus the tdim time embedding;
        # output: same flattened feature size.
        self.linear = nn.Linear(channels * height * width + tdim,
                                channels * height * width)

将特征图展平后与时间嵌入拼接,通过全连接层处理。在当前配置中未使用(为扩展预留)。

class UNet(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout,
                 in_channels=3, out_channels=3):
参数说明

  • T:时间步总数(1000)
  • ch:基础通道数
  • ch_mult:各级通道倍率列表
  • attn:使用注意力的级别索引列表
  • num_res_blocks:每级残差块数
  • dropout:Dropout 概率
  • in_channels:输入通道数
  • out_channels:输出通道数
def forward(self, x, t, return_middle_feature=False, return_all_features=False):
    """Predict the noise for x_t at timestep t (excerpt: default path only).

    The feature-return flags are documented below; their branches are not
    shown in this excerpt.
    """
    temb = self.time_embedding(t)
    h = self.head(x)
    hs = [h]  # stack of encoder activations for the skip connections

    # Encoder
    for layer in self.downblocks:
        h = layer(h, temb)
        hs.append(h)

    # Middle
    for layer in self.middleblocks:
        h = layer(h, temb)

    # Decoder (with skip connections)
    for layer in self.upblocks:
        if isinstance(layer, ResBlock):
            h = torch.cat([h, hs.pop()], dim=1)  # concat matching encoder feature
        h = layer(h, temb)

    h = self.tail(h)
    return h

可选返回中间特征(用于分析):

  • return_middle_feature=True:返回 (output, middle_feature, temb)
  • return_all_features=True:返回 (output, middle_feature, pretail_features, temb)
def initialize(self):
    """Xavier-initialize the head/tail convolutions.

    The tail conv uses a tiny gain so the network's initial output is
    near zero — a common trick for stable diffusion-model training.
    """
    init.xavier_uniform_(self.head.weight)
    init.zeros_(self.head.bias)
    init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)  # near-zero initial prediction
    init.zeros_(self.tail[-1].bias)

尾部卷积使用极小的增益(1e-5),使初始输出接近零——这是扩散模型训练的常见技巧。

以 CIFAR-10 (32×32, ch=128, ch_mult=[1,2,3,4]) 为例:

编码器:
  Level 0: 128 → 128 (ResBlock ×2), DownSample → 16×16
  Level 1: 128 → 256 (ResBlock ×2), DownSample → 8×8
  Level 2: 256 → 384 (ResBlock ×2), DownSample → 4×4
  Level 3: 384 → 512 (ResBlock ×2)

中间层:
  512 → 512 (ResBlock with Attn)
  512 → 512 (ResBlock)

解码器:
  Level 3: 512+512 → 512 (ResBlock ×3)
  UpSample → 8×8
  Level 2: 512+384 → 384 (ResBlock ×3)
  UpSample → 16×16
  Level 1: 384+256 → 256 (ResBlock ×3)
  UpSample → 32×32
  Level 0: 256+128 → 128 (ResBlock ×3)

相关内容