目录

PyTorch Dataset 体系技术文档

章节主题主要内容
一、总览整体架构与文档脉络知识结构、各部分作用与关联、优缺点与适用场景
二、torch.utils.data.DatasetPyTorch 数据集根基抽象基类协议、__len__/__getitem__/__getitems__/__add__、与 DataLoader 的契约
三、torch.utils.data 内置扩展TensorDataset / ConcatDataset / Subset / StackDataset / IterableDataset / ChainDataset / random_split结构、完整接口、使用方式与可运行示例
四、torch_geometric.data.Dataset图数据专用 Dataset文件管理、三阶段 transform、Dataset vs InMemoryDataset
五、HuggingFace datasets独立数据体系Arrow、map/filter、set_format 与 PyTorch 互转
六、领域库数据集torchvision / torchaudiotransform 管线、典型用法
七、典型项目扩展模式多任务/RL/图风格数据入口BaseDataset/RLDataset/GTDataset 等概念与调用关系(通用模式)
八、Padding 与 collate 协作变长样本成 batch单条 padding 与 collate_fn 的职责划分、典型实现模式
九、与 DataLoader 的协作Dataset → DataLoader → 训练collate_fn、多进程、与 dataloader_guide.md 的衔接

PyTorch 的数据集体系以 **torch.utils.data.Dataset** 为根基,向上衍生出多条线:

torch.utils.data.Dataset              ← 所有 PyTorch 系的根基(map-style)
    ├── TensorDataset                  包装多个等长 tensor
    ├── ConcatDataset                  逻辑拼接多个 map-style 数据集
    ├── Subset                         按索引取子集(train/val split)
    ├── StackDataset                   多数据集按样本对齐堆叠(tuple/dict)
    ├── IterableDataset                流式数据,另一条线,不支持随机访问
    │       └── ChainDataset           串联多个 IterableDataset(+ 运算符)
    ├── random_split(dataset, lengths) 随机划分成多个 Subset
    ├── torchvision / torchaudio 各数据集
    │       └── 在基类上加了 PIL/波形 transform 管线
    └── torch_geometric.data.Dataset   图数据:raw/processed 文件管理 + 三阶段 transform
            └── InMemoryDataset        全量加载进内存,适合小图集

datasets.Dataset (HuggingFace)        ← 独立体系,Arrow 格式,可通过 set_format 转 torch
部分作用/主题解决的核心问题
Dataset 基类定义「长度 + 按索引取一条」的协议让 DataLoader 能统一驱动任意数据源,而不关心存储与格式
TensorDataset把已有 tensor 包装成 Dataset快速从内存 tensor 建可迭代数据源
ConcatDataset / Subset逻辑组合与切分多数据源合并、train/val 划分,无需拷贝数据
StackDataset多集按样本对齐堆叠为 tuple/dict多模态(图+文等)同索引对齐
IterableDataset流式、无索引大文件、管道、无法事先知道总条数的场景
ChainDataset串联多个 IterableDataset流式多源顺序拼接;IterableDataset 的 + 运算符
random_split按长度或比例随机划分可复现的 train/val/test 划分
PyG Dataset图数据 + 文件与预处理管线下载、process、pre_filter/pre_transform/transform 分工
HuggingFace datasets列式存储、map/filter、格式转换大表、并行预处理、与 PyTorch 无缝对接
领域库图像/音频等领域的标准数据集与 transform常见任务开箱即用
本项目 BaseDataset 系多任务 pkl 文件列表、任务比例、padding、collate统一多任务驾驶数据与仿真入口
本项目 GTDataset借 PyG 外壳、pkl 路径列表、按需加载与 padding与 BaseDataset 类似但走 PyG 目录约定,便于与现有脚本统一
  • DataLoader 只依赖「长度 + 按索引取一条」或「可迭代流」;Dataset 决定单条样本格式collate_fn 决定多条如何成 batch
  • Padding__getitem__/get() 内对单条样本执行,collate_fn 在 DataLoader 内对 list of 样本执行;两者配合得到固定形状的 batch。
