Distributed Training 面试 Cheat Sheet
DDP / FSDP2 / ZeRO / TP / PP / SP / CP / EP — 公式推导 + From-Scratch 代码 + 25 高频题(L1 必会 · L2 进阶 · L3 顶级 lab)
§0 TL;DR Cheat Sheet
DP(数据切分) × TP(张量切分) × PP(层切分) × SP(序列切分) × CP(context 切分) × EP(专家切分) × 激活重计算。详见后文 §1–§11 推导。
- DDP(PyTorch):每张卡持完整模型,反向时用 NCCL all-reduce 同步梯度,bucket fusion + computation/communication overlap 是关键工程优化。
- ZeRO 1/2/3(Rajbhandari et al., SC 2020):把 optimizer state / gradient / parameter 分别切到 $N$ 张卡上,单卡显存从 $\Phi (2+2+12) = 16\Phi$ 字节降到 $16\Phi/N$(fp16 训练,Adam,参数量 $\Phi$)。
- FSDP / FSDP2(Zhao et al., VLDB 2023;PyTorch 2.4+, 2024):PyTorch 原生 ZeRO-3。FSDP2 用 per-parameter DTensor sharding 替换 FSDP1 的 flat-parameter 模式,与 TP / PP 复合更顺。
- TP(Megatron-LM, Shoeybi 2019;Narayanan SC 2021):列并行 → 行并行 配对,每层只需 2× all-reduce/forward;attention 按 head 切,FFN 第一层 col-parallel + 第二层 row-parallel。
- PP:GPipe(Huang NeurIPS 2019)→ 1F1B / PipeDream(Narayanan SOSP 2019)→ Megatron-LM interleaved 1F1B(Narayanan SC 2021)。Bubble ratio $\approx (P-1)/M$($P$ 个 stage,$M$ 个 micro-batch),interleaved 把它再压 $V$ 倍。
- SP / CP / EP:SP(Korthikanti et al., MLSys 2023)把 TP 外的 LayerNorm/Dropout activation 沿序列切;CP(Megatron-LM 2024 / Ring Attention Liu 2024)切长 context;EP 把 MoE expert 分到不同 GPU,前向用 all-to-all routing。
- 2024 前沿:DualPipe(DeepSeek-V3, Dec 2024)双向流水线把 forward/backward 计算与通信完全 overlap;Llama 3 405B(Meta 2024)用 16K H100、3.8e25 FLOPs、54 天 466 次中断(419 unexpected + 47 计划维护);TorchTitan(Liang et al., ICLR 2025)打通 FSDP2 + TP + PP + SP + Float8。
§1 直觉 — 为什么模型大了一张卡装不下
一道送分题但很多人答不全:训练一个模型,单卡显存里到底装了什么?
考虑 fp16 mixed-precision 训练 + Adam optimizer,参数量为 $\Phi$:
| 项 | 单卡占用 | 备注 |
|---|---|---|
| Parameter (fp16) | $2\Phi$ | forward / backward 用 |
| Gradient (fp16) | $2\Phi$ | backward 累积 |
| Optimizer state | $12\Phi$ | fp32 master copy ($4\Phi$) + Adam $m$, $v$ ($4\Phi + 4\Phi$) |
| 小计(model states) | $16\Phi$ | 与 batch size 无关 |
| Activation | $O(B \cdot L \cdot D \cdot \text{depth})$ | 与 batch size、seq len 线性,与 depth 线性 |
| Workspace / temp buffer | varies | NCCL / cuDNN workspace |
7B 参数模型仅 model states 就要 $\approx 112$ GB——一张 80GB H100 都装不下。所以分布式训练首要目的是切 model states + activation。
训练任务可以沿以下三轴切分,理论上彼此独立、实践中复合使用:
- 数据维度(DP / ZeRO / FSDP):切 batch、模型副本之间通信梯度
- 模型维度(深度方向)(PP):切 layer,stage 之间传 activation
- 模型维度(宽度方向)(TP / SP / CP / EP):切单层内部的张量,每步通信中间结果
3D / 4D / 5D parallelism 就是这几条轴的笛卡尔积。Llama 3 / DeepSeek-V3 都用 4D(DP × TP × PP × CP/SP/EP 子集)。
§2 NCCL 通信原语与拓扑
分布式训练 99% 的通信交给 NCCL,必须熟悉 5 个原语的语义和通信量。
2.1 五大集合通信原语
设有 $N$ 张 GPU,每张卡持有 size = $S$ 的 buffer。
| 原语 | 输入 → 输出 | 等价语义 | Ring 算法通信量 / GPU |
|---|---|---|---|
| all-reduce | $N$ 个 $S$ → $N$ 个相同 $S$ | sum + broadcast | $2(N-1)/N \cdot S \approx 2S$ |
| reduce-scatter | $N$ 个 $S$ → $N$ 个不同 $S/N$ | sum 后切片 | $(N-1)/N \cdot S \approx S$ |
| all-gather | $N$ 个 $S/N$ → $N$ 个 $S$ | 把各 rank 的片段拼成完整 | $(N-1)/N \cdot S \approx S$ |
| broadcast | rank0 的 $S$ → $N$ 个 $S$ | 单源 | $S$(树形) |
| all-to-all | $N \times N$ 块矩阵转置 | shuffle | $(N-1)/N \cdot S \approx S$ |
关键恒等式:
$$\boxed{\;\text{all-reduce} = \text{reduce-scatter} + \text{all-gather}\;}$$
每一步 ring 上各传 $S/N$,共 $2(N-1)$ 步 → 单 GPU 总流量 $2(N-1)S/N \approx 2S$ bytes(与 $N$ 几乎无关,这是 ring 算法的精髓)。
NCCL 在小 message 用 tree all-reduce(latency-bound,$O(\log N)$ 跳);大 message 用 ring all-reduce(bandwidth-bound,吞吐最优)。NVLink 拓扑下还有 NVLS (NVLink SHARP),硬件做 reduction(H100/H200 NVSwitch 支持)。
2.2 NVLink / IB / 拓扑
| 链路 | 单向带宽(H100 代) | 用途 |
|---|---|---|
| NVLink 4.0 | 900 GB/s (per GPU, 18 链路聚合) | 同节点 GPU↔GPU |
| PCIe 5.0 x16 | 64 GB/s | GPU↔CPU、慢路径 |
| InfiniBand NDR 400G | 50 GB/s (per port) | 跨节点 |
经验法则:节点内通信比跨节点快 10-20 倍。所以 TP 一定要塞进一个节点,DP / PP 才跨节点。Llama 3 训练拓扑:TP=8(节点内 NVLink)× PP=16(跨节点 IB)× DP=128。
2.3 PyTorch 中的 NCCL 调用
import torch
import torch.distributed as dist
dist.init_process_group(backend="nccl") # 后端固定 nccl
rank = dist.get_rank()
world_size = dist.get_world_size()
# all-reduce(默认 SUM;可选 AVG / MIN / MAX / PRODUCT)
buf = torch.ones(1024, device=f"cuda:{rank % 8}") * rank
dist.all_reduce(buf, op=dist.ReduceOp.SUM)
# 现在 buf == sum(0..world_size-1) * ones(1024)
device = torch.device(f"cuda:{rank % 8}")
# reduce-scatter
input_list = [torch.full((1024,), float(rank + i), device=device) for i in range(world_size)]
output = torch.empty(1024, device=device)
dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)
# all-gather
local = torch.full((1024,), float(rank), device=device)
gathered = [torch.empty(1024, device=device) for _ in range(world_size)]
dist.all_gather(gathered, local)
# all-to-all(MoE 必备)
in_split = list(torch.randn(world_size, 1024, device=device).unbind(0))
out_split = [torch.empty(1024, device=device) for _ in range(world_size)]
dist.all_to_all(out_split, in_split)
§3 DDP — DistributedDataParallel
3.1 算法骨架
DDP 是数据并行最朴素的实现:
- 复制:每个 rank 都持有完整模型副本(参数 / 梯度 / 优化器状态全量)
- 切 batch:global batch $B$ 切成 $N$ 份 micro-batch,每 rank 算自己那份的 forward + backward
- 同步梯度:backward 完毕后对所有 gradient 做 all-reduce(SUM 后除以 $N$ 取平均,等价于 AVG)
- 本地 optimizer step:每 rank 用同样的 gradient 跑同样的 optimizer,参数始终一致
数学形式(loss $\mathcal{L}$ 在 mini-batch 上的均值):
$$g_\text{global} = \frac{1}{N}\sum_{i=1}^N \nabla_\theta \mathcal{L}(\theta; \mathcal{B}_i) = \text{all-reduce-mean}(g_1, \dots, g_N)$$
每个 rank 拿到的 $g_\text{global}$ 完全相同,因此 $N$ 个 rank 上的 $\theta$ 永远保持一致(同初始化 + 同梯度 + 同 optimizer)。
3.2 Bucket Fusion + Overlap(DDP 工程精髓)
朴素实现:backward 完成 → 把所有梯度按张量拼起来 → 一次 all-reduce。这样 GPU 空等通信,浪费惨重。
PyTorch DDP 实际做法:
- Bucket:把多个 gradient tensor 按反向计算顺序打包成 bucket(默认 25 MB)
- Hook:在每个 parameter grad 计算完时触发 hook
- 重叠:当一个 bucket 填满(所有 grad 都已计算),立刻在后台 stream 上发起 all-reduce,同时主 stream 继续算更早的层的 backward
- 结果:通信被前面层的反向计算完全 / 部分掩盖
Backward time axis →
Layer N: [grad N]──┐
Layer N-1: [grad N-1]┼─bucket_N─[all-reduce N]
Layer N-2: [grad N-2]┘ ↓ (后台 stream)
Layer N-3: [grad N-3]──┐
Layer N-4: [grad N-4]──┼─bucket_N-1─[all-reduce N-1]
... ↓
Layer 1: [grad 1]──────────────[all-reduce 1]
主 stream (compute): ████████████████████████████████
后台 stream (NCCL): ▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓
↑ 与计算重叠
太小:communication launch overhead 主导,bandwidth 利用率低;太大:等齐时间长、首个 all-reduce 启动晚。PyTorch 默认 25 MB 是大模型常用 sweet spot。可通过 DDP(bucket_cap_mb=...) 调。
3.3 PyTorch 代码(含 overlap)
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_ddp(rank, world_size, local_rank):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(local_rank) # 多节点时 local_rank != global rank
def main(rank, world_size, local_rank):
setup_ddp(rank, world_size, local_rank)
device = torch.device(f"cuda:{local_rank}")
model = MyModel().to(device)
model = DDP(
model,
device_ids=[local_rank],
bucket_cap_mb=25, # bucket 大小
gradient_as_bucket_view=True, # 内存优化: grad 直接是 bucket view
static_graph=False, # 若图静态可开,启用更多 fusion
)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loader = make_distributed_loader(rank, world_size)
for batch in loader:
x, y = batch[0].to(device), batch[1].to(device)
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad(set_to_none=True)
loss.backward() # ← bucket hook 在这里触发 all-reduce
optimizer.step()
3.4 复杂度
后续凡用 $\Phi$ 表示 参数量(个数)。fp16 / bf16 训练下一个参数 2 bytes,故 fp16 weight buffer 大小 = $2\Phi$ bytes。下表"$2\Phi$ bytes"指字节数,"$\Phi$ params"指个数。
| 项 | 量 | 说明 |
|---|---|---|
| 单 GPU 显存 | $16\Phi$ bytes + activation | model states 不分摊(含 fp16 params/grads + fp32 master/Adam) |
| 单 step 通信量 | $\approx 2 \cdot 2\Phi = 4\Phi$ bytes(ring all-reduce on fp16 gradient) | 等价于 reduce-scatter + all-gather |
| 扩展性 | 计算线性 scale;通信 bandwidth $O(1)$(ring 不依赖 $N$) | 但 latency-bound 小模型仍受 $N$ 影响 |
DDP 的硬伤:model states 不切。Adam 训练 70B 模型,单卡光 model states 就 1.12 TB,DDP 完全装不下。所以 7B+ 模型必须上 ZeRO / FSDP。
§4 ZeRO 1/2/3 — Zero Redundancy Optimizer
Rajbhandari et al. "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" (SC 2020) 是分布式训练第二个划时代工作(第一个是 Megatron-TP)。核心思想:DDP 把 model states 复制 $N$ 份是浪费,切成 $N$ 份每张卡只放 $1/N$ 即可,通信只多一点点。
4.1 三阶段显存数学
记 $\Phi$ = 参数量,$N$ = DP world size。fp16 mixed precision + Adam 下单卡 model states:
$$\text{DDP}: \quad 2\Phi + 2\Phi + 12\Phi = 16\Phi$$
ZeRO 三阶段按 「切什么」 区分:
| 阶段 | 切的内容 | 单卡 model states | 通信量 (per step) |
|---|---|---|---|
| ZeRO-1 | optimizer state | $2\Phi + 2\Phi + 12\Phi/N$ | $2\Phi$(all-reduce 等价于 reduce-scatter + all-gather;ZeRO-1 只 reduce-scatter grad) |
| ZeRO-2 | optimizer state + gradient | $2\Phi + 2\Phi/N + 12\Phi/N$ | $2\Phi$(同 ZeRO-1) |
| ZeRO-3 | opt state + grad + parameter | $2\Phi/N + 2\Phi/N + 12\Phi/N = 16\Phi/N$ | $3\Phi$(forward + backward 各一次 all-gather,backward 一次 reduce-scatter) |
$N$ 足够大时(如 1024 张 H100),ZeRO-3 单卡 model states $16\Phi / 1024 = 0.0156\Phi$。65B 模型即 $\approx 1$ GB / GPU,配合 activation checkpoint 单 H100 完全装得下。
4.2 ZeRO-3 工作流(最常用,前向 / 反向 / 优化)
ZeRO-3 把参数本身也切了,所以参数在用之前必须 all-gather 才能 forward。
前向流程(每层 / 每 module):
1. all-gather: 从 N 张卡 collect 完整 W^(ℓ) ──[通信: φ_ℓ bytes]
2. compute: y = f(x; W^(ℓ)) ──[计算]
3. release: 丢掉本地不属于自己的 shard ──[释放显存]
反向流程:
1. all-gather: 再次取回 W^(ℓ)(forward 时已释放)──[通信: φ_ℓ]
2. compute: grad_W^(ℓ), grad_x
3. reduce-scatter: 把 grad_W^(ℓ) reduce 并切回各自的 shard ──[通信: φ_ℓ]
4. release: 丢掉非己 shard
伪代码:
# ZeRO-3 forward (单层抽象, 简化版)
def zero3_forward(layer_idx, x, sharded_W):
# 1. all-gather full weight
full_W = all_gather(sharded_W) # [Φ_ℓ / N, ...] × N -> [Φ_ℓ, ...]
# 2. compute
y = layer_forward(x, full_W)
# 3. 释放 full_W (只留 sharded_W)
del full_W
return y
def zero3_backward(layer_idx, dy, sharded_W, cached_input):
# 1. all-gather (forward 时已释放)
full_W = all_gather(sharded_W)
# 2. compute local gradients
dW_local, dx = layer_backward(dy, full_W, cached_input)
del full_W
# 3. reduce-scatter gradient 到对应 shard
dW_sharded = reduce_scatter(dW_local) # [Φ_ℓ, ...] / N -> [Φ_ℓ/N, ...]
return dW_sharded, dx
4.3 ZeRO-1/2/3 vs DDP 通信对比(重要面试题)
总参数 $\Phi$(个数)。下表的单位是 "fp16 weight buffer 等价数"(即 $\Phi$ 在通信流量列代表 $2\Phi$ bytes 的 fp16 buffer 流量)。一次 forward+backward+update 的 per-GPU 通信流量(ring 假设):
| 模式 | Forward | Backward | Optim | 总(fp16 buffer equiv.) |
|---|---|---|---|---|
| DDP | 0 | $2\Phi$ (all-reduce grad) | 0 | $2\Phi$(即 $4\Phi$ bytes) |
| ZeRO-1 | 0 | $\Phi$ (reduce-scatter grad) | $\Phi$ (all-gather updated params) | $2\Phi$ |
| ZeRO-2 | 0 | $\Phi$ (reduce-scatter grad) | $\Phi$ (all-gather) | $2\Phi$ |
| ZeRO-3 | $\Phi$ (all-gather params, on-the-fly) | $\Phi$ (all-gather) + $\Phi$ (reduce-scatter grad) | 0 (已在 backward 中) | $3\Phi$ |
ZeRO-1/2 通信量与 DDP 相同($2\Phi$ buffer),但显存大幅下降;ZeRO-3 多 $1.5\times$ 通信,换 $N\times$ 显存下降。实际工程中 ZeRO-3 通信也能通过 prefetch + overlap 部分隐藏,所以是 70B+ 模型的主流选择。换算成字节:DDP $\approx 4\Phi$ bytes,ZeRO-3 $\approx 6\Phi$ bytes。
4.4 ZeRO-Offload / ZeRO-Infinity
ZeRO-Offload(Ren et al., USENIX ATC 2021):把 optimizer state + 部分 gradient 卸到 CPU,CPU 端跑 Adam update。代价:CPU↔GPU PCIe 通信 + CPU 计算慢。适合小集群 + 大模型场景(如单机 8 卡跑 13B)。
ZeRO-Infinity(Rajbhandari et al., SC 2021):再进一步,把参数 / 优化器卸到 NVMe。理论上单机能跑 1T 参数(实际 throughput 极低,主要用于推理 / 微调)。
卸 CPU 后单 step 时间通常变长 1.5-3 倍;NVMe 卸更慢 5-10 倍。只在「装不下 + 没钱加卡」时才用。生产训练优先扩 GPU 数。
§5 FSDP / FSDP2 — PyTorch 原生 ZeRO-3
Zhao et al. "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel" (VLDB 2023) 把 ZeRO-3 思想做进 PyTorch 主干。FSDP2(PyTorch 2.4+, 2024)是大重写:per-parameter DTensor sharding 替换 FSDP1 的 flat-parameter。
5.1 FSDP1 vs FSDP2 核心区别
| 维度 | FSDP1 (2022-2023) | FSDP2 (2024+) |
|---|---|---|
| 数据结构 | FlatParameter:把同一 wrap unit 内所有 param 拼成一个 1D buffer,再 chunk | DTensor per-parameter:每个 param 独立按 dim-0 切 |
| 状态字典 | 需 all-gather 才能产出 unflattened state dict | 通信自由 sharded state dict |
| 冻结参数 | 同 unit 内必须全冻或全可训 | 各 param 独立,混合冻可训自然 |
| TP 复合 | 困难(flat-buffer 与 TP 沿不同维切冲突) | 天然兼容:DTensor 描述多轴 placement (Shard(0), Replicate, Shard(1) 组合) |
| API | FullyShardedDataParallel | fully_shard() 函数式 wrap |
5.2 FSDP wrap 策略(最核心的设计决策)
FSDP 不是把整个模型当一个 unit。它按自定义 unit boundary 切,每个 unit 内的参数一起 all-gather / reduce-scatter。
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard # FSDP2 API
from torch.distributed.fsdp import MixedPrecisionPolicy
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.attn = MultiHeadAttention(d_model, n_heads)
self.mlp = FeedForward(d_model)
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
# 把每个 TransformerBlock 当一个 FSDP unit
model = MyLLM(n_layers=32, d_model=4096, n_heads=32).cuda()
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, # forward 用 bf16
reduce_dtype=torch.float32, # gradient reduce 用 fp32 防误差累积
)
for block in model.blocks:
fully_shard(block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy) # root 也要 wrap
Unit 越小(如每个 linear 都 wrap):每次 all-gather 只取一层,显存 peak 低,但通信次数多、prefetch 难做;Unit 越大(如整个 block 或多个 block):通信次数少、容易 overlap,但 peak memory 高。TransformerBlock 粒度是 LLaMA / GPT 类的标准答案。
5.3 FSDP2 + 混合精度 + 激活检查点
from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper, CheckpointImpl
)
# 1. 激活检查点 (gradient checkpointing) — 重算换显存
def apply_ac(model):
for i, block in enumerate(model.blocks):
model.blocks[i] = checkpoint_wrapper(
block,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_ac(model)
# 2. FSDP2 wrap
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
)
offload_policy = OffloadPolicy() # 或 CPUOffloadPolicy() 卸到 CPU
for block in model.blocks:
fully_shard(block, mp_policy=mp_policy, offload_policy=offload_policy)
fully_shard(model, mp_policy=mp_policy)
# 3. 训练 (注意: 必须用 torch._foreach_ 友好的 optimizer)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)
for batch in loader:
loss = compute_loss(model, batch)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
5.4 FSDP vs ZeRO-3 通信量差异(L3 题)
结论:完全一样(在 wrap unit = layer 这个粒度上)。FSDP 是 ZeRO-3 的 PyTorch 原生实现,算法等价。差异在工程:
| 维度 | DeepSpeed ZeRO-3 | FSDP2 |
|---|---|---|
| 集成 | 外部库 + monkey patch | PyTorch 主干 |
| state_dict | 需要单独 API 拼装 | 原生 set_model_state_dict / DCP |
| TP 复合 | DeepSpeed-Megatron 桥接 | DTensor 原生 |
| 冻结 / LoRA | 复杂 (flat-param 冲突) | 自然 |
| 编译 | 受限 | torch.compile(fullgraph=False) 友好 |
实际选型:新项目用 FSDP2;DeepSpeed 仍在 70B+ MoE / 多框架场景用得多。
5.5 HSDP — Hybrid Sharded DP(混合分片)
问题:1024 张 GPU 全部 ZeRO-3,单次 all-gather 跨 1024 卡 → 单卡通信量 $(N-1)/N \cdot \phi \to \phi$ 不变,但 latency 大幅上升(IB 跨节点贵)。
HSDP:组内 ZeRO-3(如节点内 8 卡),组间 DDP。等价于把 1024 卡分成 128 个 ZeRO-3 group(每组 8 卡)。
| 模式 | Shard 范围 | 单卡 model states | 跨节点通信 |
|---|---|---|---|
| 纯 ZeRO-3 | 全 1024 卡 | $16\Phi / 1024$ | 多 |
| HSDP (8) | 8 卡组内 | $16\Phi / 8 = 2\Phi$ | 少(组间是 grad all-reduce) |
trade-off:显存换通信效率。Llama 3 在某些训练阶段用 HSDP。
§6 ZeRO++ — 通信优化(2023)
Wang et al. "ZeRO++: Extremely Efficient Collective Communication for Giant Model Training" (NeurIPS MLSys workshop 2023, arXiv 2306.10209)。在 ZeRO-3 之上做了 3 件事:
6.1 qwZ:Quantized Weight all-gather(量化前向通信)
ZeRO-3 forward 时每层都要 all-gather fp16 weight。qwZ 把 fp16 → int8 后再 all-gather,通信量减半。
forward (原始): weight_shard (fp16) ──all-gather──> full_weight (fp16) ──compute
forward (qwZ): weight_shard (fp16) ──quant──> shard (int8)
──all-gather──> full (int8)
──dequant──> full (fp16) ──compute
代价:每次 all-gather 前后做 block-wise quantization / dequantization(block size 一般 2048-4096 元素一组,保留 per-block scale)。block-quant 比 per-tensor quant 精度好得多,对训练 loss 影响 < 1%。
6.2 hpZ:Hierarchical Partition(分层切分)
观察:backward 时 all-gather 的代价比 forward 大(更深更靠后的层先反传)。hpZ 把 weight 在 节点内复制(全 NVLink),跨节点不切:
- 节点内:8 张 GPU 各持完整 weight(hpZ = 节点内 DDP)
- 跨节点:weight 切片(依旧 ZeRO-3 模式)
backward 的 weight all-gather 只在节点内做(NVLink,便宜),不跨 IB。代价:每节点显存翻 8 倍——但 model states 本来就被切到 1024 卡上,乘 8 还远小于 DDP。
6.3 qgZ:Quantized Gradient reduce(量化梯度通信)
backward 末段 reduce-scatter gradient 也走 int8(fp16 → int8 + 量化 reduce)。原始 reduce-scatter 用 SUM 操作不能简单量化(量化 + sum 会累积误差),ZeRO++ 改用 all-to-all + 反量化 + local sum 路径绕开。
6.4 综合效果
ZeRO++ 论文报告:384 GPU 规模 throughput 2.16× 提升,通信量降 4×(fp16→int8 节省 2× × 2 个通信原语)。代价:实现复杂、精度需小心 ablation。
这里只量化集合通信途中的 transient buffer,权重本身的存储和计算仍是 fp16 / bf16。所以 loss 影响通常可忽略;和 fp8 训练(如 Hopper 上的 fp8 GEMM)是两件事。
§7 Tensor Parallel — Megatron-LM
Shoeybi et al. "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism" (arXiv 2019) + Narayanan et al. "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM" (SC 2021)。
7.1 核心思想:列并行 → 行并行 配对
考虑一个 MLP 层:$Y = \text{GELU}(X W_1) W_2$,$X \in \mathbb{R}^{B \times D}$,$W_1 \in \mathbb{R}^{D \times 4D}$,$W_2 \in \mathbb{R}^{4D \times D}$。
Column Parallel(切 $W_1$ 的输出维 $4D$):
$$W_1 = [W_1^{(1)} \mid W_1^{(2)} \mid \cdots \mid W_1^{(T)}], \quad W_1^{(i)} \in \mathbb{R}^{D \times 4D/T}$$
每个 rank $i$ 持有 $W_1^{(i)}$,独立算 $Y_1^{(i)} = X W_1^{(i)} \in \mathbb{R}^{B \times 4D/T}$。输入 $X$ 全副本,输出沿 column 切。
Row Parallel(切 $W_2$ 的输入维 $4D$):
$$W_2 = \begin{bmatrix} W_2^{(1)} \\ W_2^{(2)} \\ \vdots \\ W_2^{(T)} \end{bmatrix}, \quad W_2^{(i)} \in \mathbb{R}^{4D/T \times D}$$
输入沿 row 切:$Z^{(i)} = \text{GELU}(Y_1^{(i)}) W_2^{(i)} \in \mathbb{R}^{B \times D}$。每个 rank 算自己那部分 partial sum,最后 all-reduce 求和:
$$Y = \sum_{i=1}^T Z^{(i)} = \text{all-reduce}(Z^{(1)}, \dots, Z^{(T)})$$
7.2 通信量分析
列并行:输入 $X$ 是全副本,输出沿 column 切 → forward 无通信(输出已分散在各 rank)。
行并行:输入已沿 row 切,输出需 all-reduce → forward 1× all-reduce $(BD)$。
col + row 配对:列并行末端不通信(结果直接喂给行并行),行并行末端 1× all-reduce。整个 MLP block forward 只 1× all-reduce。
Backward 镜像:每个 MLP block backward 也是 1× all-reduce(gradient 流过 col-row 时镜像通信)。
总计:单个 transformer block(MLP + Attention,attention 也是 col-row)forward + backward = 4× all-reduce, 每次 $BLD$($L$ = seq len)。
7.3 Attention 的 TP 切法
Multi-head attention 的 $W_Q, W_K, W_V \in \mathbb{R}^{D \times D}$。直接按 head 维切(每个 head 独立、列并行天然):
H heads, T-way TP:
- 每 rank 持 H/T 个 head 的 W_Q^(i), W_K^(i), W_V^(i)
- 每 rank 独立算自己的 head_h(Q W_Q^h, ...)
- 输出 W_O 切行并行 (row-parallel) → 末端 all-reduce
head 维天然独立(不同 head 不交互),切 head 不需要跨 rank 通信中间结果;切 hidden dim 则需要在 softmax 前后做额外通信。代码也简洁:直接把 head 维分配给 ranks。
7.4 TP 代码骨架
下面是教学版 col-parallel / row-parallel,明确写出 forward 和 backward 的通信对(关键:col-parallel forward 不通信,backward all-reduce 输入梯度;row-parallel forward all-reduce 输出,backward 不通信):
import math
import torch
import torch.nn as nn
import torch.distributed as dist
class _CopyToTPRegion(torch.autograd.Function):
""" Identity in forward, all-reduce in backward (col-parallel 入口) """
@staticmethod
def forward(ctx, x, tp_group):
ctx.tp_group = tp_group
return x
@staticmethod
def backward(ctx, grad_out):
dist.all_reduce(grad_out, op=dist.ReduceOp.SUM, group=ctx.tp_group)
return grad_out, None
class _ReduceFromTPRegion(torch.autograd.Function):
""" All-reduce in forward, identity in backward (row-parallel 出口) """
@staticmethod
def forward(ctx, x, tp_group):
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=tp_group)
return x
@staticmethod
def backward(ctx, grad_out):
return grad_out, None
class ColumnParallelLinear(nn.Module):
""" Y = X W, W ∈ R^{in×out}, 切 out 维成 T 份;
forward: input replicated -> 输出 [B, out/T] 沿 last dim 切
backward: input grad 要 all-reduce(多个 TP rank 各算了 partial dX)"""
def __init__(self, in_features, out_features, tp_group, tp_size):
super().__init__()
assert out_features % tp_size == 0
self.tp_group = tp_group
self.out_per_rank = out_features // tp_size
self.weight = nn.Parameter(torch.empty(self.out_per_rank, in_features))
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, x):
# x: [B, in] (replicated). 入口 _CopyToTPRegion 让 backward 自动 all-reduce dX
x = _CopyToTPRegion.apply(x, self.tp_group)
return torch.nn.functional.linear(x, self.weight) # [B, out/T]
class RowParallelLinear(nn.Module):
""" Y = X W, W ∈ R^{in×out}, 切 in 维成 T 份;
forward: input [B, in/T] (sharded), 输出 partial sum -> all-reduce
backward: dW/dX 各自本地算即可,不需通信 """
def __init__(self, in_features, out_features, tp_group, tp_size):
super().__init__()
assert in_features % tp_size == 0
self.tp_group = tp_group
self.in_per_rank = in_features // tp_size
self.weight = nn.Parameter(torch.empty(out_features, self.in_per_rank))
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, x):
# x: [B, in/T] (sharded along last dim)
local_out = torch.nn.functional.linear(x, self.weight) # [B, out] partial sum
return _ReduceFromTPRegion.apply(local_out, self.tp_group)
class TPTransformerMLP(nn.Module):
""" Megatron-LM 标准 col+row 配对 """
def __init__(self, d_model, tp_group, tp_size):
super().__init__()
d_ff = 4 * d_model
self.fc1 = ColumnParallelLinear(d_model, d_ff, tp_group, tp_size) # 出 [B, L, 4D/T]
self.fc2 = RowParallelLinear(d_ff, d_model, tp_group, tp_size) # 入 [B, L, 4D/T] -> [B, L, D]
def forward(self, x):
# x: [B, L, D] (replicated)
h = torch.nn.functional.gelu(self.fc1(x)) # [B, L, 4D/T]
return self.fc2(h) # [B, L, D] after all-reduce
col-parallel: forward = identity, backward = all-reduce(dX);row-parallel: forward = all-reduce(Y), backward = identity。一个 transformer block (col + row × 2 sets: MLP + attention) 共 4 次 all-reduce,2 在 forward、2 在 backward——这就是上文 §7.2 那 4 次的来源。
7.5 TP 必须塞进节点(NVLink)
每个 transformer block forward + backward = 4× all-reduce of activation tensor 大小 $\approx BLD$。对 7B 模型 $D = 4096$,$B \times L = 2048$ → 单次 $\approx 32$ MB,per-block 4 次 × 32 layer = 128 次 / step, $\approx 4$ GB / step。
NVLink 带宽 900 GB/s 下 $\approx 4.4$ ms;IB 50 GB/s 下 $\approx 80$ ms。所以 TP 必须塞进节点内(NVLink 域)。这就是为什么 LLaMA / Llama 3 / GPT-3 都用 TP=8(一个节点 8 卡)而不是 TP=16。
§8 Sequence Parallel — 省 activation memory
Korthikanti et al. "Reducing Activation Recomputation in Large Transformer Models" (MLSys 2023)。
8.1 动机:activation memory 主要是 LayerNorm / Dropout 这种 element-wise op
TP 把 attention / MLP 的中间 activation 切了(沿 head / hidden 维),但 TP 外的部分(LayerNorm 输入输出、Dropout、residual)仍是全副本——每张 TP rank 都存一份 $[B, L, D]$ 大小的 activation。
7B 模型 $B \times L \times D = 2048 \times 4096$, fp16 = 16 MB / activation。一层有约 4-6 个这种 full-replica activation。32 层 × 5 → 2.5 GB / GPU 浪费。
8.2 SP 解法:沿 sequence 维切
让 TP 外的 activation 也沿 $L$(sequence)维切到 TP rank 上。
TP-only: [B, L, D] [B, L, D]
全副本 全副本
↓ ↑
LayerNorm → split → Attention → all-gather → LayerNorm
(TP)
TP + SP: [B, L/T, D] [B, L/T, D]
沿 L 切 沿 L 切
↓ ↑
LayerNorm → all-gather → Attention → reduce-scatter → LayerNorm
(TP)
关键操作变化:
- TP-only: forward 内 attention/MLP 入口处需要 broadcast / 退化为 no-op,出口 all-reduce
- TP + SP: 入口处 all-gather(把 SP 切的 L 维拼回全长),出口 reduce-scatter(把全长 reduce 再切回 L 维)
通信量:每个 transformer block 在 TP+SP 下需要 1× all-gather + 1× reduce-scatter(替换纯 TP 的 1× all-reduce),二者通信量等价,总量与纯 TP 一致。$\boxed{\text{通信量与纯 TP 完全相同}}$——但 LayerNorm / Dropout 的 activation 从 $BLD$ 降到 $BLD/T$。
因为 $\text{all-reduce} = \text{reduce-scatter} + \text{all-gather}$,纯 TP 末端 1× all-reduce $= $ SP 模式下 1× reduce-scatter + 后续 1× all-gather(next block 入口)。通信被重新分布,总量不变,但 activation 显存省下来了。
8.3 SP activation memory 省了多少(L3 题)
设未切分的单 transformer block 总 activation 内存为 $A = A_\text{in} + A_\text{out}$,其中:
- $A_\text{in}$:TP 内部那部分(attention 的 Q/K/V/scores、MLP 中间 [B, L, 4D]),TP 把它沿 hidden 切
- $A_\text{out}$:TP 外部那部分(LayerNorm/Dropout/residual 的 [B, L, D] activation),纯 TP 时全副本
纯 TP 单卡 activation:$A^\text{TP}_\text{per-card} = A_\text{in}/T + A_\text{out}$。
TP + SP 单卡 activation:SP 把 TP-外那部分也沿 seq 维切(每卡持 $[B, L/T, D]$)→ $A^\text{TP+SP}_\text{per-card} = A_\text{in}/T + A_\text{out}/T = A/T$。
SP 比纯 TP 省的部分:
$$\boxed{\;A^\text{TP}_\text{per-card} - A^\text{TP+SP}_\text{per-card} = A_\text{out} \cdot \left(1 - \frac{1}{T}\right)\;}$$
对 LLaMA-class 模型,$A_\text{out}$ 约占未切分总 activation 的 30-50%。TP=8 下省 $7/8 \cdot A_\text{out}$ ≈ 总 activation 减少 25-40%(与 Korthikanti 2023 报告一致)。
8.4 Selective Activation Recompute
Korthikanti et al. 同篇还提了 selective recompute:只对部分 op(如 attention 的 $QK^\top$ 和 softmax 这种二次 memory 的)做 recompute,其余 op 不重算。比 full activation checkpoint 减 30-40% 计算开销,省同样多的 memory。
§9 Pipeline Parallel
9.1 朴素 PP 与 bubble
把 $L$ 层模型按深度切到 $P$ 个 stage(每 stage $L/P$ 层)。Naive PP 让每个 mini-batch 顺序流过所有 stage:
Stage 0 (GPU0): [F1] [B1]
Stage 1 (GPU1): [F1] [B1]
Stage 2 (GPU2): [F1] [B1]
Stage 3 (GPU3): [F1] [B1]
↑ 大量 GPU idle ↑ ↑大量 idle ↑
GPU 利用率 = $1/P$,完全不能用。
9.2 GPipe(Huang et al., NeurIPS 2019)
把 mini-batch 切成 $M$ 个 micro-batch,micro-batch 之间流水:
Stage 0: [F1][F2][F3][F4] [B4][B3][B2][B1]
Stage 1: [F1][F2][F3][F4] [B4][B3][B2][B1]
Stage 2: [F1][F2][F3][F4] [B4][B3][B2][B1]
Stage 3: [F1][F2][F3][F4] [B4][B3][B2][B1]
|←─ "warm-up" ─→| |← 全部 micro-batch backward ─→| "cool-down"
Bubble(GPU 空闲时间占比):
$$\boxed{\;\text{bubble ratio} = \frac{P - 1}{M + P - 1}\;}$$
推导:每个 micro-batch 走完 $P$ stage 要 $P$ 个 step;warm-up 阶段(stage $i$ 等前 $i$ 个 micro-batch)共 $P-1$ 个 step idle;cool-down 阶段对称 $P-1$ 个 step idle;总 step = $M + P - 1$(forward)+ $M + P - 1$(backward)= $2(M+P-1)$;其中 idle = $2(P-1)$。idle 占比 = $(P-1)/(M+P-1)$。
必须等所有 $M$ 个 micro-batch 全部 forward 完才能开始 backward——所有 micro-batch 的 activation 都要存住,activation memory 与 $M$ 线性增长。这是 1F1B 解决的问题。
9.3 1F1B / PipeDream(Narayanan SOSP 2019, Megatron-LM-2 SC 2021)
1F1B = 1 Forward 1 Backward:每个 stage 在 forward 一个 micro-batch 后立刻 backward(前一个已经完成的):
Stage 0: [F1][F2][F3][F4][B1][F5][B2][F6][B3][F7][B4]...
↑ 一旦 micro-batch 1 backward 准备好(其 forward 已到 stage P)
就立刻 B1,腾出 micro-batch 1 的 activation 内存
关键性质:
- Bubble ratio 仍然是 $(P-1)/(M+P-1)$(warm-up + cool-down 没省)
- 但 每 stage 同时存活的 activation 数 = $P$(不是 GPipe 的 $M$)——大幅省 activation memory
Stage i 在稳态时的 activation 在内存中:
- micro-batch i forward 后等 backward 的中间 activation
- 因为有 P-i 个 stage 在 i 之后,且每 stage 走 forward / backward 都一个 step
- 所以 stage i 同时存 P-i 个 forward 完但还没 backward 的 activation
- 第一个 stage (i=0) 存最多: P 个;最后 stage 存最少: 1 个
9.4 Interleaved 1F1B(Megatron-LM-2, SC 2021)
进一步把 bubble 再压。每张 GPU 不持有连续 $L/P$ 层,而是持有 $V$ 段不连续的层(virtual stages):
原 1F1B (P=4, L=8 layers):
GPU0: layers 0,1 GPU1: layers 2,3 GPU2: layers 4,5 GPU3: layers 6,7
Interleaved (P=4, V=2, L=8):
GPU0: layers 0,4 GPU1: layers 1,5 GPU2: layers 2,6 GPU3: layers 3,7
(每张 GPU 持 V=2 段, 共 L/(PV) = 1 层 / 段)
每个 micro-batch 在 stage 0 跑 layer 0 → stage 1 跑 layer 1 → ... → stage 3 跑 layer 3 → 再回 stage 0 跑 layer 4 → ... → stage 3 跑 layer 7。一个 micro-batch 通过 stage 列 $V$ 次。
Bubble ratio(Narayanan et al., SC 2021, Eq. 4):
$$\boxed{\;\text{interleaved bubble} = \frac{P-1}{V \cdot M + P - 1} \approx \frac{P-1}{V \cdot M}\;}$$
把分母里 $M$ 替换成 $V \cdot M$(流水管子里同时塞了 $V$ 倍 micro-batch 的虚拟 stage 通过次数),warm-up / cool-down 的 $P-1$ 不变。代价:单 micro-batch 跨 GPU 的 send/recv 次数变 $V$ 倍(每个 micro-batch 在 stage 列上走 $V$ 圈),所以 $V$ 不能太大(一般 $V = 2$ 到 $4$)。
普通 1F1B:bubble = $(P-1)/(M+P-1)$,分母是 micro-batch 数 $M$ 加上 warm/cool $P-1$;interleaved:每个 micro-batch 跨 GPU 走 $V$ 圈,等效 micro-batch 数变成 $V \cdot M$,bubble = $(P-1)/(V M + P - 1) \approx (P-1)/(VM)$。"$V$ 倍 micro-batch" 是核心直觉。
9.5 1F1B 调度伪代码
def one_f_one_b_schedule(P, M, stage_rank, num_warmup_microbatches):
"""
stage_rank: 当前 GPU 持有的 pipeline stage 编号 (0..P-1)
num_warmup_microbatches = P - 1 - stage_rank
"""
# ===== Warm-up: stage_rank 走 (P-1-stage_rank) 个 forward
for i in range(num_warmup_microbatches):
recv_activation_from_prev_stage()
out = forward(model, activation)
send_activation_to_next_stage(out)
# ===== Steady state: F1B1 alternating
num_microbatches_remaining = M - num_warmup_microbatches
for i in range(num_microbatches_remaining):
# forward
recv_activation_from_prev_stage()
out = forward(model, activation)
send_activation_to_next_stage(out)
# backward (前面 forward 过、现在到的)
recv_grad_from_next_stage()
grad_in = backward(model, grad_out)
send_grad_to_prev_stage(grad_in)
# ===== Cool-down: 剩余 backward
for i in range(num_warmup_microbatches):
recv_grad_from_next_stage()
grad_in = backward(model, grad_out)
send_grad_to_prev_stage(grad_in)
9.6 PP 通信特点
PP 只在 stage 边界传 activation / gradient,单次传一个 micro-batch 大小 $\approx B/M \cdot L \cdot D$ bytes。通信量极小(远小于 TP / DP),所以 PP 可以跨节点(IB 也够)。
9.7 PP 的硬伤:load imbalance
- 不同 stage 的层数 / 计算量需手动平衡(embedding + LM head 很重,常需特殊处理)
- 一个 stage 的 OOM / 慢都拖死整个 pipeline(短板效应)
- micro-batch 数 $M$ 太小 → bubble 大;$M$ 太大 → activation memory 涨
§10 Context Parallel — 长序列切分
随着上下文窗口从 4K → 128K → 1M,单卡装不下完整 attention 计算。Context Parallel (CP) 沿 sequence 维切。
10.1 Ring Attention(Liu et al., arXiv 2310.01889, 2024)
把 $Q, K, V$ 沿 seq 维切到 $C$ 个 rank:每 rank 持 $L/C$ 个 token 的 Q/K/V。
Rank 0: Q[0:L/C], K[0:L/C], V[0:L/C]
Rank 1: Q[L/C:2L/C], K[L/C:2L/C], V[L/C:2L/C]
...
但 attention 需要每个 query 看到所有 key(causal 下看到所有过去 key)。Ring attention 解法:
- 每个 rank 用本地 Q 与本地 K, V 算一块 partial attention
- 把本地 K, V 沿 ring 转发给下一个 rank
- 下一个 rank 用本地 Q × 上一 rank 的 K, V 算 partial attention,累积到 output(online softmax 风格)
- 转 $C$ 步后每个 Q 都看遍全部 K, V,attention 完整
通信量:每个 rank 持 $L/C$ 个 token 的 K, V,大小 $2 \cdot L/C \cdot D$(fp16 下乘 2 bytes)。Ring 转 $C-1$ 步才让每个 K, V shard 走遍所有 rank。每 rank 总传输量 $\approx (C-1) \cdot 2 L D / C \approx 2 L D$(与 $C$ 几乎无关——这正是 ring 的常规性质)。
关键:用 online softmax(FlashAttention 同款)让 partial attention 可累积,不需物化完整 $L \times L$ scores。
FlashAttention 是单卡内沿 block tiling 算 attention(block 在 SRAM 内)。Ring attention 是把这个 tiling 推到多卡 / 多节点 level:每个 rank 持有一块 K, V,按 ring 顺序计算 partial attention 并累积。两者数学上一脉相承——都是 online softmax + block-wise accumulation。
10.2 Llama 3 CP 实现
Llama 3 paper (arXiv 2407.21783) 报告在 128K 长上下文 阶段用 CP=16(短上下文 8K 阶段不需 CP,分配回 DP)。结合 FlashAttention v3,128K context 单 step 时间从不可训练 → 几秒级可控。
10.3 CP 与 TP 的正交性
- TP 沿 hidden / head 维切 → 节点内
- CP 沿 sequence 维切 → 可跨节点(通信量 $\propto L$ 而非 $L^2$)
- 复合:每张 GPU 持 $L/C$ 个 token 的 $D/T$ 维 sub-tensor
§11 Expert Parallel — MoE 路由
11.1 MoE 基本结构
Mixture-of-Experts 把单个 FFN 替换为 $E$ 个 expert FFN + 一个 gate / router:
$$y = \sum_{e=1}^E G_e(x) \cdot \text{Expert}_e(x), \quad G \in \mathbb{R}^E, \quad \sum_e G_e = 1$$
实际部署用 top-K routing:只选 $G$ 输出最大的 $K$ 个 expert(典型 $K = 1, 2$),其余 expert 不计算。计算量与 expert 数 $E$ 无关,只与 $K$ 有关——这是 MoE 能 scale 参数的关键。
11.2 Expert Parallel:expert 分到不同 GPU
| 模式 | 单卡 expert 数 | 通信 |
|---|---|---|
| 不并行 (replicate) | $E$ | 0 |
| TP-style expert split | 每 expert 切片,全 GPU 算每个 expert | many all-reduce |
| EP: 每 GPU 持 $E/N$ 个 expert | $E/N$ | all-to-all dispatch + combine |
EP 的 forward 流程:
1. 每 GPU 对自己的 token batch 算 gate → 得到每 token 的 routing decision (assign 给 e_1, e_2, ..., e_K)
2. all-to-all dispatch: 把 token 按 expert 归属发到对应 GPU
- 输入: 每 GPU 持 B/N 个 token, 每 token 带 K 个 (expert_id, token_data)
- 输出: 每 GPU 收到所有 GPU 发给本机 expert 的 token
3. 每 GPU 用本地 expert 算 FFN
4. all-to-all combine: 把 expert 输出送回 token 原属 GPU
5. 每 GPU 按 gate weight 合并 K 个 expert 输出
(1) load imbalance: gate 倾向于 favor 某几个 expert,导致部分 GPU 过载;解:load balancing loss(Switch Transformer / GShard)。(2) all-to-all 通信量大: 与 token 数 × hidden 成正比,跨节点 IB 是瓶颈。DeepSeek-V3 用 node-limited routing 限制每 token 至多 dispatch 到 $M$ 个节点,减少 IB 流量。
11.3 EP all-to-all 代码骨架(伪代码)
下面是 EP 前向流程的伪代码骨架——重点在 dispatch / combine 的两次 all-to-all,省去了 routing / bucketing 的工程细节(生产代码见 Megatron-Core MoE 或 DeepSpeed-MoE):
def expert_parallel_forward(
x, # [B, L, D]
gate, # nn.Module: tokens [BL, D] -> (top_k_ids, top_k_w) 各 [BL, K]
experts_local, # 本地持有的 E_local 个 expert (nn.ModuleList)
ep_group, ep_size, ep_rank,
E_total, K,
):
""" 教学版:每张 GPU 持 E_local = E_total / ep_size 个 expert """
B, L, D = x.shape
tokens = x.reshape(B * L, D)
# 1. routing: 每 token 选 K 个 expert
top_k_ids, top_k_w = gate(tokens) # 都是 [BL, K]
# 2. expand: top-K 把每 token 复制 K 份, 每份带一个 expert_id
expanded = tokens.unsqueeze(1).expand(-1, K, -1).reshape(B * L * K, D)
expand_ids = top_k_ids.reshape(B * L * K) # [BL*K]
expand_w = top_k_w.reshape(B * L * K)
# 3. 算每个复制 token 该送到哪个 EP rank
target_rank = expand_ids // (E_total // ep_size) # [BL*K]
# 4. 按 target_rank 排序 + 统计每 rank send_count
perm = torch.argsort(target_rank)
sorted_tokens = expanded[perm]
send_counts = torch.bincount(target_rank, minlength=ep_size).tolist()
# 5a. 交换 send_counts -> recv_counts (每 rank 告诉别人自己要收多少)
send_t = torch.tensor(send_counts, dtype=torch.int64, device=x.device)
recv_t = torch.empty_like(send_t)
dist.all_to_all_single(recv_t, send_t, group=ep_group) # 一次小 a2a 同步 counts
recv_counts = recv_t.tolist()
# 5b. 同步 token 的 expert_id (用于本机分配到对应 local expert)
sorted_ids = expand_ids[perm]
received_ids = torch.empty(sum(recv_counts), dtype=sorted_ids.dtype, device=x.device)
dist.all_to_all_single(received_ids, sorted_ids,
output_split_sizes=recv_counts,
input_split_sizes=send_counts,
group=ep_group)
# 5c. 同步 token 本体
received_tokens = torch.empty(sum(recv_counts), D, device=x.device, dtype=x.dtype)
dist.all_to_all_single(received_tokens, sorted_tokens,
output_split_sizes=recv_counts,
input_split_sizes=send_counts,
group=ep_group)
# 6. local expert compute
received_out = torch.zeros_like(received_tokens)
for local_eid in range(len(experts_local)):
global_eid = ep_rank * (E_total // ep_size) + local_eid
mask = (received_ids == global_eid)
if mask.any():
received_out[mask] = experts_local[local_eid](received_tokens[mask])
# 7. all-to-all combine: 反向, 用对调过的 split sizes
combined = torch.empty_like(sorted_tokens)
dist.all_to_all_single(combined, received_out,
output_split_sizes=send_counts, # 反向
input_split_sizes=recv_counts,
group=ep_group)
# 8. 反排序 + gate weight 合并
inv_perm = torch.empty_like(perm)
inv_perm[perm] = torch.arange(perm.numel(), device=perm.device)
out_expanded = combined[inv_perm] # [BL*K, D]
out_expanded = out_expanded * expand_w.unsqueeze(-1)
out = out_expanded.view(B * L, K, D).sum(dim=1) # [BL, D]
return out.view(B, L, D)
真实代码处理 (a) capacity factor 限制 (避免某 expert 过载);(b) drop tokens 当 expert 满了;(c) NVL/IB hierarchical all-to-all(DeepSeek 的 node-limited routing);(d) backward 的镜像 all-to-all + gate gradient。本骨架仅示意 dispatch / combine 双向流程。
§12 Activation Memory 优化
12.1 Gradient Checkpointing(Chen et al., arXiv:1604.06174, 2016)
不存中间 activation,反向时重算。空间 $O(\sqrt{L})$(最优分段),时间 + 33% (1 次 extra forward)。
from torch.utils.checkpoint import checkpoint
def block_forward(x):
return transformer_block(x)
# 反向时重算 block_forward,不存中间 activation
y = checkpoint(block_forward, x, use_reentrant=False)
12.2 Selective Recompute(Korthikanti 2023)
只对二次 memory 的 op(attention 的 $QK^\top$ 和 softmax)做 recompute,其余存。比 full recompute 减 30-40% 计算 overhead,省同样多 memory。Megatron-LM 默认开启。
12.3 Offload(ZeRO-Infinity)
把 activation / optimizer state 卸到 CPU RAM 或 NVMe。CPU 卸适合 13B-30B 量级单机训练;NVMe 卸吞吐极低,主要用于推理或 trillion-scale 大模型探索。
12.4 Activation Memory 公式(必背)
一个 transformer block 的 activation 大致(fp16, full save):
$$A_\text{block} \approx \underbrace{34 \cdot B \cdot L \cdot D}_{\text{各 LayerNorm/QKV/output/MLP residual}} + \underbrace{5 \cdot B \cdot L^2 \cdot H}_{\text{attention 中间矩阵}}$$
详见 Korthikanti et al. 2023 Table 2。$L \gg D$ 时第二项主导(FlashAttention 直接干掉这部分);$L \approx D$ 时第一项主导(SP 把它降 $T$ 倍)。
§13 综合:3D / 4D / 5D Parallelism
把前面几条轴拼起来。设 world size $W$,
$$W = D_{DP} \times T_{TP} \times P_{PP} \times C_{CP} \times E_{EP}$$
每条轴的角色:
| 轴 | 切什么 | 在哪 | 通信原语 |
|---|---|---|---|
| DP / FSDP | batch + model states | 跨节点 IB ok | all-reduce / reduce-scatter + all-gather |
| TP | 单层 hidden / head | 节点内 NVLink | all-reduce (per block × 4) |
| SP | TP 外的 LayerNorm / Dropout activation | 节点内(与 TP 共 group) | reduce-scatter + all-gather |
| PP | layer depth | 跨节点 IB ok | point-to-point send/recv |
| CP | sequence 维 | 节点内或跨节点都可(通信 $\propto L$) | ring K/V |
| EP | MoE expert | 跨节点 IB ok(all-to-all 量大) | all-to-all |
13.1 Llama 3 405B 训练拓扑(公开信息)
Meta 2024 报告 (2407.21783):
- 16K H100 GPUs(总计 16384)
Parallelism(保持总 GPU 数不变,CP 上去时 DP 下来):
- 短上下文阶段 (8K):TP=8 × CP=1 × PP=16 × DP=128 = 16384
- 长上下文阶段 (128K):TP=8 × CP=16 × PP=16 × DP=8 = 16384
- 训练精度:BF16(论文 Table 报告 BF16 MFU);FP8 用于 inference 量化,未用于 405B 训练
- 54 天共 466 次中断(419 unexpected + 47 planned/maintenance),其中 78% unexpected 是硬件原因
- 有效训练时间 > 90%
13.2 DeepSeek-V3 训练拓扑(公开信息)
DeepSeek 2024 报告 (2412.19437):
- 2048 H800 GPUs
- Parallelism: TP=1(无 TP!靠 ZeRO + EP + PP 弥补) × PP=16 × EP=64(跨 8 节点) × ZeRO-1 DP
- DualPipe 双向流水(详见下节)+ all-to-all overlap
- fp8 GEMM + bf16 accumulation
V3 用了 MLA (multi-head latent attention) + 大量 MoE expert,TP 切 head 收益小(latent attention head 维已经很小);而 EP 的 all-to-all overlap 配合 DualPipe 把通信全藏起来。整体 PP × EP × ZeRO 已够装下 671B 参数。
§14 DualPipe — 2024 流水线前沿
DeepSeek 2024 在 V3 paper 与独立 repo 发布的 DualPipe 算法(arXiv 2412.19437 + github.com/deepseek-ai/DualPipe)。
14.1 核心想法
1F1B 的 bubble 来自 warm-up + cool-down 阶段。DualPipe 让 两个方向同时跑流水线——一组 micro-batch 从 stage 0 → P,另一组从 P → 0,两组在中间相遇时刚好填满 warm-up / cool-down 空隙。
传统 1F1B (P=4):
Stage 0: [F1][F2][F3][F4][B1][F5][B2]...
Stage 1: [F1][F2][F3][F4][B1][F5][B2]...
(前后两端 bubble)
DualPipe (P=4):
Forward 方向 micro-batch: [F1][F2][F3][F4]...
Reverse 方向 micro-batch: ...[F4'][F3'][F2'][F1']
Stage 0: 同时 process Forward 的 F + Reverse 的 F' + 对应 B ← 计算 / 通信完全 overlap
更精确:DualPipe 设计了一个双向 schedule,让每张 GPU 在任意时刻都有两个 micro-batch 在 forward / backward 重叠执行;同时 expert parallel 的 all-to-all 通信也被 hide 在两组 micro-batch 的间隙里。
14.2 DualPipe vs 普通 1F1B 性质
| 维度 | 1F1B | Interleaved 1F1B | DualPipe |
|---|---|---|---|
| Bubble | $(P-1)/(M+P-1)$ | $(P-1)/(VM+P-1)$ | 理想 0 (warm-up / cool-down 互补) |
| Activation memory | $P$ × per stage | $V \cdot P$ × | $\approx 2 \times P$ |
| 通信 overlap | 部分 | 同 1F1B | 几乎 100% all-to-all overlap |
| 实现复杂度 | 中 | 高 | 极高(需 forward & reverse 调度) |
14.3 DualPipe 的代价
- 代码复杂度爆炸:一个 stage 同时跑 forward (方向 1) + forward (方向 2) + backward,cuda stream 管理极困难
- 2× activation memory:两个方向都要存 activation,比 1F1B 多 $\approx 2\times$
- stage 必须均衡:任何一个 stage 卡住都会让两个方向都堵
DeepSeek 用 DualPipe 是因为 EP all-to-all 通信极重,必须想办法藏起来。一般训练任务普通 1F1B 已足够。
§15 TorchTitan — PyTorch 原生 4D 平台
Liang et al. (ICLR 2025, arXiv 2410.06511) "TorchTitan: One-stop PyTorch Native Solution for Production-Ready LLM Pre-training"。
15.1 设计目标
- 不要再 monkey patch:把 FSDP2 / TP / PP / SP / CP / Float8 /
torch.compile全做进 PyTorch 主干 - DTensor 作为统一语言:所有切分都用 DTensor placement 描述(
Shard(d),Replicate,Partial) - Composable:FSDP2 与 TP 自然复合(FSDP1 与 TP 几乎不可复合)
15.2 代码风格(与 DeepSpeed monkey patch 对比)
# TorchTitan 风格:声明式 DTensor placement
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
parallelize_module, RowwiseParallel, ColwiseParallel,
SequenceParallel, PrepareModuleInput,
)
from torch.distributed.fsdp import fully_shard
# 1. 构造 4D 设备 mesh
mesh = init_device_mesh(
"cuda",
mesh_shape=(2, 8, 4, 8), # PP=2, FSDP=8, CP=4, TP=8
mesh_dim_names=("pp", "fsdp", "cp", "tp"),
)
# 2. 应用 TP + SP(声明每个 sub-module 的 sharding 策略)
parallelize_module(
model,
mesh["tp"],
{
"attn.wq": ColwiseParallel(),
"attn.wk": ColwiseParallel(),
"attn.wv": ColwiseParallel(),
"attn.wo": RowwiseParallel(),
"mlp.fc1": ColwiseParallel(),
"mlp.fc2": RowwiseParallel(),
"norm1": SequenceParallel(),
"norm2": SequenceParallel(),
},
)
# 3. FSDP2 wrap(自动认到 mesh["fsdp"])
for block in model.blocks:
fully_shard(block, mesh=mesh["fsdp"])
fully_shard(model, mesh=mesh["fsdp"])
# 4. PP(pipeline schedule,1F1B interleaved,伪代码示意)
# 真实 ScheduleInterleaved1F1B 需要 List[PipelineStage](每 rank 持 V 个 virtual stage)
from torch.distributed.pipelining import PipelineStage, ScheduleInterleaved1F1B
stages = [PipelineStage(submod_v0, ...), PipelineStage(submod_v1, ...)] # V=2 virtual stages
schedule = ScheduleInterleaved1F1B(stages, n_microbatches=32, loss_fn=loss_fn)
15.3 Float8 训练(Hopper / Blackwell)
TorchTitan 集成 Float8 训练(H100/B100 fp8 GEMM),权重 / activation 都用 fp8,accumulation 在 fp32。以下是 torchao 的典型用法(具体 API 请以 torchao 当前版本文档为准):
import torch.nn as nn
from torchao.float8 import convert_to_float8_training, Float8LinearConfig
convert_to_float8_training(
model,
config=Float8LinearConfig(), # 默认 dynamic scaling
module_filter_fn=lambda m, n: isinstance(m, nn.Linear) and "lm_head" not in n,
)
效果:H100 上 throughput +20-40%,loss 几乎无差异(block-wise scaling 关键)。
15.4 异步 Checkpointing
import torch.distributed.checkpoint as DCP
# 异步保存:返回 Future,不阻塞训练
future = DCP.async_save(model.state_dict(), checkpoint_id="step_10000")
# ... 继续训练
future.result() # 必要时等
DCP(Distributed Checkpoint)配合 FSDP2 的 sharded state dict,单 checkpoint 写盘只需几秒(每 rank 写自己那部分到分布式存储),不阻塞训练。
§16 通信原语总结表
为了在面试中快速 recall,一张表收尾。
| 操作 | 当原始数据全部加起来是 $S$(每 rank $S$) | 每 rank 通信量(ring) | 用在 |
|---|---|---|---|
all-reduce(buf, SUM) | $N \cdot S$ → 全 rank $S$ | $2(N-1)/N \cdot S$ | DDP gradient sync, TP block 末端 |
reduce-scatter(buf) | $N \cdot S$ → 各 rank $S/N$ | $(N-1)/N \cdot S$ | ZeRO grad reduce, SP 出口 |
all-gather(shard) | $N \cdot S/N$ → 全 rank $S$ | $(N-1)/N \cdot S$ | ZeRO-3 forward param, SP 入口 |
broadcast(buf, src) | rank src 的 $S$ → 全 rank $S$ | $S$(树) / $(N-1)/N \cdot S$(ring) | model 初始化广播 |
all-to-all(buf) | $N \cdot N$ 块 → 转置 | $(N-1)/N \cdot S$ | MoE EP routing, sequence sharding 变换 |
point-to-point send/recv | 1 → 1 | $S$ | PP stage 边界 |
§17 25 高频面试题
按 L1 / L2 / L3 排序,每题 collapsible 答案要点。
L1 必会题(任何 ML 工程岗都会问)
Q1. DDP 和 DP(DataParallel)的区别?
- DP(
nn.DataParallel):单进程多卡,主 rank scatter input + gather output,有 GIL + 主 rank 瓶颈,已 deprecated - DDP(
DistributedDataParallel):多进程多卡,每进程一张 GPU,NCCL all-reduce 同步梯度 - DDP 比 DP 快 1.5-3 倍且 scaling 远好
陷阱:说 DP 是"训练快"——错,DP 是历史遗留 API,生产用 DDP。
Q2. NCCL all-reduce 的通信量?
- Ring 算法:单 GPU 总流量 $2(N-1)/N \cdot S \approx 2S$ bytes
- 与 $N$ 几乎无关(这是 ring 的精髓)
- 等价于 reduce-scatter ($S$) + all-gather ($S$)
陷阱:说 $N \cdot S$ 或 $S/N$;忘了 ring 的 step 数和 per-step 流量。
Q3. Adam mixed-precision 训练单卡 model states 多少?
- 参数 fp16: $2\Phi$
- 梯度 fp16: $2\Phi$
- Optimizer (fp32 master + Adam m + v): $4\Phi + 4\Phi + 4\Phi = 12\Phi$
- 总计 $16\Phi$ bytes
陷阱:忘了 fp32 master copy;或把 Adam $m, v$ 当 fp16(实际 fp32)。
Q4. ZeRO 1/2/3 分别切什么?
- ZeRO-1: 切 optimizer state
- ZeRO-2: + gradient
- ZeRO-3: + parameter(最激进)
单卡 model states 从 $16\Phi$ 降到 $16\Phi/N$(ZeRO-3)。
陷阱:把顺序搞反;不知道 ZeRO-1/2 通信量与 DDP 完全一样(都是 $2\Phi$)。
Q5. FSDP 和 ZeRO-3 有什么区别?
- 算法层面:完全等价(FSDP 就是 ZeRO-3 的 PyTorch 实现)
- 工程层面:FSDP 集成在 PyTorch 主干;ZeRO-3 在 DeepSpeed 库
- FSDP2 改用 per-parameter DTensor sharding,与 TP 复合自然,state_dict 简洁
陷阱:说"FSDP 比 ZeRO 通信少"或反之——错,通信量同。
Q6. Tensor Parallel 切 attention 的哪一维?
- 切 head 维(每张卡持 $H/T$ 个 head)
- $W_Q, W_K, W_V$ column-parallel(按输出维切)
- $W_O$ row-parallel(按输入维切)
- 每个 attention block forward 1× all-reduce, backward 1× all-reduce
陷阱:说切 hidden_dim 维;或忘了 col + row 配对让中间不通信。
Q7. TP 必须放在哪里?
- 节点内 NVLink 域(带宽 900 GB/s)
- 跨节点 IB(50 GB/s)走 TP 性能会暴跌
- 这就是为什么 TP=8 是黄金尺寸(一个节点 8 卡)
陷阱:说 TP 可以任意跨节点;或把节点内 / 跨节点带宽搞反。
Q8. Pipeline Parallel 的 bubble 怎么算?
- 朴素 PP:$M = 1$,bubble = $(P-1)/P$(巨大)
- GPipe / 1F1B with $M$ micro-batch:$(P-1)/(M+P-1)$
- 通常 $M \geq 4P$ 让 bubble < 20%
陷阱:忘了 $M$ 在分母;只说 $1/P$ 不说 $M$。
Q9. 1F1B 比 GPipe 好在哪?
- bubble ratio 一样 $(P-1)/(M+P-1)$
- 但 1F1B 每 stage 同时存活的 activation 数 = $P$(不是 GPipe 的 $M$)
- 大幅省 activation memory
陷阱:说 1F1B 减小了 bubble——错,1F1B 不减 bubble,省 activation。
Q10. 激活检查点(gradient checkpointing)的代价?
- 显存:$O(L) \to O(\sqrt{L})$
- 时间:+ 33%(一次额外 forward)
- 实际生产几乎必开(70B+ 模型不开就 OOM)
陷阱:说"显存减半"——不准确,理论是 $\sqrt{L}$;或说"时间减半"。
L2 进阶题(research-oriented 岗位)
Q11. ZeRO-3 forward 为什么需要 all-gather?
- 参数被切到 $N$ 张卡 → 每张卡只持 $1/N$
- 算某层 forward 时需要完整 $W^{(\ell)}$ → 临时 all-gather 到所有 rank
- forward 完立刻 release,只保留 shard
- backward 时再次 all-gather(forward 中已释放)
陷阱:以为参数在所有 rank 上常驻;或忘了 backward 也要重新 all-gather。
Q12. 推导:DDP vs ZeRO-3 通信量差。
- DDP: backward 1× all-reduce = $2\Phi$,总 $2\Phi$
- ZeRO-3: forward 1× all-gather ($\Phi$) + backward 1× all-gather ($\Phi$) + 1× reduce-scatter ($\Phi$) = $3\Phi$
- ZeRO-3 比 DDP 多 50% 通信,换 $N\times$ 显存下降
陷阱:以为 ZeRO 减通信——错,ZeRO 减显存,可能增通信(视阶段)。
Q13. NCCL ring all-reduce 单步流量?
- 总 step 数:$2(N-1)$(reduce-scatter $N-1$ + all-gather $N-1$)
- 单 step:每 rank 发送 $S/N$ bytes
- 单 GPU 总流量:$2(N-1) \cdot S/N \approx 2S$
- Bandwidth 利用率 $\to 1$ 当 $N \to \infty$
陷阱:说 $S$ 而非 $2S$;或说 $N \cdot S$(忘了 ring 的精髓)。
Q14. SP(Sequence Parallel)省了什么?
- 不省通信(与纯 TP 通信量一样)
- 省 TP 外的 activation memory(LayerNorm / Dropout)
- 把全副本 $[B, L, D]$ 切成 $[B, L/T, D]$
- 总 activation memory 降 25-40%
陷阱:说"SP 减通信"——错,SP 只是把 all-reduce 重排成 reduce-scatter + all-gather(总量等同),但 activation 切了。
Q15. Interleaved 1F1B 怎么减 bubble?
- 每 stage 持 $V$ 个 virtual stage(不连续的 $V$ 段 layer)
- bubble: $(P-1)/(VM+P-1) \approx (P-1)/(VM)$
- 同 $M$ 下 bubble 降 $V$ 倍
- 代价:通信次数 × $V$
陷阱:说"interleaved 减计算量"——错,只重排时间轴;或忘了通信 × V 的代价。
Q16. MoE 的 EP all-to-all 通信量?
- 每 token 选 $K$ 个 expert
- Dispatch: 每 GPU 把本机 $B/N$ token 发到对应 expert 的 GPU
- Combine: 反向
- 总 per-rank 通信 $\approx 2 \cdot K \cdot B/N \cdot D$(双向 all-to-all)
陷阱:忘了 top-K(不是 $E$);忘了 dispatch + combine 两次。
Q17. HSDP 和 FSDP 区别?
- FSDP:所有 GPU 同一 sharding group
- HSDP:组内 FSDP / ZeRO-3,组间 DDP
- HSDP 单 group 内通信少(节点内 NVLink),组间用 grad all-reduce(不需 weight all-gather)
- trade-off:组内 model states 多(不切到全 world),换跨节点通信减
陷阱:以为 HSDP 减总通信——精度上是减跨节点通信,组内 / 组间是 trade-off。
Q18. Llama 3 405B 用了哪些并行?
- 16K H100(总 16384,保持不变;CP 上去时 DP 下来)
- 短 context (8K):TP=8 × CP=1 × PP=16 × DP=128
- 长 context (128K):TP=8 × CP=16 × PP=16 × DP=8
- 54 天共 466 次中断(419 unexpected + 47 planned),90%+ 有效训练时间
- 训练用 BF16(不是 FP8——FP8 是推理量化)
- Meta 2024, arXiv 2407.21783
陷阱:说用了 EP——错,Llama 3 是 dense 模型不用 EP;或忘了 CP 这一阶段的存在。
Q19. ZeRO++ 三个 trick?
- qwZ:forward all-gather 用 int8 量化(block-wise quant)
- hpZ:weight 在节点内复制(NVLink 域),跨节点仍切——backward all-gather 走 NVLink
- qgZ:backward gradient reduce-scatter 也走 int8
- 总通信量降 4×,throughput 384 GPU 上 +116% (Wang 2023)
陷阱:以为 ZeRO++ 改训练精度——只在通信途中量化 buffer,权重和计算仍是 fp16/bf16。
Q20. 一个 7B 模型在 8 张 A100-40G 上能跑 fp16 训练吗?
- Model states (Adam): $16 \times 7$B $= 112$ GB
- DDP: 单卡 112 GB,单卡放不下(A100 40G)
- ZeRO-3 / FSDP: 单卡 $112/8 = 14$ GB ✓
- 加 activation (with checkpoint): 几 GB ✓
- 加 workspace: 几 GB
- 结论:FSDP/ZeRO-3 + 激活检查点可以跑
陷阱:说"DDP 也能跑";或忘了 model states 而只算参数 $\Phi$;或忘了 fp32 master copy。
L3 高级变体(顶级 lab / 自研基建岗)
Q21. 推导 1F1B bubble ratio + interleaved 怎么进一步减小。
- 1F1B:warm-up $P-1$ step (forward fill) + cool-down $P-1$ step (backward drain) + steady $M-P+1$ step
- 总 step $= 2M$(forward + backward)需要 $2(M + P - 1)$ 时间槽(含 warm/cool 各 $P-1$ idle 槽)
- bubble = $2(P-1) / [2(M+P-1)] = (P-1)/(M+P-1)$
- Interleaved 1F1B(Narayanan SC 2021, Eq. 4):每 GPU 持 $V$ 段不连续 layer (virtual stages),每 micro-batch 在 stage 列上走 $V$ 圈 → 流水里等效 $V \cdot M$ 个 micro-batch 在排队
- bubble = $(P-1) / (V \cdot M + P - 1) \approx (P-1)/(V \cdot M)$
- 同 $M$ 下 bubble 降 $V$ 倍;代价:跨 GPU send/recv 次数 $\times V$
陷阱:忘了 interleaved 通信 × V 的代价;说"interleaved 减 forward 计算"——只重排时间。
Q22. FSDP vs ZeRO-3 通信量 + 各自适用场景。
- 通信量完全一样:forward all-gather $\Phi$ + backward all-gather $\Phi$ + reduce-scatter $\Phi$ = $3\Phi$
工程差异:
- FSDP2: 用 DTensor 描述 per-param sharding,与 TP / PP 复合自然,
torch.compile友好 - DeepSpeed ZeRO-3: 用 flat-parameter,要 monkey patch 拼装,但 ZeRO++ / Offload / Infinity 生态完整
- FSDP2: 用 DTensor 描述 per-param sharding,与 TP / PP 复合自然,
选型:
- 新项目(dense + 4D 并行)→ FSDP2 / TorchTitan
- MoE + 跨框架 + offload → DeepSpeed
陷阱:说"FSDP 比 ZeRO 通信少"——算法相同;或说"FSDP 不支持 offload"——FSDP2 OffloadPolicy 已支持。
Q23. TP + SP 相比纯 TP 省多少 activation memory?
- 设 transformer block activation $A_\text{block}$
- TP-内 (attention 中间 / MLP 中间):约 $A_\text{block} \times 0.5-0.7$,TP 切了
- TP-外 (LayerNorm / Dropout / residual):约 $A_\text{block} \times 0.3-0.5$,TP 不切(全副本)
- 纯 TP 单卡 activation = $A_\text{TP-内}/T + A_\text{TP-外}$
- TP+SP 单卡 activation = $A_\text{TP-内}/T + A_\text{TP-外}/T = A_\text{block}/T$
- 省了 $A_\text{TP-外} \cdot (1 - 1/T)$,约总 activation 的 25-40%(取决于模型)
陷阱:说"SP 让 activation 减 $T$ 倍"——不准确,只对 TP-外那部分;或忘了 SP 通信量与纯 TP 一样。
Q24. DualPipe 相对 1F1B 的核心改进?
- 双向流水线:一组 micro-batch 从 stage 0 → P,另一组从 P → 0
- 两组在中间相遇时刚好填满 1F1B 的 warm-up / cool-down bubble
- 理论 bubble = 0(理想下两侧互补)
- 关键收益:all-to-all 通信完全 overlap(EP 重要)
- 代价:activation memory $\times 2$(两个方向都要存);实现极复杂
- DeepSeek-V3 December 2024 报告 (arXiv 2412.19437) + github.com/deepseek-ai/DualPipe
陷阱:说"DualPipe 是新的 PP 切法"——它是新的 schedule;或忘了 2× activation 的代价。
Q25. 现在 frontier 训练(如 DeepSeek-V3 / Llama 3)的 4D / 5D parallelism 怎么组合?
- 5 个正交维度:DP / FSDP × TP × PP × CP × EP
- World size $W = D \times T \times P \times C \times E$
经验法则:
- TP=8(一个节点 NVLink 域内必须塞下)
- PP=8-32(跨节点 IB ok,bubble 控制 $M \geq 4P$)
- CP 看 context 长度(Llama 3 在 128K 用 CP=16,1M 可能 CP=64+)
- EP 看 MoE expert 数(DeepSeek-V3 用 EP=64)
- FSDP / DP 用剩下的 world size
Llama 3 405B(保持总 16K GPU 不变,CP 上 DP 下):
- 8K context: TP=8 × CP=1 × PP=16 × DP=128
- 128K context: TP=8 × CP=16 × PP=16 × DP=8
- DeepSeek-V3 671B: TP=1 + PP=16 × EP=64 × ZeRO-1 DP=2,共 2048 H800
关键工程点:
- 通信原语按拓扑放(TP 走 NVLink,PP / DP 走 IB)
- DualPipe / Interleaved 1F1B 压 PP bubble
- FlashAttention v3 + SP / CP 压 activation
- fp8 GEMM(H100 / B100)+ bf16 accumulation 提 throughput
陷阱:把维度顺序搞反;忘了 TP 必须节点内;以为 DeepSeek 用了 TP(V3 实际 TP=1)。
§A 附录:完整 4D wrap 代码骨架
下面是一个 4D parallelism 的 minimal 端到端 wrap 示例(FSDP2 + TP + SP + PP),按 TorchTitan 风格。
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
parallelize_module, RowwiseParallel, ColwiseParallel,
SequenceParallel, PrepareModuleInput, PrepareModuleOutput,
)
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.pipelining import pipeline, SplitPoint, ScheduleInterleaved1F1B
from torch.utils.checkpoint import checkpoint
def build_4d_model(model, world_size):
""" 4D parallelism wrap (PP × FSDP × CP × TP) """
# Step 1. 建 4D device mesh
# 例: 64 GPU = PP=2 × FSDP=4 × CP=1 × TP=8
mesh = init_device_mesh(
"cuda",
mesh_shape=(2, 4, 1, 8),
mesh_dim_names=("pp", "fsdp", "cp", "tp"),
)
# Step 2. TP + SP wrap (mesh["tp"])
tp_plan = {
# Attention
"attn.wq": ColwiseParallel(),
"attn.wk": ColwiseParallel(),
"attn.wv": ColwiseParallel(),
"attn.wo": RowwiseParallel(),
# MLP
"mlp.fc1": ColwiseParallel(),
"mlp.fc2": RowwiseParallel(),
# SP: norm / residual 沿 seq 维切
"ln1": SequenceParallel(),
"ln2": SequenceParallel(),
}
for block in model.blocks:
parallelize_module(block, mesh["tp"], tp_plan)
# Step 3. Activation checkpoint (每 block 一次)
for i, block in enumerate(model.blocks):
model.blocks[i] = _ckpt_wrap(block)
# Step 4. PP split (mesh["pp"])
# 把 model 切成 2 段(PP=2)
split_spec = {"blocks.16": SplitPoint.BEGINNING}
pipe = pipeline(
model,
mb_args=(torch.randn(1, 4096, device="cuda"),),
split_spec=split_spec,
)
pp_stage = pipe.build_stage(
stage_index=mesh["pp"].get_local_rank(),
device=torch.device(f"cuda:{torch.cuda.current_device()}"),
group=mesh["pp"].get_group(),
)
# Step 5. FSDP2 wrap (mesh["fsdp"])
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
)
for module in pp_stage.submod.modules():
if isinstance(module, TransformerBlock):
fully_shard(module, mesh=mesh["fsdp"], mp_policy=mp_policy)
fully_shard(pp_stage.submod, mesh=mesh["fsdp"], mp_policy=mp_policy)
return pp_stage, mesh
def _ckpt_wrap(block):
""" 把一个 TransformerBlock 包成 activation checkpoint 形式 """
class Wrapped(nn.Module):
def __init__(self, inner):
super().__init__()
self.inner = inner
def forward(self, *args, **kw):
return checkpoint(self.inner, *args, use_reentrant=False, **kw)
return Wrapped(block)
# ============================================================
# Training loop with PP schedule
# ============================================================
def train_step_4d(pp_stage, schedule, optimizer, batch):
""" 1 step of 4D parallel training """
optimizer.zero_grad(set_to_none=True)
# PP schedule runs forward + backward across all stages
losses = []
schedule.step(batch, losses=losses) # 内部触发 FSDP all-gather / TP all-reduce / etc.
optimizer.step()
return torch.stack(losses).mean()
说明:上述代码是教学骨架,实际生产请直接用 TorchTitan repo(pytorch/torchtitan)——其包含完整的 Trainer、checkpointing、profiling、loss / lr schedule,本节只给概念示意。
§B 参考资料
- DDP: PyTorch documentation,
nn.parallel.DistributedDataParallel - ZeRO: Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models", SC 2020 (arXiv:1910.02054)
- ZeRO-Offload: Ren et al., USENIX ATC 2021
- ZeRO-Infinity: Rajbhandari et al., SC 2021
- ZeRO++: Wang et al., "ZeRO++: Extremely Efficient Collective Communication for Giant Model Training", arXiv:2306.10209, 2023
- FSDP: Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel", VLDB 2023
- FSDP2 / TorchTitan: Liang et al., "TorchTitan: One-stop PyTorch Native Solution", ICLR 2025 (arXiv:2410.06511)
- Megatron-LM TP: Shoeybi et al., "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism", arXiv:1909.08053, 2019
- Megatron-LM-2 (Interleaved 1F1B): Narayanan et al., "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM", SC 2021
- GPipe: Huang et al., "GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism", NeurIPS 2019
- PipeDream / 1F1B: Narayanan et al., "PipeDream: Generalized Pipeline Parallelism for DNN Training", SOSP 2019
- Sequence Parallel + Selective Recompute: Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models", MLSys 2023
- Ring Attention: Liu et al., "Ring Attention with Blockwise Transformers for Near-Infinite Context", ICLR 2024 (arXiv:2310.01889)
- Gradient Checkpointing: Chen et al., "Training Deep Nets with Sublinear Memory Cost", arXiv:1604.06174, 2016
- DualPipe: DeepSeek-V3 Technical Report, arXiv:2412.19437, December 2024
- Llama 3: Meta AI, "The Llama 3 Herd of Models", arXiv:2407.21783, 2024
- GShard / Switch Transformer (MoE EP): Lepikhin et al. 2020 / Fedus et al. JMLR 2022