目录

目录

PyTorch 分布式训练与操作工具技术文档

目录
章节主题内容概要
一、概览整体架构与文档脉络为何做分布式、整体数据流、知识结构图、各部分职责与关联、DP vs DDP / 单机 vs 多机 / 原生 vs Lightning 对比
二、核心概念进程、进程组、Rank、Backend、Collective每个概念:是什么、为什么需要、解决什么问题
三、进程组初始化init_process_groupdestroy_process_group初始化流程、参数完整说明、init_method、环境变量、单卡兼容、示例
四、DistributedDataParallelDDP 包装与梯度同步前向/反向/存储结构,构造函数全部参数,model.modulefind_unused_parameters、bucket、示例
五、数据分片DistributedSampler分片原理、构造函数、set_epoch、与 DataLoader 配合、验证集用法、示例
六、集体通信barrier / all_reduce / broadcast / all_gather / gather / reduce各 API 签名、语义、ReduceOp、适用场景、本项目中的用法、示例
七、分布式启动torchrun 与 launch单机/多机命令、全部命令行参数、环境变量、与 init 的衔接、示例
八、Checkpoint 与日志仅 rank 0 写盘与 barrier保存/加载模式、barrier 放置、日志只打 rank 0、示例
九、完整示例原生 DDP 端到端与项目风格一致的可运行脚本与启动命令
十、PyTorch LightningDDPStrategy封装内容、单机/多机用法
十一、速查与小结组件对照表与延伸组件/概念速查表、进阶方向

阅读建议:先读概览与核心概念建立全局图景,再按需跳转到对应章节查阅 API 与示例;做实现时可按「启动 → init → DDP + Sampler → 训练循环 → barrier/rank0 保存」顺序对照各章。


解决的问题

  • 单卡显存/算力不足:模型或 batch 过大,单 GPU 放不下或训练太慢。
  • 提高吞吐:多张卡同时算不同 batch,单位时间内处理的样本数成倍增加,缩短总训练时间。
  • 多机扩展:单机 GPU 数量有限时,通过多台机器进一步扩展总 GPU 数。

本质:把「一份模型 + 一份数据」拆成多份,让多个进程(每进程通常绑定 1 个 GPU)协同计算,在梯度或参数上做同步,使多卡/多机在数学上等价于「大 batch 的单卡训练」(数据并行时)。

┌─────────────────────────────────────────────────────────────────────────────────┐
│                           启动阶段(仅执行一次)                                   │
│  ① 启动工具 (torchrun / torch.distributed.launch)                                 │
│       → 为每个 GPU 启动一个进程,注入 RANK / LOCAL_RANK / WORLD_SIZE 等            │
│  ② 进程内:torch.distributed.init_process_group(backend, ...)                     │
│       → 建立进程组,选定通信后端(NCCL/GLOO),各进程完成握手                       │
│  ③ 模型放到当前 GPU,再用 DistributedDataParallel(DDP) 包装                       │
│       → DDP 在 backward 时自动做梯度 AllReduce,保证各卡参数一致更新               │
│  ④ DataLoader + DistributedSampler                                                │
│       → 每个 rank 只看到数据集的不重叠子集,避免重复训练同一批数据                  │
└─────────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────────┐
│                           训练循环(每个 step)                                    │
│  各 rank:数据 → DataLoader(DistributedSampler 分片)→ 前向 → 反向 → DDP AllReduce │
│  同步点:dist.barrier() → 仅 rank 0 写 checkpoint / 打 log → 再 barrier(可选)    │
└─────────────────────────────────────────────────────────────────────────────────┘
  • 基础层:进程、进程组、Rank/Local Rank/World Size、Node、Backend、Collective、梯度同步(第二节概念)。
  • 搭建层:进程组初始化(第三节)→ DDP 包装模型(第四节)→ DistributedSampler 分数据(第五节)→ 集体通信按需使用(第六节)。
  • 入口层:分布式启动(第七节)决定进程数与环境变量,训练脚本内再 init、DDP、Sampler。
  • 工程层:Checkpoint 与日志(第八节)、完整示例(第九节)、Lightning 封装(第十节)。

