Lesson 0001:全景图 — 从 PyTorch 代码到 GPU 执行

你写了 model = torch.compile(model),然后代码就跑得快了一点。
它到底做了什么?TVM、Triton、IR 各自扮演什么角色?

本课给你一个完整的鸟瞰图。不需要记住每个缩写——先建立方位感,后面的课再逐一深入。

0. 先回答你最重要的问题

TVM 能加速 PyTorch 训练吗?

基本不能。TVM 的现代 IR(Relax)是一个 推理编译器——它接收训练好的模型,针对目标硬件做部署优化。它不包含 autograd 引擎,无法替代 PyTorch 的反向传播。

结论:训练加速 → torch.compile + Triton(你已经在用);推理部署 → TVM 或 TensorRT。这两条线分工明确,换 TVM 做训练不会有收益。

1. Eager Mode:为什么需要编译?

PyTorch 默认以 Eager Mode 运行:

x = layer1(x) # → 立即启动 CUDA kernel A x = F.relu(x) # → 立即启动 CUDA kernel B x = x + residual # → 立即启动 CUDA kernel C # 三个 kernel,三次 launch,两次中间结果写回显存

每次你调用一个操作,PyTorch 就立刻启动一个 GPU kernel。问题:

编译的目标就是:抓住计算图 → 合并小操作为大 kernel → 减少 launch 和显存往返

2. torch.compile 的三段流水线

PyTorch 代码 │ ▼ ┌─────────────┐ │ Dynamo① 图捕获:拦截 Python 字节码, │ (Graph │ 记录所有操作到 FX Graph │ Capture) │ └──────┬───────┘ │ FX Graph (IR 第一层:算子级 DAG) ▼ ┌─────────────┐ │ Inductor② 图优化 + 降低:算子融合、 │ (Graph │ 布局选择 → Loop-level IR │ Compiler) │ └──────┬───────┘ │ Loop-level IR (IR 第二层:循环级) ▼ ┌─────────────┐ │ Triton/C++③ Kernel 生成:产出 GPU 或 CPU 代码 │ (Kernel │ │ Generation)│ └──────┬───────┘ │ ▼ GPU (PTX → SASS) 或 CPU (native)

2.1 Dynamo — 图捕获(Graph Capture)

Dynamo 做的事情很直观:它 劫持 Python 解释器的帧评估钩子(frame evaluation hook),在代码执行前拦截字节码,把 PyTorch 操作记录成一张 FX 图。

用户代码: x = a + b y = torch.nn.functional.relu(x) z = y * 2 Dynamo 抓到的 FX Graph: %x = add(%a, %b) %y = relu(%x) %z = mul(%y, 2) # 每个 node 都是一个 PyTorch 算子

关键概念:Graph Break。Dynamo 不是万能的——遇到它不认识的代码(data-dependent control flow、不支持的操作),图会断开。断点太多,编译的好处就少了。用 torch.compile(fullgraph=True) 可以检测断点。

PyTorch Dynamo 文档 — Dynamo 的完整工作原理和 graph break 的处理方式。

2.2 Inductor — 图编译器(Graph Compiler)

Inductor 拿到 FX Graph 后做两件事:

第一步:算子融合(Operator Fusion)

融合前: add → relu → mul # 3 个 kernel,中间结果读写显存 2 次 融合后: fused_add_relu_mul # 1 个 kernel,中间结果留在寄存器

这是编译最大的性能收益来源。对于 transformer 里大量的 element-wise 操作(LayerNorm 周围的 add、mul、div),融合能减少 60-80% 的显存带宽消耗。

第二步:降低到 Loop-level IR

FX Graph 是算子级别的("这里有一个 add,这里有一个 relu")。Loop-level IR 是更底层的("这里有一个循环,循环体里做加法,再做 relu")。这个降低让 Inductor 可以决定:

2.3 Triton — Kernel 生成(Kernel Generation)

Inductor 把 Loop-level IR 翻译成 Triton 代码(一种 Python-like 的 GPU kernel 语言),然后 Triton 编译器把 Triton 代码编译成 PTX(NVIDIA GPU 的汇编中间表示),最终由 NVIDIA 驱动编译成 GPU 可执行的 SASS

Triton 的核心创新:block-level 编程。手写 CUDA 需要分别管理 thread 和 block,Triton 只需要你在 block 层面描述计算,它自动处理:

这就是为什么 torch.compile 能"零成本"加速——你不需要写任何 kernel 代码。

PyTorch 2.0: The Journey to 2x Training Speedups (arXiv:2304.04487) — Meta 团队对 torch.compile 架构最权威的阐述。

3. IR 到底是什么意思?你在哪一层?

IR = Intermediate Representation = 中间表示

编译器的本质是对信息做 representation(表示)→ transformation(变换)。每一步你换一种表示,在那个表示上做某种变换,然后传到下一层。

