Lesson 0002:动手 debug torch.compile —— 亲眼看到编译产物

上一课你知道了概念上的流水线:Dynamo → Inductor → Triton。
这一课你要亲眼看:FX Graph 长什么样?Triton kernel 的代码是什么?Graph break 在哪?

我们会用 TORCH_LOGS 和几个简单的 Python 脚本来做这件事。

1. 环境准备

第一件事:确认你能在终端里访问 PyTorch 2.0+。以下所有代码你都要在自己的终端里跑(你的 conda/venv 环境)。

# 确认版本 ≥ 2.0
python -c "import torch; print(torch.__version__)"

如果输出是 2.x.y,就可以开始了。如果不是,升级到最新版:

pip install --upgrade torch

2. 第一眼看 FX Graph

最简单的起点:对一个极小的函数用 torch.compile,然后导出它产生的 FX Graph。

打开你的终端,创建并运行这个脚本:
# 复制这段代码到 Python 文件中,或直接在 Python REPL 里逐行跑
import torch

def my_fn(x, y):
    # 一个刻意多步骤的小函数,方便观察融合
    a = x + y
    b = a * 2
    c = torch.nn.functional.relu(b)
    return c

# 步骤 1: 只用 Dynamo 导出图(不跑 Inductor)
x = torch.randn(4, 4)
y = torch.randn(4, 4)

explanation = torch._dynamo.explain(my_fn)(x, y)
print("=" * 60)
print("图中操作数:", explanation.graph_break_count)
print("Graph breaks:", explanation.break_reasons)
print("=" * 60)
print("FX Graph:")
print(explanation.graphs[0].graph.print_tabular())

你会看到类似这样的输出:

opcode name target args kwargs ------------- ------ ------------------ ---------------- -------- placeholder l_x_ L_x_ () {} placeholder l_y_ L_y_ () {} call_function add <built-in add> (l_x_, l_y_) {} call_function mul <built-in mul> (add, 2) {} call_function relu <function relu> (mul,) {} output output output ((relu,),) {}
你在看什么?
这就是 FX Graph——Dynamo 把你的 4 行 Python 代码翻译成了 6 个图节点。每行的 opcode 告诉你这是什么操作:placeholder(输入)、call_function(调用一个函数)、output(输出)。

注意:这时候还没有编译。这只是 Dynamo 的"抓图"阶段的作用。

3. 第二眼看 Inductor 的融合效果

上一步只有图捕获,没有编译。现在让 torch.compile 真正运行,看 Inductor 把图编译成了什么。

# 在终端里设置环境变量,然后再跑 Python
# 方法 A:在 bash/zsh 里一行搞定
TORCH_LOGS="output_code" python -c "
import torch

def my_fn(x, y):
    a = x + y
    b = a * 2
    c = torch.nn.functional.relu(b)
    return c

fn = torch.compile(my_fn)
x = torch.randn(4, 4, device='cuda')
y = torch.randn(4, 4, device='cuda')

# 第一遍:触发编译
_ = fn(x, y)
# 第二遍:使用缓存
_ = fn(x, y)
print('Done — 查看上方终端输出中的 Triton kernel 代码')
"
注意:上面代码用了 device='cuda',需要 GPU。如果你没有 GPU,把 device='cuda' 改成 device='cpu'——Inductor 会生成 C++ 代码而不是 Triton 代码。底线是一样的:你能看到编译产物的源代码。

终端输出里你会看到一长串 Triton kernel 代码。找类似这样的片段:

@triton.jit def triton__(...) : # Triton 自动生成的 fused kernel # 把 add + mul + relu 合成了一个 kernel! x = tl.load(...) y = tl.load(...) z = x + y # add z = z * 2 # mul z = tl.where(z > 0, z, 0) # relu tl.store(...)
关键 insight:你写的 3 个独立的 PyTorch 操作(add、mul、relu)被 Inductor 融合成了 1 个 Triton kernel。三个操作的内存中间结果不再写回 HBM(显存),而是留在寄存器里。这就是 torh.compile 提速的核心机制——不是魔法,就是减少显存带宽消耗。
PyTorch 官方文档:torch.compile Debugging — TORCH_LOGS 的完整参数列表和排障流程。

4. 第三眼看 Graph Break

Graph break 是 torch.compile 最常见的性能杀手。来看一个故意的例子:

