Based on PyTorch 2.6.
torch.compile() and torch.export both build on the PT2 stack but serve different purposes.
torch.compile
torch.compile() is a just-in-time (JIT) compiler. When it reaches code it cannot trace, it inserts a graph break and falls back to the regular Python runtime for that part, which makes it flexible.
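A minimal sketch of the fallback behavior, assuming PyTorch 2.x on CPU. The data-dependent `if` forces a graph break: TorchDynamo evaluates that check in ordinary Python and compiles the code on either side of it. `backend="eager"` is used only so the sketch runs without a C++ toolchain; omitting it selects the default TorchInductor backend. The function name `branchy` is illustrative.

```python
import torch

def branchy(x: torch.Tensor) -> torch.Tensor:
    y = torch.sin(x)
    # Turning a tensor into a Python bool depends on runtime data, so
    # TorchDynamo inserts a graph break here and runs the check in Python.
    if y.sum() > 0:
        return y * 2
    return y - 1

# backend="eager" runs the captured graphs without code generation;
# drop the argument to use TorchInductor instead.
compiled = torch.compile(branchy, backend="eager")

x = torch.randn(8)
out = compiled(x)
print(torch.allclose(out, branchy(x)))  # True: matches plain eager execution
```

Nothing fails here: the untraceable check simply runs in Python, which is the flexibility the JIT approach trades against a single, complete graph.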
torch.export
In contrast, torch.export (called as torch.export.export()) is an ahead-of-time (AOT) graph-capture tool. It captures the full graph, raises an error on untraceable code instead of falling back, and produces portable, low-level (ATen) graphs suited to deployment.
Internally
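- torch.compile(): Uses TorchDynamo to trace Python bytecode into torch.fx graphs; the default TorchInductor backend then uses AOT Autograd to trace forward and backward graphs at the ATen level and generates optimized kernels. Untraceable code triggers graph breaks and runs in Python, enabling runtime flexibility.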
- torch.export(): Relies on TorchDynamo for broad bytecode tracing, AOT Autograd for lowering to ATen, and torch.fx for the graph representation, producing a standalone ExportedProgram that does not depend on the original Python code.
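To see the torch.fx backbone directly, a custom torch.compile backend receives each graph TorchDynamo captures as a `torch.fx.GraphModule` (plus example inputs) and returns a callable. The sketch below, assuming PyTorch 2.x, records every graph it is handed; `recording_backend` and `captured` are illustrative names. Graphs at this stage are still at the torch-function level; AOT Autograd lowering to ATen happens inside backends such as TorchInductor. The data-dependent branch splits the function into more than one graph, which torch.export would refuse.

```python
import torch
import torch.fx
from typing import Callable, List

captured: List[torch.fx.GraphModule] = []

def recording_backend(gm: torch.fx.GraphModule, example_inputs) -> Callable:
    # Every graph TorchDynamo captures arrives here as an FX GraphModule.
    captured.append(gm)
    # Returning gm.forward runs the graph as-is, with no further compilation.
    return gm.forward

def branchy(x: torch.Tensor) -> torch.Tensor:
    y = torch.sin(x)
    if y.sum() > 0:  # graph break: the branch is decided in Python
        return y * 2
    return y - 1

compiled = torch.compile(branchy, backend=recording_backend)
compiled(torch.ones(4))

print(len(captured))  # FX graphs produced on either side of the graph break
for gm in captured:
    print(gm.code)  # generated Python source of each FX graph
```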
Comparison Table
Feature | torch.compile() | torch.export() |
---|---|---|
Compilation | JIT | AOT |
Graph Capture | Partial (graph breaks allowed) | Full (errors if untraceable) |
Untraceable Handling | Python runtime fallback | Requires code rewrite or extra hints (e.g., torch.cond, dynamic_shapes) |
Graph Output | Flexible, tied to the Python runtime | Portable, ATen-level |
Python Support | High (via fallback) | Broad (bytecode tracing) |