Static FX Graph Generation Pass
Project Overview
This document provides a comprehensive specification for building a Static FX Graph Generation Pass in the Jaseci compiler. Unlike PyTorch's dynamic graph capture (torch.fx.symbolic_trace executes the model with Proxy objects, and TorchDynamo in PyTorch 2 traces Python bytecode at runtime), this pass performs static analysis on the UniIR representation of PyTorch models to construct an FX-like graph without symbolic execution. The key innovation is that instead of breaking the graph when it encounters unsupported code (such as data-dependent control flow), the pass tags and annotates those regions while keeping a complete representation of the model.
Motivation
PyTorch's dynamic graph capture mechanisms (torch.fx.symbolic_trace, and TorchDynamo in PyTorch 2) have several limitations:
- Graph Breaks: When TorchDynamo encounters data-dependent control flow or unsupported operations, it splits the program into multiple disconnected FX graphs (torch.fx.symbolic_trace raises an error instead)
- Runtime Requirement: Graph capture requires executing the model (symbolic_trace runs forward with Proxy objects; TorchDynamo traces a function when it is actually called)
- Limited Analysis: Graph breaks prevent whole-program optimization and analysis
- No Static Guarantees: Control flow cannot be analyzed without runtime execution
This project aims to create a static analysis pass that:
- Works entirely on the UniIR AST without execution
- Creates a complete graph representation even for "untraceable" code
- Tags problematic regions (graph break causes) inline in the representation
- Enables static optimizations and analysis not possible with dynamic tracing
Background: PyTorch FX Graph System
How PyTorch FX Works
```mermaid
graph TB
    A[PyTorch Module] -->|symbolic_trace| B[Tracer]
    B -->|Creates Proxy Objects| C[Proxy Execution]
    C -->|Records Operations| D[FX Graph]
    D -->|Nodes & Operations| E[GraphModule]
    style C fill:#e74c3c,stroke:#ecf0f1,stroke-width:2px,color:#fff
    style D fill:#27ae60,stroke:#ecf0f1,stroke-width:2px,color:#fff
```
PyTorch FX uses a runtime tracing approach:
- Proxy Objects: Special wrapper objects that intercept operations via `__torch_function__`
- Symbolic Execution: The model's forward method is executed with Proxy inputs
- Operation Recording: Each PyTorch operation on Proxies is recorded as a Node in the Graph
- IR Construction: The recorded operations form an FX Graph (intermediate representation)
FX Graph Structure
An FX Graph is a list of nodes, and each node's op is one of the following types:
- placeholder: Input parameters
- get_attr: Fetching a parameter, buffer, or other attribute from the module hierarchy
- call_function: Free function calls (torch.relu, operator.add, etc.)
- call_method: Method calls on values (for example x.view(...))
- call_module: Submodule invocations
- output: Return values
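The record-and-replay idea behind proxy tracing can be shown without PyTorch. The sketch below is a deliberately tiny, pure-Python imitation. `ToyProxy`, `ToyGraph`, `relu`, and `trace` are hypothetical names, not torch.fx APIs. Operator overloads produce call_function nodes, and unknown attribute calls produce call_method nodes. `__bool__` raises an error, which mirrors why real symbolic tracing cannot follow an `if` on a traced tensor.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ToyNode:
    op: str  # placeholder / call_function / call_method / output
    name: str
    target: Any
    args: tuple = ()


@dataclass
class ToyGraph:
    nodes: list[ToyNode] = field(default_factory=list)

    def create(self, op: str, target: Any, args: tuple = ()) -> "ToyProxy":
        node = ToyNode(op, f"n{len(self.nodes)}", target, args)
        self.nodes.append(node)
        return ToyProxy(self, node)


def _arg(value: Any) -> Any:
    """Refer to traced values by node name; keep constants as-is."""
    return value.node.name if isinstance(value, ToyProxy) else value


class ToyProxy:
    """Stands in for a tensor and records every operation applied to it."""

    def __init__(self, graph: ToyGraph, node: ToyNode) -> None:
        self.graph = graph
        self.node = node

    def __add__(self, other: Any) -> "ToyProxy":
        return self.graph.create("call_function", "add", (self.node.name, _arg(other)))

    def __mul__(self, other: Any) -> "ToyProxy":
        return self.graph.create("call_function", "mul", (self.node.name, _arg(other)))

    def __getattr__(self, method: str) -> Callable[..., "ToyProxy"]:
        # Only reached for attributes not set in __init__: x.sum(), x.view(-1), ...
        def record(*args: Any) -> "ToyProxy":
            return self.graph.create(
                "call_method", method, (self.node.name, *map(_arg, args))
            )
        return record

    def __bool__(self) -> bool:
        # No concrete value exists while tracing, so `if` cannot branch on it
        raise RuntimeError("cannot branch on a traced value")


def relu(x: ToyProxy) -> ToyProxy:
    return x.graph.create("call_function", "relu", (x.node.name,))


def trace(fn: Callable[[ToyProxy], ToyProxy]) -> ToyGraph:
    graph = ToyGraph()
    out = fn(graph.create("placeholder", "x"))
    graph.create("output", "output", (out.node.name,))
    return graph


for n in trace(lambda x: relu(x * 2 + 1).sum()).nodes:
    print(n.op, n.target, n.args)
```

Running it prints one line per recorded node (placeholder, mul, add, relu, sum, output). Calling `trace(lambda x: x if x.sum() else x)` raises `RuntimeError`, which is the toy analogue of a graph break.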
Graph Breaks in PyTorch 2
Graph breaks occur when TorchDynamo encounters code it cannot trace:
```python
def forward(self, x):
    # Traceable
    x = self.linear1(x)
    # GRAPH BREAK: Data-dependent control flow
    if x.sum() > 0:  # Tensor value is unknown at trace time
        x = self.linear2(x)
    else:
        x = self.linear3(x)
    # New graph starts here
    return self.output(x)
```
Common Graph Break Causes:
- Data-dependent conditionals (`if tensor.sum() > 0`)
- Data-dependent loops (`for i in range(tensor.item())`)
- Python I/O operations (`print`, file operations)
- Unsupported Python builtins (`getattr`, `inspect` module)
- Dynamic attribute access
- Certain tensor operations that create dynamic shapes
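Many of these causes are syntactic, so they can be spotted without executing the model. That is the premise of this project. The minimal sketch below uses Python's standard `ast` module (not UniIR). It flags `if`/`while` conditions and `for` iterables that call a value-producing tensor method, and it also flags `print` calls. `RUNTIME_VALUE_METHODS` is a hypothetical heuristic; a real pass would use type information to confirm that the receiver is actually a tensor.

```python
import ast

# Hypothetical heuristic: tensor methods whose result is a concrete runtime value
RUNTIME_VALUE_METHODS = {"item", "sum", "mean", "max", "min", "any", "all", "tolist"}


def calls_runtime_method(expr: ast.AST) -> bool:
    """True if expr contains a call such as t.sum() or t.item()."""
    return any(
        isinstance(n, ast.Call)
        and isinstance(n.func, ast.Attribute)
        and n.func.attr in RUNTIME_VALUE_METHODS
        for n in ast.walk(expr)
    )


def find_graph_break_candidates(source: str) -> list[tuple[int, str]]:
    """Return sorted (line, reason) pairs for likely graph-break sites."""
    found: list[tuple[int, str]] = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.If, ast.While)) and calls_runtime_method(node.test):
            found.append((node.lineno, "data_dependent_control_flow"))
        elif isinstance(node, ast.For) and calls_runtime_method(node.iter):
            found.append((node.lineno, "dynamic_loop"))
        elif (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == "print"
        ):
            found.append((node.lineno, "python_io"))
    return sorted(found)


SRC = """
def forward(self, x):
    x = self.linear1(x)
    if x.sum() > 0:
        x = self.linear2(x)
    for _ in range(x.max().item()):
        x = x * 2
    print(x)
    return self.output(x)
"""
print(find_graph_break_candidates(SRC))
# [(4, 'data_dependent_control_flow'), (6, 'dynamic_loop'), (8, 'python_io')]
```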
PyTorch Codebase References
To understand FX graph generation, examine these key files in the PyTorch repository:
```text
torch/fx/
├── symbolic_trace.py    # Main tracing entry point
├── proxy.py             # Proxy object implementation
├── graph.py             # Graph and Node classes
├── interpreter.py       # Graph execution
└── node.py              # Node representation
torch/_dynamo/
├── symbolic_convert.py  # TorchDynamo bytecode analysis
├── guards.py            # Guard construction (checks that a compiled graph is still valid)
└── eval_frame.py        # Frame evaluation hooks
```
Key Functions to Study:
- torch.fx.symbolic_trace(): Main entry point
- torch.fx.Tracer.trace(): Core tracing logic
- torch.fx.Proxy.__torch_function__(): Operation interception
- torch._dynamo.symbolic_convert.InstructionTranslator: Bytecode to FX conversion
The Jaseci Compiler Architecture
High-Level Compiler Pipeline
```mermaid
graph TB
    subgraph "Frontend"
        A[Source Code] --> B[Lexer/Parser]
        B --> C[UniIR AST]
    end
    subgraph "Analysis Passes"
        C --> D[SymTabBuildPass]
        D --> E[DeclImplMatchPass]
        E --> F[DefUsePass]
        F --> G[SemDefMatchPass]
        G --> H[CFGBuildPass]
    end
    subgraph "Type Checking"
        H --> I[TypeCheckPass]
    end
    subgraph "Code Generation"
        I --> J[PyastGenPass]
        J --> K[PyJacAstLinkPass]
        K --> L[PyBytecodeGenPass]
    end
    L --> M[Python Bytecode]
    style C fill:#3498db,stroke:#ecf0f1,stroke-width:2px,color:#fff
    style H fill:#e67e22,stroke:#ecf0f1,stroke-width:2px,color:#fff
    style M fill:#27ae60,stroke:#ecf0f1,stroke-width:2px,color:#fff
```
Key Compiler Components
1. UniIR (Unified Intermediate Representation)
Located in jac/jaclang/compiler/unitree.py, UniIR is the AST representation used throughout the compiler.
Key Characteristics:
- Universal: Represents both Jac and Python code
- Hierarchical: Tree structure with parent/child relationships
- Annotated: Nodes carry symbol tables, type information, and source locations
- Traversable: Built for pass-based transformations
Core UniNode Base Class:
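```python
from __future__ import annotations

from typing import Optional, Sequence

# CodeGenTarget and CodeLocInfo are defined elsewhere in the jaclang compiler


class UniNode:
    """Base class for all IR nodes (simplified excerpt)"""

    def __init__(self, kid: Sequence[UniNode]) -> None:
        self.parent: Optional[UniNode] = None
        self.kid: list[UniNode] = list(kid)  # Child nodes
        for child in self.kid:
            child.parent = self  # Keep parent pointers consistent with kid
        self._sub_node_tab: dict[type, list[UniNode]] = {}  # Fast subnode lookup
        self.gen: CodeGenTarget = CodeGenTarget()  # Code generation target
        self.loc: CodeLocInfo = CodeLocInfo()  # Source location
```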
Important UniNode Types:
```mermaid
classDiagram
    UniNode <|-- UniScopeNode
    UniNode <|-- Expr
    UniNode <|-- UniCFGNode
    UniScopeNode <|-- Module
    UniScopeNode <|-- Archetype
    UniScopeNode <|-- Ability
    Expr <|-- FuncCall
    Expr <|-- AtomTrailer
    Expr <|-- BinaryExpr
    UniCFGNode <|-- IfStmt
    UniCFGNode <|-- WhileStmt
    UniCFGNode <|-- ForStmt
    UniCFGNode <|-- Assignment
    class UniNode {
        +parent: Optional[UniNode]
        +kid: list[UniNode]
        +loc: CodeLocInfo
        +sym_tab: UniScopeNode
    }
    class UniScopeNode {
        +sym_tab: SymbolTable
        +scope_type: str
    }
    class UniCFGNode {
        +bb_in: list[UniCFGNode]
        +bb_out: list[UniCFGNode]
    }
```
2. Pass Infrastructure
Located in jac/jaclang/compiler/passes/, the compiler uses a pass-based architecture.
Base Transform Class (transform.py):
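```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Generic, TypeVar

T = TypeVar("T")
R = TypeVar("R")


class Transform(ABC, Generic[T, R]):
    """Base class for all compiler transformations (simplified)"""

    def __init__(self, ir_in: T, prog: JacProgram) -> None:
        self.ir_in: T = ir_in
        self.prog = prog
        # Initialize diagnostics before running the transform, so that
        # transform() can record errors and warnings
        self.errors_had: list[Alert] = []
        self.warnings_had: list[Alert] = []
        self.ir_out: R = self.transform(ir_in=ir_in)

    @abstractmethod
    def transform(self, ir_in: T) -> R:
        """Transform IR input to IR output"""
```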
UniPass Class (uni_pass.py):
```python
class UniPass(Transform[uni.Module, uni.Module]):
    """Base class for AST traversal passes (simplified)"""

    def enter_node(self, node: uni.UniNode) -> None:
        """Called when entering a node during traversal"""
        ...  # Dispatches to enter_{node_type} methods (see below)

    def exit_node(self, node: uni.UniNode) -> None:
        """Called when exiting a node during traversal"""
        ...  # Dispatches to exit_{node_type} methods

    def traverse(self, node: uni.UniNode) -> uni.UniNode:
        """Depth-first traversal of the AST"""
        self.enter_node(node)
        for child in node.kid:
            self.traverse(child)
        self.exit_node(node)
        return node
```
Pass Dispatch Mechanism:
Passes use dynamic method dispatch based on node type:
```python
# In UniPass
def enter_node(self, node: uni.UniNode) -> None:
    method_name = f"enter_{pascal_to_snake(type(node).__name__)}"
    if hasattr(self, method_name):
        getattr(self, method_name)(node)


# In your pass
def enter_func_call(self, node: uni.FuncCall) -> None:
    # Handle function call nodes
    pass


def exit_if_stmt(self, node: uni.IfStmt) -> None:
    # Handle if statement nodes
    pass
```
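The dispatch rule can be tried outside the compiler with a self-contained sketch. `Node`, `FuncCall`, `IfStmt`, `MiniPass`, `CallLogger`, and this version of `pascal_to_snake` are stand-ins written for illustration, not jaclang's classes. The example shows how a class name such as `IfStmt` becomes the hook name `exit_if_stmt`, and how a missing hook is simply skipped.

```python
import re


def pascal_to_snake(name: str) -> str:
    """IfStmt -> if_stmt, FuncCall -> func_call (acronyms like CFG split per letter)."""
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


class Node:
    def __init__(self, *kid: "Node") -> None:
        self.kid = list(kid)


class FuncCall(Node): ...


class IfStmt(Node): ...


class MiniPass:
    def __init__(self) -> None:
        self.log: list[str] = []

    def _dispatch(self, prefix: str, node: Node) -> None:
        # Look up e.g. "enter_func_call"; nodes without a hook are ignored
        hook = getattr(self, f"{prefix}_{pascal_to_snake(type(node).__name__)}", None)
        if hook is not None:
            hook(node)

    def traverse(self, node: Node) -> Node:
        self._dispatch("enter", node)
        for child in node.kid:
            self.traverse(child)
        self._dispatch("exit", node)
        return node


class CallLogger(MiniPass):
    def enter_func_call(self, node: FuncCall) -> None:
        self.log.append("enter func_call")

    def exit_if_stmt(self, node: IfStmt) -> None:
        self.log.append("exit if_stmt")


p = CallLogger()
p.traverse(IfStmt(FuncCall(), Node()))
print(p.log)  # ['enter func_call', 'exit if_stmt']
```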
3. Existing Passes
SymTabBuildPass
Purpose: Constructs symbol tables for name resolution
Location: jac/jaclang/compiler/passes/main/sym_tab_build_pass.py
Key Operations:
- Creates scope hierarchies
- Registers symbols (variables, functions, classes)
- Links parent-child scopes
- Adds special symbols (`self`, `super`)
Example:
```python
def enter_archetype(self, node: uni.Archetype) -> None:
    self.push_scope_and_link(node)
    node.parent_scope.def_insert(node, access_spec=node)


def enter_ability(self, node: uni.Ability) -> None:
    self.push_scope_and_link(node)
    if node.is_method:
        # Add 'self' symbol
        node.sym_tab.def_insert(uni.Name.gen_stub_from_node(node, "self"))
```
CFGBuildPass
Purpose: Builds control flow graphs for program analysis
Location: jac/jaclang/compiler/passes/main/cfg_build_pass.py
Key Operations:
- Identifies basic blocks (straight-line code sequences)
- Links basic blocks with control flow edges
- Tracks loop structures
- Handles conditional branches
CFG Structure:
```mermaid
graph TD
    subgraph "CFG Construction"
        A[Entry Block] --> B[Basic Block 1]
        B --> C{If Statement}
        C -->|True| D[Then Block]
        C -->|False| E[Else Block]
        D --> F[Merge Block]
        E --> F
        F --> G[Exit Block]
    end
    style C fill:#e74c3c,stroke:#ecf0f1,stroke-width:2px,color:#fff
```
Basic Block Connections:
```python
from __future__ import annotations


class UniCFGNode(UniNode):
    """Node participating in control flow"""

    bb_in: list[UniCFGNode]  # Incoming edges
    bb_out: list[UniCFGNode]  # Outgoing edges
```
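The sketch below illustrates how `bb_in` and `bb_out` encode the if/else diamond from the CFG diagram above. It links stand-in blocks by hand and computes a reverse postorder. For an acyclic CFG, a reverse postorder lists every block before its successors, which is a convenient order when a later pass wants to visit code "in execution order". `Block`, `link`, and `reverse_postorder` are hypothetical helpers, not CFGBuildPass APIs.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)  # identity semantics: each block is a distinct graph vertex
class Block:
    label: str
    bb_in: list[Block] = field(default_factory=list)
    bb_out: list[Block] = field(default_factory=list)


def link(src: Block, dst: Block) -> None:
    """Add a control-flow edge, keeping both directions in sync."""
    src.bb_out.append(dst)
    dst.bb_in.append(src)


def reverse_postorder(start: Block) -> list[str]:
    """Labels ordered so each block precedes its successors (acyclic CFG)."""
    seen: set[int] = set()
    order: list[str] = []

    def visit(b: Block) -> None:
        if id(b) in seen:
            return
        seen.add(id(b))
        for succ in b.bb_out:
            visit(succ)
        order.append(b.label)  # postorder: after all successors

    visit(start)
    return order[::-1]


entry, cond, then_b, else_b, merge, exit_b = (
    Block(n) for n in ("entry", "if", "then", "else", "merge", "exit")
)
link(entry, cond)
link(cond, then_b)
link(cond, else_b)
link(then_b, merge)
link(else_b, merge)
link(merge, exit_b)

print([b.label for b in merge.bb_in])  # ['then', 'else']
print(reverse_postorder(entry))  # ['entry', 'if', 'else', 'then', 'merge', 'exit']
```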
PreDynamoPass
Purpose: Prepares PyTorch code for dynamic compilation by transforming certain patterns
Location: jac/jaclang/compiler/passes/main/predynamo_pass.py
Example Transformation:
```python
# Before
if condition:
    x = tensor_a
else:
    x = tensor_b

# After (transformed to)
x = torch.where(condition, tensor_a, tensor_b)
```
This pass is directly relevant to this project because it already performs PyTorch-specific transformations.
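The rewrite pattern can be illustrated as a tree transformation. The sketch below applies it to Python source using the standard `ast` module. `IfToWhere` is a hypothetical class that only handles the exact shape shown above (one assignment to the same name in each branch). It is not PreDynamoPass's implementation, which works on UniIR.

```python
import ast


class IfToWhere(ast.NodeTransformer):
    """Rewrite `if c: x = a` / `else: x = b` into `x = torch.where(c, a, b)`."""

    def visit_If(self, node: ast.If) -> ast.AST:
        self.generic_visit(node)  # rewrite nested ifs first
        if len(node.body) != 1 or len(node.orelse) != 1:
            return node
        then_s, else_s = node.body[0], node.orelse[0]
        if not (isinstance(then_s, ast.Assign) and isinstance(else_s, ast.Assign)):
            return node
        if len(then_s.targets) != 1 or len(else_s.targets) != 1:
            return node
        t1, t2 = then_s.targets[0], else_s.targets[0]
        if not (isinstance(t1, ast.Name) and isinstance(t2, ast.Name) and t1.id == t2.id):
            return node  # branches assign different things: leave unchanged
        where = ast.Call(
            func=ast.Attribute(value=ast.Name("torch", ast.Load()), attr="where", ctx=ast.Load()),
            args=[node.test, then_s.value, else_s.value],
            keywords=[],
        )
        new = ast.Assign(targets=[ast.Name(t1.id, ast.Store())], value=where)
        return ast.copy_location(new, node)


src = "if condition:\n    x = tensor_a\nelse:\n    x = tensor_b\n"
tree = ast.fix_missing_locations(IfToWhere().visit(ast.parse(src)))
print(ast.unparse(tree))  # x = torch.where(condition, tensor_a, tensor_b)
```

torch.where evaluates both branch expressions. A rewrite like this therefore preserves semantics only when both sides are free of side effects and the condition is a tensor that broadcasts against them.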
Project Architecture
Overall Design
```mermaid
graph TB
    subgraph "Existing Passes"
        A[SymTabBuildPass] --> B[DefUsePass]
        B --> C[CFGBuildPass]
    end
    subgraph "New Static FX Pass"
        C --> D[FunctionInlinePass]
        D --> E[StaticFxGraphBuildPass]
        E --> F[GraphBreakAnnotationPass]
    end
    subgraph "Output"
        F --> G[StaticFxGraph IR]
        G --> H[Graph Visualization]
        G --> I[Optimization Passes]
    end
    style D fill:#f39c12,stroke:#ecf0f1,stroke-width:2px,color:#fff
    style E fill:#27ae60,stroke:#ecf0f1,stroke-width:2px,color:#fff
    style F fill:#3498db,stroke:#ecf0f1,stroke-width:2px,color:#fff
    style G fill:#9b59b6,stroke:#ecf0f1,stroke-width:2px,color:#fff
```
Phase 1: Function Inlining Pass
Purpose: Inline method calls in PyTorch model forward methods to create a flat representation for analysis.
Why Inlining is Necessary
PyTorch models typically have this structure:
```python
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    # layer1 and layer2 are assumed to be submodules created in __init__
    def forward(self, x):
        x = self.layer1(x)  # Calls another module
        x = self._helper(x)  # Calls private method
        return self.layer2(x)

    def _helper(self, x):
        return F.relu(x)
```
Without inlining, we'd have:
- Incomplete static analysis (can't see into _helper)
- Missing optimization opportunities
- Fragmented graph representation
With inlining, we get:
```python
def forward_inlined(self, x):
    # Inlined self.layer1(x)
    x = self.layer1.linear(x)
    # Inlined self._helper(x)
    x = F.relu(x)
    # Inlined self.layer2(x)
    x = self.layer2.linear(x)
    return x
```
Implementation Strategy
Location: Create jac/jaclang/compiler/passes/main/function_inline_pass.py
Data Structure for Inlining:
```python
from dataclasses import dataclass

import jaclang.compiler.unitree as uni


@dataclass
class InlineCandidate:
    """Represents a function that can be inlined"""
    func_node: uni.Ability
    call_site: uni.FuncCall
    caller_node: uni.Ability
    inline_priority: int  # Higher = inline first


@dataclass
class InlineContext:
    """Context for performing inlining"""
    variable_mapping: dict[str, str]  # Old name -> New name
    parent_scope: uni.UniScopeNode
    depth: int  # Current inline depth
```
Algorithm:
```mermaid
flowchart TD
    A[Start: Identify Model Class] --> B[Find forward method]
    B --> C[Collect function calls in forward]
    C --> D{Is call to:<br/>- Private method?<br/>- nn.Module subclass?<br/>- Simple function?}
    D -->|Yes| E[Mark as inline candidate]
    D -->|No| F[Skip inlining]
    E --> G[Retrieve function body]
    G --> H[Rename variables to avoid conflicts]
    H --> I[Replace call with inlined body]
    I --> J[Update symbol table]
    J --> K{More candidates?}
    K -->|Yes| C
    K -->|No| L[End]
    F --> K
```
Key Methods:
```python
from typing import Optional


class FunctionInlinePass(UniPass):
    def before_pass(self) -> None:
        self.inline_candidates: list[InlineCandidate] = []
        self.inline_depth_limit = 3  # Prevent infinite recursion
        self.inlined_functions: set[str] = set()
        self.inline_counter = 0  # Source of unique rename suffixes (Step 2.5)
        self.current_model: Optional[uni.Archetype] = None

    def enter_archetype(self, node: uni.Archetype) -> None:
        """Look for PyTorch model classes"""
        # Check if inherits from nn.Module
        if self.is_pytorch_module(node):
            self.current_model = node

    def exit_archetype(self, node: uni.Archetype) -> None:
        """Stop treating methods as model methods once the class is left"""
        if node is self.current_model:
            self.current_model = None

    def enter_ability(self, node: uni.Ability) -> None:
        """Find forward method of the current model class"""
        if self.current_model is not None and node.name_ref.value == "forward":
            self.analyze_for_inlining(node)

    def analyze_for_inlining(self, forward_method: uni.Ability) -> None:
        """Find all function calls that should be inlined"""
        calls = self.get_all_sub_nodes(forward_method, uni.FuncCall)
        for call in calls:
            if self.should_inline(call):
                candidate = self.create_inline_candidate(call)
                self.inline_candidates.append(candidate)

    def should_inline(self, call: uni.FuncCall) -> bool:
        """Determine if a call should be inlined"""
        # Inline if:
        # 1. Private method call (self._method)
        # 2. Call to simple nn.Module (Linear, Conv2d, etc.)
        # 3. Utility function (not recursive, not too large)
        raise NotImplementedError  # See Step 2.3

    def perform_inline(self, candidate: InlineCandidate) -> None:
        """Replace call with inlined function body"""
        # 1. Clone function body
        # 2. Rename variables (prevent conflicts)
        # 3. Replace parameters with arguments
        # 4. Update symbol table
        # 5. Replace call node with inlined body
        raise NotImplementedError  # See Steps 2.4 to 2.6
```
Variable Renaming Example:
```python
# Original helper function
def _helper(self, x, scale):
    result = x * scale
    return result

# Called as: y = self._helper(a, 2.0)

# After inlining (renamed variables)
__helper_x_0 = a
__helper_scale_0 = 2.0
__helper_result_0 = __helper_x_0 * __helper_scale_0
y = __helper_result_0
```
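The sketch below reproduces the renaming scheme above. It is self-contained and written against Python's `ast` module rather than UniIR. `inline_call` is a hypothetical helper. It assumes the callee is a method whose only `return` is its last statement, and it does not handle early returns, nested functions, or keyword arguments.

```python
import ast


def inline_call(func_src: str, result_var: str, arg_srcs: list[str], counter: int = 0) -> str:
    """Inline `result_var = self.func(*args)` as straight-line statements."""
    func = ast.parse(func_src).body[0]
    assert isinstance(func, ast.FunctionDef)
    params = [a.arg for a in func.args.args if a.arg != "self"]
    prefix = f"_{func.name}_"  # "_helper" -> "__helper_"
    # Parameters and every locally assigned name get a unique, prefixed name
    rename = {p: f"{prefix}{p}_{counter}" for p in params}
    for node in ast.walk(func):
        if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
            rename.setdefault(node.id, f"{prefix}{node.id}_{counter}")

    class Renamer(ast.NodeTransformer):
        def visit_Name(self, node: ast.Name) -> ast.Name:
            # Attribute names (the `x` in `self.x`) are strings, not Name nodes,
            # so they are never renamed here
            if node.id in rename:
                return ast.copy_location(ast.Name(rename[node.id], node.ctx), node)
            return node

    # Bind arguments to the renamed parameters, then emit the renamed body
    lines = [f"{rename[p]} = {a}" for p, a in zip(params, arg_srcs)]
    for stmt in func.body:
        stmt = Renamer().visit(stmt)
        if isinstance(stmt, ast.Return) and stmt.value is not None:
            lines.append(f"{result_var} = {ast.unparse(stmt.value)}")
        else:
            lines.append(ast.unparse(stmt))
    return "\n".join(lines)


helper = "def _helper(self, x, scale):\n    result = x * scale\n    return result\n"
print(inline_call(helper, "y", ["a", "2.0"]))
```

The printed lines match the "After inlining" listing above. Passing a different `counter` value for each call site keeps two inlined copies of the same helper from colliding.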
Handling Special Cases
- Recursive calls: Don't inline (mark with annotation)
- External functions: Only inline if definition is available in IR
- Lambda functions: Can inline directly
- Generator functions: Skip (cannot inline)
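The generator and recursion cases can be checked mechanically. The sketch below uses Python `ast`. Its hypothetical `inline_blocker` helper reports why a method should not be inlined. It treats any `yield` as marking a generator, and a call to `f(...)` or `self.f(...)` inside `f` as direct recursion. Mutual recursion would need a cycle check over the call graph, and lambdas and external definitions are outside its scope. Because `ast.walk` also descends into nested functions, the check is conservative.

```python
import ast
from typing import Optional


def _callee(call: ast.Call) -> Optional[str]:
    """'f' for f(...), 'm' for self.m(...), otherwise None."""
    if isinstance(call.func, ast.Name):
        return call.func.id
    if (
        isinstance(call.func, ast.Attribute)
        and isinstance(call.func.value, ast.Name)
        and call.func.value.id == "self"
    ):
        return call.func.attr
    return None


def inline_blocker(func_src: str) -> Optional[str]:
    """Return why a function must not be inlined, or None if it may be."""
    func = ast.parse(func_src).body[0]
    if not isinstance(func, ast.FunctionDef):
        return "unsupported definition"  # e.g. async def or a class
    for node in ast.walk(func):
        if isinstance(node, (ast.Yield, ast.YieldFrom)):
            return "generator"
        if isinstance(node, ast.Call) and _callee(node) == func.name:
            return "recursive"
    return None


print(inline_blocker("def _helper(self, x):\n    return x * 2\n"))  # None
print(inline_blocker("def _gen(self, xs):\n    for x in xs:\n        yield x\n"))  # generator
print(inline_blocker("def _f(self, n):\n    return self._f(n - 1)\n"))  # recursive
```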
Phase 2: Static FX Graph Builder
Purpose: Construct an FX-like graph representation from the inlined UniIR.
Graph IR Design
Location: Create jac/jaclang/compiler/passes/main/static_fx_ir.py
New IR Nodes:
```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class StaticFxNode:
    """Node in the static FX graph"""
    op: str  # "placeholder", "call_function", "call_method", "call_module", "get_attr", "output"
    name: str  # Unique identifier
    target: Any  # Function/method/attribute being called
    args: tuple[Any, ...]  # Positional arguments
    kwargs: dict[str, Any]  # Keyword arguments
    meta: dict[str, Any]  # Metadata (types, shapes, source location)
    users: list[StaticFxNode] = field(default_factory=list)  # Nodes that use this node's output
    graph_break_reason: Optional[str] = None  # Why this causes a graph break


@dataclass
class GraphBreakRegion:
    """Represents a region that would cause graph breaks in PyTorch"""
    reason: str  # "data_dependent_control_flow", "dynamic_loop", etc.
    nodes: list[StaticFxNode]  # Nodes involved in the break
    source_loc: CodeLocInfo
    workaround: Optional[str]  # Suggested fix


@dataclass
class StaticFxGraph:
    """Complete FX-like graph representation"""
    nodes: list[StaticFxNode]
    input_nodes: list[StaticFxNode]
    output_nodes: list[StaticFxNode]
    graph_break_regions: list[GraphBreakRegion]
```
Graph Construction Algorithm
```mermaid
flowchart TD
    A[Start: Get inlined forward method] --> B[Create StaticFxGraph]
    B --> C[Process function signature]
    C --> D[Create placeholder nodes for inputs]
    D --> E[Traverse function body in execution order]
    E --> F{Node type?}
    F -->|Assignment| G[Create get_attr or call_*]
    F -->|FuncCall| H[Analyze call target]
    F -->|IfStmt| I[Handle conditional]
    F -->|Loop| J[Handle loop]
    F -->|Return| K[Create output node]
    G --> L[Add to graph]
    H --> M{Is PyTorch op?}
    M -->|Yes| N[Create call_function node]
    M -->|No| O[Mark as potential graph break]
    I --> P{Can statically analyze?}
    P -->|Yes| Q[Create conditional node]
    P -->|No| R[Mark as graph break: data-dependent]
    J --> S{Static loop bounds?}
    S -->|Yes| T[Unroll loop]
    S -->|No| U[Mark as graph break: dynamic loop]
    L --> V{More nodes?}
    N --> V
    O --> V
    Q --> V
    R --> V
    T --> V
    U --> V
    K --> W[End]
    V -->|Yes| E
    V -->|No| W
```
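For straight-line code, the flowchart reduces to a single walk over the statements, keeping a map from each variable name to the node that produced it. This map is where FX-style def-use edges come from. The sketch below does this for Python source with the standard `ast` module. `build_graph` and the dict node format are illustrative assumptions, not the StaticFxGraphBuildPass API. So is the prefix-based classification: `torch.`/`F.` calls become call_function, `self.` calls become call_module, and other attribute calls become call_method. Branches, loops, and nested calls are out of scope here.

```python
import ast


def dotted(expr: ast.AST) -> str:
    """Best-effort dotted name: torch.relu -> 'torch.relu'."""
    if isinstance(expr, ast.Name):
        return expr.id
    if isinstance(expr, ast.Attribute):
        return f"{dotted(expr.value)}.{expr.attr}"
    return ast.unparse(expr)


def build_graph(source: str) -> list[dict]:
    """Turn a straight-line forward() into FX-style node dicts."""
    fn = ast.parse(source).body[0]
    assert isinstance(fn, ast.FunctionDef)
    nodes: list[dict] = []
    env: dict[str, str] = {}  # variable name -> name of the node that produced it

    def add(op: str, target: str, args: list[str], graph_break: bool = False) -> str:
        name = f"n{len(nodes)}"
        nodes.append({"name": name, "op": op, "target": target,
                      "args": args, "graph_break": graph_break})
        return name

    def operand(e: ast.AST) -> str:
        # Variables resolve to producer nodes; anything else stays as source text
        return env.get(e.id, e.id) if isinstance(e, ast.Name) else ast.unparse(e)

    for a in fn.args.args[1:]:  # skip `self`
        env[a.arg] = add("placeholder", a.arg, [])
    for stmt in fn.body:
        value = stmt.value if isinstance(stmt, (ast.Assign, ast.Return)) else None
        if isinstance(value, ast.Call):  # one call per statement (not flattened)
            callee = dotted(value.func)
            args = [operand(x) for x in value.args]
            if callee.startswith(("torch.", "F.")):
                out = add("call_function", callee, args)
            elif callee.startswith("self."):
                out = add("call_module", callee[len("self."):], args)
            elif isinstance(value.func, ast.Attribute):  # e.g. h.view(-1)
                out = add("call_method", value.func.attr, [operand(value.func.value), *args])
            else:  # unknown free function: keep it in the graph, but tag it
                out = add("call_function", callee, args, graph_break=True)
        elif isinstance(value, ast.Name):
            out = operand(value)
        else:
            continue  # other statement kinds are skipped in this sketch
        if isinstance(stmt, ast.Return):
            add("output", "output", [out])
        elif isinstance(stmt.targets[0], ast.Name):
            env[stmt.targets[0].id] = out
    return nodes


SRC = """
def forward(self, x):
    h = self.linear1(x)
    h = torch.relu(h)
    h = h.view(-1)
    return h
"""
for node in build_graph(SRC):
    print(node)
```

Each reassignment of `h` simply rebinds the name to a newer node. That is how the sketch gets SSA-like node names (`n1`, `n2`, `n3`) without an explicit renaming step.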
Implementation:
```python
from typing import Optional


class StaticFxGraphBuildPass(UniPass):
    def before_pass(self) -> None:
        self.current_graph: Optional[StaticFxGraph] = None
        self.node_map: dict[str, StaticFxNode] = {}  # Variable name -> FX node
        self.graph_break_regions: list[GraphBreakRegion] = []

    def enter_ability(self, node: uni.Ability) -> None:
        """Build graph for forward method"""
        if node.name_ref.value == "forward" and self.is_pytorch_module(node.parent):
            self.current_graph = StaticFxGraph(
                nodes=[], input_nodes=[], output_nodes=[], graph_break_regions=[]
            )
            self.build_graph(node)

    def build_graph(self, forward_method: uni.Ability) -> None:
        """Main graph construction"""
        # 1. Create placeholder nodes for parameters
        self.create_input_placeholders(forward_method)
        # 2. Process function body
        for stmt in forward_method.body:
            self.process_statement(stmt)
        # 3. Create output node
        self.create_output_node(forward_method)

    def process_statement(self, stmt: uni.UniNode) -> None:
        """Process a single statement"""
        if isinstance(stmt, uni.Assignment):
            self.process_assignment(stmt)
        elif isinstance(stmt, uni.IfStmt):
            self.process_conditional(stmt)
        elif isinstance(stmt, (uni.InForStmt, uni.WhileStmt)):
            self.process_loop(stmt)
        elif isinstance(stmt, uni.ReturnStmt):
            self.process_return(stmt)

    def process_assignment(self, node: uni.Assignment) -> None:
        """Convert assignment to FX node"""
        if isinstance(node.value, uni.FuncCall):
            fx_node = self.create_fx_call_node(node.value)
            # Map variable to FX node (only simple `name = call(...)` targets)
            target = node.target[0]
            if isinstance(target, uni.Name):
                self.node_map[target.value] = fx_node

    def create_fx_call_node(self, call: uni.FuncCall) -> StaticFxNode:
        """Create FX node for function call"""
        break_reason: Optional[str] = None
        # Determine operation type
        if self.is_torch_function(call):
            op = "call_function"
            target = self.resolve_torch_function(call)
        elif self.is_tensor_method(call):
            op = "call_method"
            target = self.extract_method_name(call)
        elif self.is_module_call(call):
            op = "call_module"
            target = self.extract_module_path(call)
        else:
            # Unknown call type: keep it in the graph, tagged as a potential break
            op = "call_function"
            target = "unknown"
            break_reason = "unsupported_call"  # illustrative reason tag
        # Extract arguments
        args = self.extract_args(call)
        kwargs = self.extract_kwargs(call)
        fx_node = StaticFxNode(
            op=op,
            name=self.generate_node_name(),
            target=target,
            args=args,
            kwargs=kwargs,
            meta={"source_loc": call.loc},
            users=[],
            graph_break_reason=break_reason,
        )
        self.current_graph.nodes.append(fx_node)
        return fx_node

    def process_conditional(self, node: uni.IfStmt) -> None:
        """Handle if statements"""
        condition_expr = node.condition
        # Check if condition is data-dependent
        if self.is_data_dependent(condition_expr):
            # Mark as graph break
            break_region = GraphBreakRegion(
                reason="data_dependent_control_flow",
                nodes=[],  # Will be populated
                source_loc=node.loc,
                workaround="Consider using torch.where() or masking",
            )
            self.graph_break_regions.append(break_region)
            # Also attach it to the graph, where Phase 3 looks for regions
            self.current_graph.graph_break_regions.append(break_region)
            # Still process both branches (record them)
            self.process_branch_with_tag(node.body, "then_branch", break_region)
            if node.else_body:
                self.process_branch_with_tag(node.else_body.body, "else_branch", break_region)
        else:
            # Statically resolvable (e.g., a check based on model config, not
            # tensor values): include it as a regular conditional node
            self.process_static_conditional(node)  # hypothetical helper

    def is_data_dependent(self, expr: uni.Expr) -> bool:
        """Check if expression depends on tensor runtime values"""
        # Data-dependent if:
        # - Tensor item() call
        # - Tensor comparison (tensor > 0)
        # - Tensor shape/size accessed as value
        # - Any operation on tensor producing scalar bool
        # Find all tensor operations in expression
        tensor_ops = self.find_tensor_operations(expr)
        return any(self.produces_runtime_value(op) for op in tensor_ops)

    def process_loop(self, node: uni.UniNode) -> None:
        """Handle loops"""
        # Only `for` loops with static bounds can be unrolled
        if isinstance(node, uni.InForStmt) and self.has_static_bounds(node):
            self.unroll_loop(node)
            return
        # Dynamic `for` loops and `while` loops: graph break
        break_region = GraphBreakRegion(
            reason="dynamic_loop",
            nodes=[],
            source_loc=node.loc,
            workaround="Use static loop bounds or vectorize",
        )
        self.graph_break_regions.append(break_region)
        self.current_graph.graph_break_regions.append(break_region)
        # Still record loop body with tag
        self.process_loop_body_with_tag(node.body, break_region)
```
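Static bound detection and unrolling can be sketched in the same style. In this sketch, a `for` loop counts as static only when it iterates over `range(...)` with integer-literal arguments and contains no `break`/`continue`. A real pass would also fold module constants, such as a layer count fixed in `__init__`. `static_range` and `unroll_static_for` are hypothetical helpers over Python `ast`, not UniIR, and they assume the loop body never reassigns the loop variable.

```python
import ast
import copy
from typing import Optional


def static_range(loop: ast.For) -> Optional[range]:
    """Return the range if the loop is `for <name> in range(<int literals>)`."""
    it = loop.iter
    if not (isinstance(it, ast.Call) and isinstance(it.func, ast.Name) and it.func.id == "range"):
        return None
    if it.keywords or not 1 <= len(it.args) <= 3:
        return None
    try:
        bounds = [ast.literal_eval(a) for a in it.args]
    except ValueError:  # e.g. range(x.item()): not a literal, so not static
        return None
    if not all(isinstance(b, int) for b in bounds):
        return None
    return range(*bounds)


class _SubstLoopVar(ast.NodeTransformer):
    """Replace reads of the loop variable with a constant."""

    def __init__(self, name: str, value: int) -> None:
        self.name = name
        self.value = value

    def visit_Name(self, node: ast.Name) -> ast.AST:
        if node.id == self.name and isinstance(node.ctx, ast.Load):
            return ast.copy_location(ast.Constant(self.value), node)
        return node


def unroll_static_for(loop: ast.For) -> Optional[list[ast.stmt]]:
    """Unroll a statically bounded loop; None means 'tag it as dynamic_loop'."""
    bounds = static_range(loop)
    if bounds is None or not isinstance(loop.target, ast.Name) or loop.orelse:
        return None
    # break/continue make the iteration count data-dependent; refuse to unroll
    if any(isinstance(n, (ast.Break, ast.Continue)) for s in loop.body for n in ast.walk(s)):
        return None
    unrolled: list[ast.stmt] = []
    for i in bounds:
        for stmt in loop.body:
            unrolled.append(_SubstLoopVar(loop.target.id, i).visit(copy.deepcopy(stmt)))
    return unrolled


static_loop = ast.parse("for i in range(3):\n    x = self.blocks[i](x)\n").body[0]
print("\n".join(ast.unparse(s) for s in unroll_static_for(static_loop)))
dynamic_loop = ast.parse("for i in range(x.item()):\n    x = x * 2\n").body[0]
print(unroll_static_for(dynamic_loop))  # None
```

The static loop prints three statements, `x = self.blocks[0](x)` through `x = self.blocks[2](x)`. The `range(x.item())` loop is left for the graph-break tagger.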
Recognizing PyTorch Operations
PyTorch Operation Categories:
```python
TORCH_FUNCTIONS = {
    # Tensor creation
    "torch.zeros", "torch.ones", "torch.tensor", "torch.randn",
    # Operations
    "torch.matmul", "torch.add", "torch.mul", "torch.div",
    "torch.cat", "torch.stack", "torch.split",
    # Activations
    "torch.relu", "torch.sigmoid", "torch.tanh", "torch.softmax",
    # NN functions
    "F.linear", "F.conv2d", "F.max_pool2d", "F.dropout",
}

TENSOR_METHODS = {
    "view", "reshape", "permute", "transpose",
    "sum", "mean", "std", "max", "min",
    "unsqueeze", "squeeze", "expand",
}

NN_MODULES = {
    "torch.nn.Linear", "torch.nn.Conv2d", "torch.nn.BatchNorm2d",
    "torch.nn.ReLU", "torch.nn.Dropout", "torch.nn.MaxPool2d",
}
```
Recognition Strategy:
```python
def is_torch_function(self, call: uni.FuncCall) -> bool:
    """Check if call is a torch.* function"""
    if isinstance(call.target, uni.AtomTrailer):
        # Resolve the dotted name; it is matched as written, so an alias such
        # as F must be spelled the same way in TORCH_FUNCTIONS
        full_name = self.resolve_qualified_name(call.target)
        return full_name in TORCH_FUNCTIONS
    return False


def resolve_qualified_name(self, node: uni.AtomTrailer) -> str:
    """Flatten a dotted access chain, e.g. torch.nn.functional.relu ->
    'torch.nn.functional.relu' (no alias resolution)"""
    parts: list[str] = []
    current = node
    # Walk from the outermost access inward: relu, functional, nn, ...
    while isinstance(current, uni.AtomTrailer):
        if isinstance(current.right, uni.Name):
            parts.append(current.right.value)
        current = current.target
    if isinstance(current, uni.Name):
        parts.append(current.value)
    # Parts were collected right-to-left, so reverse them before joining
    return ".".join(reversed(parts))
```
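`TORCH_FUNCTIONS` mixes `torch.*` and `F.*` spellings, so matching on the raw dotted name depends on how each file imports PyTorch. One option is to normalize every name through the module's import aliases first and write the sets with canonical names such as `torch.nn.functional.relu`. The sketch below does this with Python's `ast`. `collect_aliases` and `canonical_name` are hypothetical names, not part of the jaclang compiler.

```python
import ast


def collect_aliases(tree: ast.Module) -> dict[str, str]:
    """Map locally bound names to fully qualified module paths."""
    aliases: dict[str, str] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for a in node.names:
                if a.asname:  # import torch.nn.functional as F
                    aliases[a.asname] = a.name
                else:  # import torch.nn binds only `torch`
                    root = a.name.split(".")[0]
                    aliases[root] = root
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            for a in node.names:  # from torch import nn
                if a.name != "*":
                    aliases[a.asname or a.name] = f"{node.module}.{a.name}"
    return aliases


def canonical_name(dotted: str, aliases: dict[str, str]) -> str:
    """Rewrite the first component of a dotted name through the alias table."""
    head, _, rest = dotted.partition(".")
    base = aliases.get(head, head)
    return f"{base}.{rest}" if rest else base


tree = ast.parse("import torch\nimport torch.nn.functional as F\nfrom torch import nn\n")
aliases = collect_aliases(tree)
print(canonical_name("F.relu", aliases))  # torch.nn.functional.relu
print(canonical_name("nn.Linear", aliases))  # torch.nn.Linear
print(canonical_name("torch.relu", aliases))  # torch.relu
```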
Phase 3: Graph Break Annotation
Purpose: Enhance graph break regions with detailed analysis and suggested workarounds.
```python
class GraphBreakAnnotationPass(UniPass):
    """Analyzes graph breaks and suggests fixes"""

    def annotate_graph_breaks(self, graph: StaticFxGraph) -> None:
        """Add detailed annotations to graph break regions"""
        for region in graph.graph_break_regions:
            # Analyze the cause
            analysis = self.analyze_break_cause(region)
            # Suggest workaround
            region.workaround = self.suggest_workaround(region, analysis)
            # Estimate performance impact
            analysis["impact"] = self.estimate_break_impact(region, graph)
            # GraphBreakRegion keeps details in its `analysis` dict (see static_fx_ir.py)
            region.analysis.update(analysis)

    def suggest_workaround(self, region: GraphBreakRegion, analysis: dict) -> str:
        """Generate workaround suggestion"""
        if region.reason == "data_dependent_control_flow":
            return self.suggest_control_flow_fix(region, analysis)
        elif region.reason == "dynamic_loop":
            return self.suggest_loop_fix(region, analysis)
        # ... more cases
        # Fall back to any workaround recorded while building the graph
        return region.workaround or ""
```
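One way to keep workaround suggestions extensible is a table keyed by break reason, plus a simple impact measure. The sketch below is self-contained and hypothetical. The `Region` class is a stand-in, the `WORKAROUNDS` texts reuse the suggestions from Phase 2, and the impact heuristic (the fraction of graph nodes inside the region, with an arbitrary 0.5 severity threshold) is an illustrative choice rather than a specification.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

# Reason -> suggestion table; new break kinds only need a new entry
WORKAROUNDS = {
    "data_dependent_control_flow": "Consider using torch.where() or masking",
    "dynamic_loop": "Use static loop bounds or vectorize",
}


@dataclass
class Region:
    reason: str
    node_names: list[str]
    workaround: Optional[str] = None
    analysis: dict[str, Any] = field(default_factory=dict)


def annotate(regions: list[Region], total_nodes: int) -> None:
    """Fill in workaround and a crude impact estimate for each region."""
    for r in regions:
        r.workaround = r.workaround or WORKAROUNDS.get(r.reason)
        share = len(r.node_names) / total_nodes if total_nodes else 0.0
        r.analysis["impact"] = round(share, 2)
        r.analysis["severity"] = "error" if share > 0.5 else "warning"


regions = [Region("dynamic_loop", ["n3", "n4"]), Region("python_io", ["n7"])]
annotate(regions, total_nodes=10)
for r in regions:
    print(r.reason, r.workaround, r.analysis)
```

Reasons without a table entry keep `workaround=None`, which makes unhandled break kinds easy to spot in a report.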
Implementation Guide
Step 1: Setup and Prerequisites
1. Create new files in the codebase:
```shell
# Navigate to passes directory
cd jac/jaclang/compiler/passes/main

# Create new pass files
touch function_inline_pass.py
touch static_fx_graph_pass.py
touch graph_break_annotation_pass.py

# Create IR definition file
touch static_fx_ir.py
```
2. Update pass registry:
Edit jac/jaclang/compiler/passes/main/__init__.py:
```python
from .function_inline_pass import FunctionInlinePass
from .static_fx_graph_pass import StaticFxGraphBuildPass
from .graph_break_annotation_pass import GraphBreakAnnotationPass

__all__ = [
    # ... existing passes
    "FunctionInlinePass",
    "StaticFxGraphBuildPass",
    "GraphBreakAnnotationPass",
]
```
3. Study existing passes:
Read and understand these files thoroughly:
- unitree.py (lines 44-650): UniNode class hierarchy
- uni_pass.py: Pass infrastructure
- cfg_build_pass.py: Example of complex pass with graph building
- predynamo_pass.py: PyTorch-specific transformations
Step 2: Implement Function Inlining Pass
File: jac/jaclang/compiler/passes/main/function_inline_pass.py
Implementation checklist:
```mermaid
graph TD
    A[Step 2.1: Identify PyTorch modules] --> B[Step 2.2: Find forward methods]
    B --> C[Step 2.3: Collect inline candidates]
    C --> D[Step 2.4: Implement variable renaming]
    D --> E[Step 2.5: Perform inlining]
    E --> F[Step 2.6: Update symbol tables]
    F --> G[Step 2.7: Test on simple model]
```
2.1: Start with detecting PyTorch model classes:
```python
def is_pytorch_module(self, node: uni.Archetype) -> bool:
    """Check if archetype directly inherits from nn.Module"""
    # Accept the usual spellings of the base class. A substring test such as
    # `'Module' in base_name` would also match unrelated classes.
    module_bases = {"nn.Module", "torch.nn.Module", "Module"}
    for base in getattr(node, "base_classes", None) or []:
        base_name = self.get_qualified_name(base)
        if base_name in module_bases:
            return True
    return False
```
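The base-class check above only sees direct bases. It would miss `class Block(MyBase)` when `MyBase(nn.Module)` is defined in the same file. The sketch below shows a transitive check over the classes defined in one module, using Python `ast`. `module_subclasses` and the `MODULE_BASES` spellings are assumptions for illustration. Bases imported from other files would still need symbol-table resolution.

```python
import ast

MODULE_BASES = {"nn.Module", "torch.nn.Module", "Module"}


def module_subclasses(source: str) -> set[str]:
    """Classes that inherit from nn.Module directly or through local classes."""
    classes = {
        n.name: [ast.unparse(b) for b in n.bases]
        for n in ast.walk(ast.parse(source))
        if isinstance(n, ast.ClassDef)
    }
    found: set[str] = set()
    changed = True
    while changed:  # propagate until no new class qualifies (fixed point)
        changed = False
        for name, bases in classes.items():
            if name not in found and any(b in MODULE_BASES or b in found for b in bases):
                found.add(name)
                changed = True
    return found


SRC = """
import torch.nn as nn

class Base(nn.Module):
    pass

class Block(Base):
    pass

class Config:
    pass
"""
print(sorted(module_subclasses(SRC)))  # ['Base', 'Block']
```

The fixed-point loop makes the result independent of the order in which classes appear, so a subclass defined before its base is still found.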
2.2: Find forward methods and helper methods:
```python
def collect_methods(self, model_class: uni.Archetype) -> dict[str, uni.Ability]:
    """Collect all methods in the model class"""
    methods = {}
    for node in model_class.body:
        if isinstance(node, uni.Ability):
            methods[node.name_ref.value] = node
    return methods
```
2.3: Determine what to inline:
```python
def should_inline(self, call: uni.FuncCall, context: InlineContext) -> bool:
    """Decide if this call should be inlined"""
    # Don't inline if too deep
    if context.depth >= self.inline_depth_limit:
        return False
    # Check if it's a method call on self
    if isinstance(call.target, uni.AtomTrailer):
        if self.is_self_method_call(call.target):
            method_name = self.extract_method_name(call.target)
            # Inline private methods
            if method_name.startswith('_'):
                return True
    # Check for simple nn.Module calls
    if self.is_simple_nn_module(call):
        return True
    return False
```
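2.4: Rename variables to avoid conflicts:
```python
def rename_variables_in_body(
    self,
    function_body: list[uni.UniNode],
    suffix: str,
) -> tuple[list[uni.UniNode], dict[str, str]]:
    """Clone function body and rename all locally assigned variables"""
    # Use deep copy to clone nodes
    cloned_body = self.deep_clone_nodes(function_body)
    # Find all variable assignments and references. get_all_sub_nodes is used
    # here on a list of statements; if the helper only accepts a single root
    # node, apply it to each statement in turn.
    assignments = self.get_all_sub_nodes(cloned_body, uni.Assignment)
    names = self.get_all_sub_nodes(cloned_body, uni.Name)
    # Build rename map
    rename_map: dict[str, str] = {}
    for assign in assignments:
        for target in assign.target:
            if isinstance(target, uni.Name):
                old_name = target.value
                new_name = f"__{old_name}_{suffix}"
                rename_map[old_name] = new_name
    # NOTE: parameters are not assignments, so references to them are not
    # renamed here; a complete implementation must seed rename_map with the
    # parameter names before applying renames. Attribute names (the `x` in
    # `self.x`) must also be excluded.
    # Apply renames
    for name_node in names:
        if name_node.value in rename_map:
            name_node.value = rename_map[name_node.value]
    return cloned_body, rename_map
```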
2.5: Perform the actual inlining:
```python
def inline_function_call(
    self,
    call: uni.FuncCall,
    func_def: uni.Ability,
    context: InlineContext,
) -> list[uni.UniNode]:
    """Replace function call with inlined body"""
    # 1. Clone and rename function body
    # (self.inline_counter must be initialized, e.g. to 0 in before_pass)
    suffix = f"inline_{context.depth}_{self.inline_counter}"
    self.inline_counter += 1
    inlined_body, rename_map = self.rename_variables_in_body(
        func_def.body,
        suffix,
    )
    # 2. Create parameter assignments
    param_assignments = self.create_parameter_bindings(
        func_def.signature.params,
        call.params,
        rename_map,
    )
    # 3. Combine: parameter bindings + function body
    result = param_assignments + inlined_body
    # 4. Handle return statement (searching a list of statements, as in Step 2.4)
    return_stmts = self.get_all_sub_nodes(result, uni.ReturnStmt)
    if return_stmts:
        # Replace return with assignment to call result
        self.replace_returns_with_assignments(return_stmts)
    return result
```
2.6: Update symbol tables after inlining:
```python
def update_symbol_table_after_inline(
    self,
    parent_scope: uni.UniScopeNode,
    inlined_nodes: list[uni.UniNode],
    rename_map: dict[str, str],
) -> None:
    """Register new symbols from inlined code"""
    # Add renamed variables to parent scope's symbol table
    for new_name in rename_map.values():
        # Create stub symbol
        stub = uni.Name.gen_stub_from_node(
            parent_scope,
            new_name,
        )
        parent_scope.sym_tab.def_insert(stub)
```
2.7: Test with a simple PyTorch model:
```python
# Test case: test_function_inline_pass.py
def test_inline_simple_method():
    source = """
import torch.nn as nn

class SimpleModel(nn.Module):
    def forward(self, x):
        return self._helper(x)

    def _helper(self, x):
        return x * 2
"""
    # Run passes up to and including inline pass
    # Verify that _helper is inlined into forward
    pass
```
Step 3: Define Static FX IR
File: jac/jaclang/compiler/passes/main/static_fx_ir.py
Key data structures:
```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional


class FxOpType(Enum):
    """FX operation types (matching torch.fx)"""
    PLACEHOLDER = "placeholder"
    GET_ATTR = "get_attr"
    CALL_FUNCTION = "call_function"
    CALL_METHOD = "call_method"
    CALL_MODULE = "call_module"
    OUTPUT = "output"


# eq=False: nodes compare by identity, so `node in users` checks are cheap and
# two distinct nodes with identical fields are never confused
@dataclass(eq=False)
class StaticFxNode:
    """Node in static FX graph"""
    op: FxOpType
    name: str
    target: Any
    args: tuple[Any, ...] = field(default_factory=tuple)
    kwargs: dict[str, Any] = field(default_factory=dict)
    # Additional static analysis info
    meta: dict[str, Any] = field(default_factory=dict)
    # Graph break annotation
    is_graph_break: bool = False
    graph_break_reason: Optional[str] = None
    graph_break_workaround: Optional[str] = None
    # Connections
    users: list['StaticFxNode'] = field(default_factory=list)

    def __repr__(self) -> str:
        args_repr = ', '.join(str(a) for a in self.args)
        kwargs_repr = ', '.join(f'{k}={v}' for k, v in self.kwargs.items())
        all_args = ', '.join(filter(None, [args_repr, kwargs_repr]))
        break_marker = " [GRAPH_BREAK]" if self.is_graph_break else ""
        return f"{self.name}: {self.op.value}[{self.target}]({all_args}){break_marker}"


@dataclass
class GraphBreakRegion:
    """Represents a region causing graph breaks"""
    reason: str
    nodes: list[StaticFxNode] = field(default_factory=list)
    source_loc: Any = None
    workaround: Optional[str] = None
    severity: str = "warning"  # "info", "warning", "error"
    # Detailed analysis
    analysis: dict[str, Any] = field(default_factory=dict)
    fixable: bool = False


class StaticFxGraph:
    """Complete FX graph with graph break annotations"""

    def __init__(self, name: str):
        self.name = name
        self.nodes: list[StaticFxNode] = []
        self._node_map: dict[str, StaticFxNode] = {}
        self.input_nodes: list[StaticFxNode] = []
        self.output_nodes: list[StaticFxNode] = []
        self.graph_break_regions: list[GraphBreakRegion] = []

    def add_node(self, node: StaticFxNode) -> StaticFxNode:
        """Add node to graph"""
        self.nodes.append(node)
        self._node_map[node.name] = node
        return node

    def get_node(self, name: str) -> Optional[StaticFxNode]:
        """Retrieve node by name"""
        return self._node_map.get(name)

    def add_edge(self, from_node: StaticFxNode, to_node: StaticFxNode) -> None:
        """Add edge between nodes"""
        if to_node not in from_node.users:
            from_node.users.append(to_node)

    def print_graph(self) -> str:
        """Pretty print the graph"""
        lines = [f"Graph: {self.name}", "=" * 50]
        for node in self.nodes:
            lines.append(str(node))
            if node.graph_break_reason:
                lines.append(f"  └─ Break Reason: {node.graph_break_reason}")
            if node.graph_break_workaround:
                lines.append(f"  └─ Workaround: {node.graph_break_workaround}")
        if self.graph_break_regions:
            lines.append("\nGraph Break Summary:")
            lines.append("-" * 50)
            for i, region in enumerate(self.graph_break_regions):
                lines.append(f"{i+1}. {region.reason}")
                lines.append(f"   Severity: {region.severity}")
                if region.workaround:
                    lines.append(f"   Fix: {region.workaround}")
        return "\n".join(lines)

    def to_dot(self) -> str:
        """Generate Graphviz DOT representation"""
        lines = ["digraph StaticFxGraph {"]
        lines.append('  rankdir=TB;')
        # Nodes
        for node in self.nodes:
            color = "red" if node.is_graph_break else "black"
            # Escape double quotes so targets containing '"' keep the DOT valid
            target = str(node.target).replace('"', '\\"')
            label = f"{node.name}\\n{node.op.value}\\n{target}"
            lines.append(f'  "{node.name}" [label="{label}", color="{color}"];')
        # Edges
        for node in self.nodes:
            for user in node.users:
                lines.append(f'  "{node.name}" -> "{user.name}";')
        lines.append("}")
        return "\n".join(lines)
```
Step 4: Implement Static FX Graph Builder#
File: jac/jaclang/compiler/passes/main/static_fx_graph_pass.py
Main pass structure:
```python
from typing import Optional

import jaclang.compiler.unitree as uni
from jaclang.compiler.passes import UniPass

from .static_fx_ir import (
    StaticFxGraph, StaticFxNode, FxOpType, GraphBreakRegion
)


class StaticFxGraphBuildPass(UniPass):
    """Build static FX graph from inlined forward method"""

    def before_pass(self) -> None:
        """Initialize pass state"""
        self.graphs: dict[str, StaticFxGraph] = {}
        self.current_graph: Optional[StaticFxGraph] = None
        self.current_model: Optional[uni.Archetype] = None
        self.model_name = ""
        # Maps variable names to FX nodes (reset for each forward method)
        self.value_map: dict[str, StaticFxNode] = {}
        # Counter for unique node names (reset for each forward method)
        self.node_counter = 0
        # Track potential graph breaks
        self.graph_break_candidates: list[tuple[uni.UniNode, str]] = []

    def after_pass(self) -> None:
        """Publish graphs on the program for later passes, tests and tools"""
        if not hasattr(self.prog, "static_fx_graphs"):
            self.prog.static_fx_graphs = {}
        self.prog.static_fx_graphs.update(self.graphs)

    def enter_archetype(self, node: uni.Archetype) -> None:
        """Process PyTorch model classes"""
        if not self.is_pytorch_module(node):
            return
        self.current_model = node
        self.model_name = node.name.value

    def exit_archetype(self, node: uni.Archetype) -> None:
        """Leave model scope so methods of later classes are not picked up"""
        if node is self.current_model:
            self.current_model = None

    def enter_ability(self, node: uni.Ability) -> None:
        """Process forward method"""
        if node.name_ref.value != "forward":
            return
        if self.current_model is None:
            return
        # Create new graph and reset per-graph state
        graph_name = f"{self.model_name}.forward"
        self.current_graph = StaticFxGraph(name=graph_name)
        self.value_map = {}
        self.node_counter = 0
        # Build the graph
        self.build_graph_from_forward(node)
        # Store graph
        self.graphs[graph_name] = self.current_graph

    def build_graph_from_forward(self, forward: uni.Ability) -> None:
        """Main graph construction from forward method"""
        # Step 1: Create placeholders for inputs
        self.create_input_placeholders(forward)
        # Step 2: Process function body
        for stmt in forward.body:
            self.process_statement(stmt)
        # Step 3: Handle return/output
        return_stmts = self.get_all_sub_nodes(forward, uni.ReturnStmt)
        if return_stmts:
            self.create_output_node(return_stmts[-1])

    def create_input_placeholders(self, forward: uni.Ability) -> None:
        """Create placeholder nodes for function parameters"""
        if not forward.signature or not forward.signature.params:
            return
        for param in forward.signature.params:
            if isinstance(param, uni.Param):
                # Skip 'self' parameter
                if param.name.value == 'self':
                    continue
                node = StaticFxNode(
                    op=FxOpType.PLACEHOLDER,
                    name=param.name.value,
                    target=param.name.value,
                    meta={
                        "type": param.type_tag if param.type_tag else "unknown"
                    }
                )
                self.current_graph.add_node(node)
                self.current_graph.input_nodes.append(node)
                self.value_map[param.name.value] = node

    def process_statement(self, stmt: uni.UniNode) -> None:
        """Process different statement types"""
        if isinstance(stmt, uni.Assignment):
            self.process_assignment(stmt)
        elif isinstance(stmt, uni.ExprStmt):
            self.process_expr_stmt(stmt)
        elif isinstance(stmt, uni.IfStmt):
            self.process_if_statement(stmt)
        elif isinstance(stmt, (uni.InForStmt, uni.IterForStmt)):
            self.process_for_loop(stmt)
        elif isinstance(stmt, uni.WhileStmt):
            self.process_while_loop(stmt)
        elif isinstance(stmt, uni.ReturnStmt):
            # Handled separately
            pass
        else:
            # Unknown statement type - potential issue
            self.log_warning(f"Unhandled statement type: {type(stmt).__name__}", stmt)

    def process_assignment(self, stmt: uni.Assignment) -> None:
        """Process assignment: x = f(y)"""
        if stmt.value is None:
            # Declaration without a value (e.g. `x: Tensor`)
            return
        # Process right-hand side
        if isinstance(stmt.value, uni.FuncCall):
            fx_node = self.process_func_call(stmt.value)
        elif isinstance(stmt.value, uni.BinaryExpr):
            fx_node = self.process_binary_expr(stmt.value)
        elif isinstance(stmt.value, uni.Name):
            # Simple variable reference
            fx_node = self.value_map.get(stmt.value.value)
        else:
            # Other expression types
            fx_node = self.process_expression(stmt.value)
        # Map target variable(s) to result node
        for target in stmt.target:
            if isinstance(target, uni.Name) and fx_node:
                self.value_map[target.value] = fx_node

    def process_func_call(self, call: uni.FuncCall) -> StaticFxNode:
        """Process function call and create FX node"""
        # torch.* and torch.nn.functional.* both become call_function nodes
        if self.is_torch_function(call) or self.is_torch_nn_functional(call):
            return self.create_torch_function_node(call)
        elif self.is_module_call(call):
            return self.create_module_call_node(call)
        elif self.is_tensor_method(call):
            return self.create_tensor_method_node(call)
        else:
            # Unknown call - potential graph break
            return self.create_unknown_call_node(call)

    def create_torch_function_node(self, call: uni.FuncCall) -> StaticFxNode:
        """Create node for torch.* function call"""
        # Extract function name
        func_name = self.get_qualified_name(call.target)
        # Extract arguments
        args = self.extract_call_args(call)
        kwargs = self.extract_call_kwargs(call)
        # Create node
        node = StaticFxNode(
            op=FxOpType.CALL_FUNCTION,
            name=self.generate_unique_name(func_name),
            target=func_name,
            args=args,
            kwargs=kwargs,
            meta={
                "source_loc": call.loc,
            }
        )
        self.current_graph.add_node(node)
        # Add edges from arguments
        self.add_edges_from_args(args, node)
        return node

    def create_module_call_node(self, call: uni.FuncCall) -> StaticFxNode:
        """Create node for nn.Module forward call"""
        # Extract module path (e.g., "self.layer1")
        module_path = self.get_module_path(call.target)
        args = self.extract_call_args(call)
        kwargs = self.extract_call_kwargs(call)
        node = StaticFxNode(
            op=FxOpType.CALL_MODULE,
            name=self.generate_unique_name(module_path),
            target=module_path,
            args=args,
            kwargs=kwargs,
            meta={
                "source_loc": call.loc,
            }
        )
        self.current_graph.add_node(node)
        self.add_edges_from_args(args, node)
        return node

    def process_if_statement(self, stmt: uni.IfStmt) -> None:
        """Process if statement - check for data dependency"""
        condition = stmt.condition
        # Both branches start from the values live before the if
        before = dict(self.value_map)
        # Check if condition depends on tensor values
        if self.is_data_dependent_condition(condition):
            # This is a graph break!
            region = GraphBreakRegion(
                reason="Data-dependent control flow (if statement)",
                source_loc=stmt.loc,
                severity="warning",
            )
            # Still process both branches to record operations
            self.process_branch_with_annotation(stmt.body, region, branch_type="then")
            if stmt.else_body:
                self.value_map = dict(before)
                self.process_branch_with_annotation(
                    stmt.else_body.body, region, branch_type="else"
                )
            self.current_graph.graph_break_regions.append(region)
        else:
            # Static condition - safe to trace
            # (could be evaluated at compile time to keep only the taken branch)
            for s in stmt.body:
                self.process_statement(s)
            if stmt.else_body:
                self.value_map = dict(before)
                for s in stmt.else_body.body:
                    self.process_statement(s)
        # NOTE: value_map now holds the last branch's bindings; merging the two
        # branches (e.g. into a torch.where node) is not handled here

    def is_data_dependent_condition(self, condition: uni.Expr) -> bool:
        """Check if condition depends on runtime tensor values"""
        # Look for patterns that indicate data dependency:
        # 1. Tensor comparisons: x > 0, x.sum() > threshold
        # 2. Tensor.item() calls
        # 3. Tensor bool evaluation
        # Find all function calls in condition
        calls = self.get_all_sub_nodes(condition, uni.FuncCall)
        for call in calls:
            # Check for .item()
            if self.is_tensor_item_call(call):
                return True
        # Check for tensor comparisons
        comparisons = self.get_all_sub_nodes(condition, uni.BinaryExpr)
        for comp in comparisons:
            if self.involves_tensor_comparison(comp):
                return True
        return False

    def process_for_loop(self, stmt: uni.InForStmt | uni.IterForStmt) -> None:
        """Process for loop - check for dynamic bounds"""
        # Check if loop has static bounds
        if self.has_static_loop_bounds(stmt):
            # Can unroll the loop
            self.unroll_static_loop(stmt)
        else:
            # Dynamic loop bounds - graph break
            region = GraphBreakRegion(
                reason="Dynamic loop bounds",
                source_loc=stmt.loc,
                severity="warning",
                workaround="Use static loop bounds or vectorize operations"
            )
            # Process loop body with annotation
            self.process_loop_body_with_annotation(stmt.body, region)
            self.current_graph.graph_break_regions.append(region)

    def has_static_loop_bounds(self, stmt: uni.InForStmt | uni.IterForStmt) -> bool:
        """Check if for loop has compile-time known bounds"""
        # Only `for x in <collection>` loops are inspected; IterForStmt
        # (init/condition/step form) has no collection attribute
        if not isinstance(stmt, uni.InForStmt):
            return False
        # Static iterables: range(constant), list literals
        collection = stmt.collection
        if isinstance(collection, uni.FuncCall):
            # Check for range(N) where N is constant
            if self.is_range_with_constant(collection):
                return True
        elif isinstance(collection, uni.ListVal):
            # List literal - static
            return True
        return False

    def unroll_static_loop(self, stmt: uni.InForStmt) -> None:
        """Unroll loop with static bounds"""
        # Get loop iterations
        iterations = self.get_loop_iterations(stmt)
        loop_var = stmt.target.value
        # Process body once per iteration
        for iter_value in iterations:
            # TODO: bind loop_var to iter_value so body expressions that use it
            # resolve to a constant (FX passes constants as literal args)
            for body_stmt in stmt.body:
                self.process_statement(body_stmt)

    def create_output_node(self, return_stmt: uni.ReturnStmt) -> None:
        """Create output node for return value"""
        if not return_stmt.expr:
            return
        # Get the FX node representing return value
        return_value = self.get_value_node(return_stmt.expr)
        node = StaticFxNode(
            op=FxOpType.OUTPUT,
            name="output",
            target="output",
            args=(return_value,) if return_value else (),
            meta={
                "source_loc": return_stmt.loc,
            }
        )
        self.current_graph.add_node(node)
        self.current_graph.output_nodes.append(node)
        if return_value:
            self.current_graph.add_edge(return_value, node)

    # Helper methods
    def generate_unique_name(self, base: str) -> str:
        """Generate unique node name"""
        # Use the last dotted component: "self.fc1" -> "fc1_0", "F.relu" -> "relu_1"
        name = f"{base.split('.')[-1]}_{self.node_counter}"
        self.node_counter += 1
        return name

    def extract_call_args(self, call: uni.FuncCall) -> tuple:
        """Extract positional arguments as FX nodes"""
        args = []
        for param in call.params or []:
            if isinstance(param, uni.KWPair):
                continue  # Skip kwargs
            args.append(self.get_value_node(param))
        return tuple(args)

    def extract_call_kwargs(self, call: uni.FuncCall) -> dict:
        """Extract keyword arguments as FX nodes"""
        kwargs = {}
        for param in call.params or []:
            if isinstance(param, uni.KWPair):
                kwargs[param.key.value] = self.get_value_node(param.value)
        return kwargs

    def get_value_node(self, expr: uni.Expr) -> Optional[StaticFxNode]:
        """Get or create FX node for expression"""
        if isinstance(expr, uni.Name):
            # Variable reference
            return self.value_map.get(expr.value)
        elif isinstance(expr, uni.FuncCall):
            return self.process_func_call(expr)
        else:
            # Other expressions - may need special handling
            return self.process_expression(expr)

    def add_edges_from_args(self, args: tuple, target_node: StaticFxNode) -> None:
        """Add edges from positional and keyword argument nodes to target node"""
        for arg in (*args, *target_node.kwargs.values()):
            if isinstance(arg, StaticFxNode):
                self.current_graph.add_edge(arg, target_node)
```
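The builder's core loop (placeholders, threading values through value_map, last-component unique names, and flagging a data-dependent if while still recording both branches) can be prototyped outside jaclang. The sketch below swaps UniIR for Python's standard ast module; that swap, the MiniStaticTracer name, the tuple node format and the "condition reads any traced value" heuristic are all assumptions for illustration, not the pass's actual API.

```python
import ast
import textwrap

SOURCE = textwrap.dedent("""
    def forward(self, x):
        x = self.fc1(x)
        if x.sum() > 0:
            x = torch.relu(x)
        else:
            x = torch.tanh(x)
        x = self.fc2(x)
        return x
""")


class MiniStaticTracer(ast.NodeVisitor):
    """Hypothetical: records FX-style nodes for forward() without executing it."""

    def __init__(self) -> None:
        self.nodes: list = []      # (name, op, target, args) tuples
        self.value_map: dict = {}  # variable name -> producing node name
        self.breaks: list = []     # (line number, reason) tuples
        self.counter = 0

    def unique(self, target: str) -> str:
        # Last dotted component + counter, e.g. "self.fc1" -> "fc1_0"
        name = f"{target.split('.')[-1]}_{self.counter}"
        self.counter += 1
        return name

    def visit_FunctionDef(self, fn: ast.FunctionDef) -> None:
        for arg in fn.args.args[1:]:  # skip `self`
            self.nodes.append((arg.arg, "placeholder", arg.arg, ()))
            self.value_map[arg.arg] = arg.arg
        for stmt in fn.body:
            self.visit(stmt)

    def record_call(self, call: ast.Call) -> str:
        target = ast.unparse(call.func)
        op = "call_module" if target.startswith("self.") else "call_function"
        # Resolve variable arguments to the node that last produced them
        args = tuple(
            self.value_map.get(a.id, a.id) if isinstance(a, ast.Name) else ast.unparse(a)
            for a in call.args
        )
        name = self.unique(target)
        self.nodes.append((name, op, target, args))
        return name

    def visit_Assign(self, stmt: ast.Assign) -> None:
        if isinstance(stmt.value, ast.Call):
            produced = self.record_call(stmt.value)
            for tgt in stmt.targets:
                if isinstance(tgt, ast.Name):
                    self.value_map[tgt.id] = produced

    def visit_If(self, stmt: ast.If) -> None:
        # Conservative heuristic: a condition that reads any traced value is
        # data-dependent; record the break but keep tracing both branches
        read = {n.id for n in ast.walk(stmt.test) if isinstance(n, ast.Name)}
        if read & set(self.value_map):
            self.breaks.append((stmt.lineno, "Data-dependent control flow (if statement)"))
        before = dict(self.value_map)
        for s in stmt.body:
            self.visit(s)
        if stmt.orelse:
            self.value_map = dict(before)  # else branch starts from pre-if values
            for s in stmt.orelse:
                self.visit(s)
        # After the if, value_map keeps the last branch's bindings (no merge)

    def visit_Return(self, stmt: ast.Return) -> None:
        if isinstance(stmt.value, ast.Name):
            src = self.value_map.get(stmt.value.id)
            self.nodes.append(("output", "output", "output", (src,) if src else ()))


tracer = MiniStaticTracer()
tracer.visit(ast.parse(SOURCE))
for name, op, target, args in tracer.nodes:
    print(f"{name}: {op}[{target}]({', '.join(args)})")
print(tracer.breaks)
```

Note that both relu_1 and tanh_2 read fc1_0, matching the then/else annotation in Example 2, and the break is recorded instead of stopping the trace.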
Step 5: Implement Graph Break Annotation Pass
File: jac/jaclang/compiler/passes/main/graph_break_annotation_pass.py
```python
from jaclang.compiler.passes import UniPass

from .static_fx_ir import StaticFxGraph, GraphBreakRegion


class GraphBreakAnnotationPass(UniPass):
    """Analyze and annotate graph break regions"""

    def after_pass(self) -> None:
        """Process all graphs after traversal"""
        # Graphs published on the program by StaticFxGraphBuildPass
        for graph in getattr(self.prog, "static_fx_graphs", {}).values():
            self.annotate_graph(graph)

    def annotate_graph(self, graph: StaticFxGraph) -> None:
        """Add detailed annotations to graph breaks"""
        for region in graph.graph_break_regions:
            # Perform detailed analysis
            analysis = self.analyze_break_region(region)
            region.analysis = analysis
            # Generate workaround suggestion
            region.workaround = self.suggest_workaround(region, analysis)
            # Determine if fixable
            region.fixable = self.is_fixable(region, analysis)
            # Estimate impact
            region.analysis['impact'] = self.estimate_impact(region, graph)

    def analyze_break_region(self, region: GraphBreakRegion) -> dict:
        """Perform detailed analysis of graph break"""
        analysis = {
            "reason": region.reason,
            "location": str(region.source_loc),
            "affected_nodes": len(region.nodes),
        }
        reason = region.reason.lower()
        # Match the builder's wording, e.g. "Data-dependent control flow (if statement)"
        if "control flow" in reason:
            analysis.update(self.analyze_control_flow_break(region))
        elif "loop" in reason:
            analysis.update(self.analyze_loop_break(region))
        return analysis

    def suggest_workaround(self, region: GraphBreakRegion, analysis: dict) -> str:
        """Generate specific workaround suggestion"""
        if "data-dependent control" in region.reason.lower():
            return (
                "Consider refactoring to use tensor operations instead:\n"
                "  • Replace if-else with torch.where(condition, x, y)\n"
                "  • Use masking: result = condition * x + (~condition) * y\n"
                "  • Apply conditional operations: x.masked_fill(condition, value)"
            )
        elif "dynamic loop" in region.reason.lower():
            return (
                "Consider these alternatives:\n"
                "  • Use static loop bounds if possible\n"
                "  • Vectorize the loop using tensor operations\n"
                "  • Use torch.vmap for batch operations"
            )
        return "No specific workaround available"

    def is_fixable(self, region: GraphBreakRegion, analysis: dict) -> bool:
        """Determine if break can be automatically fixed"""
        # Simple cases that can be auto-fixed:
        # - if-else with tensor operations -> torch.where
        # - Static loops -> unrolling
        if "control flow" in region.reason.lower():
            # Check if both branches are simple tensor ops
            return analysis.get('branches_are_simple', False)
        return False
```
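Matching on free-form reason strings is fragile: a check for "control_flow" (with an underscore) never fires on the builder's "Data-dependent control flow ..." text. One way to keep the builder and this pass in sync, sketched below as an assumption rather than the project's design, is a single keyword table plus a check that every reason the builder emits maps to a concrete suggestion. WORKAROUND_RULES, BUILDER_REASONS and NO_WORKAROUND are hypothetical names.

```python
# Hypothetical rule table: keyword (matched against the lowercased reason) -> suggestion
WORKAROUND_RULES = [
    ("data-dependent control flow",
     "Replace if/else with torch.where(condition, x, y) or masking"),
    ("dynamic loop",
     "Use static loop bounds, vectorize the loop, or use torch.vmap"),
]
NO_WORKAROUND = "No specific workaround available"

# Reason strings emitted by StaticFxGraphBuildPass in Step 4
BUILDER_REASONS = [
    "Data-dependent control flow (if statement)",
    "Dynamic loop bounds",
]


def suggest_workaround(reason: str) -> str:
    """Return the first suggestion whose keyword occurs in the reason."""
    text = reason.lower()
    for keyword, suggestion in WORKAROUND_RULES:
        if keyword in text:
            return suggestion
    return NO_WORKAROUND


# Consistency check: every builder reason should have a concrete suggestion
uncovered = [r for r in BUILDER_REASONS if suggest_workaround(r) == NO_WORKAROUND]
print(uncovered)  # → []
```

Running a check like this in the test suite would catch a reason string being reworded in one pass but not the other.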
Step 6: Integration and Testing
6.1 Update JacProgram:
Edit jac/jaclang/compiler/program.py:
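```python
# Add the new passes (imported from jaclang/compiler/passes/main/) to ir_gen_sched
ir_gen_sched = [
    SymTabBuildPass,
    DeclImplMatchPass,
    DefUsePass,
    SemDefMatchPass,
    CFGBuildPass,
    FunctionInlinePass,        # NEW
    StaticFxGraphBuildPass,    # NEW: publishes prog.static_fx_graphs
    GraphBreakAnnotationPass,  # NEW: must run after the build pass
]
```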
6.2 Create test suite:
Create jac/jaclang/compiler/passes/main/tests/test_static_fx_pass.py:
```python
import textwrap
import unittest

from jaclang.compiler.program import JacProgram


class TestStaticFxPass(unittest.TestCase):
    def test_simple_linear_model(self):
        """Test static FX graph for simple linear model"""
        # dedent so the embedded model source parses as top-level Python
        source = textwrap.dedent("""
            import torch
            import torch.nn as nn

            class SimpleModel(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.linear = nn.Linear(10, 5)

                def forward(self, x):
                    x = self.linear(x)
                    x = torch.relu(x)
                    return x
        """)
        prog = JacProgram()
        prog.compile(use_str=source, file_path="test.py")
        # Check that graph was built
        graphs = getattr(prog, "static_fx_graphs", None)
        self.assertIsNotNone(graphs)
        self.assertIn("SimpleModel.forward", graphs)
        graph = graphs["SimpleModel.forward"]
        # Verify graph structure
        self.assertEqual(len(graph.input_nodes), 1)  # x
        self.assertEqual(len(graph.output_nodes), 1)
        # Should have: placeholder, call_module (linear), call_function (relu), output
        self.assertEqual(len(graph.nodes), 4)

    def test_data_dependent_control_flow(self):
        """Test detection of graph breaks"""
        source = textwrap.dedent("""
            import torch
            import torch.nn as nn

            class ConditionalModel(nn.Module):
                def forward(self, x):
                    if x.sum() > 0:  # GRAPH BREAK
                        x = torch.relu(x)
                    else:
                        x = torch.tanh(x)
                    return x
        """)
        prog = JacProgram()
        prog.compile(use_str=source, file_path="test.py")
        graph = prog.static_fx_graphs["ConditionalModel.forward"]
        # Should have detected graph break
        self.assertEqual(len(graph.graph_break_regions), 1)
        region = graph.graph_break_regions[0]
        self.assertIn("data-dependent", region.reason.lower())
        self.assertIsNotNone(region.workaround)


if __name__ == "__main__":
    unittest.main()
```
6.3 CLI tool for visualization:
Create jac/jaclang/cli/static_fx_tool.py:
```python
#!/usr/bin/env python3
"""CLI tool for static FX graph analysis"""
import argparse

from jaclang.compiler.program import JacProgram


def main() -> None:
    parser = argparse.ArgumentParser(description="Static FX Graph Analyzer")
    parser.add_argument("file", help="Python file with PyTorch model")
    parser.add_argument("--output", "-o", help="Output file for graph")
    parser.add_argument("--format", choices=["text", "dot"], default="text")
    args = parser.parse_args()
    # Compile and generate graphs
    prog = JacProgram()
    prog.compile(file_path=args.file)
    rendered = []
    for name, graph in getattr(prog, "static_fx_graphs", {}).items():
        body = graph.print_graph() if args.format == "text" else graph.to_dot()
        print(f"\n{'=' * 60}")
        print(f"Graph: {name}")
        print("=" * 60)
        print(body)
        rendered.append(body)
    # Write all graphs once; writing inside the loop would keep only the last one
    if args.output:
        with open(args.output, "w") as f:
            f.write("\n\n".join(rendered))


if __name__ == "__main__":
    main()
```
Usage Examples
Example 1: Simple Model
Input (simple_model.py):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
```
Static FX Graph Output:
```text
Graph: SimpleNet.forward
==================================================
x: placeholder[x]()
fc1_0: call_module[self.fc1](x)
relu_1: call_function[F.relu](fc1_0)
fc2_2: call_module[self.fc2](relu_1)
output: output[output](fc2_2)

Graph Break Summary:
--------------------------------------------------
No graph breaks detected!
```
Example 2: Model with Graph Breaks
Input (conditional_model.py):
```python
import torch
import torch.nn as nn


class ConditionalNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1(x)
        # GRAPH BREAK: Data-dependent control flow
        if x.sum() > 0:
            x = torch.relu(x)
        else:
            x = torch.tanh(x)
        x = self.fc2(x)
        return x
```
Static FX Graph Output:
```text
Graph: ConditionalNet.forward
==================================================
x: placeholder[x]()
fc1_0: call_module[self.fc1](x)
sum_1: call_method[sum](fc1_0)
compare_2: call_function[operator.gt](sum_1, 0) [GRAPH_BREAK]
  └─ Break Reason: Data-dependent control flow (if statement)
  └─ Workaround: Consider refactoring to use tensor operations:
       • Replace if-else with torch.where(condition, x, y)
       • Use masking: result = condition * x + (~condition) * y
relu_3: call_function[torch.relu](fc1_0) [GRAPH_BREAK: then_branch]
tanh_4: call_function[torch.tanh](fc1_0) [GRAPH_BREAK: else_branch]
fc2_5: call_module[self.fc2](...)
output: output[output](fc2_5)

Graph Break Summary:
--------------------------------------------------
1. Data-dependent control flow (if statement)
   Severity: warning
   Fixable: yes
   Fix: Consider refactoring to use tensor operations:
     • Replace if-else with torch.where(condition, x, y)
     • Use masking
```
Future Extensions
1. Automatic Graph Break Fixing
Implement a pass that automatically refactors code to eliminate graph breaks:
```python
import jaclang.compiler.unitree as uni
from jaclang.compiler.passes import UniPass


class GraphBreakFixerPass(UniPass):
    """Automatically fix graph breaks where possible"""

    def fix_data_dependent_conditional(self, if_stmt: uni.IfStmt) -> uni.Assignment:
        """Transform if-else to torch.where"""
        # Convert:
        #     if cond: x = a
        #     else:    x = b
        # To:
        #     x = torch.where(cond, a, b)
        pass
```
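The rewrite described in fix_data_dependent_conditional can be prototyped on Python's ast module before porting it to uni.IfStmt. The IfToWhere class below is a hypothetical sketch: it only rewrites an if/else whose branches are single assignments to the same variable, and because torch.where evaluates both branch expressions, the rewrite is only valid when both sides are side-effect free and safe to compute for every input.

```python
import ast


class IfToWhere(ast.NodeTransformer):
    """Hypothetical: rewrite `if c: x = a` / `else: x = b` as `x = torch.where(c, a, b)`."""

    def visit_If(self, node: ast.If) -> ast.AST:
        self.generic_visit(node)  # rewrite nested ifs (elif chains) first
        # Only handle exactly one statement per branch
        if len(node.body) != 1 or len(node.orelse) != 1:
            return node
        then_stmt, else_stmt = node.body[0], node.orelse[0]
        if not (isinstance(then_stmt, ast.Assign) and isinstance(else_stmt, ast.Assign)):
            return node
        if len(then_stmt.targets) != 1 or len(else_stmt.targets) != 1:
            return node
        then_tgt, else_tgt = then_stmt.targets[0], else_stmt.targets[0]
        # Both branches must assign the same plain variable
        if not (isinstance(then_tgt, ast.Name) and isinstance(else_tgt, ast.Name)
                and then_tgt.id == else_tgt.id):
            return node
        where_call = ast.Call(
            func=ast.Attribute(value=ast.Name(id="torch", ctx=ast.Load()),
                               attr="where", ctx=ast.Load()),
            args=[node.test, then_stmt.value, else_stmt.value],
            keywords=[],
        )
        assign = ast.Assign(targets=[ast.Name(id=then_tgt.id, ctx=ast.Store())],
                            value=where_call)
        return ast.copy_location(assign, node)


SOURCE = """
if x.sum() > 0:
    x = torch.relu(x)
else:
    x = torch.tanh(x)
"""
tree = ast.fix_missing_locations(IfToWhere().visit(ast.parse(SOURCE)))
rewritten = ast.unparse(tree)
print(rewritten)  # → x = torch.where(x.sum() > 0, torch.relu(x), torch.tanh(x))
```

Visiting children first means an elif chain that assigns the same variable in every branch folds into nested torch.where calls.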
2. Shape Analysis Integration
Integrate with shape inference to detect dynamic shapes:
```python
from jaclang.compiler.passes import UniPass
from jaclang.compiler.passes.main.static_fx_ir import StaticFxGraph


class ShapeAnalysisPass(UniPass):
    """Infer tensor shapes through the graph"""

    def infer_shapes(self, graph: StaticFxGraph) -> None:
        """Propagate shape information through graph"""
        # Start from input placeholders
        # Propagate through operations
        # Detect dynamic shapes
        pass
```
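To make the three steps in infer_shapes concrete, here is a minimal sketch that propagates shapes over the node list from Example 1 (SimpleNet). The tuple node format, the MODULE_OUT_FEATURES table (standing in for facts the pass would read from __init__) and the two op rules are assumptions; any node without a rule gets None, i.e. a shape that would need runtime information.

```python
from typing import Optional

# Assumed facts extracted from SimpleNet.__init__: nn.Linear out_features per module
MODULE_OUT_FEATURES = {"self.fc1": 128, "self.fc2": 10}
# Ops that preserve their input's shape
ELEMENTWISE = {"F.relu", "torch.relu", "torch.tanh"}


def propagate_shapes(nodes: list, input_shapes: dict) -> dict:
    """Walk nodes in topological order; shapes are tuples of ints or symbols like "B"."""
    shapes: dict = {}
    for name, op, target, args in nodes:
        if op == "placeholder":
            shapes[name] = input_shapes.get(name)  # None if no shape was given
            continue
        in_shape: Optional[tuple] = shapes.get(args[0]) if args else None
        if in_shape is None:
            shapes[name] = None  # unknown input -> unknown output
        elif op == "call_module" and target in MODULE_OUT_FEATURES:
            # nn.Linear keeps the leading dims and replaces the last one
            shapes[name] = in_shape[:-1] + (MODULE_OUT_FEATURES[target],)
        elif (op == "call_function" and target in ELEMENTWISE) or op == "output":
            shapes[name] = in_shape
        else:
            shapes[name] = None  # no rule: treat as dynamic
    return shapes


# Node list matching the Example 1 output for SimpleNet.forward
nodes = [
    ("x", "placeholder", "x", ()),
    ("fc1_0", "call_module", "self.fc1", ("x",)),
    ("relu_1", "call_function", "F.relu", ("fc1_0",)),
    ("fc2_2", "call_module", "self.fc2", ("relu_1",)),
    ("output", "output", "output", ("fc2_2",)),
]
shapes = propagate_shapes(nodes, {"x": ("B", 784)})
dynamic = [n for n, s in shapes.items() if s is None]
print(shapes["output"], dynamic)  # → ('B', 10) []
```

Keeping the batch dimension symbolic ("B") lets the analysis stay static while still distinguishing truly unknown shapes (None) from merely unbound ones.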
3. Optimization Passes
Use the static FX graph for optimizations:
```python
from jaclang.compiler.passes import UniPass
from jaclang.compiler.passes.main.static_fx_ir import StaticFxGraph


class FxOptimizationPass(UniPass):
    """Optimize FX graph"""

    def fuse_operations(self, graph: StaticFxGraph) -> StaticFxGraph:
        """Fuse consecutive operations (e.g., conv + relu)"""
        pass

    def eliminate_dead_code(self, graph: StaticFxGraph) -> StaticFxGraph:
        """Remove unused nodes"""
        pass
```
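Because nodes are stored in topological order, eliminate_dead_code can be a single backward sweep: walk the nodes in reverse and mark everything the output transitively reads. The sketch below assumes (name, op, args) tuples and that every non-output op is side-effect free; in-place ops or modules that update buffers (e.g. BatchNorm running statistics) would also need to be treated as live. Placeholders are kept so the graph's signature does not change.

```python
def eliminate_dead_code(nodes: list) -> list:
    """Drop nodes whose results never reach the output.

    nodes: topologically ordered (name, op, args) tuples (hypothetical format).
    """
    node_names = {name for name, _, _ in nodes}
    live = set()
    for name, op, args in reversed(nodes):
        if op == "output" or name in live:
            live.add(name)
            # Everything a live node reads is live too (literal args are ignored)
            live.update(a for a in args if a in node_names)
    return [n for n in nodes if n[0] in live or n[1] == "placeholder"]


nodes = [
    ("x", "placeholder", ()),
    ("fc1_0", "call_module", ("x",)),
    ("sum_1", "call_method", ("fc1_0",)),  # result never used
    ("relu_2", "call_function", ("fc1_0",)),
    ("output", "output", ("relu_2",)),
]
print([name for name, _, _ in eliminate_dead_code(nodes)])
# → ['x', 'fc1_0', 'relu_2', 'output']
```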
4. Visualization Dashboard
Create an interactive web dashboard:
```python
# Flask app for visualizing graphs
from flask import Flask

app = Flask(__name__)


@app.route('/graph/<model_name>')
def show_graph(model_name):
    # Render interactive graph with:
    # - Node highlighting
    # - Graph break annotations
    # - Suggested fixes
    pass
```
Key Differences from PyTorch FX

| Aspect | PyTorch FX | This Project (Static FX) |
|---|---|---|
| Analysis Type | Dynamic (runtime symbolic tracing) | Static (compile-time analysis) |
| Execution | Requires running the model (with Proxy inputs) | No execution needed |
| Graph Breaks | Creates multiple disconnected graphs (under torch.compile; plain symbolic_trace raises an error instead) | Single annotated graph |
| Control Flow | Cannot trace data-dependent control flow | Tags it but keeps it in the graph |
| Completeness | Only traceable code | Complete program representation |
| Use Case | Runtime optimization | Static analysis & optimization |
Conclusion
This project combines:
- Compiler Engineering: Leveraging Jaseci's pass-based architecture
- Static Analysis: Analyzing code without execution
- ML Systems: Understanding PyTorch's execution model
- Graph Theory: Building and manipulating computation graphs
The resulting system will enable:
- Better Analysis: See the full model structure without runtime execution
- Optimization Opportunities: Identify patterns for compiler optimizations
- Developer Tools: Help developers write more traceable PyTorch code
- Research Platform: Experiment with novel ML compiler techniques
Next Steps: Start by implementing the FunctionInlinePass, then move to basic graph construction for simple models, and gradually add support for more complex patterns and graph break detection.
Good luck with the implementation! This is an exciting project at the intersection of compilers and ML systems.