类型优点缺点适用场景
map-style Dataset支持随机访问、shuffle、多进程索引分片需事先能确定长度绝大多数训练/验证
IterableDataset可流式、不必装全量进内存不能真正 shuffle,多进程需自行分片大日志、管道、实时流
TensorDataset实现极简、零拷贝仅适合已在内存的等长 tensor小规模、快速实验
ConcatDataset/Subset逻辑视图、无拷贝仅做索引映射多源合并、比例划分
StackDataset多集对齐、tuple/dict 输出要求各集长度一致多模态按样本对齐
ChainDataset流式串联、无预加载仅支持 IterableDataset多流顺序拼接
PyG Dataset标准化的下载/process/transform需实现 len/get 与文件列表图结构、需持久化处理结果
HuggingFace datasetsmap/filter 强、Arrow 大表友好与 PyTorch 无继承关系NLP/表格、Hub 数据、大批量预处理
本项目 BaseDataset多任务、比例采样、与训练脚本深度集成依赖配置与 pkl 目录结构多任务驾驶模型训练与仿真
本项目 GTDataset与 PyG 生态一致、接口简单未用 PyG 的 process(),本质仍是 pkl 列表需要 DataLoader + PyG 风格时的数据入口

所有 PyTorch map-style 数据集的抽象基类。基类本身不包含任何文件管理、transform 或下载逻辑,只定义与 DataLoader 的协议

  • 必须实现__getitem__(self, index) — 按索引返回一条样本。
  • 强烈建议实现__len__(self) — 返回数据集大小;多数 Sampler 和 DataLoader 默认行为依赖它。
  • 可选__getitems__(self, indices) — 批量取多条样本,用于 DataLoader 加速;未实现时内部会多次调用 __getitem__
  • 内置便捷方法__add__(self, other) — 返回 ConcatDataset([self, other]),可用 ds1 + ds2 拼接。

与 DataLoader 的契约:DataLoader 通过 len(dataset) 得到长度,通过 dataset[i](或 dataset.__getitems__(indices))取样本;索引由 Sampler 提供,通常为 0 .. len(dataset)-1 的某种排列。若使用非整数 key 或不可求长度,需自定义 Sampler。

  • map-style:有固定长度、支持整数索引、可随机访问;与 IterableDataset(仅 __iter__,无索引)相对。
  • 为何以「长度 + 按索引取一条」为核心:DataLoader 只关心「有多少条」和「第 i 条是什么」;存储、解码、增强等都在子类的 __getitem__ 内完成,接口最小且稳定。
  • **__getitems__ 的作用**:当 DataLoader 一次要取一批索引时,若子类实现了 __getitems__(indices),会优先调用它以减少 Python 调用次数、便于做批量 I/O,否则会逐次调用 __getitem__
  • **__add__**:语法糖,ds1 + ds2 等价于 ConcatDataset([ds1, ds2]),逻辑拼接、无数据拷贝。

继承并实现 __len____getitem__ 即可接入 DataLoader;样本格式由业务决定,collate_fn 需与之匹配。需要加速时可实现 __getitems__。

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data        # shape: (N, ...)
        self.labels = labels    # shape: (N,)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    # 可选:批量取样本,供 DataLoader 优化用
    def __getitems__(self, indices):
        return [(self.data[i], self.labels[i]) for i in indices]

ds = MyDataset(torch.randn(100, 32), torch.randint(0, 10, (100,)))
print(len(ds))         # 100
print(ds[0][0].shape)  # torch.Size([32])

# __add__ 示例
ds2 = MyDataset(torch.randn(50, 32), torch.randint(0, 10, (50,)))
combined = ds + ds2    # ConcatDataset
print(len(combined))   # 150

基于 Dataset / IterableDataset,PyTorch 在 torch.utils.data 中提供以下组件,按用途分为:包装类(TensorDataset、StackDataset)、组合类(ConcatDataset、Subset、ChainDataset)、流式基类(IterableDataset)、工具函数(random_split)。

