import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
 
# TODO: not yet verified end-to-end.
#
# Rank, world size and local rank are read from the environment variables that
# `torchrun` (torch.distributed.run) exports. The defaults describe a single
# process, so the script can also be started directly with `python`.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
# device_ids / set_device need the GPU index on *this* node, i.e. the local
# rank, not the global rank (they only coincide on a single node).
local_rank = int(os.environ.get("LOCAL_RANK", str(rank)))

# Bind this process to its own GPU before creating the NCCL process group, so
# that the communicator and default CUDA allocations do not all land on GPU 0.
torch.cuda.set_device(local_rank)

# Join the default process group unless something earlier in this process has
# already initialized it (the default group can only be initialized once).
if not dist.is_initialized():
    dist.init_process_group(backend='nccl', init_method='tcp://localhost:23456', rank=rank, world_size=world_size)

# Load the pretrained YOLOv5s model and move it onto this process's GPU.
# DistributedDataParallel with device_ids expects the module to already live on
# that device. Then wrap it in DistributedDataParallel.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model = model.to(local_rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

# References