上一课你知道了概念上的流水线:Dynamo → Inductor → Triton。
这一课你要亲眼看:FX Graph 长什么样?Triton kernel 的代码是什么?Graph break 在哪?
我们会用 TORCH_LOGS 和几个简单的 Python 脚本来做这件事。
第一件事:确认你能在终端里访问 PyTorch 2.0+。以下所有代码你都要在自己的终端里跑(你的 conda/venv 环境)。
# 确认版本 ≥ 2.0
python -c "import torch; print(torch.__version__)"
如果输出是 2.x.y,就可以开始了。如果不是,升级到最新版:
pip install --upgrade torch
最简单的起点:对一个极小的函数用 torch.compile,然后导出它产生的 FX Graph。
# 复制这段代码到 Python 文件中,或直接在 Python REPL 里逐行跑
import torch
def my_fn(x, y):
# 一个刻意多步骤的小函数,方便观察融合
a = x + y
b = a * 2
c = torch.nn.functional.relu(b)
return c
# 步骤 1: 只用 Dynamo 导出图(不跑 Inductor)
x = torch.randn(4, 4)
y = torch.randn(4, 4)
explanation = torch._dynamo.explain(my_fn)(x, y)
print("=" * 60)
print("图中操作数:", explanation.graph_break_count)
print("Graph breaks:", explanation.break_reasons)
print("=" * 60)
print("FX Graph:")
print(explanation.graphs[0].graph.print_tabular())
你会看到类似这样的输出:
opcode 告诉你这是什么操作:placeholder(输入)、call_function(调用一个函数)、output(输出)。上一步只有图捕获,没有编译。现在让 torch.compile 真正运行,看 Inductor 把图编译成了什么。
# 在终端里设置环境变量,然后再跑 Python
# 方法 A:在 bash/zsh 里一行搞定
TORCH_LOGS="output_code" python -c "
import torch
def my_fn(x, y):
a = x + y
b = a * 2
c = torch.nn.functional.relu(b)
return c
fn = torch.compile(my_fn)
x = torch.randn(4, 4, device='cuda')
y = torch.randn(4, 4, device='cuda')
# 第一遍:触发编译
_ = fn(x, y)
# 第二遍:使用缓存
_ = fn(x, y)
print('Done — 查看上方终端输出中的 Triton kernel 代码')
"
device='cuda',需要 GPU。如果你没有 GPU,把 device='cuda' 改成 device='cpu'——Inductor 会生成 C++ 代码而不是 Triton 代码。底线是一样的:你能看到编译产物的源代码。
终端输出里你会看到一长串 Triton kernel 代码。找类似这样的片段:
Graph break 是 torch.compile 最常见的性能杀手。来看一个故意的例子:
# 终端里运行:
TORCH_LOGS="graph_breaks" python -c "
import torch
def bad_fn(x):
# Step 1: 正常 tensor 操作 → 可以被编译
y = x * 2
y = y + 1
# Step 2: .item() 强制 CPU 同步 → Graph Break!
val = y.sum().item()
print(f'val = {val}') # Python print 也会断
# Step 3: 后面又是正常 tensor 操作
z = y / 3
return z
fn = torch.compile(bad_fn)
x = torch.randn(100)
_ = fn(x)
"
你会看到类似这样的输出:
这意味着你的函数被 切成 3 段:
每一次进出 eager mode,编译的收益就被削弱一次。这就是为什么 fullgraph=True 强制编译整个图——如果用了它报错,说明你的模型里存在 graph break。
现在你已经知道怎么单独看了。搞一个真实的调试命令:
# 一键看 graph breaks + 每次重编译原因 + Triton kernel 输出
TORCH_LOGS="graph_breaks,recompiles,output_code" python your_training_script.py 2>&1 | head -200
这条命令的含义:
| Flag | 回答什么 | 什么时候用 |
|---|---|---|
graph_breaks |
图在哪断了?为什么? | 每次调 torch.compile 都应该看一眼 |
recompiles |
为什么同一个函数被反复编译?(shape 变了?dtype 变了?) | 训练慢但 graph break 很少——可能是反复重编译 |
output_code |
Inductor 生成了什么 kernel?融合效果如何? | 想确认编译确实在做融合,不是在空转 |
guards |
编译缓存的守卫条件是什么?(哪些条件变化会触发重编译) | 排查重编译风暴的根因 |
一切同理,只是 backend 换成 C++:
# CPU 版本
TORCH_LOGS="output_code" python -c "
import torch
def my_fn(x, y):
a = x + y
b = a * 2
c = torch.nn.functional.relu(b)
return c
fn = torch.compile(my_fn)
x = torch.randn(4, 4) # 无 device=... → CPU
y = torch.randn(4, 4)
_ = fn(x, y)
print('Done')
"
你会看到 Inductor 生成的 C++/OpenMP 代码而不是 Triton。原理完全一样——图捕获 → 融合 → 生成代码,只是目标硬件不同。
你学会了三个调试技能:
torch._dynamo.explain() 导出图,验证 Dynamo 抓到了什么TORCH_LOGS="output_code",确认 Inductor 在做融合而不是空转TORCH_LOGS="graph_breaks" 定位断点,修复后加 fullgraph=TrueTORCH_LOGS="output_code" 会显示什么?torch._dynamo.explain() 的作用是?TORCH_LOGS="graph_breaks",把输出贴给我,我帮你分析。