PyTorch Dataset 体系技术文档
文档索引
| 章节 | 主题 | 主要内容 |
|---|---|---|
| 一、总览 | 整体架构与文档脉络 | 知识结构、各部分作用与关联、优缺点与适用场景 |
| 二、torch.utils.data.Dataset | PyTorch 数据集根基 | 抽象基类协议、__len__/__getitem__/__getitems__/__add__、与 DataLoader 的契约 |
| 三、torch.utils.data 内置扩展 | TensorDataset / ConcatDataset / Subset / StackDataset / IterableDataset / ChainDataset / random_split | 结构、完整接口、使用方式与可运行示例 |
| 四、torch_geometric.data.Dataset | 图数据专用 Dataset | 文件管理、三阶段 transform、Dataset vs InMemoryDataset |
| 五、HuggingFace datasets | 独立数据体系 | Arrow、map/filter、set_format 与 PyTorch 互转 |
| 六、领域库数据集 | torchvision / torchaudio | transform 管线、典型用法 |
| 七、典型项目扩展模式 | 多任务/RL/图风格数据入口 | BaseDataset/RLDataset/GTDataset 等概念与调用关系(通用模式) |
| 八、Padding 与 collate 协作 | 变长样本成 batch | 单条 padding 与 collate_fn 的职责划分、典型实现模式 |
| 九、与 DataLoader 的协作 | Dataset → DataLoader → 训练 | collate_fn、多进程、与 dataloader_guide.md 的衔接 |
一、总览
1.1 整体架构与知识结构
PyTorch 的数据集体系以 **torch.utils.data.Dataset** 为根基,向上衍生出多条线:
torch.utils.data.Dataset ← 所有 PyTorch 系的根基(map-style)
├── TensorDataset 包装多个等长 tensor
├── ConcatDataset 逻辑拼接多个 map-style 数据集
├── Subset 按索引取子集(train/val split)
├── StackDataset 多数据集按样本对齐堆叠(tuple/dict)
├── IterableDataset 流式数据,另一条线,不支持随机访问
│ └── ChainDataset 串联多个 IterableDataset(+ 运算符)
├── random_split(dataset, lengths) 随机划分成多个 Subset
│
├── torchvision / torchaudio 各数据集
│ └── 在基类上加了 PIL/波形 transform 管线
│
└── torch_geometric.data.Dataset 图数据:raw/processed 文件管理 + 三阶段 transform
└── InMemoryDataset 全量加载进内存,适合小图集
datasets.Dataset (HuggingFace) ← 独立体系,Arrow 格式,可通过 set_format 转 torch1.2 各部分作用与主题
| 部分 | 作用/主题 | 解决的核心问题 |
|---|---|---|
| Dataset 基类 | 定义「长度 + 按索引取一条」的协议 | 让 DataLoader 能统一驱动任意数据源,而不关心存储与格式 |
| TensorDataset | 把已有 tensor 包装成 Dataset | 快速从内存 tensor 建可迭代数据源 |
| ConcatDataset / Subset | 逻辑组合与切分 | 多数据源合并、train/val 划分,无需拷贝数据 |
| StackDataset | 多集按样本对齐堆叠为 tuple/dict | 多模态(图+文等)同索引对齐 |
| IterableDataset | 流式、无索引 | 大文件、管道、无法事先知道总条数的场景 |
| ChainDataset | 串联多个 IterableDataset | 流式多源顺序拼接;IterableDataset 的 + 运算符 |
| random_split | 按长度或比例随机划分 | 可复现的 train/val/test 划分 |
| PyG Dataset | 图数据 + 文件与预处理管线 | 下载、process、pre_filter/pre_transform/transform 分工 |
| HuggingFace datasets | 列式存储、map/filter、格式转换 | 大表、并行预处理、与 PyTorch 无缝对接 |
| 领域库 | 图像/音频等领域的标准数据集与 transform | 常见任务开箱即用 |
| 本项目 BaseDataset 系 | 多任务 pkl 文件列表、任务比例、padding、collate | 统一多任务驾驶数据与仿真入口 |
| 本项目 GTDataset | 借 PyG 外壳、pkl 路径列表、按需加载与 padding | 与 BaseDataset 类似但走 PyG 目录约定,便于与现有脚本统一 |
1.3 彼此关联
- DataLoader 只依赖「长度 + 按索引取一条」或「可迭代流」;Dataset 决定单条样本格式,collate_fn 决定多条如何成 batch。
- Padding 在
__getitem__/get()内对单条样本执行,collate_fn 在 DataLoader 内对 list of 样本执行;两者配合得到固定形状的 batch。
1.4 优缺点与适用场景对比
| 类型 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| map-style Dataset | 支持随机访问、shuffle、多进程索引分片 | 需事先能确定长度 | 绝大多数训练/验证 |
| IterableDataset | 可流式、不必装全量进内存 | 不能真正 shuffle,多进程需自行分片 | 大日志、管道、实时流 |
| TensorDataset | 实现极简、零拷贝 | 仅适合已在内存的等长 tensor | 小规模、快速实验 |
| ConcatDataset/Subset | 逻辑视图、无拷贝 | 仅做索引映射 | 多源合并、比例划分 |
| StackDataset | 多集对齐、tuple/dict 输出 | 要求各集长度一致 | 多模态按样本对齐 |
| ChainDataset | 流式串联、无预加载 | 仅支持 IterableDataset | 多流顺序拼接 |
| PyG Dataset | 标准化的下载/process/transform | 需实现 len/get 与文件列表 | 图结构、需持久化处理结果 |
| HuggingFace datasets | map/filter 强、Arrow 大表友好 | 与 PyTorch 无继承关系 | NLP/表格、Hub 数据、大批量预处理 |
| 本项目 BaseDataset | 多任务、比例采样、与训练脚本深度集成 | 依赖配置与 pkl 目录结构 | 多任务驾驶模型训练与仿真 |
| 本项目 GTDataset | 与 PyG 生态一致、接口简单 | 未用 PyG 的 process(),本质仍是 pkl 列表 | 需要 DataLoader + PyG 风格时的数据入口 |
二、torch.utils.data.Dataset — 根基
2.1 结构与边界
所有 PyTorch map-style 数据集的抽象基类。基类本身不包含任何文件管理、transform 或下载逻辑,只定义与 DataLoader 的协议:
- 必须实现:
__getitem__(self, index)— 按索引返回一条样本。 - 强烈建议实现:
__len__(self)— 返回数据集大小;多数 Sampler 和 DataLoader 默认行为依赖它。 - 可选:
__getitems__(self, indices)— 批量取多条样本,用于 DataLoader 加速;未实现时内部会多次调用__getitem__。 - 内置便捷方法:
__add__(self, other)— 返回ConcatDataset([self, other]),可用ds1 + ds2拼接。
与 DataLoader 的契约:DataLoader 通过 len(dataset) 得到长度,通过 dataset[i](或 dataset.__getitems__(indices))取样本;索引由 Sampler 提供,通常为 0 .. len(dataset)-1 的某种排列。若使用非整数 key 或不可求长度,需自定义 Sampler。
2.2 关键概念
- map-style:有固定长度、支持整数索引、可随机访问;与 IterableDataset(仅
__iter__,无索引)相对。 - 为何以「长度 + 按索引取一条」为核心:DataLoader 只关心「有多少条」和「第 i 条是什么」;存储、解码、增强等都在子类的
__getitem__内完成,接口最小且稳定。 **__getitems__的作用**:当 DataLoader 一次要取一批索引时,若子类实现了__getitems__(indices),会优先调用它以减少 Python 调用次数、便于做批量 I/O,否则会逐次调用__getitem__。**__add__**:语法糖,ds1 + ds2等价于ConcatDataset([ds1, ds2]),逻辑拼接、无数据拷贝。
2.3 使用方式与示例
继承并实现 __len__ 与 __getitem__ 即可接入 DataLoader;样本格式由业务决定,collate_fn 需与之匹配。需要加速时可实现 __getitems__。
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data # shape: (N, ...)
self.labels = labels # shape: (N,)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
# 可选:批量取样本,供 DataLoader 优化用
def __getitems__(self, indices):
return [(self.data[i], self.labels[i]) for i in indices]
ds = MyDataset(torch.randn(100, 32), torch.randint(0, 10, (100,)))
print(len(ds)) # 100
print(ds[0][0].shape) # torch.Size([32])
# __add__ 示例
ds2 = MyDataset(torch.randn(50, 32), torch.randint(0, 10, (50,)))
combined = ds + ds2 # ConcatDataset
print(len(combined)) # 150三、torch.utils.data 内置扩展
3.1 结构说明与模块划分
基于 Dataset / IterableDataset,PyTorch 在 torch.utils.data 中提供以下组件,按用途分为:包装类(TensorDataset、StackDataset)、组合类(ConcatDataset、Subset、ChainDataset)、流式基类(IterableDataset)、工具函数(random_split)。
| 类/函数 | 用途 | 核心接口/行为 |
|---|---|---|
| TensorDataset | 包装多个第 0 维等长的 tensor | __init__(*tensors);__getitem__(idx) → (t[0][idx], t[1][idx], ...);__len__;属性 .tensors |
| ConcatDataset | 逻辑拼接多个 map-style Dataset | __init__(datasets);__len__ 为各集长度之和;__getitem__(idx) 按 cumulative_sizes 映射到对应子集;静态方法 cumsum(sequence);属性 cumulative_sizes;不支持 IterableDataset |
| Subset | 按索引列表取子集 | __init__(dataset, indices);__getitem__(idx) → dataset[indices[idx]],支持 list 索引;__getitems__(indices) 委托底层以支持 DataLoader 批取;__len__ → len(indices) |
| StackDataset | 多数据集「按样本对齐」堆叠为 tuple 或 dict | __init__(*args) 或 __init__(**kwargs);__getitem__(idx) 返回 (ds0[idx], ds1[idx], ...) 或 {k: ds[idx] for k, ds};__getitems__ 若子集支持则批取;__len__ 要求各数据集长度一致 |
| IterableDataset | 流式数据基类 | 实现 __iter__,无 __len__/__getitem__;__add__ 返回 ChainDataset;多 worker 时需用 get_worker_info() 分片 |
| ChainDataset | 串联多个 IterableDataset | __init__(datasets);__iter__ 依次 yield from 各集;__len__ 为各集长度之和(若均有 __len__) |
| random_split | 按长度或比例随机划分 | random_split(dataset, lengths, generator=None) 返回 list[Subset];lengths 可为整数列表或和为 1 的比例列表 |
3.2 关键概念与注意点
- TensorDataset:所有传入 tensor 的第 0 维必须相同;适合已在内存的简单特征+标签;实现极简、零拷贝。
- ConcatDataset:子集类型可不同;索引自动偏移(含负索引),无数据拷贝;不能包含 IterableDataset(会 AssertionError)。
- Subset:
indices可为任意整数序列,可重复;不拷贝数据。若子类重写__getitem__,必须同时重写__getitems__,否则 DataLoader 会报错。 - StackDataset:多个 Dataset 长度必须一致;按「同一索引」从各集取一条再组成 tuple/dict,适合多模态(如图+文)对齐。
- IterableDataset:不能使用 DataLoader 的
shuffle=True(需在数据源侧打乱);多 worker 时必须在__iter__或worker_init_fn中按get_worker_info()分片,否则每个 worker 会重复全量数据。 - ChainDataset:仅接受 IterableDataset;流式串联,不预加载。
- random_split:
lengths若为比例(和为一),会先换算为整数并处理余数;可用generator固定随机种子以保证可复现。
3.3 示例(可运行)
import math
import torch
from torch.utils.data import (
TensorDataset, ConcatDataset, Subset, StackDataset,
IterableDataset, ChainDataset, random_split,
)
# ---------- TensorDataset ----------
X = torch.randn(200, 16)
y = torch.randint(0, 5, (200,))
tensor_ds = TensorDataset(X, y)
print(len(tensor_ds), tensor_ds[0][0].shape) # 200, torch.Size([16])
# ---------- ConcatDataset:索引映射与负索引 ----------
combined = ConcatDataset([tensor_ds, tensor_ds])
print(len(combined)) # 400
print(combined.cumulative_sizes) # [200, 400]
print(combined[0] is tensor_ds[0], combined[200] is tensor_ds[0]) # True True
# ---------- Subset:train/val ----------
train_ds = Subset(tensor_ds, range(0, 160))
val_ds = Subset(tensor_ds, range(160, 200))
# ---------- random_split ----------
train_sub, val_sub = random_split(tensor_ds, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
assert len(train_sub) + len(val_sub) == len(tensor_ds)
# ---------- StackDataset:多模态对齐 ----------
text_len = len(tensor_ds)
# 假设有两个等长数据集
stacked = StackDataset(tensor_ds, Subset(tensor_ds, range(text_len)))
print(stacked[0]) # ((tensor(...), tensor(...)), (tensor(...), tensor(...)))
# dict 形式:StackDataset(a=X_ds, b=y_ds) -> stacked[i] == {"a": ..., "b": ...}
# ---------- IterableDataset + 多 worker 分片 ----------
class RangeIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start, self.end = start, end
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
iter_start, iter_end = self.start, self.end
else:
per_worker = int(math.ceil((self.end - self.start) / worker_info.num_workers))
iter_start = self.start + worker_info.id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return iter(range(iter_start, iter_end))
it_ds = RangeIterableDataset(0, 10)
# 单进程:list(DataLoader(it_ds, num_workers=0)) 得到 0..9 各一次
# 多进程需上述分片,否则会重复
# ---------- ChainDataset ----------
chained = RangeIterableDataset(0, 3) + RangeIterableDataset(10, 13) # 0,1,2,10,11,12
print(list(chained)) # [0, 1, 2, 10, 11, 12]四、torch_geometric.data.Dataset — 图数据专用
4.1 结构与目录约定
继承自 torch.utils.data.Dataset,增加:
- 文件管理:
raw_dir(原始文件)与processed_dir(处理后的.pt),按需download()/process()。 - 三阶段 transform:
- pre_filter:在
process()时过滤,不存盘。 - pre_transform:在
process()时执行,结果存盘。 - transform:每次
get()时执行,不存盘。
目录与调用链:
root/
raw/ ← raw_dir
processed/ ← processed_dir
__init__ → 检查 raw → download();检查 processed → process()
dataset[idx] → get(idx) → [pre_filter / pre_transform 已在 process 中] → transform(data)4.2 必须实现的接口(PyG Dataset)
| 接口 | 说明 |
|---|---|
raw_file_names (property) | 原始文件名列表,用于判断是否需要 download |
processed_file_names (property) | 处理后文件名列表,用于判断是否需要 process |
len() | 数据集大小(注意是 len() 不是 __len__,PyG 内部会调) |
get(idx) | 按索引返回单条样本(注意是 get 不是 __getitem__,基类会在此上应用 transform) |
可选:download()、process()。
4.3 Dataset vs InMemoryDataset
| 对比项 | Dataset | InMemoryDataset |
|---|---|---|
| 存储 | 每样本一个 .pt | 全部合并为一个文件 |
get() | 每次从磁盘读 | 内存切片(data + slices) |
| 规模 | 大图集 | 小/中图集 |
是否必须实现 get() | 是 | 否(内部用 data/slices) |
4.4 三阶段 transform 对比
| 阶段 | 触发时机 | 是否存盘 | 典型用途 |
|---|---|---|---|
| pre_filter | process() 一次 | 否 | 过滤无效图(如 0 节点) |
| pre_transform | process() 一次 | 是 | 特征归一化、边属性 |
| transform | 每次 get() | 否 | 随机 dropout edge、数据增强 |
4.5 与 DataLoader 的配合
PyG 的 Data 对象可直接作为样本;DataLoader 的 collate_fn 通常使用 torch_geometric.loader.DataLoader 自带的 collate,将 list of Data 聚成 Batch(含 batch 向量等)。若使用 torch.utils.data.DataLoader,需自定义 collate 或封装成 tensor/dict。
五、HuggingFace datasets.Dataset — 独立体系
5.1 结构与边界
HuggingFace 的 datasets.Dataset 与 PyTorch 无继承关系,底层为 Apache Arrow 列式存储,支持内存映射、大表高效访问。常用能力与调用关系:
- load_dataset():从 Hub 或本地(CSV/JSON/Parquet 等)加载,返回
Dataset或DatasetDict。 - .map(fn, batched=True, …):对每行或每批应用函数,并行预处理(如 tokenize);
batched=True时传入批,性能更好。 - .filter(fn):按条件筛样本,返回新 Dataset。
- .set_format(“torch”, columns=…): 使迭代时输出变为 torch tensor,可直接配合
torch.utils.data.DataLoader(需用IterableDataset包装或通过.with_format迭代)。 - .train_test_split():划分 train/test,返回
DatasetDict。
5.2 关键概念
- 为何独立于 PyTorch:面向多框架(PyTorch/TF/JAX)与 Hub 生态,列式存储便于列级操作与流式。
- 与 PyTorch 对接:
set_format("torch")后,迭代得到的是 tensor;可包成IterableDataset或逐批取用,也可用Dataset的__getitem__ 封装成 map-style。
5.3 使用方式与示例(需安装 datasets)
# 需安装: pip install datasets
from datasets import load_dataset
ds = load_dataset("imdb", split="train[:100]") # 取前 100 条
ds = ds.map(lambda x: {"len": len(x["text"].split())}, batched=False)
ds = ds.filter(lambda x: x["len"] > 10)
# set_format("torch", columns=...) 在列已是数值/tensor 时使用,便于进 DataLoader
# 划分
splits = ds.train_test_split(test_size=0.2, seed=42)
train_ds, test_ds = splits["train"], splits["test"]六、领域库数据集
6.1 结构说明
torchvision / torchaudio 等均继承 torch.utils.data.Dataset,在基类上增加领域 transform 与标准数据接口:
- torchvision:
datasets.CIFAR10、ImageFolder等;transform 为 PIL → Tensor 或增强(Resize、Normalize、RandomCrop 等);transforms.Compose串联。 - torchaudio:
datasets.LIBRISPEECH等;transform 为波形/频谱(Resample、MelSpectrogram 等)。 - torchtext:文本 pipeline,逐渐被 HuggingFace 替代。
transform 在 __getitem__ 中对单样本执行;与 DataLoader 的 collate_fn 无关。
6.2 典型用法示例(可运行)
import torch
from torchvision import datasets, transforms
# 使用 CIFAR10(需能访问或已下载)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# root 需存在或留空并允许下载
# cifar = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
# 仅演示结构:用 TensorDataset 模拟「图像+标签」
X = torch.rand(100, 3, 32, 32)
y = torch.randint(0, 10, (100,))
ds = torch.utils.data.TensorDataset(X, y)
loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch[0].shape, batch[1].shape) # torch.Size([8, 3, 32, 32]), torch.Size([8])ImageFolder:目录结构为 root/class_a/xx.png,自动按子目录名作标签,配合 transform 即可用于分类。
七、典型项目扩展模式
7.1 结构/边界说明
许多项目会在 torch.utils.data.Dataset 或 PyG/HuggingFace 之上再封一层,以统一多任务、多数据源、比例采样、padding 与 collate。本节描述通用模式;若你的项目中有类似 BaseDataset、RLDataset、SplitMergeDataset、GTDataset 等,可对应下列概念与调用关系。
- BaseDataset 系:从配置与 pkl 文件列表构建;产出「dict of tensors + labels」;支持多任务比例采样(通过 Sampler 或自定义 DataLoader 逻辑)。
- RLDataset:在 BaseDataset 思路上增加轨迹/步采样、专用 collate(如按步长 pad)。
- SplitMergeDataset:多数据源按比例或顺序合并,常与
ConcatDataset或自定义 Sampler 配合。 - GTDataset:借 PyG 的
Dataset外壳(目录约定、len/get),但数据来自 pkl 路径列表、按需加载并做 padding,与 BaseDataset 产出形态类似,便于与既有 PyG 风格脚本统一。
7.2 关键概念与调用关系
- 单条形态:通常为
dict of tensors + labels,便于 collate_fn 统一拼成 batch(如default_collate或自定义gt_collate_fn)。 - Padding:变长序列/图在 Dataset 的
__getitem__/get()内对单条做 pad,或留到 collate_fn 对 list 做 pad;两者分工见下一节。 - 与 DataLoader 的衔接:Dataset 只负责「单条样本」;Sampler 负责「取哪些、顺序」;collate_fn 负责「多条 → batch」。
7.3 适用场景
- 多任务驾驶/仿真:多 pkl 源、任务比例、统一入口。
- 需要 PyG 风格目录与 DataLoader 时:用 GTDataset 封装 pkl 列表,配合 PyG DataLoader 或自定义 collate。
八、Padding 与 collate 协作
8.1 职责划分
- 单条 padding(在 Dataset 内):在
__getitem__ 或 PyG 的get()中,对单条样本做长度对齐(如 pad 到固定长度或到本 batch 内最大长度前的某个上界)。适合「每条独立 pad 到同一规则」的场景。 - batch 内 padding(在 collate_fn 内):DataLoader 收集到 list of 样本后,在 collate_fn 里按本 batch 最大长度做 pad,再 stack 成 tensor。更灵活,且只 pad 到当前 batch 所需长度,节省显存。
两者可组合:例如在 Dataset 里做简单截断/最小 pad,在 collate_fn 里再做 batch 维度的 pad 与 mask。
8.2 关键概念
- 为何需要 padding:变长序列/图无法直接
torch.stack成规则 tensor;padding 后得到固定形状,便于模型前向。 - mask:通常与 padding 一起维护一个 mask 张量(如 1 为有效、0 为 pad),供模型忽略 pad 位置。
8.3 典型 collate_fn 模式(可运行)
import torch
def pad_collate(batch):
"""batch: list of (seq, label), seq 为 1D tensor 变长"""
seqs = [b[0] for b in batch]
labels = torch.stack([b[1] for b in batch])
max_len = max(s.size(0) for s in seqs)
padded = torch.zeros(len(seqs), max_len, dtype=seqs[0].dtype)
mask = torch.zeros(len(seqs), max_len, dtype=torch.bool)
for i, s in enumerate(seqs):
padded[i, :s.size(0)] = s
mask[i, :s.size(0)] = True
return padded, mask, labels
# 使用示例
data = [(torch.randn(i + 1), torch.tensor(i % 2)) for i in range(5)]
padded, mask, labels = pad_collate(data)
print(padded.shape, mask.shape, labels.shape) # [5, 5], [5, 5], [5]九、与 DataLoader 的协作
9.1 数据流与职责
- Dataset:提供
__len__与__getitem__(或 IterableDataset 的__iter__),定义「单条样本」形态。 - DataLoader:根据 Sampler 产生索引(或直接消费迭代器),在 worker 中取样本,经 collate_fn 聚成 batch,可选多进程与 pin_memory。
因此:Dataset 决定单条格式;collate_fn 决定多条如何成 batch;Sampler 决定取哪些、顺序。
9.2 关键衔接点
- map-style:DataLoader 调用
len(dataset)与dataset[indices](或__getitems__(indices));collate_fn 的输入是「list of 单条样本」。 - IterableDataset:无索引,DataLoader 从
iter(dataset)取数据;多 worker 时需在 dataset 或worker_init_fn中分片,避免重复。 - 自定义 batch:若需自定义每批的组成(如加权、多任务比例),可传
batch_sampler,或使用Sampler控制索引。