目录

Analytic Diffusion Studio — 最优贝叶斯去噪器

07 — 最优贝叶斯去噪器 (Optimal)

文件:src/local_diffusion/models/optimal.py

Optimal 去噪器实现了贝叶斯最优估计——后验均值 E[x0xt]\mathbb{E}[x_0 | x_t]。它不做任何分布假设,直接对数据集中的所有图像做 softmax 加权平均。

解决的问题:在给定有限数据集的条件下,计算理论上最优的去噪估计。

核心思想:将数据集视为经验分布 p(x0)=1Niδ(x0x0(i))p(x_0) = \frac{1}{N}\sum_i \delta(x_0 - x_0^{(i)}),则后验均值变为 softmax 加权平均。

贝叶斯最优去噪器:

D(xt,t)=i=1Nx0(i)wi(xt,t)i=1Nwi(xt,t)D^*(x_t, t) = \frac{\sum_{i=1}^{N} x_0^{(i)} \cdot w_i(x_t, t)}{\sum_{i=1}^{N} w_i(x_t, t)}

其中权重为:

wi(xt,t)=exp(αˉtxt/αˉtx0(i)22(1αˉt)τ)w_i(x_t, t) = \exp\left(-\frac{\bar{\alpha}_t \|x_t / \sqrt{\bar{\alpha}_t} - x_0^{(i)}\|^2}{2(1-\bar{\alpha}_t) \cdot \tau}\right)
  • τ\tau 是温度参数(temperature),τ=1\tau = 1 时为标准贝叶斯最优
  • 分子中 xt/αˉtx_t / \sqrt{\bar{\alpha}_t} 是对 xtx_t 的缩放,使其与 x0x_0 在同一尺度

直接遍历所有 N 个数据点计算权重代价太高。本实现使用 FAISS 库进行近似最近邻搜索,只对 top-k 个近邻计算权重。

if num_images > 1_000_000:
    # 大数据集:使用 IVF(倒排文件)索引
    nlist = min(4096, num_images // 39)
    quantizer = faiss.IndexFlatL2(self.dim)
    index = faiss.IndexIVFFlat(quantizer, self.dim, nlist)
    index.train(train_data)  # 需要训练聚类中心
else:
    # 小数据集:精确搜索
    index = faiss.IndexFlatL2(self.dim)

index.add(dataset_images.numpy().astype(np.float32))
  • IndexFlatL2:暴力 L2 距离搜索,精确但慢
  • IndexIVFFlat:基于倒排文件的近似搜索,先找到最近的聚类,再在聚类内搜索
if platform.system() == "Darwin":
    os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
    faiss.omp_set_num_threads(1)  # 避免 fork 安全问题
@register_model("optimal")
class OptimalDenoiser(BaseDenoiser):
    def __init__(self, dataset, device, num_steps, *, params=None, **kwargs):
参数默认值说明
params.index_pathNoneFAISS 索引存储路径
params.temperature1.0softmax 温度参数
params.num_neighbors2000近邻搜索数量 k

默认索引路径:data/models/optimal/{dataset_name}_{resolution}

def train(self, dataset: DatasetBundle):
    try:
        self.faiss_index, self.dataset_images, self.dim = load_optimal_index(self.index_path)
    except FileNotFoundError:
        # 1. 遍历数据集,展平为 [N, n_pixels]
        # 2. 构建 FAISS 索引
        # 3. 保存索引和数据到磁盘

保存的文件:

  • index.index:FAISS 索引文件
  • data.pt:包含 dataset_images 张量和 dim 维度信息
@torch.no_grad()
def denoise(self, latents, timestep, *, generator=None, **kwargs):
    # 1. 获取调度器参数
    alpha_prod_t = self.scheduler.alphas_cumprod[timestep_index]
    beta_prod_t = 1 - alpha_prod_t

    # 2. 缩放查询向量:x_scaled = x_t / √ᾱ_t
    latents_scaled = latents / torch.sqrt(alpha_prod_t)
    latents_flat = latents_scaled.flatten(start_dim=1)

    # 3. FAISS 近邻搜索
    query_vectors = latents_flat.cpu().numpy().astype(np.float32)
    k = min(self.num_neighbors, self.dataset_images.shape[0])
    distances_np, indices_np = self.faiss_index.search(query_vectors, k)

    # 4. 缩放距离(补偿查询缩放)
    scaled_distances = distances * alpha_prod_t

    # 5. 获取近邻图像
    neighbor_images = self.dataset_images[indices].to(self.device)

    # 6. 计算 softmax 权重
    logits = -scaled_distances / (2 * beta_prod_t * self.temperature)
    weights = torch.softmax(logits, dim=1)

    # 7. 加权平均
    pred_x0_flat = torch.bmm(weights.unsqueeze(1), neighbor_images).squeeze(1)
    pred_x0 = pred_x0_flat.view_as(latents)
    return pred_x0
  1. 缩放查询:将 xtx_t 除以 αˉt\sqrt{\bar{\alpha}_t},使查询向量与数据集图像在同一尺度。这等价于在原始空间中比较 xtx_tαˉtx0(i)\sqrt{\bar{\alpha}_t} x_0^{(i)}

  2. FAISS 搜索:在缩放后的空间中找到 k 个最近邻。FAISS 返回 L2 距离和索引。

  3. 距离修正:FAISS 返回的距离是缩放空间中的,乘以 αˉt\bar{\alpha}_t 恢复到原始空间。

  4. softmax 权重logiti=αˉtdi22(1αˉt)τ\text{logit}_i = -\frac{\bar{\alpha}_t d_i^2}{2(1-\bar{\alpha}_t)\tau},然后 softmax 归一化。

  5. 加权平均torch.bmm 实现批量矩阵乘法 [B, 1, k] × [B, k, n] → [B, 1, n]

def save_optimal_index(faiss_index, dataset_images, save_path, dim):
    faiss.write_index(faiss_index, str(save_path / "index.index"))
    torch.save({"data": dataset_images.cpu(), "dim": dim}, save_path / "data.pt")
def load_optimal_index(load_path):
    faiss_index = faiss.read_index(str(load_path / "index.index"))
    saved_data = torch.load(load_path / "data.pt", weights_only=True)
    return faiss_index, saved_data["data"], saved_data["dim"]
# configs/optimal/cifar10.yaml
model:
  name: optimal
  params:
    temperature: 1.0
    num_neighbors: 200
温度 τ效果
τ → 0退化为最近邻(硬选择)
τ = 1标准贝叶斯最优
τ → ∞所有权重趋于均匀(输出趋向全局均值)
  • 需要将整个数据集的展平图像存储在内存中
  • FAISS 搜索需要 CPU-GPU 数据传输
  • 生成图像是数据集图像的加权平均,可能模糊
  • k 值选择影响精度和速度的权衡

相关内容