Part 9 of How To Scale Your Model (第8部分:服务 LLaMA | 第10部分:JAX)
前面几章全是理论推导。理论能让你走很远,但真到了优化的时候,还得看实际情况:XLA 编译器干了什么?哪里慢了?这章教你用 Profiler 工具找出问题所在。
前面几章我们一直在做”纸上谈兵”:用 roofline 模型估算性能上限。
但实际优化时,你需要知道:
这就需要 Profiler(性能分析器)。
写 TPU 程序有几个层次,从高到低:
| 层次 | 是什么 | 谁用 |
|---|---|---|
| JAX | NumPy 风格的高级 API | 大多数程序员 |
| StableHLO | 平台无关的中间表示 | XLA 编译器 |
| HLO | 硬件相关的中间表示 | Profiler 显示的 |
| LLO | 低级优化器,直接操作 TPU | 内部 |
| 机器码 | TPU 执行的二进制 | 内部 |
我们写的是 JAX,看的是 HLO。
import jax
import jax.numpy as jnp
def multiply(x, y):
return jnp.einsum('bf,fd->db', x, y)
y = jax.jit(multiply)(jnp.ones((128, 256)), jnp.ones((256, 16), dtype=jnp.bfloat16))
jax.jit 告诉 JAX:追踪这个函数,编译成高效代码。
编译后的 HLO 大概长这样:
ENTRY %main.5 (Arg_0.1: f32[128,256], Arg_1.2: bf16[256,16]) -> f32[16,128] {
%Arg_1.2 = bf16[256,16]{1,0} parameter(1), metadata={op_name="y"}
%convert.3 = f32[256,16]{1,0} convert(bf16[256,16]{1,0} %Arg_1.2),
%Arg_0.1 = f32[128,256]{1,0} parameter(0), metadata={op_name="x"}
ROOT %dot.4 = f32[16,128]{1,0} dot(f32[256,16]{1,0} %convert.3, f32[128,256]{1,0} %Arg_0.1),
lhs_contracting_dims={0}, rhs_contracting_dims={1},
}
别慌,这其实很好读。dot.4 就是那个矩阵乘法,沿着维度 0 和 1 收缩。
当程序跑得慢时,我们用 Profiler 看 HLO 层面发生了什么。如果 HLO 层面都解决不了,就用 Pallas 写自定义内核。
import jax
with jax.profiler.trace("/tmp/tensorboard"):
key = jax.random.key(0)
x = jax.random.normal(key, (1024, 1024))
y = x @ x
y.block_until_ready() # 等待计算完成
# 然后在终端运行:
# tensorboard --logdir=/tmp/tensorboard
# 或者在 Colab 里:
# %load_ext tensorboard
# %tensorboard --logdir=/tmp/tensorboard
三个最有用的标签:
| 标签 | 看什么 |
|---|---|
| Trace Viewer | 时间线,看每个操作花了多久 |
| Graph Viewer | 计算图,看操作之间怎么连接 |
| Memory Profile | 内存使用随时间的变化 |
想先体验一下?这里有个在线 Perfetto 链接:简单 Transformer 的 Trace
或者用这个 Colab 生成完整 Profile 自己玩。
这是一个 Transformer 的 Trace。可以看到:
导航技巧:W/S 放大缩小,A/D 左右移动。像游戏一样!
看到这种东西不要怕:
%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)} fusion(bf16[32,32,8192]{...} %fusion.32), kind=kCustom, calls=%all-reduce-scatter.3
拆开来看:
| 部分 | 含义 |
|---|---|
fusion.3 | 操作名 |
bf16[32,32,4096] | 输出类型和形状 |
{2,1,0:T(8,128)(2,1)} | 内存布局和 tiling |
S(1) | 存储位置:S(0)=HBM, S(1)=VMEM |
fusion(...) | 输入参数 |
kind=kCustom | 操作类型 |
关于 Tiling
{1,0:T(2,2)} 是什么意思?
1,0:维度在内存中的顺序(从右往左读)T(2,2):以 2×2 块 tiling更复杂的例子:bf16[4,8]{1,0,T(2,4)(2,1)}
两层 tiling:外层 2×4,内层 2×1(bf16 需要 4 字节对齐)。
为什么 Tiling 重要?
有时候 XLA 会插入”重新布局”操作来调整 tiling,这会带来开销。如果你在 profile 里看到很多 copy 操作,可能就是这个问题。
HLO 操作太复杂?Graph Viewer 把它可视化了。
鼠标悬停在节点上,可以看到对应的代码行。
多盯着看几遍,试着把 HLO 操作和你的代码对应起来。
这是一个假 Transformer 的 Profile:
用 这个 Colab 自己生成一个来玩。
放大 FFN 块。up-projection 操作:
bf16[8, 1024, 8192] × bf16[8192, 16384] bf16[32, 1024, 16384] 这是 4 路 DP + 2 路 TP 分片后的本地视图。全局形状:
X: bf16[32, 1024, 8192] × W_in: bf16[8192, 32768] → Tmp: bf16[32, 1024, 32768]
验算一下时间:
每分片批次 = 8 × 1024 = 8192 token → 计算受限 ✓
理论时间 = 2 × 32 × 1024 × 8192 × 32768 / (23e12 × 8) = 95.6ms
实际时间 = 96ms
几乎完美命中 roofline!
第二个 matmul 末尾有个小操作:
%fusion.1 = bf16[8,1024,4096]{...} fusion(...), kind=kCustom, calls=%all-reduce-scatter.1
这是个 ReduceScatter。
验算:
数组大小 = 2 × 32 × 1024 × 8192 = 537MB(全局) 每分片 = 537 / 4 = 134MB
单跳 ICI 带宽 = 1.2e11 B/s
理论时间 = 134e6 / 1.2e11 = 1.1ms
实际时间 = 1.13ms
又命中了!
Q 投影用的矩阵:[d_model=8192, n_heads=32, d_qkv=256]
沿头维度做 Megatron 分片。
试试自己验算这些操作应该花多久?
这个视图显示内存使用随时间的变化。
例子里可以看到:
对调试 OOM 很有帮助:找到峰值是在哪里,是什么操作导致的。
任务(先不要看代码!只看 Profile):
这是两个矩阵乘法:
def matmul(w1, w2, x):
return jnp.einsum('wf,bf->bw', w2, jnp.einsum('fw,bw->bf', w1, x))
Profile 里可以看到:reduce → 两个大 fusion → all-reduce
第一个 fusion:
%fusion.1 = bf16[4096]{...} fusion(bf16[4096,8192]{...} %param.1, bf16[8192]{...} %reduce.6)
每个分片:bf16[8192] × bf16[4096, 8192] → bf16[4096]
AllReduce 的 replica_groups 显示 8 组 → 8 路张量并行
全局形状:bf16[8, 8192] × bf16[32768, 8192] → bf16[8, 32768]
用这个 Colab 里的简单 Transformer:
jax.lax.with_sharding_constraint 试着优化参考数据:
完成后,纯从 Profile 回答: