目录

Analytic Diffusion Studio — 最近邻基线

10 — 最近邻基线 (Nearest Dataset)

文件:src/local_diffusion/models/nearest_dataset.py

Nearest Dataset 是最简单的去噪基线:对于每个噪声图像 xtx_t,在数据集中找到欧氏距离最近的图像作为 x^0\hat{x}_0

解决的问题:提供一个最低复杂度的参考基线,用于衡量其他方法的改进幅度。

直觉:如果 Optimal 去噪器的温度 τ0\tau \to 0,softmax 退化为 argmax,就得到最近邻。

DNN(xt,t)=x0(i),i=argminixtx0(i)2D_{\text{NN}}(x_t, t) = x_0^{(i^*)}, \quad i^* = \arg\min_i \|x_t - x_0^{(i)}\|_2

注意:这里直接用 xtx_tx0(i)x_0^{(i)} 比较,没有做 αˉt\sqrt{\bar{\alpha}_t} 缩放(与 Optimal 不同)。

@register_model("nearest_dataset")
class NearestDatasetDenoiser(BaseDenoiser):
    def __init__(self, dataset, device, num_steps, *, params=None, **kwargs):
        params = params or {}
        super().__init__(
            resolution=dataset.resolution,
            device=device,
            num_steps=num_steps,
            in_channels=dataset.in_channels,
            dataset_name=dataset.name,
            **kwargs,
        )
        self.dataset = dataset

无额外超参数。

def train(self, dataset: DatasetBundle):
    images = []
    for batch in dataset.dataloader:
        if isinstance(batch, (tuple, list)):
            batch = batch[0]
        images.append(batch)

    dataset_tensor = torch.cat(images, dim=0).contiguous().to(self.device)
    self.register_buffer("dataset_images", dataset_tensor)
    self.to(self.device)
    return self

将整个数据集加载到 GPU 内存,存储为 [N, C, H, W] 张量。

内存需求

  • MNIST (60k × 1 × 28 × 28):约 150 MB
  • CIFAR-10 (50k × 3 × 32 × 32):约 600 MB
  • CelebA-HQ 64×64 (30k × 3 × 64 × 64):约 1.4 GB
@torch.no_grad()
def denoise(self, latents, timestep, *, generator=None, **_):
    # 1. 展平
    latents_flat = latents.flatten(start_dim=1)  # [B, n]

    # 2. 计算所有距离
    distances = torch.cdist(latents_flat,
                            self.dataset_images.flatten(start_dim=1))  # [B, N]

    # 3. 找最近邻
    min_dist, indices = torch.min(distances, dim=1)  # [B]

    # 4. 返回最近邻图像
    pred_x0 = self.dataset_images[indices]  # [B, C, H, W]
    return pred_x0
  1. 将查询和数据集图像都展平为向量
  2. torch.cdist 计算成对 L2 距离矩阵 [B, N]
  3. 沿数据集维度取最小值,得到最近邻索引
  4. 用索引取出对应的数据集图像

复杂度O(BNn)O(B \cdot N \cdot n),其中 NN 是数据集大小,nn 是像素数。

# configs/nearest_dataset/mnist.yaml
model:
  name: nearest_dataset
  params: {}  # 无超参数
  • 只能"复制"训练集中的图像,无法生成新图像
  • 不考虑时间步 tt(距离计算不依赖 αˉt\bar{\alpha}_t
  • 需要将整个数据集放入 GPU 内存
  • torch.cdist 对大数据集可能内存不足(需要 [B, N] 距离矩阵)
  • 生成多样性完全取决于初始噪声到不同数据集图像的距离
特性NearestOptimal
距离缩放xt/αˉtx_t / \sqrt{\bar{\alpha}_t}
选择方式argmax(硬选择)softmax(软加权)
输出单张数据集图像多张图像的加权平均
时间步感知
近邻搜索暴力搜索FAISS 加速

相关内容