Analytic Diffusion Studio — 最优贝叶斯去噪器
系列 - Analytic Diffusion Studio系列
目录
07 — 最优贝叶斯去噪器 (Optimal)
文件:src/local_diffusion/models/optimal.py
7.1 概述
Optimal 去噪器实现了贝叶斯最优估计——后验均值 。它不做任何分布假设,直接对数据集中的所有图像做 softmax 加权平均。
解决的问题:在给定有限数据集的条件下,计算理论上最优的去噪估计。
核心思想:将数据集视为经验分布 ,则后验均值变为 softmax 加权平均。
7.2 数学公式
贝叶斯最优去噪器:
其中权重为:
- 是温度参数(
temperature), 时为标准贝叶斯最优 - 分子中 是对 的缩放,使其与 在同一尺度
7.3 FAISS 加速
直接遍历所有 N 个数据点计算权重代价太高。本实现使用 FAISS 库进行近似最近邻搜索,只对 top-k 个近邻计算权重。
FAISS 索引构建
if num_images > 1_000_000:
# 大数据集:使用 IVF(倒排文件)索引
nlist = min(4096, num_images // 39)
quantizer = faiss.IndexFlatL2(self.dim)
index = faiss.IndexIVFFlat(quantizer, self.dim, nlist)
index.train(train_data) # 需要训练聚类中心
else:
# 小数据集:精确搜索
index = faiss.IndexFlatL2(self.dim)
index.add(dataset_images.numpy().astype(np.float32))IndexFlatL2:暴力 L2 距离搜索,精确但慢IndexIVFFlat:基于倒排文件的近似搜索,先找到最近的聚类,再在聚类内搜索
macOS 兼容性
if platform.system() == "Darwin":
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
faiss.omp_set_num_threads(1) # 避免 fork 安全问题7.4 类定义
@register_model("optimal")
class OptimalDenoiser(BaseDenoiser):
def __init__(self, dataset, device, num_steps, *, params=None, **kwargs):构造函数参数
| 参数 | 默认值 | 说明 |
|---|---|---|
params.index_path | None | FAISS 索引存储路径 |
params.temperature | 1.0 | softmax 温度参数 |
params.num_neighbors | 2000 | 近邻搜索数量 k |
默认索引路径:data/models/optimal/{dataset_name}_{resolution}
7.5 train() 方法
def train(self, dataset: DatasetBundle):
try:
self.faiss_index, self.dataset_images, self.dim = load_optimal_index(self.index_path)
except FileNotFoundError:
# 1. 遍历数据集,展平为 [N, n_pixels]
# 2. 构建 FAISS 索引
# 3. 保存索引和数据到磁盘保存的文件:
index.index:FAISS 索引文件data.pt:包含dataset_images张量和dim维度信息
7.6 denoise() 方法
@torch.no_grad()
def denoise(self, latents, timestep, *, generator=None, **kwargs):
# 1. 获取调度器参数
alpha_prod_t = self.scheduler.alphas_cumprod[timestep_index]
beta_prod_t = 1 - alpha_prod_t
# 2. 缩放查询向量:x_scaled = x_t / √ᾱ_t
latents_scaled = latents / torch.sqrt(alpha_prod_t)
latents_flat = latents_scaled.flatten(start_dim=1)
# 3. FAISS 近邻搜索
query_vectors = latents_flat.cpu().numpy().astype(np.float32)
k = min(self.num_neighbors, self.dataset_images.shape[0])
distances_np, indices_np = self.faiss_index.search(query_vectors, k)
# 4. 缩放距离(补偿查询缩放)
scaled_distances = distances * alpha_prod_t
# 5. 获取近邻图像
neighbor_images = self.dataset_images[indices].to(self.device)
# 6. 计算 softmax 权重
logits = -scaled_distances / (2 * beta_prod_t * self.temperature)
weights = torch.softmax(logits, dim=1)
# 7. 加权平均
pred_x0_flat = torch.bmm(weights.unsqueeze(1), neighbor_images).squeeze(1)
pred_x0 = pred_x0_flat.view_as(latents)
return pred_x0步骤详解
缩放查询:将 除以 ,使查询向量与数据集图像在同一尺度。这等价于在原始空间中比较 与 。
FAISS 搜索:在缩放后的空间中找到 k 个最近邻。FAISS 返回 L2 距离和索引。
距离修正:FAISS 返回的距离是缩放空间中的,乘以 恢复到原始空间。
softmax 权重:,然后 softmax 归一化。
加权平均:
torch.bmm实现批量矩阵乘法[B, 1, k] × [B, k, n] → [B, 1, n]。
7.7 辅助函数
save_optimal_index()
def save_optimal_index(faiss_index, dataset_images, save_path, dim):
faiss.write_index(faiss_index, str(save_path / "index.index"))
torch.save({"data": dataset_images.cpu(), "dim": dim}, save_path / "data.pt")load_optimal_index()
def load_optimal_index(load_path):
faiss_index = faiss.read_index(str(load_path / "index.index"))
saved_data = torch.load(load_path / "data.pt", weights_only=True)
return faiss_index, saved_data["data"], saved_data["dim"]7.8 配置示例
# configs/optimal/cifar10.yaml
model:
name: optimal
params:
temperature: 1.0
num_neighbors: 2007.9 温度参数的影响
| 温度 τ | 效果 |
|---|---|
| τ → 0 | 退化为最近邻(硬选择) |
| τ = 1 | 标准贝叶斯最优 |
| τ → ∞ | 所有权重趋于均匀(输出趋向全局均值) |
7.10 局限性
- 需要将整个数据集的展平图像存储在内存中
- FAISS 搜索需要 CPU-GPU 数据传输
- 生成图像是数据集图像的加权平均,可能模糊
- k 值选择影响精度和速度的权衡