各节关系:启动 提供进程与环境 → init 建立通信 → DDPSampler 分别负责梯度同步与数据分片 → 集体通信 用于用户级同步与聚合 → Checkpoint/日志示例/Lightning 为落地用法。

组件职责/主题与其它部分的关系
启动工具生成多进程并注入 RANK/WORLD_SIZE/MASTER_*必须在 init_process_group 之前完成;不启动多进程就没有分布式
init_process_group建立进程组、选定 backend所有分布式 API(DDP、collective、barrier)都依赖已初始化的进程组
DDP包装模型,backward 时同步梯度依赖已 init 的进程组;需配合 DistributedSampler 才能数据不重不漏
DistributedSampler按 rank 切分数据索引依赖 rank/world_size;DataLoader 用其替代默认 sampler
集体通信 / barrier同步、聚合张量依赖已 init 的进程组;DDP 内部用 AllReduce;用户用 barrier 做「等齐再往下」
Rank 0 写盘与 log单点持久化与日志避免多进程写同一文件、日志刷屏;常与 barrier 配合
维度DataParallel (DP)DistributedDataParallel (DDP)
实现方式单进程多线程,主卡聚合梯度再广播多进程,每进程一卡,梯度 AllReduce
通信梯度汇总到主卡再广播;主卡瓶颈明显各卡对等通信(如 Ring-AllReduce),主卡无瓶颈
速度多卡扩展差,常比单卡还慢多卡接近线性加速,推荐
使用难度model = nn.DataParallel(model) 即可需多进程启动 + init + DistributedSampler
适用场景快速试验、卡数少(如 2 卡)正式训练、多卡/多机

结论:PyTorch 官方推荐用 DDP 做数据并行;DP 仅适合临时试验。

维度单机多卡多机多卡
通信机内 NVLink/PCIe,延迟低、带宽高跨机网络,延迟与带宽逊于机内
启动torchrun --nproc_per_node=N train.py每台机器各起一份 torchrun,需指定 nnodes、node_rank、master_addr、master_port
环境变量通常由 torchrun 自动设置有时由调度系统(Slurm/K8s)设置 MASTER_*、RANK 等
维度原生 torch.distributed + DDPPyTorch Lightning (DDPStrategy)
控制力完全手控 init、sampler、barrier、保存框架封装,init/sampler/保存由 Trainer 处理
代码量多:需写启动、rank 判断、sampler、barrier少:指定 strategy=DDPStrategy() 即可
适用自定义训练循环、非标准流程标准 train/val 循环、快速实验

  • 进程 (Process)
    是什么:操作系统中的一个独立执行单元;DDP 里通常「一个进程绑定一张 GPU」。
    为什么需要:实现真正并行(多进程可跑在多核/多机上),避免 Python GIL 限制。
    解决什么问题:单进程多线程无法充分利用多卡;多进程才能让每张卡在独立进程中运行,互不阻塞。

  • 进程组 (Process Group)
    是什么:参与集体通信的一组进程的集合;默认用 world 表示「所有进程」。
    为什么需要:为 collective(AllReduce、Barrier 等)划定「和谁通信」。
    解决什么问题:多机多任务时可能只让部分节点参与一次训练,进程组用来区分「这一组」进程。

  • Rank(全局 rank)
    是什么:当前进程在整个分布式任务中的唯一编号,0 到 world_size-1。
    作用:标识「我是谁」,用于划分数据(DistributedSampler)、决定谁写 checkpoint(如 rank 0)。
    解决什么问题:多进程中需要唯一 ID,否则无法分片和选主。

  • Local Rank(本机 rank)
    是什么:当前进程在本机内的 GPU 编号,通常 0 到「本机 GPU 数-1」。
    作用torch.cuda.set_device(local_rank) 绑定当前进程到对应 GPU。
    解决什么问题:多机时每台机器都有 rank 0;用 local_rank 才能正确绑定到本机某张卡。

  • World Size
    是什么:参与该次训练的总进程数(通常等于总 GPU 数)。
    作用:AllReduce、DistributedSampler 等都需要「一共有多少参与方」。
    解决什么问题:集体通信与数据分片都依赖「总数」这一信息。

  • 是什么:一台物理或逻辑机器,上面有多张 GPU。
  • 作用:多机训练时用 node_rank 区分机器,用 nnodes 表示机器数。
  • 解决什么问题:启动多机任务时要指明「有多少台机器、当前是第几台」,以便正确建连(master_addr/master_port)。
  • 是什么:进程间做集体通信时使用的底层实现。
  • 常见选择
    • NCCL:NVIDIA 的多卡/多机 GPU 通信库,CUDA 训练默认推荐
    • GLOO:CPU 或 GPU 都可用,多机无 NCCL 时可用;CPU 上调试时也常用。
    • MPI:需单独安装 MPI 与 PyTorch MPI 后端,多用于 HPC。
  • 作用:决定梯度、张量如何在不同进程间同步。
  • 解决什么问题:不同硬件/环境需要不同通信实现,backend 提供统一 API、多种实现。
  • 是什么:一组进程共同参与的通信原语,如 AllReduce、Broadcast、Barrier、AllGather、Gather、ReduceScatter。
  • 作用
    • AllReduce:各进程提供形状相同的张量,结果变为「所有进程得到同一份聚合后的张量」(如梯度求和/求平均);DDP 的梯度同步即 AllReduce。
    • Barrier:所有进程在此阻塞,直到都执行到 barrier,再一起继续。
    • Broadcast:根进程把张量发到所有进程。
    • AllGather:各进程提供一个张量,汇总后每个进程得到完整列表。
  • 解决什么问题:多进程要「对齐状态」或「汇总结果」,必须依赖集体通信而不是各自为政。
  • 是什么:DDP 在 backward() 时,把各卡梯度做 AllReduce(通常求和再除以 world_size),使各卡用同一份梯度更新参数。
  • 作用:数学上等价于「单卡大 batch」;每卡只算本地 batch 的梯度,通过同步得到全局梯度。
  • 解决什么问题:数据并行下「多卡算多份小 batch,如何得到等价于大 batch 的更新」。

  • 职责:建立默认进程组(world)、根据 backend 初始化通信库、让当前进程加入该组。
  • 调用前:本进程已由启动工具(如 torchrun)启动,且环境变量或参数中已有正确的 rank、world_size;多机时还需 MASTER_ADDR、MASTER_PORT。
  • 调用后:方可使用 DDP、DistributedSampler、dist.barrier()dist.all_reduce() 等。

