smalldiffusion 数据模块:data.py
系列 - Smalldiffusion系列
目录
本文件提供数据集工具函数和三个 2D 玩具数据集,用于快速验证扩散模型的正确性。
3.1 模块结构
data.py
├── MappedDataset # 数据集映射包装器
├── img_train_transform # 图像训练预处理
├── img_normalize # 图像反归一化
├── Swissroll # 瑞士卷数据集
├── DatasaurusDozen # Datasaurus 数据集
├── interpolate_polyline() # 多段线插值辅助函数
└── TreeDataset # 树形条件数据集3.2 MappedDataset
是什么
一个通用的数据集包装器,对原始数据集的每个元素应用一个映射函数。
class MappedDataset(Dataset):
def __init__(self, dataset, fn):
self.dataset = dataset
self.fn = fn
def __len__(self):
return len(self.dataset)
def __getitem__(self, i):
return self.fn(self.dataset[i])为什么需要
PyTorch 的标准图像数据集(如 FashionMNIST, CIFAR10)返回 (image, label) 元组。但无条件扩散模型的训练只需要图像数据,不需要标签。MappedDataset 提供了一种简洁的方式来丢弃标签或做其他变换。
使用示例
from torchvision.datasets import FashionMNIST
from smalldiffusion import MappedDataset, img_train_transform
# 丢弃标签,只保留图像
dataset = MappedDataset(
FashionMNIST('datasets', train=True, download=True, transform=img_train_transform),
lambda x: x[0] # x 是 (image, label),只取 image
)
# dataset[0] 现在直接返回图像张量,而非 (image, label) 元组3.3 img_train_transform
是什么
用于图像扩散模型训练的标准预处理管道。
img_train_transform = tf.Compose([
tf.RandomHorizontalFlip(), # 随机水平翻转(数据增强)
tf.ToTensor(), # PIL Image → Tensor, 值域 [0, 1]
tf.Lambda(lambda t: (t * 2) - 1) # 归一化到 [-1, 1]
])为什么归一化到 [-1, 1]
扩散模型假设数据分布的均值接近 0。将像素值从 [0, 1] 映射到 [-1, 1] 使数据中心化,有助于训练稳定性。
3.4 img_normalize
是什么
将模型输出从 [-1, 1] 反归一化回 [0, 1] 用于可视化。
img_normalize = lambda x: ((x + 1)/2).clamp(0, 1)使用示例
from torchvision.utils import save_image, make_grid
from smalldiffusion import img_normalize
# 采样后反归一化并保存
*xt, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6, batchsize=64)
save_image(img_normalize(make_grid(x0)), 'samples.png')3.5 Swissroll 数据集
是什么
经典的 2D 瑞士卷数据集,点沿螺旋线分布。
class Swissroll(Dataset):
def __init__(self, tmin, tmax, N, center=(0,0), scale=1.0):
t = tmin + torch.linspace(0, 1, N) * tmax
center = torch.tensor(center).unsqueeze(0)
self.vals = center + scale * torch.stack([
t * torch.cos(t) / tmax,
t * torch.sin(t) / tmax
]).T
def __len__(self):
return len(self.vals)
def __getitem__(self, i):
return self.vals[i]参数说明
| 参数 | 类型 | 说明 |
|---|---|---|
tmin | float | 螺旋起始角度(弧度) |
tmax | float | 螺旋终止角度(弧度) |
N | int | 数据点数量 |
center | tuple | 螺旋中心坐标,默认 (0, 0) |
scale | float | 缩放因子,默认 1.0 |
数学原理
参数方程:
其中 在 上均匀采样。除以 使数据归一化到合理范围。
使用示例
import numpy as np
from torch.utils.data import DataLoader
from smalldiffusion import Swissroll
dataset = Swissroll(np.pi/2, 5*np.pi, 100)
print(f"数据点数: {len(dataset)}") # 100
print(f"数据维度: {dataset[0].shape}") # torch.Size([2])
loader = DataLoader(dataset, batch_size=2048)3.6 DatasaurusDozen 数据集
是什么
加载 Datasaurus Dozen 数据集中的指定子集。这是一组统计特征相同但形状完全不同的 2D 数据集。
class DatasaurusDozen(Dataset):
def __init__(self, csv_file, dataset, enlarge_factor=15,
delimiter='\t', scale=50, offset=50):
self.enlarge_factor = enlarge_factor
self.points = []
with open(csv_file, newline='') as f:
for name, *rest in csv.reader(f, delimiter=delimiter):
if name == dataset:
point = torch.tensor(list(map(float, rest)))
self.points.append((point - offset) / scale)
def __len__(self):
return len(self.points) * self.enlarge_factor
def __getitem__(self, i):
return self.points[i % len(self.points)]参数说明
| 参数 | 类型 | 说明 |
|---|---|---|
csv_file | str | TSV 文件路径 |
dataset | str | 子数据集名称(如 "dino", "star" 等) |
enlarge_factor | int | 数据重复倍数,默认 15 |
scale | float | 缩放因子,默认 50 |
offset | float | 偏移量,默认 50 |
设计细节
enlarge_factor:原始数据点较少(约 142 个),通过重复扩大数据集,使 DataLoader 的 batch 采样更有效(point - offset) / scale:将数据中心化并缩放到 [-1, 1] 附近__getitem__使用取模运算实现循环访问
使用示例
from smalldiffusion import DatasaurusDozen
dataset = DatasaurusDozen('datasets/DatasaurusDozen.tsv', 'dino')
print(f"数据点数: {len(dataset)}") # 142 * 15 = 2130
print(f"数据维度: {dataset[0].shape}") # torch.Size([2])3.7 interpolate_polyline 辅助函数
是什么
沿多段线(polyline)均匀采样点的工具函数,被 TreeDataset 使用。
def interpolate_polyline(points, num_samples):
points = np.array(points)
dists = np.linalg.norm(np.diff(points, axis=0), axis=1) # 相邻点距离
cumdist = np.concatenate(([0], np.cumsum(dists))) # 累积弧长
total_length = cumdist[-1]
sample_dists = np.linspace(0, total_length, num_samples) # 均匀弧长采样点
samples = []
for d in sample_dists:
seg = np.searchsorted(cumdist, d, side='right') - 1
seg = min(seg, len(dists) - 1)
t = (d - cumdist[seg]) / dists[seg] if dists[seg] > 0 else 0
sample = (1 - t) * points[seg] + t * points[seg + 1]
samples.append(sample)
return np.array(samples)工作原理
- 计算相邻点之间的欧氏距离
- 计算累积弧长
- 在总弧长上均匀取
num_samples个位置 - 对每个位置,找到所在线段并线性插值
3.8 TreeDataset
是什么
一个带标签的 2D 条件数据集,数据点沿树形结构分布。每个叶节点对应一个类别,用于条件扩散模型的训练和 Classifier-Free Guidance 的演示。
class TreeDataset(Dataset):
def __init__(self, branching_factor=4, depth=3, num_samples_per_path=30):
self.data = []
self.total_leaves = branching_factor ** depth
for i in range(self.total_leaves):
path_points = [np.array([0.0, 0.0])] # 根节点
for l in range(1, depth + 1):
group_size = branching_factor ** (depth - l)
A_l = i // group_size
avg_index = A_l * group_size + (group_size - 1) / 2.0
theta = avg_index * (2 * np.pi / self.total_leaves)
r = l / depth
p = np.array([r * np.cos(theta), r * np.sin(theta)])
path_points.append(p)
samples = interpolate_polyline(path_points, num_samples_per_path)
for sample in samples:
self.data.append((torch.tensor(sample, dtype=torch.float32), i))参数说明
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
branching_factor | int | 4 | 每个节点的分支数 |
depth | int | 3 | 树的深度 |
num_samples_per_path | int | 30 | 每条路径上的采样点数 |
数据结构
- 总叶节点数 =
branching_factor ** depth(默认 64) - 每个叶节点有一条从根 (0,0) 到单位圆上某点的路径
- 路径经过
depth个中间节点,每个节点的角度由其子树的平均叶节点位置决定 - 每条路径上均匀采样
num_samples_per_path个点 - 总数据量 =
total_leaves * num_samples_per_path
返回格式
与其他数据集不同,TreeDataset.__getitem__ 返回 (coordinate, label) 元组:
coordinate:torch.FloatTensor,形状[2]label:int,叶节点索引(类别标签)
这种格式直接支持条件训练(training_loop 的 conditional=True)。
使用示例
from torch.utils.data import DataLoader
from smalldiffusion import TreeDataset, ConditionalMLP, ScheduleLogLinear, training_loop
dataset = TreeDataset(branching_factor=4, depth=3)
print(f"总叶节点: {dataset.total_leaves}") # 64
print(f"总数据点: {len(dataset)}") # 64 * 30 = 1920
loader = DataLoader(dataset, batch_size=512, shuffle=True)
batch, labels = next(iter(loader))
print(f"数据形状: {batch.shape}") # [512, 2]
print(f"标签形状: {labels.shape}") # [512]
# 条件训练
model = ConditionalMLP(dim=2, num_classes=dataset.total_leaves)
schedule = ScheduleLogLinear(N=200, sigma_min=0.01, sigma_max=10)
trainer = training_loop(loader, model, schedule, epochs=100, conditional=True)
losses = [ns.loss.item() for ns in trainer]