Analytic Diffusion Studio — 模型基类与采样循环
系列 - Analytic Diffusion Studio系列
目录
summary: “Analytic Diffusion Studio — 模型基类与采样循环”
文件:src/local_diffusion/models/base.py、src/local_diffusion/models/__init__.py
5.1 模型注册表
与数据集类似,模型也使用注册表模式:
# models/__init__.py
MODEL_REGISTRY: Dict[str, Callable[..., Any]] = {}
def register_model(name: str):
def decorator(cls_or_factory):
MODEL_REGISTRY[name.lower()] = cls_or_factory
return cls_or_factory
return decorator
def create_model(name: str, **kwargs) -> Any:
factory = MODEL_REGISTRY.get(name.lower())
if factory is None:
raise ValueError(f"Unknown model '{name}'. Available: {sorted(MODEL_REGISTRY)}")
return factory(**kwargs)已注册的模型在模块底部通过导入触发:
from . import nearest_dataset # "nearest_dataset"
from . import optimal # "optimal"
from . import scfdm # "scfdm"
from . import wiener # "wiener"
from . import pca_locality # "pca_locality"在 generate.py 中的调用方式:
model = create_model(
cfg.model.name, # 例如 "pca_locality"
dataset=dataset, # DatasetBundle
device=cfg.experiment.device,
num_steps=cfg.sampling.num_inference_steps,
params=model_params, # 模型特定超参数 dict
)5.2 SamplingOutput 数据类
采样结果的统一封装:
@dataclass
class SamplingOutput:
images: torch.Tensor # 最终生成图像 [N, C, H, W]
timesteps: Optional[List[int]] # 时间步列表(如 [999, 888, ...])
trajectory_xt: Optional[List[Tensor]] # 每步的噪声图像 x_t
trajectory_x0: Optional[List[Tensor]] # 每步的预测干净图像 x̂₀trajectory_xt 和 trajectory_x0 是列表,每个元素对应一个时间步,形状为 [N, C, H, W]。仅在 return_intermediates=True 时记录。
5.3 BaseDenoiser 类
所有去噪模型的抽象基类,继承自 torch.nn.Module。
构造函数
class BaseDenoiser(torch.nn.Module):
prediction_type: str = "epsilon" # 预测类型:噪声 ε
def __init__(self, resolution, device, num_steps, *,
beta_1=0.0001, beta_T=0.02,
dataset_name="cifar10",
scheduler_kwargs=None, **kwargs):
self.device = torch.device(device)
self.n_channels = kwargs.get("in_channels", 3)
self.resolution = resolution
self.num_steps = num_steps
# 初始化 DDIM 调度器
self.scheduler = DDIMScheduler(
beta_start=beta_1, beta_end=beta_T,
beta_schedule="linear",
prediction_type=self.prediction_type,
)
self.scheduler.set_timesteps(num_steps)关键属性:
self.scheduler:HuggingFace diffusers 的DDIMScheduler,管理噪声调度表和时间步self.scheduler.alphas_cumprod:累积 数组,长度 1000self.scheduler.timesteps:采样时间步(如 10 步时为[999, 899, 799, ...])
denoise() — 抽象方法
def denoise(self, latents, timestep, *, generator=None, **kwargs):
"""预测去噪后的干净图像 x̂₀。子类必须实现此方法。"""
raise NotImplementedError输入:
latents:当前噪声图像 ,形状[B, C, H, W]timestep:当前时间步(标量或 0-d 张量)
输出:
- 预测的干净图像 ,形状
[B, C, H, W]
train() — 抽象方法
def train(self, dataset):
"""预计算数据集相关参数。子类必须实现。"""
raise NotImplementedError不同模型的 train() 做不同的事:
- Wiener:计算协方差矩阵 SVD
- Optimal:构建 FAISS 索引
- PCA Locality:计算协方差 SVD + 保留数据集引用
- Nearest:将数据集加载到 GPU 内存
prepare_latents()
def prepare_latents(self, batch_size, generator=None):
shape = (batch_size, self.n_channels, self.resolution, self.resolution)
latents = torch.randn(shape, generator=generator, device=self.device)
return latents * self.scheduler.init_noise_sigma生成初始噪声 ,乘以调度器的初始噪声标准差。
compute_noise_from_x0()
def compute_noise_from_x0(self, x_t, pred_x0, timestep):
alpha_prod = self.scheduler.alphas_cumprod[t]
beta_prod = 1 - alpha_prod
sqrt_alpha = torch.sqrt(alpha_prod)
sqrt_beta = torch.sqrt(beta_prod + 1e-8)
return (x_t - sqrt_alpha * pred_x0) / sqrt_beta从预测的 反推预测噪声 :
这是因为 DDIMScheduler 期望接收噪声预测(prediction_type="epsilon")。
set_timesteps()
def set_timesteps(self, num_steps):
self.scheduler.set_timesteps(num_steps)
self.num_steps = num_steps更新采样步数。DDIMScheduler 会自动计算等间距时间步。
_image_preprocess() / _image_postprocess()
def _image_preprocess(self, img):
# 插值到目标分辨率 + 归一化到 [-1, 1]
imgs = F.interpolate(img, size=(self.resolution, self.resolution), mode="bilinear")
return (imgs - 0.5) * 2
def _image_postprocess(self, img):
# [-1, 1] → [0, 1]
return ((img + 1) / 2).clamp(0, 1)这两个方法在基类中定义但未被所有子类使用(数据预处理主要在 data/utils.py 中完成)。
5.4 采样循环
sample() — 公共接口
@torch.no_grad()
def sample(self, *, num_samples, batch_size, generator=None,
return_intermediates=False) -> SamplingOutput:处理多批次采样:
- 计算需要的批次数
- 对每批调用
_sample_batch() - 拼接所有批次的结果(images、trajectory_xt、trajectory_x0)
- 返回统一的
SamplingOutput
_sample_batch() — 单批次 DDIM 循环
def _sample_batch(self, *, batch_size, generator, return_intermediates):
latents = self.prepare_latents(batch_size, generator)
for step_idx, timestep in enumerate(self.scheduler.timesteps):
# 1. 调用子类的 denoise() 预测 x̂₀
pred_x0 = self.denoise(latents, timestep, generator=generator)
# 2. 从 x̂₀ 反推预测噪声 ε̂
predicted_noise = self.compute_noise_from_x0(latents, pred_x0, timestep)
# 3. DDIM 调度器计算 x_{t-1}
step_output = self.scheduler.step(
model_output=predicted_noise,
timestep=timestep,
sample=latents,
)
# 4. 记录轨迹(可选)
if return_intermediates:
trajectory_xt.append(latents.detach().cpu())
trajectory_x0.append(pred_x0.detach().cpu())
latents = step_output.prev_sample # 更新为 x_{t-1}
return SamplingOutput(images=last_pred_x0, ...)注意:最终返回的 images 是最后一步的 pred_x0(而非 latents),因为 pred_x0 是对干净图像的直接预测。
5.5 采样流程图
x_T (纯噪声)
│
▼ t = 999
denoise(x_T, 999) → x̂₀⁽¹⁾
compute_noise → ε̂⁽¹⁾
scheduler.step → x_{t-1}
│
▼ t = 899
denoise(x_{899}, 899) → x̂₀⁽²⁾
compute_noise → ε̂⁽²⁾
scheduler.step → x_{t-2}
│
... (重复 num_inference_steps 次)
│
▼ t = 99
denoise(x_{99}, 99) → x̂₀⁽ᴺ⁾ ← 最终输出5.6 build_sample_output()
def build_sample_output(self, images, trajectory_xt, trajectory_x0, timesteps):
return SamplingOutput(
images=images,
trajectory_xt=trajectory_xt,
trajectory_x0=trajectory_x0,
timesteps=timesteps,
)简单的工厂方法,将采样结果封装为 SamplingOutput。