目录

Analytic Diffusion Studio — 模型基类与采样循环

summary: “Analytic Diffusion Studio — 模型基类与采样循环”

文件:src/local_diffusion/models/base.pysrc/local_diffusion/models/__init__.py

与数据集类似,模型也使用注册表模式:

# models/__init__.py
MODEL_REGISTRY: Dict[str, Callable[..., Any]] = {}

def register_model(name: str):
    def decorator(cls_or_factory):
        MODEL_REGISTRY[name.lower()] = cls_or_factory
        return cls_or_factory
    return decorator

def create_model(name: str, **kwargs) -> Any:
    factory = MODEL_REGISTRY.get(name.lower())
    if factory is None:
        raise ValueError(f"Unknown model '{name}'. Available: {sorted(MODEL_REGISTRY)}")
    return factory(**kwargs)

已注册的模型在模块底部通过导入触发:

from . import nearest_dataset   # "nearest_dataset"
from . import optimal           # "optimal"
from . import scfdm             # "scfdm"
from . import wiener            # "wiener"
from . import pca_locality      # "pca_locality"

generate.py 中的调用方式:

model = create_model(
    cfg.model.name,           # 例如 "pca_locality"
    dataset=dataset,          # DatasetBundle
    device=cfg.experiment.device,
    num_steps=cfg.sampling.num_inference_steps,
    params=model_params,      # 模型特定超参数 dict
)

采样结果的统一封装:

@dataclass
class SamplingOutput:
    images: torch.Tensor                    # 最终生成图像 [N, C, H, W]
    timesteps: Optional[List[int]]          # 时间步列表(如 [999, 888, ...])
    trajectory_xt: Optional[List[Tensor]]   # 每步的噪声图像 x_t
    trajectory_x0: Optional[List[Tensor]]   # 每步的预测干净图像 x̂₀

trajectory_xttrajectory_x0 是列表,每个元素对应一个时间步,形状为 [N, C, H, W]。仅在 return_intermediates=True 时记录。

所有去噪模型的抽象基类,继承自 torch.nn.Module

class BaseDenoiser(torch.nn.Module):
    prediction_type: str = "epsilon"  # 预测类型:噪声 ε

    def __init__(self, resolution, device, num_steps, *,
                 beta_1=0.0001, beta_T=0.02,
                 dataset_name="cifar10",
                 scheduler_kwargs=None, **kwargs):
        self.device = torch.device(device)
        self.n_channels = kwargs.get("in_channels", 3)
        self.resolution = resolution
        self.num_steps = num_steps

        # 初始化 DDIM 调度器
        self.scheduler = DDIMScheduler(
            beta_start=beta_1, beta_end=beta_T,
            beta_schedule="linear",
            prediction_type=self.prediction_type,
        )
        self.scheduler.set_timesteps(num_steps)