完整签名(常用参数)

torch.distributed.init_process_group(
    backend,           # str: "nccl" | "gloo" | "mpi"
    init_method=None,  # str: "env://" | "tcp://IP:PORT" | "file:///path"
    world_size=None,   # int,不设则从环境变量读取
    rank=None,         # int,不设则从环境变量读取
    timeout=datetime.timedelta(seconds=1800),  # 集体通信超时
    store=None,        # Store,用于 bootstrap,高级用法
)
  • backend:通信后端;CUDA 训练用 "nccl",CPU 或调试可用 "gloo"
  • init_method
    • "env://":从环境变量读取 RANK、WORLD_SIZE、MASTER_ADDR、MASTER_PORT(与 torchrun 配套,推荐)。
    • "tcp://IP:PORT":指定 master 地址与端口,所有进程需能连到该地址。
    • "file:///path":通过共享文件系统做 rendezvous,适合无共享 IP 的环境。
  • world_size / rank:不传时从环境变量 WORLD_SIZE、RANK 读取(torchrun 会设置)。

为什么需要:多进程必须先「发现彼此」并约定通信方式,否则 DDP 与 collective 无法工作。
解决什么问题:统一建立进程组与通信后端,为后续 DDP 与集体通信提供基础。

  • 作用:销毁进程组,释放通信资源;训练结束后调用,便于干净退出。
  • 何时调用:所有分布式训练与 collective 完成后;若不调用,进程退出时也会清理,但显式调用更规范。
  • dist.get_rank():当前进程的全局 rank。
  • dist.get_world_size():当前进程组大小。
  • dist.is_initialized():进程组是否已初始化(用于单卡/多卡分支判断)。
  • 每个参与训练的进程都必须调用一次 init_process_group,且 backend、world_size 在所有进程中一致;rank 每进程不同。
  • 推荐用 torchrun 启动,由它注入 RANK、LOCAL_RANK、WORLD_SIZE 等,代码里用 os.environ["RANK"] 等读取后传入,或直接使用 init_method="env://" 不传 rank/world_size。
  • 单卡时可不调用 init(world_size==1),后续用 dist.is_initialized() 判断是否走分布式逻辑。