IR 层次谁产生表示什么适合做什么优化
FX Graph Dynamo 算子级 DAG(add、matmul、softmax…) 算子替换、死代码消除、常量折叠
Loop-level IR Inductor 循环 + 访存模式 算子融合、分块(tiling)、向量化、布局选择
Triton IR Inductor 输出 Block-level GPU 语义 线程映射、共享内存分配、自动调优
PTX Triton 编译器 NVIDIA GPU 虚拟 ISA 寄存器分配、指令调度
SASS NVIDIA 驱动 具体 GPU 架构的机器码 —(最终执行)
TVM Relay/Relax TVM(推理) 模型级计算图(替代 FX Graph 角色) 跨算子融合、layout rewrite、量化、异构调度
关键 insight:IR 不是一种东西——它是多层堆叠的抽象塔。每往下一层,表示更接近硬件,优化空间变小但精度变高。

TVM 的 Relay/Relax 和 torch.compile 的 FX Graph 处于同一抽象层次,只是 TVM 覆盖了更多硬件后端(ARM、Mobile GPU、FPGA 等),而 torch.compile 专注于 NVIDIA GPU + x86 CPU。

4. 三大编译器的定位:一张图讲清楚

torch.compile + TritonApache TVMNVIDIA TensorRT
主要场景 训练 + 推理 推理(多硬件) 推理(NVIDIA only)
IR 层数 FX → Loop → Triton 三层 Relax → TIR 两层 单一计算图 IR
Kernel 生成 Triton(自动, 运行时 JIT) AutoTVM / MetaSchedule(需 auto-tuning, 小时级) 手写 CUDA 库(cuDNN、cuBLAS)
硬件覆盖 NVIDIA GPU + x86 CPU GPU + CPU + Mobile + FPGA + … NVIDIA GPU only
训练支持 ✅ 原生支持 ❌ 无 autograd 引擎 ❌ 推理 only
编译开销 秒-分钟(首次 JIT) 分钟-小时(auto-tuning) 分钟(builder 构建 engine)
NVIDIA GPU 推理峰值 接近 TensorRT(90-95%) 接近 TensorRT(85-95%) 🥇 最强(手写库 + 硬件专调)

所以 TensorRT 是推理场景的"终极答案"——它用的是 NVIDIA 手写的 cuDNN/cuBLAS kernel,不是自动生成的。代价是:只支持 NVIDIA GPU、有限的算子覆盖、不碰训练。

5. 回到你的问题:训练还能怎么加速?

你已经有了最好的起点:torch.compile + Triton 是目前 PyTorch 训练加速的 唯一生产级方案。TVM 和 TensorRT 都不处理训练。

如果你还没榨干 torch.compile 的潜力,可以从这里开始:
# 当前你可能这样用:
model = torch.compile(model)

# 试试这样:
model = torch.compile(
    model,
    mode="max-autotune",     # Triton 尝试多个 kernel 配置,选最快的
    fullgraph=True,           # 强制全图编译,报错说明有 graph break
)

Mode 对比:

Mode编译时间运行时速度适用
"default"最快good开发调试
"reduce-overhead"better小 batch、多用 CUDA Graphs 消除 launch overhead
"max-autotune"慢(首次几分钟)best大规模训练、静态 shape、跑得久值得等编译
PyTorch Blog: torch.compile Tuning Guide — mode 选择和 autotune 的实战建议。

6. 小结

这一课你只需要记住三个东西:

  1. 编译三段式:Dynamo(抓图)→ Inductor(优化图)→ Triton(生成 kernel)
  2. IR 是分层堆叠的:FX Graph…Loop IR…Triton IR…PTX…SASS,越往下越接近硬件
  3. TVM 不是训练方案:它的主战场是 推理部署,尤其是非 NVIDIA 硬件。你的训练加速卡在 torch.compile 怎么调,不在要不要换 backend

7. Quiz

检验一下

1. torch.compile 的流水线中,图捕获由哪个组件完成?
Inductor
Dynamo
Triton
FX
2. IR(中间表示)在编译栈中处于哪个位置?
只存在于 Dynamo 输出处
只存在于最终的 PTX/SASS 层
存在于多个层次,每层表示不同的抽象
IR 是 TVM 特有的,torch.compile 里没有
3. 为什么 TVM 不能直接加速 PyTorch 训练?
TVM 只支持 ONNX,不支持 PyTorch
TVM 没有 autograd 引擎,无法做反向传播
TVM 的速度比 Triton 慢
TVM 只能跑在 CPU 上
4. TensorRT 是干什么的?
PyTorch 训练的默认 backend
开源的 GPU kernel 语言
NVIDIA 的推理优化引擎,用 cuDNN 等手写库做极致推理加速
跨硬件的深度学习编译器
5. Inductor 做的"算子融合"(operator fusion)主要解决什么问题?
减少模型参数量
减少显存带宽消耗和 kernel launch overhead
提高模型精度
支持更多硬件平台

📖 推荐阅读: PyTorch 2.0: The Journey to 2x Training Speedups — 本文是 torch.compile 设计团队的官方论文。读完第一课再去看 paper,你会发现自己能看懂每个组件在干什么。
💬 有疑问? 直接问我。"Inductor 的 Loop-level IR 长什么样?" "Graph break 具体怎么排查?" "Triton 和 CUDA kernel 在显存使用上有什么区别?"——我都可以展开讲。 后续课程预告:Lesson 0002 会用实际代码带你 debug 编译流程(看 FX Graph、看 Triton kernel 输出、看 graph break 在哪)。