目录

Analytic Diffusion Studio — 数据模块

目录:src/local_diffusion/data/

数据模块负责数据集的注册、加载、预处理和后处理。采用工厂模式 + 注册表实现可扩展的数据集管理。

data/
├── __init__.py                  # 导出公共 API,触发数据集注册
├── datasets.py                  # 注册表、DatasetBundle、build_dataset()
├── torchvision_datasets.py      # MNIST、Fashion-MNIST、CIFAR-10 注册
├── image_folder_datasets.py     # CelebA-HQ、AFHQ 注册
└── utils.py                     # 图像变换、后处理、子集截取

DATASET_REGISTRY 是一个全局字典,将数据集名称映射到工厂函数:

DATASET_REGISTRY: Dict[str, DatasetFactory] = {}
@register_dataset("mnist")
def build_mnist(cfg: DatasetConfig) -> DatasetFactoryOutput:
    ...

装饰器将函数注册到 DATASET_REGISTRY["mnist"]。注册在模块导入时自动完成——data/__init__.py 中显式导入了 torchvision_datasetsimage_folder_datasets

工厂函数的返回类型:

@dataclass
class DatasetFactoryOutput:
    dataset: Dataset          # PyTorch Dataset 实例
    resolution: int           # 图像分辨率(正方形边长)
    in_channels: int          # 通道数(1=灰度,3=RGB)
    postprocess: Optional[Callable]  # 后处理函数([-1,1] → [0,1])

build_dataset() 的返回类型,封装了数据集的所有信息:

@dataclass
class DatasetBundle:
    name: str                 # 数据集名称
    dataset: Dataset          # PyTorch Dataset
    dataloader: DataLoader    # 预配置的 DataLoader
    resolution: int           # 图像分辨率
    in_channels: int          # 通道数
    split: str                # 数据划分(train/test)
    postprocess: Callable     # 后处理函数
def build_dataset(cfg: DatasetConfig) -> DatasetBundle:

执行步骤:

  1. DATASET_REGISTRY 查找工厂函数
  2. 调用工厂函数,传入 DatasetConfig
  3. 可选:应用子集截取 (maybe_apply_subset)
  4. 创建 DataLoadershuffle=False, pin_memory=True
  5. 封装为 DatasetBundle 返回
# DataLoader 配置
DataLoader(
    dataset,
    batch_size=cfg.batch_size,
    shuffle=False,          # 不打乱(保证可复现)
    num_workers=cfg.num_workers,
    pin_memory=True,        # GPU 加速
    drop_last=False,
)
@register_dataset("mnist")
def build_mnist(cfg: DatasetConfig) -> DatasetFactoryOutput:
    transform = utils.compose_transform(28, in_channels=1)
    dataset = datasets.MNIST(root=cfg.root, train=cfg.split == "train",
                             download=cfg.download, transform=transform)
    postprocess = utils.get_postprocess_fn()
    return DatasetFactoryOutput(dataset=dataset, resolution=28,
                                in_channels=1, postprocess=postprocess)
  • 分辨率:28×28,灰度(1 通道)
  • 支持自动下载
@register_dataset("fashion_mnist")
def build_fashion_mnist(cfg: DatasetConfig) -> DatasetFactoryOutput:
  • 与 MNIST 结构相同,分辨率 28×28,灰度
  • 使用 datasets.FashionMNIST
@register_dataset("cifar10")
def build_cifar10(cfg: DatasetConfig) -> DatasetFactoryOutput:
    transform = utils.compose_transform(32, in_channels=3)
    dataset = datasets.CIFAR10(...)
  • 分辨率:32×32,RGB(3 通道)
  • 支持自动下载
@register_dataset("celeba_hq")
def build_celeba_hq(cfg: DatasetConfig) -> DatasetFactoryOutput:
  • 默认分辨率:256(可通过 cfg.resolution 覆盖,通常降至 64)
  • 使用自定义 ImageFolderDataset(支持扁平目录结构)
  • 可选自动下载(从 Kaggle,需要 curl)
  • 期望目录:data/datasets/celebahq-resized-256x256/versions/1/celeba_hq_256/
@register_dataset("afhq")
def build_afhq(cfg: DatasetConfig) -> DatasetFactoryOutput:
  • 默认分辨率:512(通常降至 64)
  • 使用 torchvision.datasets.ImageFolder(按类别子目录组织)
  • 可选自动下载(从 Dropbox)
  • 期望目录:data/datasets/afhq/{split}/{class}/
class ImageFolderDataset(Dataset):
    """支持扁平目录结构的图像数据集(不要求按类别子目录组织)。"""
    
    def __init__(self, root_dir: str, transform=None):
        # 扫描目录中所有 .png/.jpg/.jpeg/.bmp/.webp 文件
        self.image_files = sorted([f for f in os.listdir(root_dir)
                                    if f.lower().endswith((...)) ])
    
    def __getitem__(self, idx):
        image = Image.open(img_path).convert("RGB")
        return image, 0  # 返回 0 作为虚拟标签

torchvision.datasets.ImageFolder 的区别:后者要求图像按类别放在子目录中,而 ImageFolderDataset 支持所有图像直接放在一个目录下。

utils.compose_transform(resolution, in_channels) 构建变换链:

def compose_transform(resolution, *, in_channels):
    ops = []
    if in_channels == 1:
        ops.append(transforms.Grayscale(num_output_channels=1))
    ops.extend([
        transforms.Resize((resolution, resolution)),
        transforms.ToTensor(),                    # [0, 255] → [0, 1]
        transforms.Normalize((0.5,)*C, (0.5,)*C), # [0, 1] → [-1, 1]
    ])
    return transforms.Compose(ops)

变换步骤:

  1. 灰度转换(仅单通道数据集)
  2. Resize 到目标分辨率(双线性插值)
  3. ToTensor:PIL Image → [0, 1] 浮点张量
  4. Normalize(x - 0.5) / 0.5[-1, 1] 范围

为什么归一化到 [-1, 1]? 扩散模型假设数据分布以 0 为中心,这与高斯噪声的分布一致,有利于训练和采样的数值稳定性。

def get_postprocess_fn():
    def postprocess(tensor):
        return ((tensor + 1.0) / 2.0).clamp(0, 1)  # [-1, 1] → [0, 1]
    return postprocess

后处理将模型输出从 [-1, 1] 映射回 [0, 1],用于图像保存和可视化。

def maybe_apply_subset(dataset, subset_size):
    if subset_size is None:
        return dataset
    indices = torch.arange(subset_size)
    return Subset(dataset, indices.tolist())

通过 dataset.subset_size 配置项限制数据集大小,用于快速调试或降低计算量。取前 N 个样本(非随机采样)。

# 在 data/ 下新建文件或在已有文件中添加:
from .datasets import register_dataset, DatasetFactoryOutput

@register_dataset("my_dataset")
def build_my_dataset(cfg: DatasetConfig) -> DatasetFactoryOutput:
    transform = utils.compose_transform(cfg.resolution or 64, in_channels=3)
    dataset = MyCustomDataset(root=cfg.root, transform=transform)
    return DatasetFactoryOutput(
        dataset=dataset,
        resolution=cfg.resolution or 64,
        in_channels=3,
        postprocess=utils.get_postprocess_fn(),
    )

然后在 data/__init__.py 中导入新模块以触发注册。

相关内容