Analytic Diffusion Studio — 工具模块
系列 - Analytic Diffusion Studio系列
目录
13 — 工具模块
目录:src/local_diffusion/utils/
13.1 模块结构
utils/
├── __init__.py # 导出公共 API
├── wiener.py # Wiener 滤波计算与存储
└── neural_networks.py # UNet 网络定义

__init__.py 导出:
from .wiener import compute_wiener_filter, load_wiener_filter, save_wiener_filter
from .neural_networks import UNet

13.2 Wiener 滤波工具 (wiener.py)
compute_wiener_filter()
def compute_wiener_filter(dataloader, device, resolution, n_channels):
    """从数据集计算协方差矩阵和均值。"""

两遍扫描算法:
第一遍:计算均值
for batch in dataloader:
images = batch[0].to(device).flatten(start_dim=1) # [batch, n_pixels]
sum_images += images.sum(dim=0)
total_samples += images.shape[0]
mean = sum_images / total_samples

第二遍:计算协方差
for batch in dataloader:
images = batch[0].to(device).flatten(start_dim=1)
centered = images - mean.unsqueeze(0)
cov_accumulator += centered.T @ centered # [n, n]
S = cov_accumulator / (total_samples - 1)

返回值:
- S:协方差矩阵 [n_pixels, n_pixels]
- mean:均值向量 [n_pixels]

内存需求:协方差矩阵大小为 $n \times n \times 4$ 字节(float32),其中 $n = \text{resolution}^2 \times \text{n\_channels}$:
- MNIST (784):约 2.3 MB
- CIFAR-10 (3072):约 36 MB
- 64×64 RGB (12288):约 576 MB
为什么用两遍扫描? 一遍扫描的在线协方差算法(如 Welford)数值稳定性更好,但两遍扫描更简单且在 GPU 上更高效(可以利用矩阵乘法加速)。
save_wiener_filter()
def save_wiener_filter(U, LA, Vh, mean, save_path):
save_path.mkdir(parents=True, exist_ok=True)
torch.save(U.cpu(), save_path / "U.pt")
torch.save(LA.cpu(), save_path / "LA.pt")
torch.save(Vh.cpu(), save_path / "Vh.pt")
    torch.save(mean.cpu(), save_path / "mean.pt")

保存 SVD 分解的四个组件(U、LA、Vh、mean)到指定目录。
load_wiener_filter()
def load_wiener_filter(load_path, device=None):
U = torch.load(load_path / "U.pt", map_location=device, weights_only=True)
LA = torch.load(load_path / "LA.pt", map_location=device, weights_only=True)
Vh = torch.load(load_path / "Vh.pt", map_location=device, weights_only=True)
mean = torch.load(load_path / "mean.pt", map_location=device, weights_only=True)
    return U, LA, Vh, mean

如果文件不存在,抛出 FileNotFoundError(调用方负责处理)。
weights_only=True:PyTorch 安全加载模式,防止反序列化攻击。
13.3 UNet 网络 (neural_networks.py)
这是一个标准的 DDPM UNet 实现,用于 BaselineUNet 模型。
整体架构
输入 x_t [B, C, H, W] + 时间步 t [B]
│
▼
TimeEmbedding(t) → temb [B, tdim]
│
▼
Head Conv (C → ch)
│
▼
┌─ Encoder ──────────────────────┐
│ Level 0: ResBlock × 2 │
│ DownSample │
│ Level 1: ResBlock × 2 │
│ DownSample │
│ ... │
│ Level N: ResBlock × 2 │
└────────────────────────────────┘
│
▼
Middle: ResBlock(attn=True) + ResBlock(attn=False)
│
▼
┌─ Decoder ──────────────────────┐
│ Level N: ResBlock × 3 (+ skip)│
│ UpSample │
│ ... │
│ Level 0: ResBlock × 3 (+ skip)│
└────────────────────────────────┘
│
▼
Tail: GroupNorm → Swish → Conv (ch → C)
│
▼
输出 ε̂ [B, C, H, W]

Swish 激活函数
class Swish(nn.Module):
def forward(self, x):
        return x * torch.sigmoid(x)

$\mathrm{Swish}(x) = x \cdot \sigma(x)$,是一种平滑的非线性激活,在扩散模型中广泛使用。
TimeEmbedding
class TimeEmbedding(nn.Module):
def __init__(self, T, d_model, dim):
# 正弦位置编码
emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
emb = torch.exp(-emb)
pos = torch.arange(T).float()
emb = pos[:, None] * emb[None, :]
emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1).view(T, d_model)
self.timembedding = nn.Sequential(
nn.Embedding.from_pretrained(emb), # 查表
nn.Linear(d_model, dim),
Swish(),
nn.Linear(dim, dim),
    )

