Analytic Diffusion Studio — 最近邻基线
系列 - Analytic Diffusion Studio系列
目录
10 — 最近邻基线 (Nearest Dataset)
文件:src/local_diffusion/models/nearest_dataset.py
10.1 概述
Nearest Dataset 是最简单的去噪基线:对于每个噪声图像 ,在数据集中找到欧氏距离最近的图像作为 。
解决的问题:提供一个最低复杂度的参考基线,用于衡量其他方法的改进幅度。
直觉:如果 Optimal 去噪器的温度 ,softmax 退化为 argmax,就得到最近邻。
10.2 数学公式
注意:这里直接用 与 比较,没有做 缩放(与 Optimal 不同)。
10.3 类定义
@register_model("nearest_dataset")
class NearestDatasetDenoiser(BaseDenoiser):
def __init__(self, dataset, device, num_steps, *, params=None, **kwargs):
params = params or {}
super().__init__(
resolution=dataset.resolution,
device=device,
num_steps=num_steps,
in_channels=dataset.in_channels,
dataset_name=dataset.name,
**kwargs,
)
self.dataset = dataset无额外超参数。
10.4 train() 方法
def train(self, dataset: DatasetBundle):
images = []
for batch in dataset.dataloader:
if isinstance(batch, (tuple, list)):
batch = batch[0]
images.append(batch)
dataset_tensor = torch.cat(images, dim=0).contiguous().to(self.device)
self.register_buffer("dataset_images", dataset_tensor)
self.to(self.device)
return self将整个数据集加载到 GPU 内存,存储为 [N, C, H, W] 张量。
内存需求:
- MNIST (60k × 1 × 28 × 28):约 150 MB
- CIFAR-10 (50k × 3 × 32 × 32):约 600 MB
- CelebA-HQ 64×64 (30k × 3 × 64 × 64):约 1.4 GB
10.5 denoise() 方法
@torch.no_grad()
def denoise(self, latents, timestep, *, generator=None, **_):
# 1. 展平
latents_flat = latents.flatten(start_dim=1) # [B, n]
# 2. 计算所有距离
distances = torch.cdist(latents_flat,
self.dataset_images.flatten(start_dim=1)) # [B, N]
# 3. 找最近邻
min_dist, indices = torch.min(distances, dim=1) # [B]
# 4. 返回最近邻图像
pred_x0 = self.dataset_images[indices] # [B, C, H, W]
return pred_x0步骤详解
- 将查询和数据集图像都展平为向量
torch.cdist计算成对 L2 距离矩阵[B, N]- 沿数据集维度取最小值,得到最近邻索引
- 用索引取出对应的数据集图像
复杂度:,其中 是数据集大小, 是像素数。
10.6 配置示例
# configs/nearest_dataset/mnist.yaml
model:
name: nearest_dataset
params: {} # 无超参数10.7 局限性
- 只能"复制"训练集中的图像,无法生成新图像
- 不考虑时间步 (距离计算不依赖 )
- 需要将整个数据集放入 GPU 内存
torch.cdist对大数据集可能内存不足(需要[B, N]距离矩阵)- 生成多样性完全取决于初始噪声到不同数据集图像的距离
10.8 与 Optimal 的对比
| 特性 | Nearest | Optimal |
|---|---|---|
| 距离缩放 | 无 | |
| 选择方式 | argmax(硬选择) | softmax(软加权) |
| 输出 | 单张数据集图像 | 多张图像的加权平均 |
| 时间步感知 | 否 | 是 |
| 近邻搜索 | 暴力搜索 | FAISS 加速 |