类/函数用途核心接口/行为
TensorDataset包装多个第 0 维等长的 tensor__init__(*tensors)__getitem__(idx)(t[0][idx], t[1][idx], ...)__len__;属性 .tensors
ConcatDataset逻辑拼接多个 map-style Dataset__init__(datasets)__len__ 为各集长度之和;__getitem__(idx)cumulative_sizes 映射到对应子集;静态方法 cumsum(sequence);属性 cumulative_sizes不支持 IterableDataset
Subset按索引列表取子集__init__(dataset, indices)__getitem__(idx)dataset[indices[idx]],支持 list 索引;__getitems__(indices) 委托底层以支持 DataLoader 批取;__len__len(indices)
StackDataset多数据集「按样本对齐」堆叠为 tuple 或 dict__init__(*args)__init__(**kwargs)__getitem__(idx) 返回 (ds0[idx], ds1[idx], ...){k: ds[idx] for k, ds}__getitems__ 若子集支持则批取;__len__ 要求各数据集长度一致
IterableDataset流式数据基类实现 __iter__,无 __len__/__getitem____add__ 返回 ChainDataset;多 worker 时需用 get_worker_info() 分片
ChainDataset串联多个 IterableDataset__init__(datasets)__iter__ 依次 yield from 各集;__len__ 为各集长度之和(若均有 __len__
random_split按长度或比例随机划分random_split(dataset, lengths, generator=None) 返回 list[Subset]lengths 可为整数列表或和为 1 的比例列表
  • TensorDataset:所有传入 tensor 的第 0 维必须相同;适合已在内存的简单特征+标签;实现极简、零拷贝。
  • ConcatDataset:子集类型可不同;索引自动偏移(含负索引),无数据拷贝;不能包含 IterableDataset(会 AssertionError)。
  • Subsetindices 可为任意整数序列,可重复;不拷贝数据。若子类重写 __getitem__,必须同时重写 __getitems__,否则 DataLoader 会报错。
  • StackDataset:多个 Dataset 长度必须一致;按「同一索引」从各集取一条再组成 tuple/dict,适合多模态(如图+文)对齐。
  • IterableDataset:不能使用 DataLoader 的 shuffle=True(需在数据源侧打乱);多 worker 时必须在 __iter__worker_init_fn 中按 get_worker_info() 分片,否则每个 worker 会重复全量数据。
  • ChainDataset:仅接受 IterableDataset;流式串联,不预加载。
  • random_splitlengths 若为比例(和为一),会先换算为整数并处理余数;可用 generator 固定随机种子以保证可复现。
import math
import torch
from torch.utils.data import (
    TensorDataset, ConcatDataset, Subset, StackDataset,
    IterableDataset, ChainDataset, random_split,
)

# ---------- TensorDataset ----------
X = torch.randn(200, 16)
y = torch.randint(0, 5, (200,))
tensor_ds = TensorDataset(X, y)
print(len(tensor_ds), tensor_ds[0][0].shape)  # 200, torch.Size([16])

# ---------- ConcatDataset:索引映射与负索引 ----------
combined = ConcatDataset([tensor_ds, tensor_ds])
print(len(combined))  # 400
print(combined.cumulative_sizes)  # [200, 400]
print(combined[0] is tensor_ds[0], combined[200] is tensor_ds[0])  # True True

# ---------- Subset:train/val ----------
train_ds = Subset(tensor_ds, range(0, 160))
val_ds = Subset(tensor_ds, range(160, 200))

# ---------- random_split ----------
train_sub, val_sub = random_split(tensor_ds, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
assert len(train_sub) + len(val_sub) == len(tensor_ds)

# ---------- StackDataset:多模态对齐 ----------
text_len = len(tensor_ds)
# 假设有两个等长数据集
stacked = StackDataset(tensor_ds, Subset(tensor_ds, range(text_len)))
print(stacked[0])  # ((tensor(...), tensor(...)), (tensor(...), tensor(...)))
# dict 形式:StackDataset(a=X_ds, b=y_ds) -> stacked[i] == {"a": ..., "b": ...}

# ---------- IterableDataset + 多 worker 分片 ----------
class RangeIterableDataset(IterableDataset):
    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            iter_start, iter_end = self.start, self.end
        else:
            per_worker = int(math.ceil((self.end - self.start) / worker_info.num_workers))
            iter_start = self.start + worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))

it_ds = RangeIterableDataset(0, 10)
# 单进程:list(DataLoader(it_ds, num_workers=0)) 得到 0..9 各一次
# 多进程需上述分片,否则会重复

# ---------- ChainDataset ----------
chained = RangeIterableDataset(0, 3) + RangeIterableDataset(10, 13)  # 0,1,2,10,11,12
print(list(chained))  # [0, 1, 2, 10, 11, 12]

继承自 torch.utils.data.Dataset,增加:

  1. 文件管理raw_dir(原始文件)与 processed_dir(处理后的 .pt),按需 download() / process()
  2. 三阶段 transform
  • pre_filter:在 process() 时过滤,不存盘。
  • pre_transform:在 process() 时执行,结果存盘。
  • transform:每次 get() 时执行,不存盘。

