PyTorch DataLoader 技术解读
PyTorch DataLoader 技术解读
文档索引
| 章节 | 主题 | 主要内容 |
|---|---|---|
| 一、总览 | 整体架构与文档脉络 | 知识结构、数据流、各部分职责与关联、优缺点与适用场景 |
| 二、与 Dataset 的接口 | DataLoader 与数据源的契约 | map-style / IterableDataset、__len__/__getitem__/__iter__、ConcatDataset/Subset |
| 三、Sampler 与索引流 | 取哪些、按何顺序 | Sampler 协议、Random/Sequential/BatchSampler、batch_sampler、DistributedSampler |
| 四、batch 与 collate_fn | 聚批与默认/自定义 collate | batch_size、drop_last、default_collate 行为、自定义 collate、padding |
| 五、多进程 Worker 与传输优化 | 加速取数与 CPU→GPU | num_workers、worker_init_fn、persistent_workers、prefetch_factor、pin_memory、multiprocessing_context |
| 六、迭代行为与使用方式 | 遍历与长度 | for batch in loader、iter/next、len(loader)、多 epoch 注意点 |
| 七、小结与速查 | 汇总与速查表 | 模块对照表、参数速查、数据流一句话 |
一、总览
1.1 DataLoader 的职责与位置
torch.utils.data.DataLoader 是 PyTorch 中**将「数据集」变为「可按 batch 迭代的输入流」**的组件。它不定义数据从哪来、长什么样(由 Dataset 负责),只负责:
- 按何种顺序、每次取多少条:Sampler / batch_size / batch_sampler
- 多条样本如何拼成一个 batch:collate_fn
- 是否用多进程预取、是否锁页内存:num_workers、pin_memory、persistent_workers、prefetch_factor
在训练/验证流程中的位置:
配置/代码 → 构建 Dataset(如 GTDataset、ConcatDataset)
→ 传入 DataLoader(batch 大小、采样、collate_fn、多进程等)
→ 训练循环:for batch in train_dataloader: loss = model(batch) ...1.2 整体架构与知识结构
DataLoader 内部可拆成五条逻辑线,同一次迭代中的关系如下:
┌─────────────────────────────────────────────────────────────────┐
│ DataLoader │
│ │
┌──────────────┐ │ ① 索引流 ② 取样本 ③ 成 batch │
│ Dataset │ │ Sampler → Worker 取 → collate_fn → 输出 │
│ (数据源) │ │ (顺序/批次) dataset[i] (list→batch) │
└──────┬───────┘ │ │ │ │ │
│ │ │ │ │ │
│ │ ┌───▼───┐ ┌────▼────┐ ┌────▼────┐ │
└──────────┼─►│indexes│────────►│ samples │─────────────►│ batch │──────────┼──► 训练/验证
│ │0,3,1..│ │ (单条) │ │(tensor │ │
│ └───────┘ └─────────┘ │ or dict) │ │
│ ▲ ▲ └─────────┘ │
│ batch_size num_workers │
│ shuffle pin_memory │
└─────────────────────────────────────────────────────────────────┘- ① 索引流:Sampler(或由 shuffle 推导的 RandomSampler/SequentialSampler)产生索引;若指定
batch_sampler则直接产生「一批索引」,此时忽略batch_size/shuffle。 - ② 取样本:主进程或 Worker 根据索引调用
dataset[idx](或 IterableDataset 的迭代),得到多条单样本。 - ③ 成 batch:
collate_fn把「list of 单样本」聚合成一个 batch(如 stack 成 tensor、或 dict of tensors),供模型使用。
因此:Dataset 决定「单条是什么」;Sampler 决定「取哪些、按什么顺序」;collate_fn 决定「多条怎么变成 batch」;Worker 与 pin_memory 等决定「取数据与传输的效率」。
1.3 各部分职责与彼此关联
| 部分 | 职责/主题 | 与其它部分的关系 |
|---|---|---|
| Dataset 接口 | 数据从哪来、单条长什么样 | DataLoader 只依赖「长度 + 按索引取一条」或「可迭代流」;格式、增强、IO 全由 Dataset 负责 |
| Sampler | 取哪些索引、顺序 | 产出索引序列(或一批批索引);不接触实际数据,只驱动「谁被取」 |
| batch / collate_fn | 多少条一批、如何聚批 | Worker 取到的是 list of 单样本;collate_fn 把 list 变成模型可用的 batch |
| 多进程与传输 | 加速取数、加速 CPU→GPU | Sampler 产索引 → 在 worker 中执行 dataset[idx] → 预取若干批 → 主进程 collate,可选 pin_memory |
| 迭代 | 如何遍历 batch | 对上述链条的封装:for batch in loader 即重复「取索引→取样本→collate→产出」 |
1.4 优缺点与适用场景对比
| 选项/场景 | 优点 | 缺点 / 注意点 | 适用场景 |
|---|---|---|---|
| map-style Dataset + DataLoader | 支持随机访问、shuffle、多进程索引分片 | 需事先能确定 __len__ | 绝大多数训练/验证 |
| IterableDataset + DataLoader | 流式、不必装全量进内存 | 不能使用 shuffle=True(需在 dataset 内自行打乱),多进程需自行分片 | 大文件、管道、无法事先知道总条数 |
| shuffle=True(默认 RandomSampler) | 每 epoch 随机顺序,利于泛化 | 顺序不可复现(需设 seed) | 训练集 |
| shuffle=False(默认 SequentialSampler) | 顺序固定、可复现 | 不适合训练 | 验证/测试 |
| 自定义 sampler | 可做加权、子集、分布式分片 | 与 shuffle 互斥,需自己保证不重复不遗漏 | 多任务比例采样、DDP、子集验证 |
| batch_sampler | 完全控制每批的索引列表 | 忽略 batch_size、shuffle、sampler、drop_last | 自定义批组成、某些分布式策略 |
| num_workers=0 | 无多进程开销、调试简单 | 取数在主进程,可能拖慢 GPU | 小数据、调试 |
| num_workers>0 | 取数与训练并行,减轻 GPU 等待 | 进程切换与内存占用;Windows 需 if __name__=="__main__" | 大数据、生产训练 |
| pin_memory=True | 加速 CPU→GPU 拷贝 | 仅在与 .to(device, non_blocking=True) 配合时效果明显 | CUDA 训练 |
| persistent_workers=True | epoch 间不销毁 worker,减少 fork 开销 | 首 epoch 后不会重新初始化 worker | 多 epoch 训练 |
二、与 Dataset 的接口
2.1 结构说明
DataLoader 的第一个核心参数是 dataset,类型为:
torch.utils.data.Dataset(map-style):需实现__len__与__getitem__(idx),支持整数索引。torch.utils.data.IterableDataset:只需实现__iter__,无索引;DataLoader 从迭代器取样本并按 batch_size 聚批。
DataLoader 在迭代时**只通过「索引」或「迭代器」**与 dataset 交互,不关心 dataset 内部如何存数据、是否做 transform。
- 对 Dataset:内部用 Sampler 得到索引序列,再在某个进程里执行
dataset[i]得到单条样本。 - 对 IterableDataset:由 dataset 自己产出样本流,DataLoader 负责按
batch_size聚批并调用collate_fn;此时不能使用shuffle=True(需在 dataset 内自行打乱),且 Sampler 行为受限。
2.2 关键概念
- map-style:有长度、可随机访问。适合绝大多数「已知条数、需 shuffle」的场景。
- IterableDataset:无长度(或长度仅作提示)、只能顺序迭代。适合流式数据、大文件、无法事先知道总条数的场景。
- 为什么需要两种:map-style 便于 Sampler 任意排列索引;Iterable 便于与管道、生成器、分布式按 worker 分片对接。
2.3 使用方式
- 使用 map-style Dataset 时:必须实现
__len__和__getitem__(idx),且idx为整数;DataLoader 会用 Sampler 给出的索引调用__getitem__。 - 使用 IterableDataset 时:只需实现
__iter__;多进程时每个 worker 会得到同一迭代器的一份拷贝,若需分片需在 dataset 内根据worker_info自行划分,避免重复数据。 - ConcatDataset、Subset 等包装多个 Dataset,再整体传给 DataLoader,无需改 DataLoader 调用方式;长度与索引映射由这些包装类负责。
2.4 示例代码
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
class SimpleDataset(Dataset):
def __init__(self, size):
self.size = size
self.data = torch.randn(size, 8)
def __len__(self):
return self.size
def __getitem__(self, idx):
return self.data[idx] # 单条 shape: (8,)
ds1 = SimpleDataset(100)
ds2 = SimpleDataset(50)
combined = ConcatDataset([ds1, ds2]) # 长度 150,索引 0..99 来自 ds1,100..149 来自 ds2
loader = DataLoader(combined, batch_size=16, shuffle=True)
batch = next(iter(loader))
print(batch.shape) # torch.Size([16, 8])三、Sampler 与索引流
3.1 结构说明
Sampler 决定每次迭代时,以什么顺序、用哪些索引访问 Dataset。协议为:实现 __iter__,产出索引序列(整数);可选实现 __len__(返回索引总数,用于 len(loader) 等)。
DataLoader 内部逻辑可简化为:
- 未指定
batch_sampler:使用sampler(或由shuffle推导出的 RandomSampler/SequentialSampler)得到整集索引序列,再按batch_size切分成一批批索引,每批交给 Worker 取样本,最后经collate_fn成 batch。 - 指定了
batch_sampler:直接使用它产出「一批索引」(每个元素为 index 的 list),此时忽略batch_size、shuffle、sampler、drop_last。
3.2 常用 Sampler 与关系
| 类型 | 作用 | 典型用法 |
|---|---|---|
| SequentialSampler | 顺序 0,1,…,N-1 | 验证集默认(shuffle=False) |
| RandomSampler | 随机打乱索引(可带 replacement) | 训练集 shuffle=True |
| BatchSampler | 在某个 Sampler 上再按 batch 切分 | 自定义每批索引、与其它 Sampler 组合 |
| DistributedSampler | 多卡下每卡不同子集,不重复 | DDP 训练 |
| SubsetRandomSampler | 仅从给定索引中随机采样 | 子集训练、部分验证 |
索引流 → 取样本 → collate 的链条中,Sampler 只负责「索引」这一环,不接触实际数据。
3.3 关键概念
- 为什么需要 Sampler:将「取哪些、按什么顺序」从 Dataset 中解耦,便于实现 shuffle、加权、分布式、子集等,而不必改 Dataset。
- batch_sampler 与 sampler 互斥:指定
batch_sampler后,DataLoader 不再使用sampler、batch_size、shuffle、drop_last;若只指定sampler,则仍用batch_size在内部做批切分。
3.4 使用方式
- 不传
sampler/batch_sampler时:shuffle=True→ 内部使用RandomSampler(dataset);shuffle=False→SequentialSampler(dataset)。 - 传了
sampler后,不能再传shuffle(会冲突)。 - DDP 时通常为训练集传
DistributedSampler,并在每 epoch 前调用sampler.set_epoch(epoch)以保证各 epoch 打乱不同。
3.5 示例代码
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, BatchSampler, SubsetRandomSampler
dataset = SimpleDataset(100)
# 等价:不传 sampler,shuffle=False → 顺序
loader_seq = DataLoader(dataset, batch_size=8, shuffle=False)
loader_seq2 = DataLoader(dataset, batch_size=8, sampler=SequentialSampler(dataset))
# 训练常用:随机打乱
loader_shuffle = DataLoader(dataset, batch_size=8, shuffle=True)
# 自定义:只取前 50 个样本,且每批 4 条
indices = list(range(50))
sampler = SubsetRandomSampler(indices)
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
loader_custom = DataLoader(dataset, batch_sampler=batch_sampler)
# 使用 batch_sampler 时不能再传 batch_size、shuffle、sampler四、batch 与 collate_fn
4.1 结构说明
- batch_size:每批包含的样本数(若使用
batch_sampler则无效)。 - drop_last:若最后不足一整批,是否丢弃该批;
True时len(loader) == floor(len(dataset)/batch_size),常用于训练时保持每批大小一致(如 BN)。 - collate_fn:签名为
(list of 单样本) -> batch。Worker 取到的是「一批索引对应的多条单样本」的 list;collate_fn 把这条 list 变成「一个 batch 对象」,训练代码拿到的就是其返回值。
4.2 默认 collate_fn 行为
未传 collate_fn 时,DataLoader 使用 torch.utils.data.default_collate:
- 若单样本为 tensor:对 list 做
torch.stack(batch)。 - 若单样本为 tuple/list:按元素位置分别 stack(要求各位置类型一致、可 stack)。
- 若单样本为 dict:按 key 分别对 value 做 default_collate(递归)。
- 若为 数字:转为 tensor。
- 不可 stack 的类型(如不等长序列、str、自定义对象):会报错或行为未定义,此时需自定义 collate_fn。
4.3 关键概念
- 为什么需要 collate_fn:单条样本可能是 tensor、tuple、dict 或不等长序列;模型需要的是固定形状的 batch(如
(B, ...)的 tensor 或 dict of tensors)。collate_fn 统一完成「多条 → 一批」的转换,并可在此做 padding、mask 等。 - padding 放在哪:可在 Dataset 的
__getitem__里对单条做固定长度 padding,或在 collate_fn 里按本批最大长度做 padding;后者更省内存、更灵活。
4.4 使用方式
- 单样本为等长 tensor 或简单 tuple/list/dict of tensors:可不传
collate_fn,用默认即可。 - 单样本为不等长序列、或需要 padding/掩码:必须自定义
collate_fn,返回一个 batch(例如{"x": padded_tensor, "mask": mask_tensor})。 drop_last=True时,最后一个不完整 batch 不会出现。
4.5 示例代码
def simple_collate(batch):
"""batch: list of (tensor,) or (tensor, label)"""
if isinstance(batch[0], tuple):
xs = torch.stack([b[0] for b in batch])
ys = torch.tensor([b[1] for b in batch])
return xs, ys
return torch.stack(batch)
# 不等长序列:padding 到本批最大长度
def pad_collate(batch):
seqs = [b[0] for b in batch]
labels = torch.tensor([b[1] for b in batch])
max_len = max(s.size(0) for s in seqs)
padded = torch.zeros(len(seqs), max_len)
for i, s in enumerate(seqs):
padded[i, :s.size(0)] = s
return padded, labels
loader_default = DataLoader(dataset, batch_size=16) # 默认 stack
loader_custom = DataLoader(dataset, batch_size=16, collate_fn=simple_collate)五、多进程 Worker 与传输优化
5.1 结构说明
- num_workers:用于加载数据的子进程数。0 表示主进程加载;>0 时,主进程只负责组 batch 和送进 GPU,取样本在 worker 进程里执行,通过队列把样本传给主进程。
- worker_init_fn:每个 worker 启动时调用一次,可用于设随机种子、改线程数等,签名为
(worker_id: int) -> None。 - persistent_workers:若 True,worker 进程在 epoch 之间不销毁,避免每个 epoch 重新 fork,适合多 epoch 训练(需 num_workers>0)。
- prefetch_factor:每个 worker 预取的 batch 数(仅 num_workers>0 时有效),默认 2。
- pin_memory:若 True,主进程会把 CPU 上的 batch 放在锁页内存,加速 CPU→GPU 拷贝;需在
.to(device, non_blocking=True)时才能更好利用。 - multiprocessing_context:多进程启动方式,如
'spawn'、'fork';Windows 上默认 spawn,Linux 上多为 fork。
关系简要:Sampler 决定取哪些索引;这些索引在 worker 中转为 dataset[idx];多个 worker 并行取,预取若干批;主进程把收到的样本 list 交给 collate_fn 成 batch,再可选地 pin_memory。
5.2 关键概念
- 为什么需要多进程:数据加载往往受 IO 或 CPU 预处理限制,多进程可让「取下一批」与「当前批在 GPU 上训练」并行,减轻 GPU 空转。
- 为什么需要 persistent_workers:每次 fork 会复制进程、重新导入模块,多 epoch 时重复开销大;保持 worker 存活可减少这部分成本。
- 为什么需要 pin_memory:锁页内存在 DMA 传输时不需要先拷到可换页内存,可减少 CPU→GPU 拷贝延迟。
5.3 使用方式
- Windows 上多进程需把数据/模型构建放在
if __name__ == "__main__"内,避免重复 fork 出错。 num_workers过大反而可能因进程切换和内存占用变慢,一般 4~8 常见;数据很轻时可设为 0。persistent_workers=True时,第一个 epoch 后不会重新创建 worker;若设为 True,则 num_workers 必须 >0。pin_memory=True通常与 CUDA 训练一起使用,且配合non_blocking=True的.to(device)。
5.4 示例代码
def worker_init(worker_id):
import numpy as np
np.random.seed(worker_id)
loader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=4,
pin_memory=True,
persistent_workers=True,
prefetch_factor=2,
worker_init_fn=worker_init,
)六、迭代行为与使用方式
6.1 结构说明
DataLoader 是可迭代对象:
- for batch in loader:遍历所有 batch,每个 batch 是
collate_fn的返回值。 - iter(loader):得到迭代器;next(iterator) 得到一个 batch。多 epoch 时每轮应重新
iter(loader)或直接for batch in loader。 - len(loader):batch 的个数。使用默认 sampler 且
drop_last=False时为ceil(len(dataset)/batch_size);drop_last=True时为floor(len(dataset)/batch_size)。若使用batch_sampler,则len(loader) == len(batch_sampler)。
迭代时内部按「Sampler → 取样本 → collate」的顺序产出 batch,与前面各模块一一对应。
6.2 使用方式
- 训练循环中不要重复使用同一个迭代器对象跨 epoch,应每个 epoch 重新
for batch in train_loader或重新iter(train_loader)。 - 若 DataLoader 使用 IterableDataset 且未实现
__len__,len(loader)可能不可用或仅为估计。 - 验证/测试时通常
shuffle=False,保证结果可复现。
6.3 示例代码
loader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)
print("batches per epoch:", len(loader))
for epoch in range(3):
for batch in loader:
pass
# 或手动取一个 batch
it = iter(loader)
batch0 = next(it)
batch1 = next(it)七、小结与速查
7.1 模块对照表
| 模块 | 作用 | 主要参数/组件 |
|---|---|---|
| Dataset 接口 | 数据从哪来、单条长什么样 | dataset(Dataset / IterableDataset)、ConcatDataset / Subset |
| Sampler | 取哪些索引、顺序 | sampler、batch_sampler、shuffle、RandomSampler、SequentialSampler、BatchSampler、DistributedSampler |
| batch 与 collate | 多少条一批、如何聚批 | batch_size、drop_last、collate_fn、default_collate、自定义 padding/mask |
| 多进程与传输 | 加速取数、加速到 GPU | num_workers、worker_init_fn、persistent_workers、prefetch_factor、pin_memory、multiprocessing_context |
| 迭代 | 如何遍历 batch | for batch in loader、iter(loader)、next(it)、len(loader) |
7.2 数据流一句话
Sampler 产索引 → Worker 用索引调 Dataset 取样本 → collate_fn 把样本 list 合成 batch → 迭代输出给训练/验证循环。
7.3 参数速查(DataLoader 常用)
| 参数 | 类型 | 默认 | 说明 |
|---|---|---|---|
| dataset | Dataset / IterableDataset | 必填 | 数据源 |
| batch_size | int | 1 | 每批样本数(batch_sampler 存在时无效) |
| shuffle | bool | False | 是否打乱(与 sampler 互斥) |
| sampler | Sampler | None | 索引采样器(与 shuffle 互斥) |
| batch_sampler | Sampler$$list] | None | 若指定,忽略 batch_size/shuffle/sampler/drop_last |
| num_workers | int | 0 | 加载数据的子进程数 |
| collate_fn | callable | default_collate | (list of 样本) -> batch |
| pin_memory | bool | False | 是否使用锁页内存 |
| drop_last | bool | False | 是否丢弃最后不完整批 |
| persistent_workers | bool | False | 是否在 epoch 间保持 worker 存活 |
| worker_init_fn | callable | None | 每个 worker 启动时调用 (worker_id) |
| prefetch_factor | int | 2 | 每 worker 预取批数(num_workers>0 时有效) |
| multiprocessing_context | str/callable | None | 多进程上下文,如 ‘spawn’/‘fork’ |
理解其中一环即可对应到本文相应小节进行查阅或扩展。