Interview Prep · Distributed Training Systems

Distributed Training 面试 Cheat Sheet

DDP / FSDP2 / ZeRO / TP / PP / SP / CP / EP — 公式推导 + From-Scratch 代码 + 25 高频题(L1 必会 · L2 进阶 · L3 顶级 lab)

Source: docs/tutorials/distributed_training_tutorial.md SHA256: 86087779b70e Rendered: 2026-05-19 05:38 UTC

§0 TL;DR Cheat Sheet

一页搞定分布式训练 7 大维度

DP(数据切分) × TP(张量切分) × PP(层切分) × SP(序列切分) × CP(context 切分) × EP(专家切分) × 激活重计算。详见后文 §1–§11 推导。

  1. DDP(PyTorch):每张卡持完整模型,反向时用 NCCL all-reduce 同步梯度,bucket fusion + computation/communication overlap 是关键工程优化。
  2. ZeRO 1/2/3(Rajbhandari et al., SC 2020):把 optimizer state / gradient / parameter 分别切到 $N$ 张卡上,单卡显存从 $\Phi (2+2+12) = 16\Phi$ 字节降到 $16\Phi/N$(fp16 训练,Adam,参数量 $\Phi$)。
  3. FSDP / FSDP2(Zhao et al., VLDB 2023;PyTorch 2.4+, 2024):PyTorch 原生 ZeRO-3。FSDP2 用 per-parameter DTensor sharding 替换 FSDP1 的 flat-parameter 模式,与 TP / PP 复合更顺。
  4. TP(Megatron-LM, Shoeybi 2019;Narayanan SC 2021):列并行 → 行并行 配对,每层只需 2× all-reduce/forward;attention 按 head 切,FFN 第一层 col-parallel + 第二层 row-parallel。
  5. PP:GPipe(Huang NeurIPS 2019)→ 1F1B / PipeDream(Narayanan SOSP 2019)→ Megatron-LM interleaved 1F1B(Narayanan SC 2021)。Bubble ratio $\approx (P-1)/M$($P$ 个 stage,$M$ 个 micro-batch),interleaved 把它再压 $V$ 倍。
  6. SP / CP / EP:SP(Korthikanti et al., MLSys 2023)把 TP 外的 LayerNorm/Dropout activation 沿序列切;CP(Megatron-LM 2024 / Ring Attention Liu 2024)切长 context;EP 把 MoE expert 分到不同 GPU,前向用 all-to-all routing。
  7. 2024 前沿DualPipe(DeepSeek-V3, Dec 2024)双向流水线把 forward/backward 计算与通信完全 overlap;Llama 3 405B(Meta 2024)用 16K H100、3.8e25 FLOPs、54 天 466 次中断(419 unexpected + 47 计划维护);TorchTitan(Liang et al., ICLR 2025)打通 FSDP2 + TP + PP + SP + Float8。

§1 直觉 — 为什么模型大了一张卡装不下

一道送分题但很多人答不全:训练一个模型,单卡显存里到底装了什么?

考虑 fp16 mixed-precision 训练 + Adam optimizer,参数量为 $\Phi$:

单卡占用备注
Parameter (fp16)$2\Phi$forward / backward 用
Gradient (fp16)$2\Phi$backward 累积
Optimizer state$12\Phi$fp32 master copy ($4\Phi$) + Adam $m$, $v$ ($4\Phi + 4\Phi$)
小计(model states)$16\Phi$与 batch size 无关
Activation$O(B \cdot L \cdot D \cdot \text{depth})$与 batch size、seq len 线性,与 depth 线性
Workspace / temp buffervariesNCCL / cuDNN workspace

7B 参数模型仅 model states 就要 $\approx 112$ GB——一张 80GB H100 都装不下。所以分布式训练首要目的是切 model states + activation

三个正交切分维度

训练任务可以沿以下三轴切分,理论上彼此独立、实践中复合使用:

3D / 4D / 5D parallelism 就是这几条轴的笛卡尔积。Llama 3 / DeepSeek-V3 都用 4D(DP × TP × PP × CP/SP/EP 子集)。

§2 NCCL 通信原语与拓扑

分布式训练 99% 的通信交给 NCCL,必须熟悉 5 个原语的语义和通信量。

2.1 五大集合通信原语

设有 $N$ 张 GPU,每张卡持有 size = $S$ 的 buffer。

原语输入 → 输出等价语义Ring 算法通信量 / GPU
all-reduce$N$ 个 $S$ → $N$ 个相同 $S$sum + broadcast$2(N-1)/N \cdot S \approx 2S$
reduce-scatter$N$ 个 $S$ → $N$ 个不同 $S/N$sum 后切片$(N-1)/N \cdot S \approx S$
all-gather$N$ 个 $S/N$ → $N$ 个 $S$把各 rank 的片段拼成完整$(N-1)/N \cdot S \approx S$
broadcastrank0 的 $S$ → $N$ 个 $S$单源$S$(树形)
all-to-all$N \times N$ 块矩阵转置shuffle$(N-1)/N \cdot S \approx S$

关键恒等式

$$\boxed{\;\text{all-reduce} = \text{reduce-scatter} + \text{all-gather}\;}$$

每一步 ring 上各传 $S/N$,共 $2(N-1)$ 步 → 单 GPU 总流量 $2(N-1)S/N \approx 2S$ bytes(与 $N$ 几乎无关,这是 ring 算法的精髓)。

面试加分:NCCL 不只一个算法

NCCL 在小 message 用 tree all-reduce(latency-bound,$O(\log N)$ 跳);大 message 用 ring all-reduce(bandwidth-bound,吞吐最优)。NVLink 拓扑下还有 NVLS (NVLink SHARP),硬件做 reduction(H100/H200 NVSwitch 支持)。

链路单向带宽(H100 代)用途
NVLink 4.0900 GB/s (per GPU, 18 链路聚合)同节点 GPU↔GPU
PCIe 5.0 x1664 GB/sGPU↔CPU、慢路径
InfiniBand NDR 400G50 GB/s (per port)跨节点

经验法则:节点内通信比跨节点快 10-20 倍。所以 TP 一定要塞进一个节点,DP / PP 才跨节点。Llama 3 训练拓扑:TP=8(节点内 NVLink)× PP=16(跨节点 IB)× DP=128。

2.3 PyTorch 中的 NCCL 调用

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")  # 后端固定 nccl
rank = dist.get_rank()
world_size = dist.get_world_size()

# all-reduce(默认 SUM;可选 AVG / MIN / MAX / PRODUCT)
buf = torch.ones(1024, device=f"cuda:{rank % 8}") * rank
dist.all_reduce(buf, op=dist.ReduceOp.SUM)
# 现在 buf == sum(0..world_size-1) * ones(1024)

device = torch.device(f"cuda:{rank % 8}")

# reduce-scatter
input_list = [torch.full((1024,), float(rank + i), device=device) for i in range(world_size)]
output = torch.empty(1024, device=device)
dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)

# all-gather
local = torch.full((1024,), float(rank), device=device)
gathered = [torch.empty(1024, device=device) for _ in range(world_size)]
dist.all_gather(gathered, local)

# all-to-all(MoE 必备)
in_split = list(torch.randn(world_size, 1024, device=device).unbind(0))
out_split = [torch.empty(1024, device=device) for _ in range(world_size)]
dist.all_to_all(out_split, in_split)

§3 DDP — DistributedDataParallel

3.1 算法骨架

DDP 是数据并行最朴素的实现:

  1. 复制:每个 rank 都持有完整模型副本(参数 / 梯度 / 优化器状态全量)
  2. 切 batch:global batch $B$ 切成 $N$ 份 micro-batch,每 rank 算自己那份的 forward + backward
  3. 同步梯度:backward 完毕后对所有 gradient 做 all-reduce(SUM 后除以 $N$ 取平均,等价于 AVG)
  4. 本地 optimizer step:每 rank 用同样的 gradient 跑同样的 optimizer,参数始终一致

数学形式(loss $\mathcal{L}$ 在 mini-batch 上的均值):

$$g_\text{global} = \frac{1}{N}\sum_{i=1}^N \nabla_\theta \mathcal{L}(\theta; \mathcal{B}_i) = \text{all-reduce-mean}(g_1, \dots, g_N)$$

每个 rank 拿到的 $g_\text{global}$ 完全相同,因此 $N$ 个 rank 上的 $\theta$ 永远保持一致(同初始化 + 同梯度 + 同 optimizer)。

3.2 Bucket Fusion + Overlap(DDP 工程精髓)

朴素实现:backward 完成 → 把所有梯度按张量拼起来 → 一次 all-reduce。这样 GPU 空等通信,浪费惨重。

PyTorch DDP 实际做法:

Backward time axis →

Layer N:   [grad N]──┐
Layer N-1: [grad N-1]┼─bucket_N─[all-reduce N]
Layer N-2: [grad N-2]┘                       ↓ (后台 stream)
Layer N-3: [grad N-3]──┐
Layer N-4: [grad N-4]──┼─bucket_N-1─[all-reduce N-1]
...                                          ↓
Layer 1:   [grad 1]──────────────[all-reduce 1]

主 stream (compute):    ████████████████████████████████
后台 stream (NCCL):           ▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓
                              ↑ 与计算重叠
bucket 太小 / 太大都不好

太小:communication launch overhead 主导,bandwidth 利用率低;太大:等齐时间长、首个 all-reduce 启动晚。PyTorch 默认 25 MB 是大模型常用 sweet spot。可通过 DDP(bucket_cap_mb=...) 调。

3.3 PyTorch 代码(含 overlap)

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size, local_rank):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)   # 多节点时 local_rank != global rank

