目录

Analytic Diffusion Studio — 平滑最优去噪器

08 — 平滑最优去噪器 (SCFDM)

文件:src/local_diffusion/models/scfdm.py

论文:Score-based Generative Models with Closed-Form Denoisers

SCFDM (Smoothed Closed-Form Diffusion Model) 是对 Optimal 去噪器的平滑改进。核心思想是:对输入 xtx_t 添加多组小幅高斯扰动,分别用 Optimal 去噪器处理,然后取平均。

解决的问题:Optimal 去噪器的输出可能不够平滑(因为 softmax 权重对输入敏感),通过蒙特卡洛平均来平滑输出。

与 Optimal 的关系:SCFDM 继承自 OptimalDenoiser,复用其 FAISS 索引和 softmax 加权逻辑,仅在 denoise() 中添加扰动-平均步骤。

DSCFDM(xt,t)=1Mj=1MD(xt+σsϵj,t),ϵjN(0,I)D_{\text{SCFDM}}(x_t, t) = \frac{1}{M} \sum_{j=1}^{M} D^*(x_t + \sigma_s \epsilon_j, t), \quad \epsilon_j \sim \mathcal{N}(0, I)

其中:

  • DD^* 是 Optimal 去噪器
  • MM 是噪声采样数(num_noise
  • σs\sigma_s 是平滑标准差(smoothing_std
@register_model("scfdm")
class SmoothedCFDM(OptimalDenoiser):
    """继承自 OptimalDenoiser,添加高斯平滑。"""

    def __init__(self, dataset, device, num_steps, *, params=None, **kwargs):
        super().__init__(dataset=dataset, device=device, num_steps=num_steps,
                         params=params, **kwargs)
        self.num_noise = int(params.get("num_noise", 1))
        self.smoothing_std = float(params.get("smoothing_std", 0.0))
参数默认值说明
params.num_noise1高斯扰动采样数 M
params.smoothing_std0.0扰动标准差 σ_s
params.temperature1.0继承自 Optimal
params.num_neighbors2000继承自 Optimal

num_noise=1, smoothing_std=0.0 时,SCFDM 退化为 Optimal。

if self.num_noise <= 0:
    raise ValueError("num_noise must be a positive integer")
if self.smoothing_std < 0:
    raise ValueError("smoothing_std must be non-negative")

直接继承自 OptimalDenoiser.train(),构建或加载 FAISS 索引。无额外逻辑。

@torch.no_grad()
def denoise(self, latents, timestep, *, generator=None, **kwargs):
    # 1. 生成 M 组高斯噪声
    batch_shape = (self.num_noise, *latents.shape)  # [M, B, C, H, W]
    noise = torch.randn(batch_shape, generator=generator,
                        device=latents.device, dtype=latents.dtype)

    # 2. 对 x_t 添加扰动
    smoothed_latents = latents.unsqueeze(0) + self.smoothing_std * noise  # [M, B, C, H, W]

    # 3. 展平为单个大批次
    flat_latents = smoothed_latents.reshape(-1, *latents.shape[1:])  # [M*B, C, H, W]

    # 4. 调用父类 Optimal 的 denoise
    pred_x0 = super().denoise(flat_latents, timestep, generator=generator, **kwargs)

    # 5. 恢复形状并取平均
    return pred_x0.reshape(self.num_noise, *latents.shape).mean(dim=0)  # [B, C, H, W]
  1. 生成 [M, B, C, H, W] 形状的噪声
  2. 广播加法:latents.unsqueeze(0) 形状为 [1, B, C, H, W],加上噪声后得到 M 个扰动版本
  3. 将 M 个扰动版本合并为一个大批次 [M*B, C, H, W],一次性送入 Optimal 去噪器
  4. 将结果恢复为 [M, B, C, H, W],沿第 0 维取平均

计算量:是 Optimal 的 M 倍(FAISS 搜索量增加 M 倍)。

# configs/scfdm/celeba_hq.yaml
model:
  name: scfdm
  params:
    temperature: 1.0
    num_neighbors: 200
    num_noise: 10          # 10 组扰动
    smoothing_std: 0.1     # 扰动标准差
参数增大效果减小效果
num_noise更平滑,计算量线性增加更快,但平滑效果弱
smoothing_std更强的平滑,可能过度模糊趋向原始 Optimal
BaseDenoiser
    └── OptimalDenoiser
            └── SmoothedCFDM

SCFDM 完全复用 Optimal 的:

  • __init__ 中的 FAISS 配置
  • train() 中的索引构建/加载
  • denoise() 中的 softmax 加权逻辑

仅重写 denoise() 添加扰动-平均包装。

相关内容