如何扩展你的模型

大语言模型在 TPU 上的系统级视角 (第0部分:导论 | 第1部分:Roofline模型)

训练 LLM 常常被说成是『炼丹』,但搞懂模型性能优化其实没那么玄乎。本书想把 LLM 扩展这件事讲明白:TPU 和 GPU 到底怎么干活的、芯片之间怎么通信、LLM 在真实硬件上是怎么跑的,以及怎么把模型拆分到多个芯片上高效运行。如果你曾经想过『训练这个模型要花多少钱』、『部署需要多少显存』、『AllGather 是什么鬼』,希望这本书能帮到你。

深度学习里有很多”黑魔法”,但模型性能优化不用那么玄——哪怕是超大规模也一样!其实背后的原理挺简单的,而且从单卡到上万张卡都适用。搞明白这些,你就能:

需要什么基础? 你得大概知道 LLM 和 Transformer 是啥,但不需要了解它们怎么在大规模下运行。最好对 JAX 有点了解(不是必须)。想补课的话,可以看看这篇图解 Transformer原始论文。更多资料见延伸阅读

读完能学到什么? 你应该能为任意模型和硬件组合选出合适的并行策略,并估算训练和推理的耗时。如果做不到,欢迎给我们留言反馈!

新增了 NVIDIA GPU 专题,见第12章

为什么要学这些?

三四年前,做机器学习研究可能不太需要懂这些底层的东西。但现在不一样了——即使是”小”模型,也已经逼近硬件极限了历史上,机器学习研究在"硬核系统优化"和"易用框架封装"之间来回摆。当年 Alex Krizhevsky 得手写 CUDA 才能把 CNN 跑快,但几年后 TensorFlow、PyTorch 这些框架出来后,大家就不用管底层了。也许将来这些知识也会被封装掉。但 Scaling Law 让模型越来越大,前沿研究和高效扩展已经分不开了。

如果你在 benchmark 上提升了 20%,但效率损失了 20%,那这个提升就是假的。 很多看起来很酷的新架构最后没火起来,要么是因为跑不快,要么是没人愿意花功夫把它优化好。

模型扩展的终极目标:加芯片的同时,吞吐量也线性增长。 这叫”强扩展”(Strong Scaling)。加卡确实能加速计算,但也带来了通信开销。当通信时间超过计算时间,我们就”卡在通信上”了,再加卡也没用。随着计算时间变短,单卡上也会出现新瓶颈。你的新 TPU/GPU 标称 500 TFLOPS,但如果整天在搬运数据,实际可能只有十分之一。单卡的计算能力、内存带宽、总内存三者的配合,对扩展至关重要。

如果我们足够了解硬件,就能提前预判瓶颈在哪,从而调整模型避开它们。硬件设计师面临相反的问题:要在算力、带宽、内存之间找平衡,同时控制成本。这是个"协同设计"的博弈:你得赌两三年后算法会是什么样。TPU 就是这场博弈的一个成功案例。矩阵乘法每搬运一字节数据就能做 N 次运算,非常划算。早期 TPU 的脉动阵列架构比同期 GPU 性价比更高。GPU 后来也加了 TensorCore 来追赶。但你可以想象,如果神经网络没火起来,或者往 TPU 不擅长的方向发展,那赌输了代价有多大。

本书的目标:解释 TPU(和 GPU)怎么工作,以及 Transformer 怎么演进来适配当前硬件。希望对设计新架构的研究者和优化现有模型的工程师都有用。

这本书讲什么

整体结构:

第1章Roofline 分析——哪些因素在限制你的扩展(通信、计算、内存)。第2章第3章深入讲 TPU:单芯片怎么工作,多芯片怎么连接,芯片间的带宽和延迟是多少。我们会回答这些问题:

动图演示: 第2章中会讲 TPU 如何做逐元素乘法。根据数组大小和带宽,我们可能是计算受限(充分利用算力)或通信受限(被数据搬运拖慢)。

五年前,机器学习的模型架构百花齐放——CNN、LSTM、MLP、Transformer 都有市场。但现在基本就剩 Transformer 了。我们觉得有必要彻底搞懂 Transformer 的每个细节:每个矩阵多大、归一化在哪里、参数和 FLOPs浮点运算数(Floating point OPs),就是加法和乘法的总数。很多资料把 FLOPs 说成"每秒运算数",我们用 FLOPs/s 来明确表示后者。怎么算。第4章会细细拆解这些”Transformer 数学”,教你算训练和推理的参数量、FLOPs。这能告诉你模型吃多少内存、计算和通信各花多少时间、注意力和 FFN 哪个更重要。

图示: 标准 Transformer 层。圆圈里的点表示矩阵乘法(matmul),紫色是参数(不含归一化)。第4章会详细解释。

第5章:训练第7章:推理是本书的重头戏,讨论核心问题:给定模型大小和芯片数量,怎么并行化才能保持强扩展? 问题简单,答案却出乎意料地复杂。高层来看,有四种主要的并行策略来把模型拆分到多张卡上:

还有一些节省内存的技巧:重计算ZeRO 优化器分片卸载到主机内存梯度累积。这些我们都会讲。

读完这些章节,你应该能为新架构或新场景自己选出合适的并行方案。第6章第8章是实战教程,把这些概念应用到 LLaMA-3(一个流行的开源模型)上。

最后,第9章第10章讲怎么用 JAX 实现这些想法,以及出问题时怎么调试。第12章是新增的 GPU 专题。

全书穿插了练习题,可以动手试试。不用从头读到尾,挑感兴趣的看就行。欢迎留下反馈!这本书还在持续更新中。

感谢 James Bradbury 和 Blake Hechtman,书中很多想法源自他们。

话不多说,开始看第1章:Roofline 模型吧。

章节导航

本书可能比需要的更长,但别被吓到。前三章是预备知识,熟悉的话可以跳过(但会引入后面用到的符号)。最后几章最实用,讲怎么处理真实模型。

第一部分:基础知识

第二部分:Transformer 详解

第三部分:动手实践

第四部分:总结与扩展

Miscellaneous

*Work done at Google DeepMind, now at MatX.

Citation

For attribution in academic contexts, please cite this work as:

    Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

or as a BibTeX entry:

    @article{scaling-book,
      title = {How to Scale Your Model},
      author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
      and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
      publisher = {Google DeepMind},
      howpublished = {Online},
      note = {Retrieved from https://jax-ml.github.io/scaling-book/},
      year = {2025}
    }