目录与调用链:

root/
  raw/           raw_dir
  processed/     processed_dir

__init__  检查 raw  download();检查 processed  process()
dataset[idx]  get(idx)  [pre_filter / pre_transform 已在 process ]  transform(data)
接口说明
raw_file_names (property)原始文件名列表,用于判断是否需要 download
processed_file_names (property)处理后文件名列表,用于判断是否需要 process
len()数据集大小(注意是 len() 不是 __len__,PyG 内部会调)
get(idx)按索引返回单条样本(注意是 get 不是 __getitem__,基类会在此上应用 transform)

可选:download()process()

对比项DatasetInMemoryDataset
存储每样本一个 .pt全部合并为一个文件
get()每次从磁盘读内存切片(data + slices)
规模大图集小/中图集
是否必须实现 get()否(内部用 data/slices)
阶段触发时机是否存盘典型用途
pre_filterprocess() 一次过滤无效图(如 0 节点)
pre_transformprocess() 一次特征归一化、边属性
transform每次 get()随机 dropout edge、数据增强

PyG 的 Data 对象可直接作为样本;DataLoader 的 collate_fn 通常使用 torch_geometric.loader.DataLoader 自带的 collate,将 list of Data 聚成 Batch(含 batch 向量等)。若使用 torch.utils.data.DataLoader,需自定义 collate 或封装成 tensor/dict。


HuggingFace 的 datasets.Dataset 与 PyTorch 无继承关系,底层为 Apache Arrow 列式存储,支持内存映射、大表高效访问。常用能力与调用关系:

  • load_dataset():从 Hub 或本地(CSV/JSON/Parquet 等)加载,返回 DatasetDatasetDict
  • .map(fn, batched=True, …):对每行或每批应用函数,并行预处理(如 tokenize);batched=True 时传入批,性能更好。
  • .filter(fn):按条件筛样本,返回新 Dataset。
  • .set_format(“torch”, columns=…): 使迭代时输出变为 torch tensor,可直接配合 torch.utils.data.DataLoader(需用 IterableDataset 包装或通过 .with_format 迭代)。
  • .train_test_split():划分 train/test,返回 DatasetDict
  • 为何独立于 PyTorch:面向多框架(PyTorch/TF/JAX)与 Hub 生态,列式存储便于列级操作与流式。
  • 与 PyTorch 对接set_format("torch") 后,迭代得到的是 tensor;可包成 IterableDataset 或逐批取用,也可用 Dataset__getitem__ 封装成 map-style。
# 需安装: pip install datasets
from datasets import load_dataset

ds = load_dataset("imdb", split="train[:100]")  # 取前 100 条
ds = ds.map(lambda x: {"len": len(x["text"].split())}, batched=False)
ds = ds.filter(lambda x: x["len"] > 10)
# set_format("torch", columns=...) 在列已是数值/tensor 时使用,便于进 DataLoader
# 划分
splits = ds.train_test_split(test_size=0.2, seed=42)
train_ds, test_ds = splits["train"], splits["test"]

torchvision / torchaudio 等均继承 torch.utils.data.Dataset,在基类上增加领域 transform 与标准数据接口:

  • torchvisiondatasets.CIFAR10ImageFolder 等;transform 为 PIL → Tensor 或增强(Resize、Normalize、RandomCrop 等);transforms.Compose 串联。
  • torchaudiodatasets.LIBRISPEECH 等;transform 为波形/频谱(Resample、MelSpectrogram 等)。
  • torchtext:文本 pipeline,逐渐被 HuggingFace 替代。

transform__getitem__ 中对单样本执行;与 DataLoader 的 collate_fn 无关。

import torch
from torchvision import datasets, transforms

# 使用 CIFAR10(需能访问或已下载)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# root 需存在或留空并允许下载
# cifar = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
# 仅演示结构:用 TensorDataset 模拟「图像+标签」
X = torch.rand(100, 3, 32, 32)
y = torch.randint(0, 10, (100,))
ds = torch.utils.data.TensorDataset(X, y)
loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch[0].shape, batch[1].shape)  # torch.Size([8, 3, 32, 32]), torch.Size([8])

ImageFolder:目录结构为 root/class_a/xx.png,自动按子目录名作标签,配合 transform 即可用于分类。