import os
import torch
import torch.distributed as dist

def setup_distributed():
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if world_size == 1:
        return rank, world_size, local_rank  # 单卡不 init

    dist.init_process_group(
        backend="NCCL",
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank

if __name__ == "__main__":
    rank, world_size, local_rank = setup_distributed()
    print(f"rank={rank}, world_size={world_size}, local_rank={local_rank}")
    if dist.is_initialized():
        dist.destroy_process_group()

  • 角色nn.Module 的包装器;不改变单进程前向/反向的调用方式,在 backward 时插入梯度 AllReduce。
  • 前向:每个进程用本进程的数据做一次前向,得到本地 loss。
  • 反向loss.backward() 时,DDP 注册的 hook 在各卡梯度计算完成后对梯度做 AllReduce(默认等价于求平均),保证各卡参数用同一梯度更新。
  • 存储:原始模型在 model.module;保存/加载单卡权重时用 model.module.state_dict()

要点:各进程模型结构必须一致;各 step 各进程应处理不同数据(由 DistributedSampler 保证),否则等价于重复算同一 batch。

torch.nn.parallel.DistributedDataParallel(
    module,                      # 要包装的 nn.Module,需已在目标 device 上
    device_ids=None,            # 单卡时 [local_rank],多卡时通常 [local_rank]
    output_device=None,         # 默认与 device_ids[0] 一致
    dim=0,                       #  gather 的维度,一般用默认
    broadcast_buffers=True,      # 是否在 forward 前同步 BN 等 buffer
    process_group=None,         # 默认使用默认进程组
    bucket_cap_mb=25,           # 梯度桶大小(MB),影响通信/内存权衡
    find_unused_parameters=False,  # 若有参数未参与计算,须为 True
    gradient_as_bucket_view=False,  # True 可省部分内存,推荐与 Lightning 一致时开启
    static_graph=False,         # 若计算图固定可设为 True,利于优化
)
  • device_ids:当前进程对应的 GPU 列表,单进程单卡时为 [local_rank]
  • find_unused_parameters:若模型存在在 forward 中未参与计算的参数(如部分分支未走),必须设为 True,否则 backward 会报错。
  • gradient_as_bucket_view:梯度以 bucket 视图形式存在,可节省显存。
  • broadcast_buffers:每个 step 前将 BN 等 buffer 从 rank 0 广播到其它 rank,保证一致性。
  • 梯度桶 (bucket):DDP 将参数梯度按 bucket_cap_mb 打成若干桶,按桶做 AllReduce,以重叠通信与计算。
  • model.module:包装后的 DDP 实例的 .module 属性即原始模块;保存/加载时通常用 model.module
  • 必须在 init_process_group 之后、且模型已放到对应 GPU 上再包装:model = model.cuda(local_rank),然后 model = DDP(model, device_ids=[local_rank])
  • 优化器在 DDP 包装后创建,model.parameters() 会正确指向所有需要更新的参数。
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    rank, world_size, local_rank = setup_distributed()
    model = MyModel().cuda(local_rank)

    if world_size > 1:
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=False,
            gradient_as_bucket_view=True,
        )

    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    for batch in dataloader:
        out = model(batch)
        loss = out.mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

    if dist.is_initialized() and rank == 0:
        torch.save(model.module.state_dict(), "ckpt.pth")

  • 职责:按 rank 和 world_size 把数据集索引划分成不重叠的子集,每个 rank 只拿到自己那一份索引;DataLoader 据此取数。
  • 与 DDP 的关系:DDP 不关心数据从哪来,但若各进程用相同数据,梯度会重复;DistributedSampler 解决「数据分片」问题,保证不重不漏。

构造函数