def main(rank, world_size, local_rank):
    setup_ddp(rank, world_size, local_rank)
    device = torch.device(f"cuda:{local_rank}")

    model = MyModel().to(device)
    model = DDP(
        model,
        device_ids=[local_rank],
        bucket_cap_mb=25,                # bucket 大小
        gradient_as_bucket_view=True,     # 内存优化: grad 直接是 bucket view
        static_graph=False,               # 若图静态可开,启用更多 fusion
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    loader = make_distributed_loader(rank, world_size)

    for batch in loader:
        x, y = batch[0].to(device), batch[1].to(device)
        loss = nn.functional.cross_entropy(model(x), y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()       # ← bucket hook 在这里触发 all-reduce
        optimizer.step()

3.4 复杂度

记号约定(全文统一)

后续凡用 $\Phi$ 表示 参数量(个数)。fp16 / bf16 训练下一个参数 2 bytes,故 fp16 weight buffer 大小 = $2\Phi$ bytes。下表"$2\Phi$ bytes"指字节数,"$\Phi$ params"指个数。

说明
单 GPU 显存$16\Phi$ bytes + activationmodel states 不分摊(含 fp16 params/grads + fp32 master/Adam)
单 step 通信量$\approx 2 \cdot 2\Phi = 4\Phi$ bytes(ring all-reduce on fp16 gradient)等价于 reduce-scatter + all-gather
扩展性计算线性 scale;通信 bandwidth $O(1)$(ring 不依赖 $N$)但 latency-bound 小模型仍受 $N$ 影响

DDP 的硬伤:model states 不切。Adam 训练 70B 模型,单卡光 model states 就 1.12 TB,DDP 完全装不下。所以 7B+ 模型必须上 ZeRO / FSDP。

§4 ZeRO 1/2/3 — Zero Redundancy Optimizer

Rajbhandari et al. "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" (SC 2020) 是分布式训练第二个划时代工作(第一个是 Megatron-TP)。核心思想:DDP 把 model states 复制 $N$ 份是浪费,切成 $N$ 份每张卡只放 $1/N$ 即可,通信只多一点点。

4.1 三阶段显存数学

记 $\Phi$ = 参数量,$N$ = DP world size。fp16 mixed precision + Adam 下单卡 model states:

$$\text{DDP}: \quad 2\Phi + 2\Phi + 12\Phi = 16\Phi$$

ZeRO 三阶段按 「切什么」 区分:

阶段切的内容单卡 model states通信量 (per step)
ZeRO-1optimizer state$2\Phi + 2\Phi + 12\Phi/N$$2\Phi$(all-reduce 等价于 reduce-scatter + all-gather;ZeRO-1 只 reduce-scatter grad)
ZeRO-2optimizer state + gradient$2\Phi + 2\Phi/N + 12\Phi/N$$2\Phi$(同 ZeRO-1)
ZeRO-3opt state + grad + parameter$2\Phi/N + 2\Phi/N + 12\Phi/N = 16\Phi/N$$3\Phi$(forward + backward 各一次 all-gather,backward 一次 reduce-scatter)
极限情况下

$N$ 足够大时(如 1024 张 H100),ZeRO-3 单卡 model states $16\Phi / 1024 = 0.0156\Phi$。65B 模型即 $\approx 1$ GB / GPU,配合 activation checkpoint 单 H100 完全装得下。

4.2 ZeRO-3 工作流(最常用,前向 / 反向 / 优化)

ZeRO-3 把参数本身也切了,所以参数在用之前必须 all-gather 才能 forward。

前向流程(每层 / 每 module):

1. all-gather: 从 N 张卡 collect 完整 W^(ℓ)  ──[通信: φ_ℓ bytes]
2. compute:    y = f(x; W^(ℓ))               ──[计算]
3. release:    丢掉本地不属于自己的 shard    ──[释放显存]

反向流程

1. all-gather: 再次取回 W^(ℓ)(forward 时已释放)──[通信: φ_ℓ]
2. compute:    grad_W^(ℓ), grad_x
3. reduce-scatter: 把 grad_W^(ℓ) reduce 并切回各自的 shard ──[通信: φ_ℓ]
4. release: 丢掉非己 shard

伪代码:

# ZeRO-3 forward (单层抽象, 简化版)
def zero3_forward(layer_idx, x, sharded_W):
    # 1. all-gather full weight
    full_W = all_gather(sharded_W)    # [Φ_ℓ / N, ...] × N -> [Φ_ℓ, ...]
    # 2. compute
    y = layer_forward(x, full_W)
    # 3. 释放 full_W (只留 sharded_W)
    del full_W
    return y

def zero3_backward(layer_idx, dy, sharded_W, cached_input):
    # 1. all-gather (forward 时已释放)
    full_W = all_gather(sharded_W)
    # 2. compute local gradients
    dW_local, dx = layer_backward(dy, full_W, cached_input)
    del full_W
    # 3. reduce-scatter gradient 到对应 shard
    dW_sharded = reduce_scatter(dW_local)  # [Φ_ℓ, ...] / N -> [Φ_ℓ/N, ...]
    return dW_sharded, dx

4.3 ZeRO-1/2/3 vs DDP 通信对比(重要面试题)

总参数 $\Phi$(个数)。下表的单位是 "fp16 weight buffer 等价数"(即 $\Phi$ 在通信流量列代表 $2\Phi$ bytes 的 fp16 buffer 流量)。一次 forward+backward+update 的 per-GPU 通信流量(ring 假设):

模式ForwardBackwardOptim总(fp16 buffer equiv.)
DDP0$2\Phi$ (all-reduce grad)0$2\Phi$(即 $4\Phi$ bytes)
ZeRO-10$\Phi$ (reduce-scatter grad)$\Phi$ (all-gather updated params)$2\Phi$
ZeRO-20$\Phi$ (reduce-scatter grad)$\Phi$ (all-gather)$2\Phi$
ZeRO-3$\Phi$ (all-gather params, on-the-fly)$\Phi$ (all-gather) + $\Phi$ (reduce-scatter grad)0 (已在 backward 中)$3\Phi$
关键结论

ZeRO-1/2 通信量与 DDP 相同($2\Phi$ buffer),但显存大幅下降;ZeRO-3 多 $1.5\times$ 通信,换 $N\times$ 显存下降。实际工程中 ZeRO-3 通信也能通过 prefetch + overlap 部分隐藏,所以是 70B+ 模型的主流选择。换算成字节:DDP $\approx 4\Phi$ bytes,ZeRO-3 $\approx 6\Phi$ bytes。

4.4 ZeRO-Offload / ZeRO-Infinity

ZeRO-Offload(Ren et al., USENIX ATC 2021):把 optimizer state + 部分 gradient 卸到 CPU,CPU 端跑 Adam update。代价:CPU↔GPU PCIe 通信 + CPU 计算慢。适合小集群 + 大模型场景(如单机 8 卡跑 13B)。

ZeRO-Infinity(Rajbhandari et al., SC 2021):再进一步,把参数 / 优化器卸到 NVMe。理论上单机能跑 1T 参数(实际 throughput 极低,主要用于推理 / 微调)。

Offload 用不用是 trade-off

卸 CPU 后单 step 时间通常变长 1.5-3 倍;NVMe 卸更慢 5-10 倍。只在「装不下 + 没钱加卡」时才用。生产训练优先扩 GPU 数。

§5 FSDP / FSDP2 — PyTorch 原生 ZeRO-3

Zhao et al. "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel" (VLDB 2023) 把 ZeRO-3 思想做进 PyTorch 主干。FSDP2(PyTorch 2.4+, 2024)是大重写:per-parameter DTensor sharding 替换 FSDP1 的 flat-parameter。

5.1 FSDP1 vs FSDP2 核心区别

维度FSDP1 (2022-2023)FSDP2 (2024+)
数据结构FlatParameter:把同一 wrap unit 内所有 param 拼成一个 1D buffer,再 chunkDTensor per-parameter:每个 param 独立按 dim-0 切
状态字典需 all-gather 才能产出 unflattened state dict通信自由 sharded state dict
冻结参数同 unit 内必须全冻或全可训各 param 独立,混合冻可训自然
TP 复合困难(flat-buffer 与 TP 沿不同维切冲突)天然兼容:DTensor 描述多轴 placement (Shard(0), Replicate, Shard(1) 组合)
APIFullyShardedDataParallelfully_shard() 函数式 wrap

5.2 FSDP wrap 策略(最核心的设计决策)

FSDP 不是把整个模型当一个 unit。它按自定义 unit boundary 切,每个 unit 内的参数一起 all-gather / reduce-scatter。

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard           # FSDP2 API
from torch.distributed.fsdp import MixedPrecisionPolicy

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.mlp  = FeedForward(d_model)
        self.ln1  = nn.LayerNorm(d_model)
        self.ln2  = nn.LayerNorm(d_model)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

# 把每个 TransformerBlock 当一个 FSDP unit
model = MyLLM(n_layers=32, d_model=4096, n_heads=32).cuda()

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,        # forward 用 bf16
    reduce_dtype=torch.float32,        # gradient reduce 用 fp32 防误差累积
)

for block in model.blocks:
    fully_shard(block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)   # root 也要 wrap
wrap 粒度的工程权衡

Unit 越小(如每个 linear 都 wrap):每次 all-gather 只取一层,显存 peak 低,但通信次数多、prefetch 难做;Unit 越大(如整个 block 或多个 block):通信次数少、容易 overlap,但 peak memory 高。TransformerBlock 粒度是 LLaMA / GPT 类的标准答案

5.3 FSDP2 + 混合精度 + 激活检查点

from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper, CheckpointImpl
)

