Part 5 of How To Scale Your Model (第4部分:Transformer | 第6部分:训练 LLaMA)
训练大模型,一张卡肯定不够。这章我们聊聊四种『分而治之』的方法:数据并行、FSDP、张量并行、流水线并行。每种方法各有利弊,关键是搞清楚什么时候通信会拖后腿。
核心问题:我有一堆芯片,怎么让它们高效协作?
理想情况下,芯片数量翻倍,训练速度也翻倍——这叫强扩展。但现实没这么美好。芯片越多,它们之间的”沟通成本”也越高。如果沟通时间超过了干活时间,加再多芯片也是浪费。
打个比方:一个人搬砖很慢,两个人可以快一倍。但如果 100 个人挤在一起搬,光是”你往左我往右”的协调就够呛,效率反而可能下降。
本章目标:搞清楚四种”分活儿”的方法,以及每种方法什么时候会被”沟通成本”拖累。
| 策略 | 一句话解释 |
|---|---|
| 数据并行 | 每张卡都有完整模型,各自算不同的数据,最后把梯度汇总 |
| FSDP(ZeRO) | 模型参数切成碎片分给各卡,用的时候再拼起来 |
| 张量并行 | 每个矩阵乘法都分给多张卡一起算,算完再合并 |
| 流水线并行 | 模型按层切开,数据像流水线一样一层层往下传 |
为了后面计算方便,我们统一用这些符号:
模型参数:
| 符号 | 含义 |
|---|---|
| D | 隐藏维度(d_model) |
| F | FFN 中间维度(d_ff,通常是 4D 或 8D) |
| B | 批次大小(总 token 数,不是每卡的) |
| T | 序列长度 |
| L | 层数 |
硬件参数:
| 符号 | 含义 |
|---|---|
| C | 每芯片的 FLOPs/s |
| W | 网络带宽(双向),比如 $W_{ici}$ 表示 ICI 带宽 |
| X, Y, Z | 网格各轴的芯片数 |
为了聚焦核心问题,我们做两个简化:
前向传播:
反向传播:
一句话:每张卡都有完整模型,各自算不同的数据批次,最后平均梯度。
分片公式:\(\text{In}[B_X, D] \cdot W_\text{in}[D, F] \cdot W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D]\)
B_X 表示把 B 切成 X 份,每卡只处理 B/X 的数据。
工作流程:
前向传播(无通信):
反向传播:
优点:
缺点:
要点:数据并行最大能训练的模型 ≈ HBM 容量 ÷ 10。对于 TPU v5p,约 90 亿参数。
计算时间:\(T_{计算} = \frac{8 \cdot B \cdot D \cdot F}{X \cdot C}\)
通信时间:\(T_{通信} = \frac{8 \cdot D \cdot F}{W_{ici}}\)
(8 = 2 个矩阵 × 2 次 AllReduce × 2 字节)
要想计算受限(通信能被计算掩盖),需要:
\[\frac{B}{X} > \frac{C}{W_{ici}}\]翻译成人话:每卡的批次大小,要超过”ICI 算术强度”。
对于 TPU v5p:
也就是说,每卡至少要处理 2550 个 token,否则就会被通信拖累。
如果用三个轴都做数据并行,带宽变成 3 倍,阈值降到 850。但即使这样,一个 pod(8960 芯片)也需要 760 万 token 的批次才能跑满。
结论:纯数据并行被通信卡住的情况其实不多见!
上下文并行:这里的 B 是”总 token 数”。MLP 不在乎 token 是来自同一个序列还是不同序列,所以可以沿序列维度做数据并行(叫”上下文并行”)。注意力需要特殊处理(环形注意力),但 MLP 完全不用管。
一句话:不光切数据,连模型参数和优化器状态也切了。用的时候再临时拼起来。
分片公式:\(\text{In}[B_X, D] \cdot W_\text{in}[D_X, F] \cdot W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D]\)
注意:权重的收缩维度(D)也被切了!
核心思想:
还记得吗?AllReduce = AllGather + ReduceScatter。
既然反向传播要做 AllReduce,不如这样:
通信量完全一样,但内存省了 X 倍!这就是为什么叫”ZeRO”(Zero Redundancy Optimizer)。
前向传播:
反向传播:
ZeRO-1/2/3 是什么?
通信量都一样,所以一般直接用 ZeRO-3。
和数据并行完全一样!因为 AllReduce = AllGather + ReduceScatter,通信总量没变。
\[\frac{B}{X} > \frac{C}{W_{ici}} = 2550\]要点:FSDP 和数据并行的通信门槛一样,但 FSDP 省内存。如果你的数据并行能跑,换成 FSDP 只有好处没坏处!
实际例子:
DeepSeek-V2 用了 4000 万 token 的批次。这意味着可以扩展到约 47000 芯片(~5 个 TPU v5p pod)而不被通信限制。
LLaMA-3 70B 用 1600 万 token 批次,可以分到约 18000 芯片(~2 个 pod)。
临界批次大小:有个反直觉的事实——批次越小,越容易被通信卡住。因为通信量是固定的(和模型大小相关),但计算量随批次变小。这就是为什么 DeepSeek 等模型用超大批次训练。
一句话:不切数据,切模型。每个矩阵乘法都让多张卡一起算。
分片公式:\(\text{In}[B, D_Y] \cdot W_\text{in}[D, F_Y] \cdot W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y]\)
用 Y 表示张量并行轴(后面会和 FSDP 的 X 轴组合)。
工作流程:
关键区别:FSDP 移动的是权重,张量并行移动的是激活。
前向传播:
反向传播:同理,也需要 AllGather 和 ReduceScatter
巧妙之处:
两个矩阵配合得刚刚好!
这样,一进一出正好配对:进的时候 AllGather D,出的时候 ReduceScatter D。
要计算受限:\(F > Y \cdot \frac{C}{W_{ici}} = Y \times 2550\)
也就是说,张量并行的路数不能超过 F / 2550。
要点:张量并行最多做到 F / 2550 路。对于大多数模型(F≈30000),就是 8-16 路。再多就会被通信卡住。
实际例子:
有趣的是:这个门槛和批次大小无关!因为通信量和计算量都与 B 成正比,抵消了。
一句话:两个维度一起切,既省内存又能用小批次。
分片公式:\(\text{In}[B_X, D_Y] \cdot W_\text{in}[D_X, F_Y] \cdot W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y]\)
X 轴做 FSDP,Y 轴做张量并行。
为什么要混合?
当 B 变小时:
所以,批次小的时候多用张量并行,批次大的时候多用 FSDP。
前向传播:
设 N = X × Y 是总芯片数,最优 FSDP 分片数是:
\[X_{opt} = \sqrt{\frac{B}{F} \cdot \frac{M_X}{M_Y} \cdot N}\]其中 M_X、M_Y 是各方向的网格轴数(大约各占一半,乘积约为 2)。
实际例子:
代入公式:X ≈ 14,所以用 X=16 做 FSDP,Y=4 做张量并行。
其中 α = C/W ≈ 2550。
代入 F=32000, M_X M_Y=2:
\[\frac{B}{N} > \frac{2550^2}{2 \times 32000} \approx 100\]要点:混合 FSDP+TP 可以把每芯片批次降到约 100 token!这比纯 FSDP 的 850 小了 8 倍多。
下面是交互式演示,可以拖动滑块调整批次大小:
一句话:按层切模型,数据像流水线一样一层层传下去。
GPU 世界用得很多,TPU 上不太必要(因为 ICI 带宽够大)。
基本流程:
batch_size = 32
d_model = 128
num_layers = len(jax.devices())
x = jax.random.normal(key, (batch_size, d_model))
weights = jax.random.normal(key, (num_layers, d_model, d_model))
# 前向传播
for i in range(num_layers):
x = x @ weights[i]
if i != num_layers - 1:
x = jax.device_put(x, jax.devices()[i+1])
# 反向传播
loss, dx = jax.value_and_grad(loss_fn)(x)
for i in range(num_layers-1, -1, -1):
_, f_vjp = jax.vjp(layer_fn, intermediates[i], weights[i])
dx, dw = f_vjp(dx)
if i != 0:
dx = jax.device_put(dx, jax.devices()[i-1])
优点:
缺点:
解决气泡的方法:
因为 TPU 有很强的 ICI,流水线并行不是必需品。我们一般用 FSDP + TP 就够了。
一个 TPU v5p Pod 最大 8960 芯片。想要更多?得走 DCN(数据中心网络)。
DCN 带宽:每 TPU 约 6.25GB/s(比 ICI 慢 30 倍)
常见策略:
也就是说,每个 Pod 至少要处理 7.3 万多 token,否则 DCN 带宽不够用。
要点:跨 Pod 数据并行,每 Pod 需要至少 7.3 万 token 的批次。
实际例子:
训练 LLaMA-3 70B,批次 200 万 token:
| 策略 | 分片公式 |
|---|---|
| 数据并行 | In[B_X, D] · W_in[D, F] · W_out[F, D] → Out[B_X, D] |
| FSDP | In[B_X, D] · W_in[D_X, F] · W_out[F, D_X] → Out[B_X, D] |
| 张量并行 | In[B, D_Y] · W_in[D, F_Y] · W_out[F_Y, D] → Out[B, D_Y] |
| FSDP + TP | In[B_X, D_Y] · W_in[D_X, F_Y] · W_out[F_Y, D_X] → Out[B_X, D_Y] |
| 策略 | 每层计算 | 每层通信(字节,前向+反向) |
|---|---|---|
| DP | 12BDF/X | 0 + 8DF |
| FSDP | 12BDF/X | 4DF + 8DF |
| TP | 12BDF/Y | 4BD + 4BD |
| FSDP+TP | 12BDF/(XY) | (4BD/X + 4DF/Y) + (8BD/X + 8DF/Y) |
| 策略 | 通信受限条件 | TPU v5p 数值 |
|---|---|---|
| DP/FSDP | B/X < C/W_ici | 每卡 < 2550 token(单轴) 每卡 < 850 token(三轴) |
| 张量并行 | Y > F/2550 | 超过 8-16 路 |
| FSDP+TP | B/N < α²/(2F) | 每卡 < 100 token |
| 跨 Pod | B/Pod < C/W_dcn | 每 Pod < 73440 token |
TPU v5p(96GB)最多放 90 亿参数(纯数据并行)
用 LLaMA-2 13B 作为例子:
| 参数 | 值 |
|---|---|
| L(层数) | 40 |
| D(隐藏维度) | 5120 |
| F(FFN 维度) | 13824 |
| H(头维度) | 128 |
| V(词表大小) | 32000 |
问题 1:验证一下参数量确实是 130 亿。
问题 2:用 BS=1600 万 token 和 Adam 训练,总内存需求是多少?
参数 + 优化器:(2 + 4 + 4) × 13×10⁹ = 130GB
激活(每层存 3 个检查点):
总计:约 42TB
问题 3:在 TPU v5p 16×16×16(4096 芯片)上,用 300 万 token 批次训练:
a) 能用纯数据并行吗? b) 能用纯 FSDP 吗? c) 应该怎么配置 FSDP + TP?
a) 不能。纯数据并行需要每卡存完整模型(130GB),但 TPU v5p 只有 96GB。
b) 勉强不行。内存没问题(300万 token 只需要 ~8TB 激活),但:
c) 可以用混合策略:
前向传播:Out = In × W_in × W_out
反向传播需要计算四个量:
如何确定通信:
关键洞察:dOut 的分片方式和 Out 相同(都是输出),所以反向传播的通信模式和前向传播对称。