许多项目会在 torch.utils.data.Dataset 或 PyG/HuggingFace 之上再封一层,以统一多任务、多数据源、比例采样、padding 与 collate。本节描述通用模式;若你的项目中有类似 BaseDatasetRLDatasetSplitMergeDatasetGTDataset 等,可对应下列概念与调用关系。

  • BaseDataset 系:从配置与 pkl 文件列表构建;产出「dict of tensors + labels」;支持多任务比例采样(通过 Sampler 或自定义 DataLoader 逻辑)。
  • RLDataset:在 BaseDataset 思路上增加轨迹/步采样、专用 collate(如按步长 pad)。
  • SplitMergeDataset:多数据源按比例或顺序合并,常与 ConcatDataset 或自定义 Sampler 配合。
  • GTDataset:借 PyG 的 Dataset 外壳(目录约定、len/get),但数据来自 pkl 路径列表、按需加载并做 padding,与 BaseDataset 产出形态类似,便于与既有 PyG 风格脚本统一。
  • 单条形态:通常为 dict of tensors + labels,便于 collate_fn 统一拼成 batch(如 default_collate 或自定义 gt_collate_fn)。
  • Padding:变长序列/图在 Dataset 的 __getitem__/get() 内对单条做 pad,或留到 collate_fn 对 list 做 pad;两者分工见下一节。
  • 与 DataLoader 的衔接:Dataset 只负责「单条样本」;Sampler 负责「取哪些、顺序」;collate_fn 负责「多条 → batch」。
  • 多任务驾驶/仿真:多 pkl 源、任务比例、统一入口。
  • 需要 PyG 风格目录与 DataLoader 时:用 GTDataset 封装 pkl 列表,配合 PyG DataLoader 或自定义 collate。

  • 单条 padding(在 Dataset 内):在 __getitem__ 或 PyG 的 get() 中,对单条样本做长度对齐(如 pad 到固定长度或到本 batch 内最大长度前的某个上界)。适合「每条独立 pad 到同一规则」的场景。
  • batch 内 padding(在 collate_fn 内):DataLoader 收集到 list of 样本后,在 collate_fn 里按本 batch 最大长度做 pad,再 stack 成 tensor。更灵活,且只 pad 到当前 batch 所需长度,节省显存。

两者可组合:例如在 Dataset 里做简单截断/最小 pad,在 collate_fn 里再做 batch 维度的 pad 与 mask。

  • 为何需要 padding:变长序列/图无法直接 torch.stack 成规则 tensor;padding 后得到固定形状,便于模型前向。
  • mask:通常与 padding 一起维护一个 mask 张量(如 1 为有效、0 为 pad),供模型忽略 pad 位置。
import torch

def pad_collate(batch):
    """batch: list of (seq, label), seq 为 1D tensor 变长"""
    seqs = [b[0] for b in batch]
    labels = torch.stack([b[1] for b in batch])
    max_len = max(s.size(0) for s in seqs)
    padded = torch.zeros(len(seqs), max_len, dtype=seqs[0].dtype)
    mask = torch.zeros(len(seqs), max_len, dtype=torch.bool)
    for i, s in enumerate(seqs):
        padded[i, :s.size(0)] = s
        mask[i, :s.size(0)] = True
    return padded, mask, labels

# 使用示例
data = [(torch.randn(i + 1), torch.tensor(i % 2)) for i in range(5)]
padded, mask, labels = pad_collate(data)
print(padded.shape, mask.shape, labels.shape)  # [5, 5], [5, 5], [5]

  • Dataset:提供 __len____getitem__(或 IterableDataset 的 __iter__),定义「单条样本」形态。
  • DataLoader:根据 Sampler 产生索引(或直接消费迭代器),在 worker 中取样本,经 collate_fn 聚成 batch,可选多进程与 pin_memory。

因此:Dataset 决定单条格式;collate_fn 决定多条如何成 batch;Sampler 决定取哪些、顺序。

  • map-style:DataLoader 调用 len(dataset)dataset[indices](或 __getitems__(indices));collate_fn 的输入是「list of 单条样本」。
  • IterableDataset:无索引,DataLoader 从 iter(dataset) 取数据;多 worker 时需在 dataset 或 worker_init_fn 中分片,避免重复。
  • 自定义 batch:若需自定义每批的组成(如加权、多任务比例),可传 batch_sampler,或使用 Sampler 控制索引。

相关内容