目录

Analytic Diffusion Studio — 配置系统

文件:src/local_diffusion/configuration.py

配置系统基于 OmegaConf 实现,支持 YAML 继承、CLI 覆盖、结构化类型校验。

配置系统定义了一组嵌套的 dataclass,作为配置的结构化 schema:

实验元数据。

字段类型默认值说明
namestr"default"实验分组名(决定输出目录层级)
run_namestrNone运行名称(输出子目录名)
seedint42全局随机种子
tagsList[str][]实验标签(用于 WandB 过滤)
append_timestampboolTrue是否在运行名后追加时间戳
devicestr自动检测计算设备:cuda / cpu / mps

目录布局。

字段类型默认值说明
rootstr"data"数据根目录
datasetsstrNonedata/datasets数据集存储路径
modelsstrNonedata/models模型/索引存储路径
runsstrNonedata/runs实验输出路径
wandbstr"wandb"WandB 日志路径

数据集加载参数。

字段类型默认值说明
namestr"mnist"数据集名称(注册表键)
splitstr"train"数据集划分
downloadboolTrue是否自动下载
batch_sizeint512DataLoader 批大小
num_workersint4DataLoader 工作进程数
subset_sizeintNone子集大小(None 表示全量)
rootstrNone数据集根目录(默认继承 paths.datasets)
resolutionintNone图像分辨率覆盖

模型选择与超参数。

字段类型默认值说明
namestr"nearest_dataset"模型名称(注册表键)
paramsDict{}模型特定超参数

采样参数。

字段类型默认值说明
num_samplesint16生成样本总数
batch_sizeint8每批生成数量
num_inference_stepsint10DDIM 推理步数

评估与输出配置。

字段类型说明
baseline_pathstrUNet 基线权重路径(None 则跳过对比)
outputOutputConfig输出开关(见下)
wandbWandbConfigWandB 配置(见下)
字段默认值说明
code_snapshotTrue保存代码快照(tar.gz)
save_metricsTrue保存 metrics.json
save_final_imagesTrue保存最终生成图像
save_image_gridTrue保存图像网格
save_intermediate_imagesTrue保存中间去噪步骤图像
字段默认值说明
enabledTrue是否启用 WandB
project"local-diffusion"WandB 项目名
entityNoneWandB 团队/用户
mode"online"online / offline / disabled
tagsNone标签(None 时继承 experiment.tags)
job_type"generation"WandB 任务类型

运行时路径集合(由 ensure_run_directory() 创建)。

字段说明
run_dir运行根目录
artifacts产物目录
images最终图像目录
tensors张量存储目录
intermediate_images中间步骤图像目录
logs日志目录
config保存的配置文件路径

load_config(config_path, overrides) 的执行步骤:

1. _resolve_config_path()
   ├── 绝对路径  直接使用
   ├── 相对路径  先尝试 configs/ 前缀
   └── 否则原样使用

2. OmegaConf.load(config_path)
   └── 解析 YAML 文件

3. 处理 defaults 列表
   ├── 遍历 defaults 中的路径
   ├── 解析每个默认配置文件
   └── OmegaConf.merge(*defaults, main_config)

4. OmegaConf.merge(structured_base, merged)
   └──  Config dataclass 合并(提供类型校验和默认值)

5. 应用 CLI overrides
   └── OmegaConf.from_dotlist(overrides)

6. _resolve_default_paths()
   └── 将相对路径解析为绝对路径

7. _resolve_metrics_defaults()
   └── wandb.tags 为空时继承 experiment.tags

8. OmegaConf.set_readonly(True)
   └── 冻结配置,防止意外修改

configs/pca_locality/celeba_hq.yaml

defaults:
  - /defaults.yaml          # 继承基础配置

experiment:
  run_name: pca_locality_celeba_hq  # 覆盖运行名
  tags: [baseline, pca_locality, celeba_hq]

dataset:
  name: celeba_hq
  resolution: 64            # 覆盖分辨率(降低计算量)

model:
  name: pca_locality
  params:
    temperature: 1.0
    mask_threshold: 0.02

合并优先级:defaults.yaml < 具体配置文件 < CLI overrides

通过点号分隔的路径覆盖任意配置值:

uv run generate.py --config pca_locality/celeba_hq.yaml \
    sampling.num_samples=16 \
    model.params.temperature=0.5 \
    experiment.device=cpu \
    metrics.wandb.enabled=false

ensure_run_directory(cfg) 创建如下目录结构:

data/runs/{experiment.name}/{run_name}_{timestamp}/
├── config.yaml
├── artifacts/
│   ├── images/
│   ├── tensors/
│   └── intermediate_images/
└── logs/

snapshot_codebase() 函数:

  1. 调用 git ls-files 获取所有 git 跟踪的文件
  2. 打包为 code_snapshot.tar.gz
  3. 存储在运行目录下,确保实验可复现
函数说明
save_config(cfg, path)将 OmegaConf 配置保存为 YAML
config_to_dict(cfg)将配置转为普通 Python dict
get_git_tracked_paths(root)获取 git 跟踪的文件列表

相关内容