Analytic Diffusion Studio — 平滑最优去噪器
系列 - Analytic Diffusion Studio系列
目录
08 — 平滑最优去噪器 (SCFDM)
文件:src/local_diffusion/models/scfdm.py
论文:Score-based Generative Models with Closed-Form Denoisers
8.1 概述
SCFDM (Smoothed Closed-Form Diffusion Model) 是对 Optimal 去噪器的平滑改进。核心思想是:对输入 添加多组小幅高斯扰动,分别用 Optimal 去噪器处理,然后取平均。
解决的问题:Optimal 去噪器的输出可能不够平滑(因为 softmax 权重对输入敏感),通过蒙特卡洛平均来平滑输出。
与 Optimal 的关系:SCFDM 继承自 OptimalDenoiser,复用其 FAISS 索引和 softmax 加权逻辑,仅在 denoise() 中添加扰动-平均步骤。
8.2 数学公式
其中:
- 是 Optimal 去噪器
- 是噪声采样数(
num_noise) - 是平滑标准差(
smoothing_std)
8.3 类定义
@register_model("scfdm")
class SmoothedCFDM(OptimalDenoiser):
"""继承自 OptimalDenoiser,添加高斯平滑。"""
def __init__(self, dataset, device, num_steps, *, params=None, **kwargs):
super().__init__(dataset=dataset, device=device, num_steps=num_steps,
params=params, **kwargs)
self.num_noise = int(params.get("num_noise", 1))
self.smoothing_std = float(params.get("smoothing_std", 0.0))构造函数参数
| 参数 | 默认值 | 说明 |
|---|---|---|
params.num_noise | 1 | 高斯扰动采样数 M |
params.smoothing_std | 0.0 | 扰动标准差 σ_s |
params.temperature | 1.0 | 继承自 Optimal |
params.num_neighbors | 2000 | 继承自 Optimal |
当 num_noise=1, smoothing_std=0.0 时,SCFDM 退化为 Optimal。
参数校验
if self.num_noise <= 0:
raise ValueError("num_noise must be a positive integer")
if self.smoothing_std < 0:
raise ValueError("smoothing_std must be non-negative")8.4 train() 方法
直接继承自 OptimalDenoiser.train(),构建或加载 FAISS 索引。无额外逻辑。
8.5 denoise() 方法
@torch.no_grad()
def denoise(self, latents, timestep, *, generator=None, **kwargs):
# 1. 生成 M 组高斯噪声
batch_shape = (self.num_noise, *latents.shape) # [M, B, C, H, W]
noise = torch.randn(batch_shape, generator=generator,
device=latents.device, dtype=latents.dtype)
# 2. 对 x_t 添加扰动
smoothed_latents = latents.unsqueeze(0) + self.smoothing_std * noise # [M, B, C, H, W]
# 3. 展平为单个大批次
flat_latents = smoothed_latents.reshape(-1, *latents.shape[1:]) # [M*B, C, H, W]
# 4. 调用父类 Optimal 的 denoise
pred_x0 = super().denoise(flat_latents, timestep, generator=generator, **kwargs)
# 5. 恢复形状并取平均
return pred_x0.reshape(self.num_noise, *latents.shape).mean(dim=0) # [B, C, H, W]步骤详解
- 生成
[M, B, C, H, W]形状的噪声 - 广播加法:
latents.unsqueeze(0)形状为[1, B, C, H, W],加上噪声后得到 M 个扰动版本 - 将 M 个扰动版本合并为一个大批次
[M*B, C, H, W],一次性送入 Optimal 去噪器 - 将结果恢复为
[M, B, C, H, W],沿第 0 维取平均
计算量:是 Optimal 的 M 倍(FAISS 搜索量增加 M 倍)。
8.6 配置示例
# configs/scfdm/celeba_hq.yaml
model:
name: scfdm
params:
temperature: 1.0
num_neighbors: 200
num_noise: 10 # 10 组扰动
smoothing_std: 0.1 # 扰动标准差8.7 超参数影响
| 参数 | 增大效果 | 减小效果 |
|---|---|---|
num_noise | 更平滑,计算量线性增加 | 更快,但平滑效果弱 |
smoothing_std | 更强的平滑,可能过度模糊 | 趋向原始 Optimal |
8.8 继承关系
BaseDenoiser
└── OptimalDenoiser
└── SmoothedCFDMSCFDM 完全复用 Optimal 的:
__init__中的 FAISS 配置train()中的索引构建/加载denoise()中的 softmax 加权逻辑
仅重写 denoise() 添加扰动-平均包装。