torch.utils.data.distributed.DistributedSampler(
    dataset,        # Dataset
    num_replicas=None,  # 默认 dist.get_world_size()
    rank=None,      # 默认 dist.get_rank()
    shuffle=True,   # 是否打乱
    drop_last=False,  # 是否丢弃最后不完整 batch
    seed=0,
)
  • num_replicas / rank:不传则从当前进程组读取;与 init 后的环境一致。
  • shuffle:为 True 时每个 epoch 内索引打乱,但各 rank 仍只看到自己的子集。
  • drop_last:为 True 时丢弃最后不足一 batch 的样本,保证各 rank 迭代次数一致(DDP 要求各进程 step 数一致,否则会挂起)。

set_epoch(epoch):每个 epoch 开始时调用,内部用 epoch 作随机种子,使不同 epoch 的划分方式不同,提高数据利用与随机性。

  • 构造 DataLoader 时传入 sampler=train_sampler,且不要再设 shuffle=True(Sampler 已决定顺序)。
  • 每个 epoch 开始时调用 train_sampler.set_epoch(epoch)
  • 验证/测试若也要分布式跑,可用 DistributedSampler(..., shuffle=False);若只在 rank 0 上验证,可不使用 DistributedSampler。
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

dataset = MyDataset(...)
if dist.is_initialized():
    train_sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
    train_loader = DataLoader(
        dataset,
        batch_size=32,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
    )
else:
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

for epoch in range(num_epochs):
    if dist.is_initialized():
        train_sampler.set_epoch(epoch)
    for batch in train_loader:
        ...

  • 职责:协调多进程步调、聚合张量;DDP 内部已用 AllReduce 同步梯度,用户层更多用 barrier 做同步点,用 all_reduce / all_gather 等做指标汇总或自定义同步。
  • 约束:所有参与同一 collective 的进程都必须调用该 collective,且张量形状/类型一致,否则会挂起或报错。
torch.distributed.barrier(group=None)
  • 作用:所有进程阻塞直到都执行到这一行,再一起继续。
  • 应用:等齐再写文件、再打印、再评估;本项目 eval 脚本中在 rank 0 写结果前常用 barrier。
torch.distributed.all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False)
  • 作用:各进程提供形状相同的 tensor,按 op 聚合后写回各进程的 tensor(inplace)。
  • opReduceOp.SUMReduceOp.AVGReduceOp.PRODUCTReduceOp.MINReduceOp.MAX
  • 应用:梯度或标量指标汇总;
torch.distributed.broadcast(tensor, src, group=None, async_op=False)
  • 作用:src 进程的 tensor 广播到所有进程,覆盖各进程的 tensor。
  • 应用:rank 0 读 checkpoint 后广播到各进程,或广播超参数。
torch.distributed.reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False)
  • 作用:各进程的 tensor 按 op 聚合到 dst 进程。
  • 应用:只在一个进程上得到聚合结果时使用。
torch.distributed.all_gather(tensor_list, tensor, group=None, async_op=False)
  • 作用:各进程提供一个 tensor,汇总后每个进程得到完整的 tensor_list(所有进程的 tensor)。
  • 约束:各进程的 tensor 形状一致;tensor_list 长度为 world_size。
  • 应用:收集各卡上的局部结果成完整列表;
torch.distributed.gather(tensor, gather_list, dst, group=None, async_op=False)
  • 作用:各进程的 tensor 收集到 dst 进程的 gather_list。
  • 应用:只需在根进程汇总时使用。
torch.distributed.reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False)
  • 作用:各进程提供 input_list,先按元素 reduce,再按 rank 切分,每进程得到 output 的一块。
  • 应用:分布式优化器或特定通信模式。
  • 写 checkpoint / 打 log:dist.barrier() → 仅 rank 0 写/打 log → 再 dist.barrier()(可选)。
  • 汇总标量 loss:每卡得到标量后转为 1 元素 tensor,all_reduce(..., op=ReduceOp.AVG),再在 rank 0 打印。
import torch.distributed as dist

# 等所有进程到齐再保存
dist.barrier()
if rank == 0:
    torch.save(model.module.state_dict(), "ckpt.pth")
dist.barrier()