将离散时间步 $t \in \{0, \dots, T-1\}$ 编码为连续向量:
- 正弦/余弦位置编码(类似 Transformer)
- 两层 MLP 投影到 $tdim$ 维,其中 tdim = ch × 4
DownSample
class DownSample(nn.Module):
def __init__(self, in_ch):
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)

使用 stride=2 的卷积实现 2× 下采样。
UpSample
class UpSample(nn.Module):
def __init__(self, in_ch):
self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
def forward(self, x, temb):
x = F.interpolate(x, scale_factor=2, mode="nearest")
x = self.main(x)
        return x

先最近邻插值 2× 上采样,再卷积平滑。
AttnBlock
class AttnBlock(nn.Module):
def __init__(self, in_ch):
self.group_norm = nn.GroupNorm(32, in_ch)
self.proj_q = nn.Conv2d(in_ch, in_ch, 1)
self.proj_k = nn.Conv2d(in_ch, in_ch, 1)
self.proj_v = nn.Conv2d(in_ch, in_ch, 1)
        self.proj = nn.Conv2d(in_ch, in_ch, 1)

标准自注意力机制:
- GroupNorm 归一化
- 1×1 卷积生成 Q、K、V
- 注意力权重:$w = \mathrm{softmax}\!\left(QK^\top / \sqrt{C}\right)$
- 输出:$h = wV$
- 残差连接:$\mathrm{output} = x + \mathrm{proj}(h)$
ResBlock
class ResBlock(nn.Module):
def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
self.block1 = nn.Sequential(GroupNorm, Swish, Conv2d)
self.temb_proj = nn.Sequential(Swish, Linear(tdim, out_ch))
self.block2 = nn.Sequential(GroupNorm, Swish, Dropout, Conv2d)
self.shortcut = Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else Identity()
        self.attn = AttnBlock(out_ch) if attn else Identity()

前向传播:
def forward(self, x, temb):
h = self.block1(x)
h += self.temb_proj(temb)[:, :, None, None] # 时间嵌入注入
h = self.block2(h)
h = h + self.shortcut(x) # 残差连接
h = self.attn(h) # 可选注意力
        return h

FlattenLinear
class FlattenLinear(nn.Module):
def __init__(self, channels, height, width, tdim):
self.linear = nn.Linear(channels * height * width + tdim,
                                channels * height * width)

将特征图展平后与时间嵌入拼接,通过全连接层处理。在当前配置中未使用(为扩展预留)。
UNet 主类
class UNet(nn.Module):
def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout,
                 in_channels=3, out_channels=3):

构造函数参数
| 参数 | 说明 |
|---|---|
| T | 时间步总数(1000) |
| ch | 基础通道数 |
| ch_mult | 各级通道倍率列表 |
| attn | 使用注意力的级别索引列表 |
| num_res_blocks | 每级残差块数 |
| dropout | Dropout 概率 |
| in_channels | 输入通道数 |
| out_channels | 输出通道数 |
forward() 方法
def forward(self, x, t, return_middle_feature=False, return_all_features=False):
temb = self.time_embedding(t)
h = self.head(x)
hs = [h]
# 编码器
for layer in self.downblocks:
h = layer(h, temb)
hs.append(h)
# 中间层
for layer in self.middleblocks:
h = layer(h, temb)
# 解码器(带跳跃连接)
for layer in self.upblocks:
if isinstance(layer, ResBlock):
h = torch.cat([h, hs.pop()], dim=1) # 跳跃连接
h = layer(h, temb)
h = self.tail(h)
    return h

可选返回中间特征(用于分析):
- return_middle_feature=True:返回 (output, middle_feature, temb)
- return_all_features=True:返回 (output, middle_feature, pretail_features, temb)
权重初始化
def initialize(self):
init.xavier_uniform_(self.head.weight)
init.zeros_(self.head.bias)
init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) # 小增益
    init.zeros_(self.tail[-1].bias)

尾部卷积使用极小的增益(1e-5),使初始输出接近零——这是扩散模型训练的常见技巧。
编码器-解码器通道数示例
以 CIFAR-10 (32×32, ch=128, ch_mult=[1,2,3,4]) 为例:
编码器:
Level 0: 128 → 128 (ResBlock ×2), DownSample → 16×16
Level 1: 128 → 256 (ResBlock ×2), DownSample → 8×8
Level 2: 256 → 384 (ResBlock ×2), DownSample → 4×4
Level 3: 384 → 512 (ResBlock ×2)
中间层:
512 → 512 (ResBlock with Attn)
512 → 512 (ResBlock)
解码器:
Level 3: 512+512 → 512 (ResBlock ×3)
UpSample → 8×8
Level 2: 512+384 → 384 (ResBlock ×3)
UpSample → 16×16
Level 1: 384+256 → 256 (ResBlock ×3)
UpSample → 32×32
Level 0: 256+128 → 128 (ResBlock ×3)