如何分析 TPU 程序性能

Part 9 of How To Scale Your Model (第8部分:服务 LLaMA | 第10部分:JAX)

前面几章全是理论推导。理论能让你走很远,但真到了优化的时候,还得看实际情况:XLA 编译器干了什么?哪里慢了?这章教你用 Profiler 工具找出问题所在。

为什么需要 Profiling?

前面几章我们一直在做”纸上谈兵”:用 roofline 模型估算性能上限。

但实际优化时,你需要知道:

这就需要 Profiler(性能分析器)。


TPU 软件栈全景

写 TPU 程序有几个层次,从高到低:

层次 是什么 谁用
JAX NumPy 风格的高级 API 大多数程序员
StableHLO 平台无关的中间表示 XLA 编译器
HLO 硬件相关的中间表示 Profiler 显示的
LLO 低级优化器,直接操作 TPU 内部
机器码 TPU 执行的二进制 内部

我们写的是 JAX,看的是 HLO。

一个简单例子

import jax
import jax.numpy as jnp

def multiply(x, y):
    return jnp.einsum('bf,fd->db', x, y)

y = jax.jit(multiply)(jnp.ones((128, 256)), jnp.ones((256, 16), dtype=jnp.bfloat16))

jax.jit 告诉 JAX:追踪这个函数,编译成高效代码。

编译后的 HLO 大概长这样:

ENTRY %main.5 (Arg_0.1: f32[128,256], Arg_1.2: bf16[256,16]) -> f32[16,128] {
  %Arg_1.2 = bf16[256,16]{1,0} parameter(1), metadata={op_name="y"}
  %convert.3 = f32[256,16]{1,0} convert(bf16[256,16]{1,0} %Arg_1.2),
  %Arg_0.1 = f32[128,256]{1,0} parameter(0), metadata={op_name="x"}
  ROOT %dot.4 = f32[16,128]{1,0} dot(f32[256,16]{1,0} %convert.3, f32[128,256]{1,0} %Arg_0.1),
       lhs_contracting_dims={0}, rhs_contracting_dims={1},
}

别慌,这其实很好读。dot.4 就是那个矩阵乘法,沿着维度 0 和 1 收缩。

当程序跑得慢时,我们用 Profiler 看 HLO 层面发生了什么。如果 HLO 层面都解决不了,就用 Pallas 写自定义内核。


TensorBoard Profiler 使用指南

怎么生成 Profile

import jax

with jax.profiler.trace("/tmp/tensorboard"):
    key = jax.random.key(0)
    x = jax.random.normal(key, (1024, 1024))
    y = x @ x
    y.block_until_ready()  # 等待计算完成

# 然后在终端运行:
# tensorboard --logdir=/tmp/tensorboard

# 或者在 Colab 里:
# %load_ext tensorboard
# %tensorboard --logdir=/tmp/tensorboard

Profiler 能看什么?

三个最有用的标签:

标签 看什么
Trace Viewer 时间线,看每个操作花了多久
Graph Viewer 计算图,看操作之间怎么连接
Memory Profile 内存使用随时间的变化

想先体验一下?这里有个在线 Perfetto 链接:简单 Transformer 的 Trace

或者用这个 Colab 生成完整 Profile 自己玩。


Trace Viewer:时间线视图

这是一个 Transformer 的 Trace。可以看到:

  1. 顶行(XLA Ops):实际的 TPU 操作,名字是 HLO 名字
  2. 重复的块:每个重复就是一层
  3. 点击操作:可以看到对应的代码位置

导航技巧:W/S 放大缩小,A/D 左右移动。像游戏一样!


怎么读 HLO 代码

看到这种东西不要怕:

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)} fusion(bf16[32,32,8192]{...} %fusion.32), kind=kCustom, calls=%all-reduce-scatter.3

拆开来看:

部分 含义
fusion.3 操作名
bf16[32,32,4096] 输出类型和形状
{2,1,0:T(8,128)(2,1)} 内存布局和 tiling
S(1) 存储位置:S(0)=HBM, S(1)=VMEM
fusion(...) 输入参数
kind=kCustom 操作类型