# 汇总各卡 loss 标量
loss_t = torch.tensor([loss_item], device=device)
dist.all_reduce(loss_t, op=dist.ReduceOp.AVG)
global_loss = loss_t.item()
if rank == 0:
    print(f"step loss avg: {global_loss}")

  • 职责:为每个 GPU 启动一个进程,并设置 RANK、LOCAL_RANK、WORLD_SIZE、MASTER_ADDR、MASTER_PORT 等环境变量;用户脚本只需读环境变量并 init_process_group(init_method="env://")
  • 工具torchrun(推荐,PyTorch 1.9+)与 torch.distributed.launch(旧版,行为类似)。
参数含义示例
–nproc_per_node每台机器上的进程数(通常=GPU 数)4
–nnodes机器总数2
–node_rank当前机器编号(0 到 nnodes-1)0 或 1
–master_addr主节点 IP192.168.1.1
–master_port主节点端口29500
脚本后参数传给训练脚本–your_args …
  • RANK:全局 rank。
  • LOCAL_RANK:本机 GPU 编号。
  • WORLD_SIZE:总进程数。
  • MASTER_ADDR / MASTER_PORT:主节点地址与端口(多机时必需)。
  • 单机 4 卡:torchrun --nproc_per_node=4 train.py
  • 多机:每台机器执行一次 torchrun,指定 –nnodes、–node_rank、–master_addr、–master_port;首台为 master,其余 –master_addr 指向首台 IP。
  • 训练脚本入口放在 if __name__ == "__main__": 内,避免 spawn 时重复执行。

单机 4 卡:

torchrun --nproc_per_node=4 train.py --your_args ...

多机(2 机,每机 4 卡):

  • 机器 0:torchrun --nnodes=2 --node_rank=0 --master_addr=192.168.1.1 --master_port=29500 --nproc_per_node=4 train.py
  • 机器 1:torchrun --nnodes=2 --node_rank=1 --master_addr=192.168.1.1 --master_port=29500 --nproc_per_node=4 train.py

  • 保存:通常只在 rank 0 保存一次,避免多进程写同一文件;保存前 dist.barrier() 保证所有进程已到保存点。
  • 加载:可只在 rank 0 读文件再 broadcast,或每个进程各读一份(如共享 NFS);若用 model.module.load_state_dict(...),需在 DDP 包装后对 model.module 操作。
  • 日志:仅 rank == 0 时 print 或写 TensorBoard,避免刷屏和重复。
  • 保存前 barrier,保存后可选再 barrier。
  • 使用 Lightning 时,Trainer 会在 rank 0 保存、写 log,一般无需手写 barrier。
def save_checkpoint(model, path):
    if not dist.is_initialized():
        torch.save(model.state_dict(), path)
        return
    dist.barrier()
    if dist.get_rank() == 0:
        state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
        torch.save(state, path)
    dist.barrier()

def log_only_rank0(msg):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(msg)

与项目内 torch2j6mscripts 风格接近的原生 DDP + torchrun 最小可运行示例。

# train_ddp.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup():
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if world_size > 1:
        dist.init_process_group(backend="NCCL", init_method="env://")
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank

def main():
    rank, world_size, local_rank = setup()
    model = MyModel().cuda(local_rank)
    if world_size > 1:
        model = DDP(model, device_ids=[local_rank])

    dataset = MyDataset(...)
    sampler = DistributedSampler(dataset, shuffle=True) if world_size > 1 else None
    loader = DataLoader(
        dataset,
        batch_size=32,
        sampler=sampler,
        shuffle=(sampler is None),
        num_workers=4,
        pin_memory=True,
    )

    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(10):
        if sampler is not None:
            sampler.set_epoch(epoch)
        for batch in loader:
            batch = batch.cuda(local_rank)
            out = model(batch)
            loss = out.mean()
            loss.backward()
            opt.step()
            opt.zero_grad()

        if rank == 0:
            ckpt = model.module.state_dict() if world_size > 1 else model.state_dict()
            torch.save(ckpt, f"ckpt_epoch_{epoch}.pth")

    if world_size > 1:
        dist.destroy_process_group()

if __name__ == "__main__":
    main()

单机 4 卡启动: torchrun --nproc_per_node=4 train_ddp.py


Trainer(strategy=DDPStrategy(...)) 在内部完成:根据 devices 启动多进程(或检测已有分布式环境)、init_process_group、DDP 包装、为 DataLoader 自动注入 DistributedSampler、仅在 rank 0 写 checkpoint 与 log。用户只需构造普通 DataLoader 与 model,调用 trainer.fit(model, train_dataloader, val_dataloader)

  • 单机多卡:Trainer(devices=4, accelerator="gpu", strategy=DDPStrategy(...)),Lightning 会 spawn 多进程。
  • 多机:仍用 torchrun 在每台机器上启动,设置 NNODES、NODE_RANK、MASTER_ADDR、MASTER_PORT;Lightning 检测到已有分布式环境会加入,不再重复 spawn。
  • DDPStrategy(find_unused_parameters=True) 与原生 DDP 含义相同。
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

trainer = pl.Trainer(
    max_epochs=100,
    devices=4,
    accelerator="gpu",
    strategy=DDPStrategy(find_unused_parameters=False, gradient_as_bucket_view=True),
    logger=...,
    callbacks=...,
)
trainer.fit(model, train_loader, val_loader)

组件/概念作用解决的问题
进程组 init建立通信组、选定 backend多进程如何发现彼此、用什么后端通信
Rank / Local Rank / World Size唯一标识与规模谁存盘、谁打 log、数据如何分片
DDP包装模型并同步梯度多卡数据并行时梯度一致、等价大 batch
DistributedSampler按 rank 分数据索引各卡数据不重不漏
Barrier / AllReduce / AllGather 等同步与聚合等齐再写文件、汇总指标、收集结果
torchrun多进程启动与环境变量免手写 spawn、统一 RANK/WORLD_SIZE 等
Rank 0 写盘与 log单点持久化与日志避免多进程写同一文件、日志刷屏
Lightning DDPStrategy封装 init/DDP/Sampler/保存标准训练循环下减少样板代码

按「启动 → init → DDP + Sampler → 训练循环 → barrier/rank0 保存」这条线串联,即可覆盖 PyTorch 分布式训练与操作工具的主干用法。

  • 梯度累积:在 DDP 下每 N 个 step 再做一次 step/zero_grad,等效更大 batch。
  • 混合精度:与 DDP 结合使用 AMP(如 torch.cuda.amp),需注意梯度缩放与同步顺序。
  • 多进程 DataLoadernum_workers > 0 时每个训练进程会再起 worker,注意总进程数与资源。
  • 自定义进程组:多任务时可为不同任务建不同进程组,使用 process_group 参数。
  • 弹性训练:torchrun 支持 --max_restarts 等,节点失败时可重启;更复杂弹性可用 PyTorch Elastic。

以下是 PyTorch 分布式训练中最常用的接口,按功能分类整理:


接口描述常用参数示例
dist.init_process_group初始化分布式进程组backend, init_method, rank, world_sizedist.init_process_group('nccl', rank=rank, world_size=4)
dist.get_rank()获取当前进程的 rankrank = dist.get_rank()
dist.get_world_size()获取总进程数world_size = dist.get_world_size()
dist.is_initialized()检查是否已初始化if dist.is_initialized():
dist.destroy_process_group()销毁进程组可选 groupdist.destroy_process_group()
dist.new_group()创建新的进程子组ranksgroup = dist.new_group([0,1,2])

接口描述常用参数示例
dist.send()发送张量tensor, dstdist.send(tensor, dst=1)
dist.recv()接收张量tensor, srcdist.recv(tensor, src=0)
dist.isend()异步发送tensor, dstwork = dist.isend(tensor, dst=1)
dist.irecv()异步接收tensor, srcwork = dist.irecv(tensor, src=0)

