Part 7 of How To Scale Your Model (第6部分:训练 LLaMA | 第8部分:服务 LLaMA)
推理和训练完全是两码事。训练只看吞吐量,推理还得管延迟。这章我们从『怎么生成一个 token』讲起,一直讲到『怎么搭建一个高效的推理引擎』。
模型训练好了,终于可以用了!
说实话,损失曲线降下来、benchmark 分数升上去,这些都是代理指标。真正有意思的时刻,是你看着模型一个字一个字往外蹦的时候。
采样的原理很简单:
问题是:每生成一个 token,都要把前面所有 token 重新算一遍。
生成 n 个 token:
这显然不靠谱。
聪明的做法是:把中间结果存起来。
具体来说,注意力机制里每个 token 的 Key 和 Value 投影是可以复用的。只要我们把它们存在一个叫 KV 缓存 的数据结构里,后续 token 就不用重新计算前面的 K 和 V 了。
有了 KV 缓存,推理分成两个阶段:
| 阶段 | 做什么 | 特点 |
|---|---|---|
| 预填充 | 一次性处理整个提示,生成 KV 缓存 | 可以并行,像训练一样 |
| 生成 | 一个一个吐 token,每次更新 KV 缓存 | 必须串行,一次一个 |
现在复杂度变成:
当你在 ChatGPT 里看到回答一个字一个字蹦出来,每个字(通常)就是一次单独的模型调用。
关键洞察:预填充和生成是完全不同的任务!预填充像训练(批量并行),生成像一个超慢的循环(必须串行)。KV 缓存是推理特有的复杂性来源。
训练只关心一个指标:吞吐量(每秒处理多少 token)。
推理要复杂得多,因为多了一个维度:延迟。
| 场景 | 关心什么 |
|---|---|
| 批量推理(评估、数据生成) | 只看成本,不管延迟 |
| 聊天/流式 | 首 token 要快(TTFT),生成速度要跟上阅读速度 |
| 边缘推理(本地 llama.cpp) | 单用户,拼命压延迟 |
最大化硬件利用率仍然重要(省钱、降低 TTFT),但高利用率不一定等于好体验。
很多优化要在延迟、吞吐量、上下文长度、甚至模型质量之间做权衡。
训练时我们把 Transformer 简化成”一堆矩阵乘法”。推理要更精细地分析。
Transformer 前向传播的主要组件:
接下来我们分别分析:在预填充和生成中,什么是瓶颈?
所有线性操作本质上都一样:bf16[B, D] × bf16[D, F]。
回顾第1章的公式:
\[T_{计算} = \frac{2BDF}{C}\] \[T_{通信} = \frac{2BD + 2DF + 2BF}{W_{hbm}}\]当 B « D, F 时(批次小,模型大),分母约等于 2DF:
\[\frac{T_{计算}}{T_{通信}} \approx \frac{B \cdot W_{hbm}}{C}\]要计算受限(FLOPs 是瓶颈),需要:
\[B > \frac{C}{W_{hbm}} = B_{crit}\]| 硬件 | C/W_hbm | B_crit (bf16) |
|---|---|---|
| TPU v5e | 1.97e14 / 8.2e11 | 240 |
| H100 | ~3.9e15 / 3.35e12 | ~280 |
要点:矩阵乘法要计算受限,每副本的 token 批次大小必须超过 B_crit(TPU v5e 约 240)。
预填充:提示通常有几百甚至几千个 token,轻松超过 240。基本总是计算受限。
生成:每个请求一次只能生成一个 token!要达到 240,必须把 240 个请求批在一起。这意味着 240 个独立的 KV 缓存,实际上很难做到。
要点:预填充基本总是计算受限。生成要达到计算受限,必须把很多请求批在一起,这很难!
如果权重量化到 int8(激活仍是 bf16):
如果 FLOPs 也用 int8:
所以:B_crit = β × α_hbm,其中 β = 参数位数 / 激活位数。
这里事情变得有趣了,因为 KV 缓存来搅局。
假设用 Flash Attention(不实体化注意力矩阵):
读取:
计算:
算术强度:
\[\text{强度} = \frac{4BSTD}{4BSD + 4BTD} = \frac{ST}{S+T}\]自注意力,S = T:
\[\text{强度} = \frac{T^2}{2T} = \frac{T}{2}\]只要序列长度超过 480(TPU v5e),就能计算受限。一般没问题。
每次只处理一个新 token:
\[S \gg T=1 \implies \text{强度} \approx \frac{S \cdot 1}{S+1} \approx 1\]强度恒定为 1!不管批次多大、序列多长,都改变不了。
每次都要把整个 KV 缓存从 HBM 读一遍,却只做很少的计算。
要点:预填充的注意力可以计算受限(序列够长就行)。生成的注意力永远是内存带宽受限的。
为什么?因为每个请求都有自己的 KV 缓存。批次变大 → KV 缓存变多 → 内存读取同比例增加。没有复用,就没有收益。
这是全章最重要的公式,请务必记住。
小测验:在 TPU v5e 4×4 上服务 30B 模型(int8),8192 上下文,100kB/token 的 KV 缓存,批次大小 4。最小步骤延迟是多少?批次 256 呢?
int8 参数 = 30GB 每序列 KV 缓存 = 100kB × 8192 = 819MB 16 芯片总带宽 = 16 × 8.1e11 = 1.3e13 B/s
批次 4(带宽受限): \(T = \frac{4 \times 819e6 + 30e9}{1.3e13} = 2.5ms\)
批次 256(MLP 计算受限): \(T = \frac{256 \times 819e6}{1.3e13} + \frac{2 \times 256 \times 30e9}{16 \times 1.97e14} = 16ms + 5ms = 21ms\)
要点:关心吞吐量就用大批次(超过 B_crit ≈ 240)。关心延迟就用小批次。可能需要更大拓扑来支撑大批次。
拿 LLaMA-2 13B 做例子:
| 参数 | 值 |
|---|---|
| L | 40 |
| D | 5,120 |
| F | 13,824 |
| N (Q 头数) | 40 |
| K (KV 头数) | 40 |
| H | 128 |
参数内存:
bf16 = 26GB。量化可以更小。没有优化器、没有梯度。激活可以忽略(Flash Attention)。
KV 缓存(重点!):
\[\text{KV 大小} = 2 \times \text{bytes} \times H \times K \times L \times T\]LLaMA-2 13B,8192 序列,bf16:
\[8192 \times 40 \times 128 \times 40 \times 2 \times 2 = 6.7\text{GB}\]一个 KV 缓存就 6.7GB!4 个就超过参数了!
这就是为什么 KV 缓存是推理的大麻烦。
在 8×TPU v5e(128GB HBM,6.5TB/s 带宽,1600TF/s)上:
| 批次 | KV 缓存 (GB) | 总内存 (GB) | 步骤时间 (ms) | 吞吐量 (tok/s) |
|---|---|---|---|---|
| 1 | 6.7 | 32.7 | 5.0 | 200 |
| 8 | 53.6 | 79.6 | 12.1 | 659 |
| 16 | 107.2 | 133.2 | 20.3 | 788 |
| 32 | 214.4 | 240.4 | 36.7 | 873 |
| 64 | 428.8 | 454.8 | 69.3 | 923 |
| 240 | 1608 | 1634 | 249 | 964 |
问题:
如果 KV 缓存小 5 倍(比如用 8 个 KV 头配 40 个 Q 头):
| 批次 | KV 缓存 (GB) | 总内存 (GB) | 步骤时间 (ms) | 吞吐量 (tok/s) |
|---|---|---|---|---|
| 1 | 1.3 | 27.3 | 4.2 | 240 |
| 8 | 10.7 | 36.7 | 5.6 | 1429 |
| 16 | 21.4 | 47.4 | 7.2 | 2212 |
| 32 | 42.9 | 68.9 | 10.5 | 3048 |
| 64 | 85.8 | 111.8 | 17.0 | 3757 |
| 240 | 321.6 | 347.6 | 53.0 | 4529 |
延迟更好,吞吐量更高,批次能开更大。LLaMA-3 正是这么做的(32 个 Q 头,8 个 KV 头)。
要点:KV 缓存大小对推理性能影响巨大。小 KV = 更大批次 + 更低延迟 + 更高吞吐量。
既然 KV 缓存是罪魁祸首,大家想了很多办法来压缩它:
效果:KV 缓存减少 Q:KV 倍数。模型质量对此相对不敏感。
要点:这些优化叠加起来,可以把 KV 缓存压缩一个数量级,推理成本也能降一个数量级。
从 roofline 角度,预填充几乎和训练一样。
可以用的技术:
分片策略:
要点:预填充的分片和训练几乎一样。张量并行到 ICI 瓶颈,然后序列并行。
生成就难办多了:
不能用的策略:
| 策略 | 为什么不行 |
|---|---|
| FSDP | 我们是带宽受限的,不能通过 ICI 移动权重(太慢) |
| 数据并行 | 复制权重没意义,不如直接开多个副本 |
| 序列并行 | 每次只有一个 token,没序列可切 |
只剩下张量并行。
好消息是:因为我们是带宽受限的,可以做更激进的张量并行来改善延迟!
在训练中,ICI 瓶颈是 FLOPs 和 ICI 通信的比较。 在生成中,瓶颈是 HBM 带宽和 ICI 通信的比较。
\[T_{HBM} = \frac{2DF}{Y \cdot W_{hbm}}\] \[T_{ICI} = \frac{2BD}{W_{ici}}\]要 ICI 不成瓶颈:
\[Y < \frac{F}{B \cdot \beta}\]其中 β = W_hbm / W_ici ≈ 8(TPU v5e/v6e)。
例如:F=16384,B=32 → 可以做到 64 路张量并行!
要点:生成只能用张量并行的变体。目标是移动激活而不是 KV/参数。带宽受限时可以比训练做更多路张量并行。
KV 缓存也需要分片,而且尽量不要复制(太大了)。
分片策略:
代价:每层两次 AllToAll(Q 从张量分片转批次分片,输出再转回来)。
如果批次太小或上下文太长,还可以沿序列维度切 KV。
知道了怎么高效执行单次预填充和生成,还需要设计一个推理引擎来把它们串起来。
聚集一批请求 → 预填充 → 生成直到全部完成 → 下一批
问题:
这种方案只适合:边缘设备(单用户)或早期原型。
预填充批次大小 1(立即返回),生成批多个请求。
优点:
缺点:
预填充和生成跑在不同的 TPU/GPU 上。
工作流程:
优点:
缺点:
要点:高吞吐量、低延迟的服务,通常要把预填充和生成分离到不同服务器。预填充批次 1,生成批多个请求。
核心思想:
这样可以保持生成批次始终饱满。
预填充很贵。能不能少做一点?
观察:相同前缀的请求,KV 缓存是一样的!
例如:
应用场景:
实现要点:
Google 开源的推理引擎 JetStream:
核心组件:
Engine 接口:
prefill(tokens) → 返回 KV 缓存insert(kv_cache) → 插入到生成批次generate(batch) → 为每个请求生成一个 token还有 PyTorch 版本。
用这个虚构的模型练习:
| 参数 | 值 |
|---|---|
| L | 64 |
| D | 4,096 |
| F | 16,384 |
| N (Q 头) | 32 |
| K (KV 头) | 8 |
| H | 256 |
| V | 32,128 |
问题 1:参数量和 KV 缓存大小
参数:
KV 缓存(int8): 2 × L × K × H = 2 × 64 × 8 × 256 = 262KB/token
问题 2:在 TPU v5e 4×4 上能开多大批次?(int8,128k 上下文)
每序列 KV = 262KB × 128K = 33.5GB 16 TPU × 16GB = 256GB 总 HBM 可用 = 256 - 18.4 (参数) = 237.6GB 最大批次 = 237.6 / 33.5 ≈ 7
如果 K=1:最大批次 ≈ 56
问题 3:加载参数的理论最小时间
18.4B 字节 ÷ (16 × 8.1e11 B/s) = 1.4ms
这是步骤延迟的下限。
问题 4:预填充和生成怎么分片?
问题 5:改成 MoE(E=16 专家,k=2 激活)
(1) 总参数 = 64 × 4096 × (3×16×16384 + 2×256×40) + 131K = 212B 激活参数 = 64 × 4096 × (3×2×16384 + 2×256×40) + 131K = 31.2B
(2) B_crit = 240 × (16/2) = 1920 tokens
(3) KV 缓存不变(注意力没变)
(4) FLOPs = 2 × 激活参数 × T = 2 × 31.2B × T
实测确实在批次 240 左右看到拐点。
当拓扑很大时,可以同时沿 D 和 F 分片权重,让每块接近正方形。
通信量随 √N 下降,比 1D Megatron 更好。当 N > 81 时值得考虑。
当数据量很小时,通信时间被延迟(而非带宽)主导。
临界点:buffer < W_ici × 1μs ≈ 45KB
对于 BS=16, D=8192 的 int8 激活:16×8192=131KB,已经延迟受限了。
核心思想:用小模型快速生成草稿,大模型并行验证。
为什么快?
要点:推测采样用吞吐量换延迟。在批次受限时(KV 缓存大、硬件小),可能两者都赢。