大语言模型在 TPU 上的系统级视角 (第0部分:导论 | 第1部分:Roofline模型)
训练 LLM 常常被说成是『炼丹』,但搞懂模型性能优化其实没那么玄乎。本书想把 LLM 扩展这件事讲明白:TPU 和 GPU 到底怎么干活的、芯片之间怎么通信、LLM 在真实硬件上是怎么跑的,以及怎么把模型拆分到多个芯片上高效运行。如果你曾经想过『训练这个模型要花多少钱』、『部署需要多少显存』、『AllGather 是什么鬼』,希望这本书能帮到你。
深度学习里有很多”黑魔法”,但模型性能优化不用那么玄——哪怕是超大规模也一样!其实背后的原理挺简单的,而且从单卡到上万张卡都适用。搞明白这些,你就能:
需要什么基础? 你得大概知道 LLM 和 Transformer 是啥,但不需要了解它们怎么在大规模下运行。最好对 JAX 有点了解(不是必须)。想补课的话,可以看看这篇图解 Transformer 和原始论文。更多资料见延伸阅读。
读完能学到什么? 你应该能为任意模型和硬件组合选出合适的并行策略,并估算训练和推理的耗时。如果做不到,欢迎给我们留言反馈!
新增了 NVIDIA GPU 专题,见第12章!
三四年前,做机器学习研究可能不太需要懂这些底层的东西。但现在不一样了——即使是”小”模型,也已经逼近硬件极限了。
如果你在 benchmark 上提升了 20%,但效率损失了 20%,那这个提升就是假的。 很多看起来很酷的新架构最后没火起来,要么是因为跑不快,要么是没人愿意花功夫把它优化好。
模型扩展的终极目标:加芯片的同时,吞吐量也线性增长。 这叫”强扩展”(Strong Scaling)。加卡确实能加速计算,但也带来了通信开销。当通信时间超过计算时间,我们就”卡在通信上”了,再加卡也没用。
如果我们足够了解硬件,就能提前预判瓶颈在哪,从而调整模型避开它们。
本书的目标:解释 TPU(和 GPU)怎么工作,以及 Transformer 怎么演进来适配当前硬件。希望对设计新架构的研究者和优化现有模型的工程师都有用。
整体结构:
第1章讲 Roofline 分析——哪些因素在限制你的扩展(通信、计算、内存)。第2章和第3章深入讲 TPU:单芯片怎么工作,多芯片怎么连接,芯片间的带宽和延迟是多少。我们会回答这些问题:
五年前,机器学习的模型架构百花齐放——CNN、LSTM、MLP、Transformer 都有市场。但现在基本就剩 Transformer 了
第5章:训练和第7章:推理是本书的重头戏,讨论核心问题:给定模型大小和芯片数量,怎么并行化才能保持强扩展? 问题简单,答案却出乎意料地复杂。高层来看,有四种主要的并行策略来把模型拆分到多张卡上:
还有一些节省内存的技巧:重计算、ZeRO 优化器分片、卸载到主机内存、梯度累积。这些我们都会讲。
读完这些章节,你应该能为新架构或新场景自己选出合适的并行方案。第6章和第8章是实战教程,把这些概念应用到 LLaMA-3(一个流行的开源模型)上。
最后,第9章和第10章讲怎么用 JAX 实现这些想法,以及出问题时怎么调试。第12章是新增的 GPU 专题。
全书穿插了练习题,可以动手试试。不用从头读到尾,挑感兴趣的看就行。欢迎留下反馈!这本书还在持续更新中。
感谢 James Bradbury 和 Blake Hechtman,书中很多想法源自他们。
本书可能比需要的更长,但别被吓到。前三章是预备知识,熟悉的话可以跳过(但会引入后面用到的符号)。最后几章最实用,讲怎么处理真实模型。
第一部分:基础知识
第1章:Roofline 分析入门——算法受三个因素限制:计算、通信、内存。用这些可以估算运行速度。
第2章:TPU 是怎么工作的——TPU 的内部原理,以及这对我们能训练什么模型有什么影响。
第3章:分片矩阵与矩阵乘法——通过矩阵乘法这个最重要的操作来讲解模型分片和多卡并行。
第二部分:Transformer 详解
第4章:Transformer 数学全解——前向和反向各用多少 FLOPs?参数量怎么算?KV 缓存多大?这里详细推导。
第5章:Transformer 训练并行化——FSDP、Megatron 分片、流水线并行。给定芯片数量,怎么高效训练指定大小和批次的模型?
第6章:在 TPU 上训练 LLaMA 3——实战:怎么在 TPU 上训练 LLaMA 3?要多久?花多少钱?
第7章:Transformer 推理详解——训练完还得部署。推理多了一个要考虑的因素:延迟。我们会讲分离式服务和 KV 缓存。
第8章:在 TPU 上部署 LLaMA 3——用 TPU v5e 部署 LLaMA 3 要多少钱?延迟和吞吐量怎么权衡?
第三部分:动手实践
第9章:TPU 代码性能分析——真实的 LLM 没那么理想化。这里讲 JAX + XLA 技术栈,以及怎么用分析器找问题。
第10章:用 JAX 编程 TPU——JAX 提供了一系列 API 用于并行化,这里教你怎么用。含趣味示例和练习题。
第四部分:总结与扩展
第11章:总结与延伸阅读——收尾和更多参考资料。
第12章:GPU 是怎么工作的——GPU 专题:内部原理、网络连接、Roofline 与 TPU 的对比。