# 终端里运行:
TORCH_LOGS="graph_breaks" python -c "
import torch

def bad_fn(x):
    # Step 1: 正常 tensor 操作 → 可以被编译
    y = x * 2
    y = y + 1

    # Step 2: .item() 强制 CPU 同步 → Graph Break!
    val = y.sum().item()
    print(f'val = {val}')  # Python print 也会断

    # Step 3: 后面又是正常 tensor 操作
    z = y / 3
    return z

fn = torch.compile(bad_fn)
x = torch.randn(100)
_ = fn(x)
"

你会看到类似这样的输出:

[INFO] Graph break: .item() called on a tensor at bad_fn:10 in <stdin> Reason: call to method Tensor.item

这意味着你的函数被 切成 3 段

每一次进出 eager mode,编译的收益就被削弱一次。这就是为什么 fullgraph=True 强制编译整个图——如果用了它报错,说明你的模型里存在 graph break。

5. 实战:一键看全貌

现在你已经知道怎么单独看了。搞一个真实的调试命令:

# 一键看 graph breaks + 每次重编译原因 + Triton kernel 输出
TORCH_LOGS="graph_breaks,recompiles,output_code" python your_training_script.py 2>&1 | head -200

这条命令的含义:

Flag回答什么什么时候用
graph_breaks 图在哪断了?为什么? 每次调 torch.compile 都应该看一眼
recompiles 为什么同一个函数被反复编译?(shape 变了?dtype 变了?) 训练慢但 graph break 很少——可能是反复重编译
output_code Inductor 生成了什么 kernel?融合效果如何? 想确认编译确实在做融合,不是在空转
guards 编译缓存的守卫条件是什么?(哪些条件变化会触发重编译) 排查重编译风暴的根因

6. 如果你没有 GPU

一切同理,只是 backend 换成 C++:

# CPU 版本
TORCH_LOGS="output_code" python -c "
import torch

def my_fn(x, y):
    a = x + y
    b = a * 2
    c = torch.nn.functional.relu(b)
    return c

fn = torch.compile(my_fn)
x = torch.randn(4, 4)  # 无 device=... → CPU
y = torch.randn(4, 4)
_ = fn(x, y)
print('Done')
"

你会看到 Inductor 生成的 C++/OpenMP 代码而不是 Triton。原理完全一样——图捕获 → 融合 → 生成代码,只是目标硬件不同。

7. 本节核心技能

你学会了三个调试技能

  1. 看 FX Graph:用 torch._dynamo.explain() 导出图,验证 Dynamo 抓到了什么
  2. 看 kernel 代码:用 TORCH_LOGS="output_code",确认 Inductor 在做融合而不是空转
  3. 排查 graph break:用 TORCH_LOGS="graph_breaks" 定位断点,修复后加 fullgraph=True

8. Quiz

动手之后,检验一下

1. TORCH_LOGS="output_code" 会显示什么?
PyTorch 的版本号和编译时间
Inductor 生成的 Triton/C++ kernel 代码
模型的 loss 曲线
GPU 的利用率
2. 哪些操作会导致 Graph Break?
torch.add 和 torch.mul
torch.compile 本身
.item() 和 data-dependent 控制流
torch.nn.functional.relu
3. 一个函数被切成 3 个子图意味着?
Inductor 做了 3 层优化
有 2 处 graph break,部分代码在 eager mode 执行
模型有 3 个参数
Triton 生成了 3 个 kernel
4. torch._dynamo.explain() 的作用是?
加速模型训练
导出 FX Graph 并报告 graph break,不实际编译运行
切换 Triton backend 到 TVM
生成 ONNX 模型
5. 下面哪项是 Inductor 算子融合的效果?
把 10 层的模型压缩为 5 层
add + mul + relu 合为 1 个 kernel,中间结果不留 HBM
把 FP32 的参数转换成 FP16
把 batch size 从 16 变成 32

📖 推荐阅读: PyTorch 官方 torch.compile 排障指南 — TORCH_LOGS 的完整参数表、常见 graph break 原因及修复方案。 以及 torch.compile Troubleshooting
💬 有问题? 在终端跑命令时遇到奇怪的输出?graph break 的原因看不懂?直接问我。

推荐:拿你自己的训练代码跑一次 TORCH_LOGS="graph_breaks",把输出贴给我,我帮你分析。