关于 Tiling

{1,0:T(2,2)} 是什么意思?

更复杂的例子:bf16[4,8]{1,0,T(2,4)(2,1)}

两层 tiling:外层 2×4,内层 2×1(bf16 需要 4 字节对齐)。

为什么 Tiling 重要?

有时候 XLA 会插入”重新布局”操作来调整 tiling,这会带来开销。如果你在 profile 里看到很多 copy 操作,可能就是这个问题。


Graph Viewer:计算图视图

HLO 操作太复杂?Graph Viewer 把它可视化了。

鼠标悬停在节点上,可以看到对应的代码行。

多盯着看几遍,试着把 HLO 操作和你的代码对应起来。


实战:看一个真实 Profile

这是一个假 Transformer 的 Profile:

这个 Colab 自己生成一个来玩。

FFN 块分析

放大 FFN 块。up-projection 操作:

这是 4 路 DP + 2 路 TP 分片后的本地视图。全局形状:

X: bf16[32, 1024, 8192] × W_in: bf16[8192, 32768] → Tmp: bf16[32, 1024, 32768]

验算一下时间

每分片批次 = 8 × 1024 = 8192 token → 计算受限 ✓

理论时间 = 2 × 32 × 1024 × 8192 × 32768 / (23e12 × 8) = 95.6ms

实际时间 = 96ms

几乎完美命中 roofline!

通信分析

第二个 matmul 末尾有个小操作:

%fusion.1 = bf16[8,1024,4096]{...} fusion(...), kind=kCustom, calls=%all-reduce-scatter.1

这是个 ReduceScatter。

验算

数组大小 = 2 × 32 × 1024 × 8192 = 537MB(全局) 每分片 = 537 / 4 = 134MB

单跳 ICI 带宽 = 1.2e11 B/s

理论时间 = 134e6 / 1.2e11 = 1.1ms

实际时间 = 1.13ms

又命中了!

注意力块分析

Q 投影用的矩阵:[d_model=8192, n_heads=32, d_qkv=256]

沿头维度做 Megatron 分片。

试试自己验算这些操作应该花多久?


Memory Profile:内存视图

这个视图显示内存使用随时间的变化。

例子里可以看到:

对调试 OOM 很有帮助:找到峰值是在哪里,是什么操作导致的。


练习题

问题 1:找 Bug

看看这个 Colab/Profile

任务(先不要看代码!只看 Profile):

答案

这是两个矩阵乘法:

def matmul(w1, w2, x):
    return jnp.einsum('wf,bf->bw', w2, jnp.einsum('fw,bw->bf', w1, x))

Profile 里可以看到:reduce → 两个大 fusion → all-reduce

第一个 fusion:

%fusion.1 = bf16[4096]{...} fusion(bf16[4096,8192]{...} %param.1, bf16[8192]{...} %reduce.6)

每个分片:bf16[8192] × bf16[4096, 8192] → bf16[4096]

AllReduce 的 replica_groups 显示 8 组 → 8 路张量并行

全局形状:bf16[8, 8192] × bf16[32768, 8192] → bf16[8, 32768]


问题 2:优化 Transformer

这个 Colab 里的简单 Transformer:

  1. 生成基准 Profile
  2. 每个部分花了多久?应该花多久?
  3. 用了什么分片策略?
  4. jax.lax.with_sharding_constraint 试着优化

参考数据

完成后,纯从 Profile 回答:


下一章我们深入看 JAX 并行化

Miscellaneous

*Work done at Google DeepMind, now at MatX.

Citation

For attribution in academic contexts, please cite this work as:

    Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

or as a BibTeX entry:

    @article{scaling-book,
      title = {How to Scale Your Model},
      author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
      and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
      publisher = {Google DeepMind},
      howpublished = {Online},
      note = {Retrieved from https://jax-ml.github.io/scaling-book/},
      year = {2025}
    }