# 1. 激活检查点 (gradient checkpointing) — 重算换显存
def apply_ac(model):
    for i, block in enumerate(model.blocks):
        model.blocks[i] = checkpoint_wrapper(
            block,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
apply_ac(model)

# 2. FSDP2 wrap
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
offload_policy = OffloadPolicy()                   # 或 CPUOffloadPolicy() 卸到 CPU

for block in model.blocks:
    fully_shard(block, mp_policy=mp_policy, offload_policy=offload_policy)
fully_shard(model, mp_policy=mp_policy)

# 3. 训练 (注意: 必须用 torch._foreach_ 友好的 optimizer)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)

for batch in loader:
    loss = compute_loss(model, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

5.4 FSDP vs ZeRO-3 通信量差异(L3 题)

结论完全一样(在 wrap unit = layer 这个粒度上)。FSDP 是 ZeRO-3 的 PyTorch 原生实现,算法等价。差异在工程:

维度DeepSpeed ZeRO-3FSDP2
集成外部库 + monkey patchPyTorch 主干
state_dict需要单独 API 拼装原生 set_model_state_dict / DCP
TP 复合DeepSpeed-Megatron 桥接DTensor 原生
冻结 / LoRA复杂 (flat-param 冲突)自然
编译受限torch.compile(fullgraph=False) 友好

实际选型:新项目用 FSDP2;DeepSpeed 仍在 70B+ MoE / 多框架场景用得多

5.5 HSDP — Hybrid Sharded DP(混合分片)

问题:1024 张 GPU 全部 ZeRO-3,单次 all-gather 跨 1024 卡 → 单卡通信量 $(N-1)/N \cdot \phi \to \phi$ 不变,但 latency 大幅上升(IB 跨节点贵)。

HSDP:组内 ZeRO-3(如节点内 8 卡),组间 DDP。等价于把 1024 卡分成 128 个 ZeRO-3 group(每组 8 卡)。

模式Shard 范围单卡 model states跨节点通信
纯 ZeRO-3全 1024 卡$16\Phi / 1024$
HSDP (8)8 卡组内$16\Phi / 8 = 2\Phi$少(组间是 grad all-reduce)

trade-off:显存换通信效率。Llama 3 在某些训练阶段用 HSDP。

§6 ZeRO++ — 通信优化(2023)

Wang et al. "ZeRO++: Extremely Efficient Collective Communication for Giant Model Training" (NeurIPS MLSys workshop 2023, arXiv 2306.10209)。在 ZeRO-3 之上做了 3 件事:

6.1 qwZ:Quantized Weight all-gather(量化前向通信)

ZeRO-3 forward 时每层都要 all-gather fp16 weight。qwZ 把 fp16 → int8 后再 all-gather,通信量减半

forward (原始):  weight_shard (fp16) ──all-gather──> full_weight (fp16) ──compute
forward (qwZ):   weight_shard (fp16) ──quant──> shard (int8)
                                        ──all-gather──> full (int8)
                                        ──dequant──> full (fp16) ──compute

代价:每次 all-gather 前后做 block-wise quantization / dequantization(block size 一般 2048-4096 元素一组,保留 per-block scale)。block-quant 比 per-tensor quant 精度好得多,对训练 loss 影响 < 1%。

6.2 hpZ:Hierarchical Partition(分层切分)

观察:backward 时 all-gather 的代价比 forward 大(更深更靠后的层先反传)。hpZ 把 weight 在 节点内复制(全 NVLink),跨节点不切:

backward 的 weight all-gather 只在节点内做(NVLink,便宜),不跨 IB。代价:每节点显存翻 8 倍——但 model states 本来就被切到 1024 卡上,乘 8 还远小于 DDP。

6.3 qgZ:Quantized Gradient reduce(量化梯度通信)

backward 末段 reduce-scatter gradient 也走 int8(fp16 → int8 + 量化 reduce)。原始 reduce-scatter 用 SUM 操作不能简单量化(量化 + sum 会累积误差),ZeRO++ 改用 all-to-all + 反量化 + local sum 路径绕开。

6.4 综合效果

ZeRO++ 论文报告:384 GPU 规模 throughput 2.16× 提升,通信量降 4×(fp16→int8 节省 2× × 2 个通信原语)。代价:实现复杂、精度需小心 ablation。

量化通信 ≠ 量化训练

这里只量化集合通信途中的 transient buffer,权重本身的存储和计算仍是 fp16 / bf16。所以 loss 影响通常可忽略;和 fp8 训练(如 Hopper 上的 fp8 GEMM)是两件事。

§7 Tensor Parallel — Megatron-LM

Shoeybi et al. "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism" (arXiv 2019) + Narayanan et al. "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM" (SC 2021)。

7.1 核心思想:列并行 → 行并行 配对

考虑一个 MLP 层:$Y = \text{GELU}(X W_1) W_2$,$X \in \mathbb{R}^{B \times D}$,$W_1 \in \mathbb{R}^{D \times 4D}$,$W_2 \in \mathbb{R}^{4D \times D}$。

Column Parallel(切 $W_1$ 的输出维 $4D$):

$$W_1 = [W_1^{(1)} \mid W_1^{(2)} \mid \cdots \mid W_1^{(T)}], \quad W_1^{(i)} \in \mathbb{R}^{D \times 4D/T}$$

每个 rank $i$ 持有 $W_1^{(i)}$,独立算 $Y_1^{(i)} = X W_1^{(i)} \in \mathbb{R}^{B \times 4D/T}$。输入 $X$ 全副本,输出沿 column 切

Row Parallel(切 $W_2$ 的输入维 $4D$):

$$W_2 = \begin{bmatrix} W_2^{(1)} \\ W_2^{(2)} \\ \vdots \\ W_2^{(T)} \end{bmatrix}, \quad W_2^{(i)} \in \mathbb{R}^{4D/T \times D}$$

输入沿 row 切:$Z^{(i)} = \text{GELU}(Y_1^{(i)}) W_2^{(i)} \in \mathbb{R}^{B \times D}$。每个 rank 算自己那部分 partial sum,最后 all-reduce 求和:

$$Y = \sum_{i=1}^T Z^{(i)} = \text{all-reduce}(Z^{(1)}, \dots, Z^{(T)})$$

7.2 通信量分析

列并行:输入 $X$ 是全副本,输出沿 column 切 → forward 无通信(输出已分散在各 rank)。

行并行:输入已沿 row 切,输出需 all-reduce → forward 1× all-reduce $(BD)$

col + row 配对:列并行末端不通信(结果直接喂给行并行),行并行末端 1× all-reduce。整个 MLP block forward 只 1× all-reduce

Backward 镜像:每个 MLP block backward 也是 1× all-reduce(gradient 流过 col-row 时镜像通信)。

总计:单个 transformer block(MLP + Attention,attention 也是 col-row)forward + backward = 4× all-reduce, 每次 $BLD$($L$ = seq len)。

7.3 Attention 的 TP 切法

Multi-head attention 的 $W_Q, W_K, W_V \in \mathbb{R}^{D \times D}$。直接按 head 维切(每个 head 独立、列并行天然):

H heads, T-way TP:
  - 每 rank 持 H/T 个 head 的 W_Q^(i), W_K^(i), W_V^(i)
  - 每 rank 独立算自己的 head_h(Q W_Q^h, ...)
  - 输出 W_O 切行并行 (row-parallel) → 末端 all-reduce
为什么 attention 切 head 而不切 dim

head 维天然独立(不同 head 不交互),切 head 不需要跨 rank 通信中间结果;切 hidden dim 则需要在 softmax 前后做额外通信。代码也简洁:直接把 head 维分配给 ranks。

7.4 TP 代码骨架

下面是教学版 col-parallel / row-parallel,明确写出 forward 和 backward 的通信对(关键:col-parallel forward 不通信,backward all-reduce 输入梯度;row-parallel forward all-reduce 输出,backward 不通信):

import math
import torch
import torch.nn as nn
import torch.distributed as dist

class _CopyToTPRegion(torch.autograd.Function):
    """ Identity in forward, all-reduce in backward (col-parallel 入口) """
    @staticmethod
    def forward(ctx, x, tp_group):
        ctx.tp_group = tp_group
        return x
    @staticmethod
    def backward(ctx, grad_out):
        dist.all_reduce(grad_out, op=dist.ReduceOp.SUM, group=ctx.tp_group)
        return grad_out, None

class _ReduceFromTPRegion(torch.autograd.Function):
    """ All-reduce in forward, identity in backward (row-parallel 出口) """
    @staticmethod
    def forward(ctx, x, tp_group):
        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=tp_group)
        return x
    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None

class ColumnParallelLinear(nn.Module):
    """ Y = X W,  W ∈ R^{in×out},  切 out 维成 T 份;
        forward: input replicated -> 输出 [B, out/T] 沿 last dim 切
        backward: input grad 要 all-reduce(多个 TP rank 各算了 partial dX)"""
    def __init__(self, in_features, out_features, tp_group, tp_size):
        super().__init__()
        assert out_features % tp_size == 0
        self.tp_group = tp_group
        self.out_per_rank = out_features // tp_size
        self.weight = nn.Parameter(torch.empty(self.out_per_rank, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    def forward(self, x):
        # x: [B, in] (replicated). 入口 _CopyToTPRegion 让 backward 自动 all-reduce dX
        x = _CopyToTPRegion.apply(x, self.tp_group)
        return torch.nn.functional.linear(x, self.weight)   # [B, out/T]

class RowParallelLinear(nn.Module):
    """ Y = X W,  W ∈ R^{in×out},  切 in 维成 T 份;
        forward: input [B, in/T] (sharded), 输出 partial sum -> all-reduce
        backward: dW/dX 各自本地算即可,不需通信 """
    def __init__(self, in_features, out_features, tp_group, tp_size):
        super().__init__()
        assert in_features % tp_size == 0
        self.tp_group = tp_group
        self.in_per_rank = in_features // tp_size
        self.weight = nn.Parameter(torch.empty(out_features, self.in_per_rank))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    def forward(self, x):
        # x: [B, in/T] (sharded along last dim)
        local_out = torch.nn.functional.linear(x, self.weight)  # [B, out] partial sum
        return _ReduceFromTPRegion.apply(local_out, self.tp_group)

class TPTransformerMLP(nn.Module):
    """ Megatron-LM 标准 col+row 配对 """
    def __init__(self, d_model, tp_group, tp_size):
        super().__init__()
        d_ff = 4 * d_model
        self.fc1 = ColumnParallelLinear(d_model, d_ff, tp_group, tp_size)   # 出 [B, L, 4D/T]
        self.fc2 = RowParallelLinear(d_ff, d_model, tp_group, tp_size)      # 入 [B, L, 4D/T] -> [B, L, D]
    def forward(self, x):
        # x: [B, L, D] (replicated)
        h = torch.nn.functional.gelu(self.fc1(x))   # [B, L, 4D/T]
        return self.fc2(h)                          # [B, L, D] after all-reduce
TP 通信对子(必记)

col-parallel: forward = identity, backward = all-reduce(dX);row-parallel: forward = all-reduce(Y), backward = identity。一个 transformer block (col + row × 2 sets: MLP + attention) 共 4 次 all-reduce,2 在 forward、2 在 backward——这就是上文 §7.2 那 4 次的来源。

每个 transformer block forward + backward = 4× all-reduce of activation tensor 大小 $\approx BLD$。对 7B 模型 $D = 4096$,$B \times L = 2048$ → 单次 $\approx 32$ MB,per-block 4 次 × 32 layer = 128 次 / step, $\approx 4$ GB / step

NVLink 带宽 900 GB/s 下 $\approx 4.4$ ms;IB 50 GB/s 下 $\approx 80$ ms。所以 TP 必须塞进节点内(NVLink 域)。这就是为什么 LLaMA / Llama 3 / GPT-3 都用 TP=8(一个节点 8 卡)而不是 TP=16。

§8 Sequence Parallel — 省 activation memory

Korthikanti et al. "Reducing Activation Recomputation in Large Transformer Models" (MLSys 2023)。

8.1 动机:activation memory 主要是 LayerNorm / Dropout 这种 element-wise op

TP 把 attention / MLP 的中间 activation 切了(沿 head / hidden 维),但 TP 外的部分(LayerNorm 输入输出、Dropout、residual)仍是全副本——每张 TP rank 都存一份 $[B, L, D]$ 大小的 activation。

7B 模型 $B \times L \times D = 2048 \times 4096$, fp16 = 16 MB / activation。一层有约 4-6 个这种 full-replica activation。32 层 × 5 → 2.5 GB / GPU 浪费。

8.2 SP 解法:沿 sequence 维切

让 TP 外的 activation 也沿 $L$(sequence)维切到 TP rank 上。

TP-only:          [B, L, D]                       [B, L, D]
                  全副本                          全副本
                  ↓                                ↑
                  LayerNorm  → split → Attention → all-gather → LayerNorm
                                       (TP)       

TP + SP:          [B, L/T, D]                      [B, L/T, D]
                  沿 L 切                          沿 L 切
                  ↓                                ↑
                  LayerNorm  → all-gather → Attention → reduce-scatter → LayerNorm
                                          (TP)

关键操作变化

通信量:每个 transformer block 在 TP+SP 下需要 1× all-gather + 1× reduce-scatter(替换纯 TP 的 1× all-reduce),二者通信量等价,总量与纯 TP 一致。$\boxed{\text{通信量与纯 TP 完全相同}}$——但 LayerNorm / Dropout 的 activation 从 $BLD$ 降到 $BLD/T$。

SP 通信量为什么和纯 TP 一样

因为 $\text{all-reduce} = \text{reduce-scatter} + \text{all-gather}$,纯 TP 末端 1× all-reduce $= $ SP 模式下 1× reduce-scatter + 后续 1× all-gather(next block 入口)。通信被重新分布,总量不变,但 activation 显存省下来了

8.3 SP activation memory 省了多少(L3 题)

未切分的单 transformer block 总 activation 内存为 $A = A_\text{in} + A_\text{out}$,其中:

纯 TP 单卡 activation:$A^\text{TP}_\text{per-card} = A_\text{in}/T + A_\text{out}$。

TP + SP 单卡 activation:SP 把 TP-外那部分也沿 seq 维切(每卡持 $[B, L/T, D]$)→ $A^\text{TP+SP}_\text{per-card} = A_\text{in}/T + A_\text{out}/T = A/T$。

SP 比纯 TP 省的部分

$$\boxed{\;A^\text{TP}_\text{per-card} - A^\text{TP+SP}_\text{per-card} = A_\text{out} \cdot \left(1 - \frac{1}{T}\right)\;}$$

对 LLaMA-class 模型,$A_\text{out}$ 约占未切分总 activation 的 30-50%。TP=8 下省 $7/8 \cdot A_\text{out}$ ≈ 总 activation 减少 25-40%(与 Korthikanti 2023 报告一致)。

8.4 Selective Activation Recompute

Korthikanti et al. 同篇还提了 selective recompute:只对部分 op(如 attention 的 $QK^\top$ 和 softmax 这种二次 memory 的)做 recompute,其余 op 不重算。比 full activation checkpoint 减 30-40% 计算开销,省同样多的 memory。

§9 Pipeline Parallel

9.1 朴素 PP 与 bubble

把 $L$ 层模型按深度切到 $P$ 个 stage(每 stage $L/P$ 层)。Naive PP 让每个 mini-batch 顺序流过所有 stage:

Stage 0 (GPU0): [F1]                                  [B1]
Stage 1 (GPU1):       [F1]                       [B1]
Stage 2 (GPU2):              [F1]           [B1]
Stage 3 (GPU3):                    [F1] [B1]
                  ↑ 大量 GPU idle ↑                ↑大量 idle ↑

GPU 利用率 = $1/P$,完全不能用。

9.2 GPipe(Huang et al., NeurIPS 2019)

把 mini-batch 切成 $M$ 个 micro-batch,micro-batch 之间流水:

Stage 0: [F1][F2][F3][F4]                                    [B4][B3][B2][B1]
Stage 1:     [F1][F2][F3][F4]                            [B4][B3][B2][B1]
Stage 2:         [F1][F2][F3][F4]                    [B4][B3][B2][B1]
Stage 3:             [F1][F2][F3][F4]            [B4][B3][B2][B1]
         |←─ "warm-up" ─→|       |← 全部 micro-batch backward ─→| "cool-down"

Bubble(GPU 空闲时间占比):

$$\boxed{\;\text{bubble ratio} = \frac{P - 1}{M + P - 1}\;}$$

推导:每个 micro-batch 走完 $P$ stage 要 $P$ 个 step;warm-up 阶段(stage $i$ 等前 $i$ 个 micro-batch)共 $P-1$ 个 step idle;cool-down 阶段对称 $P-1$ 个 step idle;总 step = $M + P - 1$(forward)+ $M + P - 1$(backward)= $2(M+P-1)$;其中 idle = $2(P-1)$。idle 占比 = $(P-1)/(M+P-1)$。

GPipe 的硬伤

必须等所有 $M$ 个 micro-batch 全部 forward 完才能开始 backward——所有 micro-batch 的 activation 都要存住,activation memory 与 $M$ 线性增长。这是 1F1B 解决的问题。

9.3 1F1B / PipeDream(Narayanan SOSP 2019, Megatron-LM-2 SC 2021)

1F1B = 1 Forward 1 Backward:每个 stage 在 forward 一个 micro-batch 后立刻 backward(前一个已经完成的):

Stage 0: [F1][F2][F3][F4][B1][F5][B2][F6][B3][F7][B4]...
                          ↑ 一旦 micro-batch 1 backward 准备好(其 forward 已到 stage P)
                            就立刻 B1,腾出 micro-batch 1 的 activation 内存

关键性质

Stage i 在稳态时的 activation 在内存中:
  - micro-batch i forward 后等 backward 的中间 activation
  - 因为有 P-i 个 stage 在 i 之后,且每 stage 走 forward / backward 都一个 step
  - 所以 stage i 同时存 P-i 个 forward 完但还没 backward 的 activation
  - 第一个 stage (i=0) 存最多: P 个;最后 stage 存最少: 1 个

9.4 Interleaved 1F1B(Megatron-LM-2, SC 2021)

进一步把 bubble 再压。每张 GPU 不持有连续 $L/P$ 层,而是持有 $V$ 段不连续的层virtual stages):

原 1F1B (P=4, L=8 layers):
  GPU0: layers 0,1     GPU1: layers 2,3     GPU2: layers 4,5     GPU3: layers 6,7

Interleaved (P=4, V=2, L=8):
  GPU0: layers 0,4     GPU1: layers 1,5     GPU2: layers 2,6     GPU3: layers 3,7
  (每张 GPU 持 V=2 段, 共 L/(PV) = 1 层 / 段)

每个 micro-batch 在 stage 0 跑 layer 0 → stage 1 跑 layer 1 → ... → stage 3 跑 layer 3 → 再回 stage 0 跑 layer 4 → ... → stage 3 跑 layer 7。一个 micro-batch 通过 stage 列 $V$ 次

Bubble ratio(Narayanan et al., SC 2021, Eq. 4):

$$\boxed{\;\text{interleaved bubble} = \frac{P-1}{V \cdot M + P - 1} \approx \frac{P-1}{V \cdot M}\;}$$

把分母里 $M$ 替换成 $V \cdot M$(流水管子里同时塞了 $V$ 倍 micro-batch 的虚拟 stage 通过次数),warm-up / cool-down 的 $P-1$ 不变。代价:单 micro-batch 跨 GPU 的 send/recv 次数变 $V$ 倍(每个 micro-batch 在 stage 列上走 $V$ 圈),所以 $V$ 不能太大(一般 $V = 2$ 到 $4$)。

interleaved bubble 推导直观版

普通 1F1B:bubble = $(P-1)/(M+P-1)$,分母是 micro-batch 数 $M$ 加上 warm/cool $P-1$;interleaved:每个 micro-batch 跨 GPU 走 $V$ 圈,等效 micro-batch 数变成 $V \cdot M$,bubble = $(P-1)/(V M + P - 1) \approx (P-1)/(VM)$。"$V$ 倍 micro-batch" 是核心直觉。

9.5 1F1B 调度伪代码

def one_f_one_b_schedule(P, M, stage_rank, num_warmup_microbatches):
    """
    stage_rank: 当前 GPU 持有的 pipeline stage 编号 (0..P-1)
    num_warmup_microbatches = P - 1 - stage_rank
    """
    # ===== Warm-up: stage_rank 走 (P-1-stage_rank) 个 forward
    for i in range(num_warmup_microbatches):
        recv_activation_from_prev_stage()
        out = forward(model, activation)
        send_activation_to_next_stage(out)

    # ===== Steady state: F1B1 alternating
    num_microbatches_remaining = M - num_warmup_microbatches
    for i in range(num_microbatches_remaining):
        # forward
        recv_activation_from_prev_stage()
        out = forward(model, activation)
        send_activation_to_next_stage(out)

        # backward (前面 forward 过、现在到的)
        recv_grad_from_next_stage()
        grad_in = backward(model, grad_out)
        send_grad_to_prev_stage(grad_in)

    # ===== Cool-down: 剩余 backward
    for i in range(num_warmup_microbatches):
        recv_grad_from_next_stage()
        grad_in = backward(model, grad_out)
        send_grad_to_prev_stage(grad_in)

9.6 PP 通信特点

PP 只在 stage 边界传 activation / gradient,单次传一个 micro-batch 大小 $\approx B/M \cdot L \cdot D$ bytes。通信量极小(远小于 TP / DP),所以 PP 可以跨节点(IB 也够)。

9.7 PP 的硬伤:load imbalance

§10 Context Parallel — 长序列切分

随着上下文窗口从 4K → 128K → 1M,单卡装不下完整 attention 计算。Context Parallel (CP) 沿 sequence 维切。

10.1 Ring Attention(Liu et al., arXiv 2310.01889, 2024)

把 $Q, K, V$ 沿 seq 维切到 $C$ 个 rank:每 rank 持 $L/C$ 个 token 的 Q/K/V。

Rank 0: Q[0:L/C],   K[0:L/C],   V[0:L/C]
Rank 1: Q[L/C:2L/C], K[L/C:2L/C], V[L/C:2L/C]
...

但 attention 需要每个 query 看到所有 key(causal 下看到所有过去 key)。Ring attention 解法

  1. 每个 rank 用本地 Q 与本地 K, V 算一块 partial attention
  2. 把本地 K, V 沿 ring 转发给下一个 rank
  3. 下一个 rank 用本地 Q × 上一 rank 的 K, V 算 partial attention,累积到 output(online softmax 风格)
  4. 转 $C$ 步后每个 Q 都看遍全部 K, V,attention 完整

通信量:每个 rank 持 $L/C$ 个 token 的 K, V,大小 $2 \cdot L/C \cdot D$(fp16 下乘 2 bytes)。Ring 转 $C-1$ 步才让每个 K, V shard 走遍所有 rank。每 rank 总传输量 $\approx (C-1) \cdot 2 L D / C \approx 2 L D$(与 $C$ 几乎无关——这正是 ring 的常规性质)。

关键:用 online softmax(FlashAttention 同款)让 partial attention 可累积,不需物化完整 $L \times L$ scores。

Ring attention 与 FlashAttention 关系

FlashAttention 是单卡内沿 block tiling 算 attention(block 在 SRAM 内)。Ring attention 是把这个 tiling 推到多卡 / 多节点 level:每个 rank 持有一块 K, V,按 ring 顺序计算 partial attention 并累积。两者数学上一脉相承——都是 online softmax + block-wise accumulation。

10.2 Llama 3 CP 实现

Llama 3 paper (arXiv 2407.21783) 报告在 128K 长上下文 阶段用 CP=16(短上下文 8K 阶段不需 CP,分配回 DP)。结合 FlashAttention v3,128K context 单 step 时间从不可训练 → 几秒级可控。

10.3 CP 与 TP 的正交性

§11 Expert Parallel — MoE 路由

11.1 MoE 基本结构

Mixture-of-Experts 把单个 FFN 替换为 $E$ 个 expert FFN + 一个 gate / router:

$$y = \sum_{e=1}^E G_e(x) \cdot \text{Expert}_e(x), \quad G \in \mathbb{R}^E, \quad \sum_e G_e = 1$$

实际部署用 top-K routing:只选 $G$ 输出最大的 $K$ 个 expert(典型 $K = 1, 2$),其余 expert 不计算。计算量与 expert 数 $E$ 无关,只与 $K$ 有关——这是 MoE 能 scale 参数的关键。

11.2 Expert Parallel:expert 分到不同 GPU

模式单卡 expert 数通信
不并行 (replicate)$E$0
TP-style expert split每 expert 切片,全 GPU 算每个 expertmany all-reduce
EP: 每 GPU 持 $E/N$ 个 expert$E/N$all-to-all dispatch + combine

EP 的 forward 流程:

1. 每 GPU 对自己的 token batch 算 gate → 得到每 token 的 routing decision (assign 给 e_1, e_2, ..., e_K)
2. all-to-all dispatch: 把 token 按 expert 归属发到对应 GPU
   - 输入: 每 GPU 持 B/N 个 token, 每 token 带 K 个 (expert_id, token_data)
   - 输出: 每 GPU 收到所有 GPU 发给本机 expert 的 token
3. 每 GPU 用本地 expert 算 FFN
4. all-to-all combine: 把 expert 输出送回 token 原属 GPU
5. 每 GPU 按 gate weight 合并 K 个 expert 输出
EP 的两个硬伤

(1) load imbalance: gate 倾向于 favor 某几个 expert,导致部分 GPU 过载;解:load balancing loss(Switch Transformer / GShard)。(2) all-to-all 通信量大: 与 token 数 × hidden 成正比,跨节点 IB 是瓶颈。DeepSeek-V3 用 node-limited routing 限制每 token 至多 dispatch 到 $M$ 个节点,减少 IB 流量。

11.3 EP all-to-all 代码骨架(伪代码)

下面是 EP 前向流程的伪代码骨架——重点在 dispatch / combine 的两次 all-to-all,省去了 routing / bucketing 的工程细节(生产代码见 Megatron-Core MoE 或 DeepSpeed-MoE):

def expert_parallel_forward(
    x,                # [B, L, D]
    gate,             # nn.Module: tokens [BL, D] -> (top_k_ids, top_k_w) 各 [BL, K]
    experts_local,    # 本地持有的 E_local 个 expert (nn.ModuleList)
    ep_group, ep_size, ep_rank,
    E_total, K,
):
    """ 教学版:每张 GPU 持 E_local = E_total / ep_size 个 expert """
    B, L, D = x.shape
    tokens = x.reshape(B * L, D)

    # 1. routing: 每 token 选 K 个 expert
    top_k_ids, top_k_w = gate(tokens)             # 都是 [BL, K]

    # 2. expand: top-K 把每 token 复制 K 份, 每份带一个 expert_id
    expanded = tokens.unsqueeze(1).expand(-1, K, -1).reshape(B * L * K, D)
    expand_ids = top_k_ids.reshape(B * L * K)     # [BL*K]
    expand_w   = top_k_w.reshape(B * L * K)

    # 3. 算每个复制 token 该送到哪个 EP rank
    target_rank = expand_ids // (E_total // ep_size)        # [BL*K]

    # 4. 按 target_rank 排序 + 统计每 rank send_count
    perm = torch.argsort(target_rank)
    sorted_tokens = expanded[perm]
    send_counts = torch.bincount(target_rank, minlength=ep_size).tolist()

    # 5a. 交换 send_counts -> recv_counts (每 rank 告诉别人自己要收多少)
    send_t = torch.tensor(send_counts, dtype=torch.int64, device=x.device)
    recv_t = torch.empty_like(send_t)
    dist.all_to_all_single(recv_t, send_t, group=ep_group)         # 一次小 a2a 同步 counts
    recv_counts = recv_t.tolist()

    # 5b. 同步 token 的 expert_id (用于本机分配到对应 local expert)
    sorted_ids = expand_ids[perm]
    received_ids = torch.empty(sum(recv_counts), dtype=sorted_ids.dtype, device=x.device)
    dist.all_to_all_single(received_ids, sorted_ids,
                           output_split_sizes=recv_counts,
                           input_split_sizes=send_counts,
                           group=ep_group)

    # 5c. 同步 token 本体
    received_tokens = torch.empty(sum(recv_counts), D, device=x.device, dtype=x.dtype)
    dist.all_to_all_single(received_tokens, sorted_tokens,
                           output_split_sizes=recv_counts,
                           input_split_sizes=send_counts,
                           group=ep_group)

    # 6. local expert compute
    received_out = torch.zeros_like(received_tokens)
    for local_eid in range(len(experts_local)):
        global_eid = ep_rank * (E_total // ep_size) + local_eid
        mask = (received_ids == global_eid)
        if mask.any():
            received_out[mask] = experts_local[local_eid](received_tokens[mask])

    # 7. all-to-all combine: 反向, 用对调过的 split sizes
    combined = torch.empty_like(sorted_tokens)
    dist.all_to_all_single(combined, received_out,
                           output_split_sizes=send_counts,    # 反向
                           input_split_sizes=recv_counts,
                           group=ep_group)

    # 8. 反排序 + gate weight 合并
    inv_perm = torch.empty_like(perm)
    inv_perm[perm] = torch.arange(perm.numel(), device=perm.device)
    out_expanded = combined[inv_perm]                          # [BL*K, D]
    out_expanded = out_expanded * expand_w.unsqueeze(-1)
    out = out_expanded.view(B * L, K, D).sum(dim=1)            # [BL, D]
    return out.view(B, L, D)
生产实现远比这复杂

真实代码处理 (a) capacity factor 限制 (避免某 expert 过载);(b) drop tokens 当 expert 满了;(c) NVL/IB hierarchical all-to-all(DeepSeek 的 node-limited routing);(d) backward 的镜像 all-to-all + gate gradient。本骨架仅示意 dispatch / combine 双向流程。

§12 Activation Memory 优化

12.1 Gradient Checkpointing(Chen et al., arXiv:1604.06174, 2016)

不存中间 activation,反向时重算。空间 $O(\sqrt{L})$(最优分段),时间 + 33% (1 次 extra forward)。

from torch.utils.checkpoint import checkpoint

def block_forward(x):
    return transformer_block(x)

# 反向时重算 block_forward,不存中间 activation
y = checkpoint(block_forward, x, use_reentrant=False)

12.2 Selective Recompute(Korthikanti 2023)

只对二次 memory 的 op(attention 的 $QK^\top$ 和 softmax)做 recompute,其余存。比 full recompute 减 30-40% 计算 overhead,省同样多 memory。Megatron-LM 默认开启。

12.3 Offload(ZeRO-Infinity)

把 activation / optimizer state 卸到 CPU RAM 或 NVMe。CPU 卸适合 13B-30B 量级单机训练;NVMe 卸吞吐极低,主要用于推理或 trillion-scale 大模型探索。

12.4 Activation Memory 公式(必背)

一个 transformer block 的 activation 大致(fp16, full save):

$$A_\text{block} \approx \underbrace{34 \cdot B \cdot L \cdot D}_{\text{各 LayerNorm/QKV/output/MLP residual}} + \underbrace{5 \cdot B \cdot L^2 \cdot H}_{\text{attention 中间矩阵}}$$

详见 Korthikanti et al. 2023 Table 2。$L \gg D$ 时第二项主导(FlashAttention 直接干掉这部分);$L \approx D$ 时第一项主导(SP 把它降 $T$ 倍)。

§13 综合:3D / 4D / 5D Parallelism

把前面几条轴拼起来。设 world size $W$,

$$W = D_{DP} \times T_{TP} \times P_{PP} \times C_{CP} \times E_{EP}$$

每条轴的角色:

切什么在哪通信原语
DP / FSDPbatch + model states跨节点 IB okall-reduce / reduce-scatter + all-gather
TP单层 hidden / head节点内 NVLinkall-reduce (per block × 4)
SPTP 外的 LayerNorm / Dropout activation节点内(与 TP 共 group)reduce-scatter + all-gather
PPlayer depth跨节点 IB okpoint-to-point send/recv
CPsequence 维节点内或跨节点都可(通信 $\propto L$)ring K/V
EPMoE expert跨节点 IB ok(all-to-all 量大)all-to-all

13.1 Llama 3 405B 训练拓扑(公开信息)

Meta 2024 报告 (2407.21783):

13.2 DeepSeek-V3 训练拓扑(公开信息)

DeepSeek 2024 报告 (2412.19437):

DeepSeek-V3 为什么不用 TP

V3 用了 MLA (multi-head latent attention) + 大量 MoE expert,TP 切 head 收益小(latent attention head 维已经很小);而 EP 的 all-to-all overlap 配合 DualPipe 把通信全藏起来。整体 PP × EP × ZeRO 已够装下 671B 参数。

§14 DualPipe — 2024 流水线前沿

DeepSeek 2024 在 V3 paper 与独立 repo 发布的 DualPipe 算法(arXiv 2412.19437 + github.com/deepseek-ai/DualPipe)。

14.1 核心想法

1F1B 的 bubble 来自 warm-up + cool-down 阶段。DualPipe 让 两个方向同时跑流水线——一组 micro-batch 从 stage 0 → P,另一组从 P → 0,两组在中间相遇时刚好填满 warm-up / cool-down 空隙。

传统 1F1B (P=4):
Stage 0: [F1][F2][F3][F4][B1][F5][B2]...
Stage 1:     [F1][F2][F3][F4][B1][F5][B2]...
                                         (前后两端 bubble)

DualPipe (P=4):
Forward 方向 micro-batch:  [F1][F2][F3][F4]...
Reverse 方向 micro-batch: ...[F4'][F3'][F2'][F1']
Stage 0: 同时 process Forward 的 F + Reverse 的 F' + 对应 B  ← 计算 / 通信完全 overlap

更精确:DualPipe 设计了一个双向 schedule,让每张 GPU 在任意时刻都有两个 micro-batch 在 forward / backward 重叠执行;同时 expert parallel 的 all-to-all 通信也被 hide 在两组 micro-batch 的间隙里。

14.2 DualPipe vs 普通 1F1B 性质

维度1F1BInterleaved 1F1BDualPipe
Bubble$(P-1)/(M+P-1)$$(P-1)/(VM+P-1)$理想 0 (warm-up / cool-down 互补)
Activation memory$P$ × per stage$V \cdot P$ ×$\approx 2 \times P$
通信 overlap部分同 1F1B几乎 100% all-to-all overlap
实现复杂度极高(需 forward & reverse 调度)

14.3 DualPipe 的代价

DeepSeek 用 DualPipe 是因为 EP all-to-all 通信极重,必须想办法藏起来。一般训练任务普通 1F1B 已足够

§15 TorchTitan — PyTorch 原生 4D 平台

Liang et al. (ICLR 2025, arXiv 2410.06511) "TorchTitan: One-stop PyTorch Native Solution for Production-Ready LLM Pre-training"

15.1 设计目标

15.2 代码风格(与 DeepSpeed monkey patch 对比)

# TorchTitan 风格:声明式 DTensor placement
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module, RowwiseParallel, ColwiseParallel,
    SequenceParallel, PrepareModuleInput,
)
from torch.distributed.fsdp import fully_shard

# 1. 构造 4D 设备 mesh
mesh = init_device_mesh(
    "cuda",
    mesh_shape=(2, 8, 4, 8),                # PP=2, FSDP=8, CP=4, TP=8
    mesh_dim_names=("pp", "fsdp", "cp", "tp"),
)

# 2. 应用 TP + SP(声明每个 sub-module 的 sharding 策略)
parallelize_module(
    model,
    mesh["tp"],
    {
        "attn.wq": ColwiseParallel(),
        "attn.wk": ColwiseParallel(),
        "attn.wv": ColwiseParallel(),
        "attn.wo": RowwiseParallel(),
        "mlp.fc1": ColwiseParallel(),
        "mlp.fc2": RowwiseParallel(),
        "norm1":   SequenceParallel(),
        "norm2":   SequenceParallel(),
    },
)

# 3. FSDP2 wrap(自动认到 mesh["fsdp"])
for block in model.blocks:
    fully_shard(block, mesh=mesh["fsdp"])
fully_shard(model, mesh=mesh["fsdp"])

# 4. PP(pipeline schedule,1F1B interleaved,伪代码示意)
# 真实 ScheduleInterleaved1F1B 需要 List[PipelineStage](每 rank 持 V 个 virtual stage)
from torch.distributed.pipelining import PipelineStage, ScheduleInterleaved1F1B
stages = [PipelineStage(submod_v0, ...), PipelineStage(submod_v1, ...)]  # V=2 virtual stages
schedule = ScheduleInterleaved1F1B(stages, n_microbatches=32, loss_fn=loss_fn)

15.3 Float8 训练(Hopper / Blackwell)

TorchTitan 集成 Float8 训练(H100/B100 fp8 GEMM),权重 / activation 都用 fp8,accumulation 在 fp32。以下是 torchao 的典型用法(具体 API 请以 torchao 当前版本文档为准):

import torch.nn as nn
from torchao.float8 import convert_to_float8_training, Float8LinearConfig

convert_to_float8_training(
    model,
    config=Float8LinearConfig(),    # 默认 dynamic scaling
    module_filter_fn=lambda m, n: isinstance(m, nn.Linear) and "lm_head" not in n,
)

效果:H100 上 throughput +20-40%,loss 几乎无差异(block-wise scaling 关键)。

15.4 异步 Checkpointing

import torch.distributed.checkpoint as DCP

# 异步保存:返回 Future,不阻塞训练
future = DCP.async_save(model.state_dict(), checkpoint_id="step_10000")
# ... 继续训练
future.result()                 # 必要时等

DCP(Distributed Checkpoint)配合 FSDP2 的 sharded state dict,单 checkpoint 写盘只需几秒(每 rank 写自己那部分到分布式存储),不阻塞训练。

§16 通信原语总结表

为了在面试中快速 recall,一张表收尾。

操作当原始数据全部加起来是 $S$(每 rank $S$)每 rank 通信量(ring)用在
all-reduce(buf, SUM)$N \cdot S$ → 全 rank $S$$2(N-1)/N \cdot S$DDP gradient sync, TP block 末端
reduce-scatter(buf)$N \cdot S$ → 各 rank $S/N$$(N-1)/N \cdot S$ZeRO grad reduce, SP 出口
all-gather(shard)$N \cdot S/N$ → 全 rank $S$$(N-1)/N \cdot S$ZeRO-3 forward param, SP 入口
broadcast(buf, src)rank src 的 $S$ → 全 rank $S$$S$(树) / $(N-1)/N \cdot S$(ring)model 初始化广播
all-to-all(buf)$N \cdot N$ 块 → 转置$(N-1)/N \cdot S$MoE EP routing, sequence sharding 变换
point-to-point send/recv1 → 1$S$PP stage 边界

§17 25 高频面试题

按 L1 / L2 / L3 排序,每题 collapsible 答案要点。

L1 必会题(任何 ML 工程岗都会问)

Q1. DDP 和 DP(DataParallel)的区别?
  • DP(nn.DataParallel):单进程多卡,主 rank scatter input + gather output,有 GIL + 主 rank 瓶颈,已 deprecated
  • DDP(DistributedDataParallel):多进程多卡,每进程一张 GPU,NCCL all-reduce 同步梯度
  • DDP 比 DP 快 1.5-3 倍且 scaling 远好

陷阱:说 DP 是"训练快"——错,DP 是历史遗留 API,生产用 DDP。

Q2. NCCL all-reduce 的通信量?
  • Ring 算法:单 GPU 总流量 $2(N-1)/N \cdot S \approx 2S$ bytes
  • 与 $N$ 几乎无关(这是 ring 的精髓)
  • 等价于 reduce-scatter ($S$) + all-gather ($S$)

陷阱:说 $N \cdot S$ 或 $S/N$;忘了 ring 的 step 数和 per-step 流量。

Q3. Adam mixed-precision 训练单卡 model states 多少?
  • 参数 fp16: $2\Phi$
  • 梯度 fp16: $2\Phi$
  • Optimizer (fp32 master + Adam m + v): $4\Phi + 4\Phi + 4\Phi = 12\Phi$
  • 总计 $16\Phi$ bytes

陷阱:忘了 fp32 master copy;或把 Adam $m, v$ 当 fp16(实际 fp32)。

Q4. ZeRO 1/2/3 分别切什么?
  • ZeRO-1: 切 optimizer state
  • ZeRO-2: + gradient
  • ZeRO-3: + parameter(最激进)

单卡 model states 从 $16\Phi$ 降到 $16\Phi/N$(ZeRO-3)。

陷阱:把顺序搞反;不知道 ZeRO-1/2 通信量与 DDP 完全一样(都是 $2\Phi$)。

Q5. FSDP 和 ZeRO-3 有什么区别?
  • 算法层面:完全等价(FSDP 就是 ZeRO-3 的 PyTorch 实现)
  • 工程层面:FSDP 集成在 PyTorch 主干;ZeRO-3 在 DeepSpeed 库
  • FSDP2 改用 per-parameter DTensor sharding,与 TP 复合自然,state_dict 简洁

陷阱:说"FSDP 比 ZeRO 通信少"或反之——错,通信量同。

Q6. Tensor Parallel 切 attention 的哪一维?
  • head 维(每张卡持 $H/T$ 个 head)
  • $W_Q, W_K, W_V$ column-parallel(按输出维切)
  • $W_O$ row-parallel(按输入维切)
  • 每个 attention block forward 1× all-reduce, backward 1× all-reduce

陷阱:说切 hidden_dim 维;或忘了 col + row 配对让中间不通信。

Q7. TP 必须放在哪里?
  • 节点内 NVLink 域(带宽 900 GB/s)
  • 跨节点 IB(50 GB/s)走 TP 性能会暴跌
  • 这就是为什么 TP=8 是黄金尺寸(一个节点 8 卡)

陷阱:说 TP 可以任意跨节点;或把节点内 / 跨节点带宽搞反。

Q8. Pipeline Parallel 的 bubble 怎么算?
  • 朴素 PP:$M = 1$,bubble = $(P-1)/P$(巨大)
  • GPipe / 1F1B with $M$ micro-batch:$(P-1)/(M+P-1)$
  • 通常 $M \geq 4P$ 让 bubble < 20%

陷阱:忘了 $M$ 在分母;只说 $1/P$ 不说 $M$。

Q9. 1F1B 比 GPipe 好在哪?
  • bubble ratio 一样 $(P-1)/(M+P-1)$
  • 但 1F1B 每 stage 同时存活的 activation 数 = $P$(不是 GPipe 的 $M$)
  • 大幅省 activation memory

陷阱:说 1F1B 减小了 bubble——错,1F1B 不减 bubble,省 activation。

Q10. 激活检查点(gradient checkpointing)的代价?
  • 显存:$O(L) \to O(\sqrt{L})$
  • 时间:+ 33%(一次额外 forward)
  • 实际生产几乎必开(70B+ 模型不开就 OOM)

陷阱:说"显存减半"——不准确,理论是 $\sqrt{L}$;或说"时间减半"。

L2 进阶题(research-oriented 岗位)

Q11. ZeRO-3 forward 为什么需要 all-gather?
  • 参数被切到 $N$ 张卡 → 每张卡只持 $1/N$
  • 算某层 forward 时需要完整 $W^{(\ell)}$ → 临时 all-gather 到所有 rank
  • forward 完立刻 release,只保留 shard
  • backward 时再次 all-gather(forward 中已释放)

陷阱:以为参数在所有 rank 上常驻;或忘了 backward 也要重新 all-gather。

Q12. 推导:DDP vs ZeRO-3 通信量差。
  • DDP: backward 1× all-reduce = $2\Phi$,总 $2\Phi$
  • ZeRO-3: forward 1× all-gather ($\Phi$) + backward 1× all-gather ($\Phi$) + 1× reduce-scatter ($\Phi$) = $3\Phi$
  • ZeRO-3 比 DDP 多 50% 通信,换 $N\times$ 显存下降

陷阱:以为 ZeRO 减通信——错,ZeRO 减显存,可能增通信(视阶段)。

Q13. NCCL ring all-reduce 单步流量?
  • 总 step 数:$2(N-1)$(reduce-scatter $N-1$ + all-gather $N-1$)
  • 单 step:每 rank 发送 $S/N$ bytes
  • 单 GPU 总流量:$2(N-1) \cdot S/N \approx 2S$
  • Bandwidth 利用率 $\to 1$ 当 $N \to \infty$

陷阱:说 $S$ 而非 $2S$;或说 $N \cdot S$(忘了 ring 的精髓)。

Q14. SP(Sequence Parallel)省了什么?
  • 不省通信(与纯 TP 通信量一样)
  • TP 外的 activation memory(LayerNorm / Dropout)
  • 把全副本 $[B, L, D]$ 切成 $[B, L/T, D]$
  • 总 activation memory 降 25-40%

陷阱:说"SP 减通信"——错,SP 只是把 all-reduce 重排成 reduce-scatter + all-gather(总量等同),但 activation 切了。

Q15. Interleaved 1F1B 怎么减 bubble?
  • 每 stage 持 $V$ 个 virtual stage(不连续的 $V$ 段 layer)
  • bubble: $(P-1)/(VM+P-1) \approx (P-1)/(VM)$
  • 同 $M$ 下 bubble 降 $V$ 倍
  • 代价:通信次数 × $V$

陷阱:说"interleaved 减计算量"——错,只重排时间轴;或忘了通信 × V 的代价。

Q16. MoE 的 EP all-to-all 通信量?
  • 每 token 选 $K$ 个 expert
  • Dispatch: 每 GPU 把本机 $B/N$ token 发到对应 expert 的 GPU
  • Combine: 反向
  • 总 per-rank 通信 $\approx 2 \cdot K \cdot B/N \cdot D$(双向 all-to-all)

陷阱:忘了 top-K(不是 $E$);忘了 dispatch + combine 两次。

Q17. HSDP 和 FSDP 区别?
  • FSDP:所有 GPU 同一 sharding group
  • HSDP:组内 FSDP / ZeRO-3,组间 DDP
  • HSDP 单 group 内通信少(节点内 NVLink),组间用 grad all-reduce(不需 weight all-gather)
  • trade-off:组内 model states 多(不切到全 world),换跨节点通信减

陷阱:以为 HSDP 减总通信——精度上是减跨节点通信,组内 / 组间是 trade-off。

Q18. Llama 3 405B 用了哪些并行?
  • 16K H100(总 16384,保持不变;CP 上去时 DP 下来)
  • 短 context (8K):TP=8 × CP=1 × PP=16 × DP=128
  • 长 context (128K):TP=8 × CP=16 × PP=16 × DP=8
  • 54 天共 466 次中断(419 unexpected + 47 planned),90%+ 有效训练时间
  • 训练用 BF16(不是 FP8——FP8 是推理量化)
  • Meta 2024, arXiv 2407.21783

陷阱:说用了 EP——错,Llama 3 是 dense 模型不用 EP;或忘了 CP 这一阶段的存在。

Q19. ZeRO++ 三个 trick?
  • qwZ:forward all-gather 用 int8 量化(block-wise quant)
  • hpZ:weight 在节点内复制(NVLink 域),跨节点仍切——backward all-gather 走 NVLink
  • qgZ:backward gradient reduce-scatter 也走 int8
  • 总通信量降 4×,throughput 384 GPU 上 +116% (Wang 2023)

陷阱:以为 ZeRO++ 改训练精度——只在通信途中量化 buffer,权重和计算仍是 fp16/bf16。

Q20. 一个 7B 模型在 8 张 A100-40G 上能跑 fp16 训练吗?
  • Model states (Adam): $16 \times 7$B $= 112$ GB
  • DDP: 单卡 112 GB,单卡放不下(A100 40G)
  • ZeRO-3 / FSDP: 单卡 $112/8 = 14$ GB ✓
  • 加 activation (with checkpoint): 几 GB ✓
  • 加 workspace: 几 GB
  • 结论:FSDP/ZeRO-3 + 激活检查点可以跑

陷阱:说"DDP 也能跑";或忘了 model states 而只算参数 $\Phi$;或忘了 fp32 master copy。

L3 高级变体(顶级 lab / 自研基建岗)

Q21. 推导 1F1B bubble ratio + interleaved 怎么进一步减小。
  • 1F1B:warm-up $P-1$ step (forward fill) + cool-down $P-1$ step (backward drain) + steady $M-P+1$ step
  • 总 step $= 2M$(forward + backward)需要 $2(M + P - 1)$ 时间槽(含 warm/cool 各 $P-1$ idle 槽)
  • bubble = $2(P-1) / [2(M+P-1)] = (P-1)/(M+P-1)$
  • Interleaved 1F1B(Narayanan SC 2021, Eq. 4):每 GPU 持 $V$ 段不连续 layer (virtual stages),每 micro-batch 在 stage 列上走 $V$ 圈 → 流水里等效 $V \cdot M$ 个 micro-batch 在排队
  • bubble = $(P-1) / (V \cdot M + P - 1) \approx (P-1)/(V \cdot M)$
  • 同 $M$ 下 bubble 降 $V$ 倍;代价:跨 GPU send/recv 次数 $\times V$

陷阱:忘了 interleaved 通信 × V 的代价;说"interleaved 减 forward 计算"——只重排时间。

Q22. FSDP vs ZeRO-3 通信量 + 各自适用场景。
  • 通信量完全一样:forward all-gather $\Phi$ + backward all-gather $\Phi$ + reduce-scatter $\Phi$ = $3\Phi$
  • 工程差异:

    • FSDP2: 用 DTensor 描述 per-param sharding,与 TP / PP 复合自然,torch.compile 友好
    • DeepSpeed ZeRO-3: 用 flat-parameter,要 monkey patch 拼装,但 ZeRO++ / Offload / Infinity 生态完整
  • 选型:

    • 新项目(dense + 4D 并行)→ FSDP2 / TorchTitan
    • MoE + 跨框架 + offload → DeepSpeed

陷阱:说"FSDP 比 ZeRO 通信少"——算法相同;或说"FSDP 不支持 offload"——FSDP2 OffloadPolicy 已支持。

Q23. TP + SP 相比纯 TP 省多少 activation memory?
  • 设 transformer block activation $A_\text{block}$
  • TP-内 (attention 中间 / MLP 中间):约 $A_\text{block} \times 0.5-0.7$,TP 切了
  • TP-外 (LayerNorm / Dropout / residual):约 $A_\text{block} \times 0.3-0.5$,TP 不切(全副本)
  • 纯 TP 单卡 activation = $A_\text{TP-内}/T + A_\text{TP-外}$
  • TP+SP 单卡 activation = $A_\text{TP-内}/T + A_\text{TP-外}/T = A_\text{block}/T$
  • 省了 $A_\text{TP-外} \cdot (1 - 1/T)$,约总 activation 的 25-40%(取决于模型)

陷阱:说"SP 让 activation 减 $T$ 倍"——不准确,只对 TP-外那部分;或忘了 SP 通信量与纯 TP 一样。

Q24. DualPipe 相对 1F1B 的核心改进?
  • 双向流水线:一组 micro-batch 从 stage 0 → P,另一组从 P → 0
  • 两组在中间相遇时刚好填满 1F1B 的 warm-up / cool-down bubble
  • 理论 bubble = 0(理想下两侧互补)
  • 关键收益:all-to-all 通信完全 overlap(EP 重要)
  • 代价:activation memory $\times 2$(两个方向都要存);实现极复杂
  • DeepSeek-V3 December 2024 报告 (arXiv 2412.19437) + github.com/deepseek-ai/DualPipe

陷阱:说"DualPipe 是新的 PP 切法"——它是新的 schedule;或忘了 2× activation 的代价。

Q25. 现在 frontier 训练(如 DeepSeek-V3 / Llama 3)的 4D / 5D parallelism 怎么组合?
  • 5 个正交维度:DP / FSDP × TP × PP × CP × EP
  • World size $W = D \times T \times P \times C \times E$
  • 经验法则:

    • TP=8(一个节点 NVLink 域内必须塞下)
    • PP=8-32(跨节点 IB ok,bubble 控制 $M \geq 4P$)
    • CP 看 context 长度(Llama 3 在 128K 用 CP=16,1M 可能 CP=64+)
    • EP 看 MoE expert 数(DeepSeek-V3 用 EP=64)
    • FSDP / DP 用剩下的 world size
  • Llama 3 405B(保持总 16K GPU 不变,CP 上 DP 下):

    • 8K context: TP=8 × CP=1 × PP=16 × DP=128
    • 128K context: TP=8 × CP=16 × PP=16 × DP=8
  • DeepSeek-V3 671B: TP=1 + PP=16 × EP=64 × ZeRO-1 DP=2,共 2048 H800
  • 关键工程点

    • 通信原语按拓扑放(TP 走 NVLink,PP / DP 走 IB)
    • DualPipe / Interleaved 1F1B 压 PP bubble
    • FlashAttention v3 + SP / CP 压 activation
    • fp8 GEMM(H100 / B100)+ bf16 accumulation 提 throughput

陷阱:把维度顺序搞反;忘了 TP 必须节点内;以为 DeepSeek 用了 TP(V3 实际 TP=1)。

§A 附录:完整 4D wrap 代码骨架

下面是一个 4D parallelism 的 minimal 端到端 wrap 示例(FSDP2 + TP + SP + PP),按 TorchTitan 风格。

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module, RowwiseParallel, ColwiseParallel,
    SequenceParallel, PrepareModuleInput, PrepareModuleOutput,
)
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.pipelining import pipeline, SplitPoint, ScheduleInterleaved1F1B
from torch.utils.checkpoint import checkpoint

def build_4d_model(model, world_size):
    """ 4D parallelism wrap (PP × FSDP × CP × TP) """
    # Step 1. 建 4D device mesh
    # 例: 64 GPU = PP=2 × FSDP=4 × CP=1 × TP=8
    mesh = init_device_mesh(
        "cuda",
        mesh_shape=(2, 4, 1, 8),
        mesh_dim_names=("pp", "fsdp", "cp", "tp"),
    )

    # Step 2. TP + SP wrap (mesh["tp"])
    tp_plan = {
        # Attention
        "attn.wq":      ColwiseParallel(),
        "attn.wk":      ColwiseParallel(),
        "attn.wv":      ColwiseParallel(),
        "attn.wo":      RowwiseParallel(),
        # MLP
        "mlp.fc1":      ColwiseParallel(),
        "mlp.fc2":      RowwiseParallel(),
        # SP: norm / residual 沿 seq 维切
        "ln1":          SequenceParallel(),
        "ln2":          SequenceParallel(),
    }
    for block in model.blocks:
        parallelize_module(block, mesh["tp"], tp_plan)

    # Step 3. Activation checkpoint (每 block 一次)
    for i, block in enumerate(model.blocks):
        model.blocks[i] = _ckpt_wrap(block)

    # Step 4. PP split (mesh["pp"])
    # 把 model 切成 2 段(PP=2)
    split_spec = {"blocks.16": SplitPoint.BEGINNING}
    pipe = pipeline(
        model,
        mb_args=(torch.randn(1, 4096, device="cuda"),),
        split_spec=split_spec,
    )
    pp_stage = pipe.build_stage(
        stage_index=mesh["pp"].get_local_rank(),
        device=torch.device(f"cuda:{torch.cuda.current_device()}"),
        group=mesh["pp"].get_group(),
    )

    # Step 5. FSDP2 wrap (mesh["fsdp"])
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
    for module in pp_stage.submod.modules():
        if isinstance(module, TransformerBlock):
            fully_shard(module, mesh=mesh["fsdp"], mp_policy=mp_policy)
    fully_shard(pp_stage.submod, mesh=mesh["fsdp"], mp_policy=mp_policy)

    return pp_stage, mesh

def _ckpt_wrap(block):
    """ 把一个 TransformerBlock 包成 activation checkpoint 形式 """
    class Wrapped(nn.Module):
        def __init__(self, inner):
            super().__init__()
            self.inner = inner
        def forward(self, *args, **kw):
            return checkpoint(self.inner, *args, use_reentrant=False, **kw)
    return Wrapped(block)

# ============================================================
# Training loop with PP schedule
# ============================================================
def train_step_4d(pp_stage, schedule, optimizer, batch):
    """ 1 step of 4D parallel training """
    optimizer.zero_grad(set_to_none=True)
    # PP schedule runs forward + backward across all stages
    losses = []
    schedule.step(batch, losses=losses)  # 内部触发 FSDP all-gather / TP all-reduce / etc.
    optimizer.step()
    return torch.stack(losses).mean()

说明:上述代码是教学骨架,实际生产请直接用 TorchTitan repo(pytorch/torchtitan)——其包含完整的 Trainer、checkpointing、profiling、loss / lr schedule,本节只给概念示意。

§B 参考资料