接口描述常用参数示例
dist.all_reduce()所有进程规约并广播结果tensor, op, group, async_opdist.all_reduce(tensor, op=dist.ReduceOp.SUM)
dist.reduce()规约到根进程tensor, dst, op, groupdist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
dist.broadcast()从根进程广播tensor, src, groupdist.broadcast(tensor, src=0)
dist.all_gather()收集所有进程数据到每个进程tensor_list, tensor, groupdist.all_gather([t1,t2,t3], tensor)
dist.gather()收集数据到根进程tensor, gather_list, dstdist.gather(tensor, gather_list, dst=0)
dist.scatter()从根进程分发数据tensor, scatter_list, srcdist.scatter(tensor, scatter_list, src=0)
dist.reduce_scatter()规约后分发output, input_list, opdist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)
dist.all_to_all()全交换output_tensor_list, input_tensor_listdist.all_to_all([out1,out2], [in1,in2])
接口描述示例
dist.all_reduce_coalesced()合并多个张量一起 AllReducedist.all_reduce_coalesced([t1,t2,t3])
dist.all_gather_coalesced()合并多个张量一起 AllGatherdist.all_gather_coalesced(output_lists, input_list)
dist.reduce_scatter_coalesced()合并多个张量一起 ReduceScatterdist.reduce_scatter_coalesced(outputs, input_lists)
接口描述示例
dist.all_reduce_multigpu()多 GPU 单进程 AllReducedist.all_reduce_multigpu([t0,t1])
dist.all_gather_multigpu()多 GPU 单进程 AllGatherdist.all_gather_multigpu([out0,out1], [in0,in1])
dist.reduce_multigpu()多 GPU 单进程 Reducedist.reduce_multigpu([t0,t1], dst=0)

接口描述参数示例
dist.barrier()同步屏障,等待所有进程到达group, async_opdist.barrier()
Work.wait()等待异步操作完成work.wait()
Work.is_completed()检查异步操作是否完成if work.is_completed():

接口描述常用参数示例
DistributedSampler分布式采样器dataset, num_replicas, rank, shufflesampler = DistributedSampler(dataset)
DataLoader配合 sampler 使用dataset, batch_size, samplerloader = DataLoader(ds, sampler=sampler)

接口描述常用参数示例
DistributedDataParallel分布式数据并行封装module, device_ids, output_devicemodel = DDP(model, device_ids=[rank])
DataParallel单机多卡数据并行(不推荐)module, device_idsmodel = DataParallel(model)

接口/工具描述常用参数示例
torch.multiprocessing.spawn启动多进程fn, args, nprocsmp.spawn(train, args=(4,), nprocs=4)
torchrun命令行启动工具--nproc_per_node, --nnodestorchrun --nproc_per_node=4 train.py
dist.launch旧版启动脚本--nproc_per_nodepython -m torch.distributed.launch

枚举/常量描述值示例
ReduceOp规约操作类型SUM, AVG, MAX, MIN, PRODUCT
group.WORLD默认全局进程组dist.group.WORLD
Backend后端类型'nccl', 'gloo', 'mpi'

ReduceOp 详解

dist.ReduceOp.SUM      # 求和
dist.ReduceOp.AVG      # 求平均
dist.ReduceOp.MAX      # 最大值
dist.ReduceOp.MIN      # 最小值
dist.ReduceOp.PRODUCT  # 求积
dist.ReduceOp.BAND     # 按位与
dist.ReduceOp.BOR      # 按位或
dist.ReduceOp.BXOR     # 按位异或

接口描述示例
dist.get_backend()获取当前后端backend = dist.get_backend()
dist.get_rank(group)获取指定组内 rankrank = dist.get_rank(group)
dist.get_world_size(group)获取指定组大小size = dist.get_world_size(group)
dist.batch_isend_irecv()批量异步通信ops = dist.batch_isend_irecv([...])
dist.reduce_op旧版规约操作dist.reduce_op.SUM(已弃用)

接口描述说明
model.require_backward_grad_sync控制是否同步梯度False 可禁用梯度同步(梯度累积)
model.no_sync()上下文管理器,临时禁用梯度同步with model.no_sync(): loss.backward()
register_comm_hook()注册自定义通信钩子model.register_comm_hook(state, hook)

类别最常用接口用途
初始化init_process_group, get_rank启动分布式环境
同步all_reduce梯度聚合
模型封装DistributedDataParallel自动梯度同步
数据加载DistributedSampler分片数据
启动torchrun多进程启动
屏障barrier进程同步

相关内容