Part 2 of How To Scale Your Model (第1部分:Roofline模型 | 第3部分:分片)
TPU 的内部结构其实很简单:一个超强的矩阵乘法引擎 + 一大块显存 + 高速互连网络。搞懂这些,你就知道为什么有些模型跑得快、有些跑得慢了。
想了解 GPU 的可以看新增的第12章!
TPU 本质上就是一个专门算矩阵乘法的计算单元(TensorCore)+ 一大块高速内存(HBM)。
TensorCore 里面有三个关键部件:
MXU 是 TPU 的灵魂。它用一种叫”脉动阵列”的结构(见附录 B),每 8 个时钟周期就能完成一次 bf16[8,128] × bf16[128,128] → f32[8,128] 的矩阵乘法。
除了矩阵乘法,还有很多”杂活”:ReLU 激活函数、向量加法、归约求和等。这些都是 VPU 干的。详见附录 A。
VMEM 是 TensorCore 内部的片上高速缓存,容量小(TPU v5e 只有 128MB)但到 MXU 的带宽极高。可以类比 CPU 的 L1/L2 缓存,但更大、由程序员控制。
重点:数据必须先从 HBM 搬到 VMEM,TensorCore 才能用它算东西。
HBM(高带宽内存) 就好理解了——就是我们平时说的”显存”:
计算时,数据流是这样的:HBM → VMEM → MXU 计算 → VMEM → HBM
TPU 做矩阵乘法时,会把整个过程流水线化:
边搬边算——搬运和计算重叠进行,这样 MXU 就不用干等数据了。
下面这个动画展示了逐元素乘法的过程(矩阵乘法类似):
一句话总结 TPU: 把权重从 HBM 搬到 VMEM,再喂给脉动阵列,每秒能做约 200 万亿次乘加。性能瓶颈通常在数据搬运(HBM↔VMEM 和 VMEM↔MXU 的带宽),而不是计算本身。
VMEM 虽然小,但带宽比 HBM 高 20 多倍。这意味着:
问题是 VMEM 太小,这通常是个挑战。
一个 TPU 芯片通常有两个核心,共享内存(叫”megacore”模式,从 v4 开始)。老款 TPU(v3 及以前)的两个核心是独立的。推理芯片(如 v5e)每个芯片只有一个核心。
4 个芯片组成一个”托盘”(tray),通过 PCIe 连到一个 CPU 主机。这就是你在 Colab 或 TPU-VM 里看到的 4 芯片/8 核心配置(通常当作 4 个逻辑核心用)。
PCIe 带宽有限——大约只有 HBM 带宽的 1/100。可以把数据卸载到主机 RAM,但很慢。
单机不够用怎么办?把多张 TPU 连起来!
在同一个 Pod 内,TPU 通过 ICI(芯片间互连)直接相连——不经过主机!
环面的好处:任意两个节点的最大距离从 N 减到 N/2。还有”扭曲环面”(像莫比乌斯带一样缠绕)可以进一步缩短距离。
SuperPod(最大 Pod):
16×16×16(4096 芯片)16×20×28(8960 芯片)这些大 Pod 由 4×4×4 的小立方体通过光学交换机连接。
也可以申请小规模配置(如 2×2×1、2×2×2),但没有环绕链路,通信时间会翻倍。完整立方体倍数(如 4×4×4、4×4×8)才有环绕链路。
TPU v5e 和 v6e 的 Pod 是单个 16×16 的 2D 环面,长边(16)才有环绕链路。
这是 TPU 和 GPU 的核心区别:
TPU 的方式更便宜、更简单、能扩展到更大规模。GPU 的方式延迟更低、任意两点通信更快。各有利弊。详见 GPU 章节。
以 TPU v5p 为例:
| 连接类型 | 带宽(每芯片) | 备注 |
|---|---|---|
| HBM ↔ VMEM | 2.5 TB/s | 最快 |
| ICI(每轴) | 90 GB/s(双向) | v5p 有 3 轴 |
| PCIe | ~16 GB/s | 比 HBM 慢 100 倍 |
| DCN(出口) | ~6 GB/s | 最慢 |
结论:把模型拆到多卡上时,要小心别让通信拖慢计算。
一组通过 ICI 连接的 TPU 叫一个”切片(Slice)“。不同切片可以通过 DCN(数据中心网络)连接——比如连接不同 Pod 上的切片。
DCN 比 ICI 慢很多,数据还得绕道:TPU → PCIe → 主机 → 网络 → 目标主机 → PCIe → TPU。尽量减少等 DCN 的时间。
TPU 结构很简单:矩阵乘法单元 + 显存 + ICI(超快)+ DCN(较慢)
带宽速度排行:HBM > ICI > PCIe > DCN
TPU 只连最近邻居:远距离通信需要跳转多个芯片
权重矩阵要填充到 128×128(v6 是 256×256)才能喂饱 MXU
低精度更快:int8 是 bf16 的 2 倍,int4 是 4 倍(VPU 操作仍是 fp32)
避免让 MXU 等数据:通信量要和各链路速度匹配
| 型号 | Pod 大小 | 单主机 | HBM/芯片 | HBM 带宽 | bf16 算力 | int8 算力 |
|---|---|---|---|---|---|---|
| v3 | 32×32 | 4×2 | 32GB | 0.9 TB/s | 140 TF/s | 140 TF/s |
| v4p | 16³ | 2×2×1 | 32GB | 1.2 TB/s | 275 TF/s | 275 TF/s |
| v5p | 16×20×28 | 2×2×1 | 96GB | 2.8 TB/s | 459 TF/s | 918 TF/s |
| v5e | 16×16 | 4×2 | 16GB | 0.8 TB/s | 197 TF/s | 394 TF/s |
| v6e | 16×16 | 4×2 | 32GB | 1.6 TB/s | 920 TF/s | 1840 TF/s |
TF/s = 10¹² FLOPs/s
ICI 带宽(每链路):
| 型号 | 单向 | 双向 |
|---|---|---|
| v3 | 100 GB/s | 200 GB/s |
| v4p | 45 GB/s | 90 GB/s |
| v5p | 90 GB/s | 180 GB/s |
| v5e | 45 GB/s | 90 GB/s |
| v6e | 90 GB/s | 180 GB/s |
PCIe 约 16 GB/s/芯片(v6e 是 32),DCN 约 6 GB/s/芯片(v6e 是 12.5,v5e 是 3.125)。
这些数字看着枯燥,但用处很大——可以让你快速估算模型性能。
题 1:推理延迟下界
假设你要从一个 2000 亿参数的 bf16 模型采样,模型分布在 32 张 TPU v4p 上。把所有参数从 HBM 加载到 MXU 要多久?
参数量:2×200×10⁹ = 400×10⁹ 字节(bf16 每参数 2 字节) 每芯片:400×10⁹ / 32 = 12.5×10⁹ 字节 HBM 带宽:1.2×10¹² 字节/s 加载时间:12.5×10⁹ / 1.2×10¹² ≈ 10ms
这就是采样延迟的理论下界——每次采样都要加载所有参数,不可能比 10ms 更快。实际上,小 batch 时接近这个值。
题 2:数一数
一个完整的 TPU v5e Pod 有:
对 v5p Pod 也算一下。
v5e:
v5p:
题 3:从主机内存算矩阵乘法
假设权重 bf16[D, 4D] 和激活 bf16[B, D] 都存在主机内存(不在 TPU 显存),你想用一张 TPU v6e 算矩阵乘法。假设 $B \ll D$。需要多大 batch 才能计算受限?(PCIe 带宽 1.5×10¹⁰ 字节/s)
计算受限条件: \(\frac{8BD^2}{9.2×10^{14}} > \frac{8D^2}{1.5×10^{10}}\)
\[B > \frac{9.2×10^{14}}{1.5×10^{10}} ≈ 61000\]需要 6 万以上的 batch 才能计算受限!PCIe 太慢了。
题 4:矩阵乘法需要多久
在 1 张 TPU v5e 上,用 int8[16384, 4096] 的权重乘以 int8[B, 4096] 的激活:
(1) 从 HBM 读:
计算受限条件:$B > 271$
(2) 从 VMEM 读:
VMEM 带宽约是 HBM 的 22 倍,临界点变成 $B > 11$。
题 5:ICI 传输
4×4 的 TPU v5e 切片,把 bf16[8, 128, 8192] 从 (0,0) 发到 (3,3)。假设每跳延迟 1μs。
总共约 188μs(带宽受限)。
题 6:综合挑战
一个 int8[128K, 128K] 的大矩阵均匀分布在 TPU v5e 4×4 切片上,但卸载到了各芯片的主机内存。你想把它全部收集到 TPU(0,0) 然后乘以 bf16[8, 128K]。要多久?
矩阵约 16GB。4×4 切片有 2 个主机(每主机 8 芯片),每主机存 8GB。
方案:通过 ICI 收集比通过 DCN 更快。
分步计算:
PCIe 加载:16GB / 16 芯片 = 1GB/芯片,带宽 1.5×10¹⁰ → 约 66ms
ICI 收集:TPU(0,0) 要收 15GB,2 个方向各 45 GB/s → 下界 167ms(实际可能更长)
HBM → MXU:16GB / 8.1×10¹¹ ≈ 19ms
计算:$2×8×128K×128K = 2.7×10^{11}$ FLOPs / 1.97×10¹⁴ ≈ 1.3ms
瓶颈在 ICI 收集。假设能部分重叠,总时间约 170-200ms。
这里更深入地介绍 TPU 内部。以 TPU v5p 为例。
VPU 是做”杂活”的向量单元:逐元素加法、ReLU、归约等。
结构:(8, 128) 的 2D SIMD 阵列
速度:大多数指令 1 周期完成,2 周期延迟。
VREGs(向量寄存器):v5p 每核 64 个,总共约 256KB。每周期可以从 VMEM 读 3 个、写 1 个。
归约:sublane 内归约很快(shuffle 几下就行),跨 lane 归约要用 XLU(慢)。
小测验:算一下 TPU v5p VPU 能做多少 FLOPs/s?(时钟 1.75GHz)
每周期:$8 × 128 × 4 × 2$(2 核)= 8192 FLOPs 总算力:$8192 × 1.75×10^9 = 1.4×10^{13}$ FLOPs/s
比 MXU 的 2×10¹⁴ 小 10 倍左右。
对比 GPU:VPU 的每个 ALU 类似 CUDA 核心,每个 lane 类似一个 Warp 调度器。
标量核心是 TPU 的”大脑”——取指令、控制 DMA、做标量运算。
注意:标量核心是单线程的,每周期只能发起一个 DMA 请求。
一个标量核心管着:1 个 VPU(4096 个 ALU)、4 个 MXU、2 个 XLU、多个 DMA 引擎。这种高度集中的控制是效率来源,但也限制了灵活性。
MXU 的核心是 128×128 脉动阵列(v6e 是 256×256)。
完全饱和时,每 8 周期完成一次 bf16[8,128] @ bf16[128,128] → f32[8,128]
原理:
看这个动画:
权重(蓝)先对角加载,输入(绿)再对角喂入。每帧里,重叠的蓝绿单元相乘,加上从上方传来的累积结果,然后向下传一格。
输出流出的过程:
多组输入/权重的流水线:
一开始有流水线气泡(等数据填满),之后就是无缝连续计算。
一个 2×3 矩阵乘法的简化动画:
要点:矩阵形状要大于 MXU 边长(128),否则会有大量浪费。多 MXU 时(v4/v5 有 4 个),需要更大的分块。
v6e 的 256×256 MXU 每周期 4 倍 FLOPs,但也需要更大的张量才能喂饱。