关键属性:

  • self.scheduler:HuggingFace diffusers 的 DDIMScheduler,管理噪声调度表和时间步
  • self.scheduler.alphas_cumprod:累积 αˉt\bar{\alpha}_t 数组,长度 1000
  • self.scheduler.timesteps:采样时间步(如 10 步时为 [999, 899, 799, ...]
def denoise(self, latents, timestep, *, generator=None, **kwargs):
    """预测去噪后的干净图像 x̂₀。子类必须实现此方法。"""
    raise NotImplementedError

输入:

  • latents:当前噪声图像 xtx_t,形状 [B, C, H, W]
  • timestep:当前时间步(标量或 0-d 张量)

输出:

  • 预测的干净图像 x^0\hat{x}_0,形状 [B, C, H, W]
def train(self, dataset):
    """预计算数据集相关参数。子类必须实现。"""
    raise NotImplementedError

不同模型的 train() 做不同的事:

  • Wiener:计算协方差矩阵 SVD
  • Optimal:构建 FAISS 索引
  • PCA Locality:计算协方差 SVD + 保留数据集引用
  • Nearest:将数据集加载到 GPU 内存
def prepare_latents(self, batch_size, generator=None):
    shape = (batch_size, self.n_channels, self.resolution, self.resolution)
    latents = torch.randn(shape, generator=generator, device=self.device)
    return latents * self.scheduler.init_noise_sigma

生成初始噪声 xTN(0,I)x_T \sim \mathcal{N}(0, I),乘以调度器的初始噪声标准差。

def compute_noise_from_x0(self, x_t, pred_x0, timestep):
    alpha_prod = self.scheduler.alphas_cumprod[t]
    beta_prod = 1 - alpha_prod
    sqrt_alpha = torch.sqrt(alpha_prod)
    sqrt_beta = torch.sqrt(beta_prod + 1e-8)
    return (x_t - sqrt_alpha * pred_x0) / sqrt_beta

从预测的 x^0\hat{x}_0 反推预测噪声 ϵ^\hat{\epsilon}

ϵ^=xtαˉtx^01αˉt\hat{\epsilon} = \frac{x_t - \sqrt{\bar{\alpha}_t} \hat{x}_0}{\sqrt{1 - \bar{\alpha}_t}}

这是因为 DDIMScheduler 期望接收噪声预测(prediction_type="epsilon")。

def set_timesteps(self, num_steps):
    self.scheduler.set_timesteps(num_steps)
    self.num_steps = num_steps

更新采样步数。DDIMScheduler 会自动计算等间距时间步。

def _image_preprocess(self, img):
    # 插值到目标分辨率 + 归一化到 [-1, 1]
    imgs = F.interpolate(img, size=(self.resolution, self.resolution), mode="bilinear")
    return (imgs - 0.5) * 2

def _image_postprocess(self, img):
    # [-1, 1] → [0, 1]
    return ((img + 1) / 2).clamp(0, 1)

这两个方法在基类中定义但未被所有子类使用(数据预处理主要在 data/utils.py 中完成)。

@torch.no_grad()
def sample(self, *, num_samples, batch_size, generator=None,
           return_intermediates=False) -> SamplingOutput:

处理多批次采样:

  1. 计算需要的批次数
  2. 对每批调用 _sample_batch()
  3. 拼接所有批次的结果(images、trajectory_xt、trajectory_x0)
  4. 返回统一的 SamplingOutput
def _sample_batch(self, *, batch_size, generator, return_intermediates):
    latents = self.prepare_latents(batch_size, generator)

    for step_idx, timestep in enumerate(self.scheduler.timesteps):
        # 1. 调用子类的 denoise() 预测 x̂₀
        pred_x0 = self.denoise(latents, timestep, generator=generator)

        # 2. 从 x̂₀ 反推预测噪声 ε̂
        predicted_noise = self.compute_noise_from_x0(latents, pred_x0, timestep)

        # 3. DDIM 调度器计算 x_{t-1}
        step_output = self.scheduler.step(
            model_output=predicted_noise,
            timestep=timestep,
            sample=latents,
        )

        # 4. 记录轨迹(可选)
        if return_intermediates:
            trajectory_xt.append(latents.detach().cpu())
            trajectory_x0.append(pred_x0.detach().cpu())

        latents = step_output.prev_sample  # 更新为 x_{t-1}

    return SamplingOutput(images=last_pred_x0, ...)

注意:最终返回的 images 是最后一步的 pred_x0(而非 latents),因为 pred_x0 是对干净图像的直接预测。

x_T (纯噪声)
  ▼ t = 999
  denoise(x_T, 999) → x̂₀⁽¹⁾
  compute_noise → ε̂⁽¹⁾
  scheduler.step → x_{t-1}
  ▼ t = 899
  denoise(x_{899}, 899) → x̂₀⁽²⁾
  compute_noise → ε̂⁽²⁾
  scheduler.step → x_{t-2}
  ... (重复 num_inference_steps 次)
  ▼ t = 99
  denoise(x_{99}, 99) → x̂₀⁽ᴺ⁾  ← 最终输出
def build_sample_output(self, images, trajectory_xt, trajectory_x0, timesteps):
    return SamplingOutput(
        images=images,
        trajectory_xt=trajectory_xt,
        trajectory_x0=trajectory_x0,
        timesteps=timesteps,
    )

简单的工厂方法,将采样结果封装为 SamplingOutput

相关内容