16 commits
a4ab52f
feat(relax/frontend/torch): Add basic range constraint support from E…
demoncoder-crypto Apr 26, 2025
71f8a98
Fix: Insert test_dynamic_shape_with_constraints
demoncoder-crypto Apr 28, 2025
f8ad0d7
refactor(test): Refactor constraint test to use verify_model and add …
demoncoder-crypto May 2, 2025
e5710b7
fix(test): Define tir.Var for TVMScript parsing in constraint test
demoncoder-crypto May 3, 2025
5c7758c
style: Apply black formatting
demoncoder-crypto May 3, 2025
9ca05a6
fix(relax/torch): Handle ExportedProgram range constraints and add tests
demoncoder-crypto May 4, 2025
8ab98aa
Merge branch 'main' into fix/relax-pytorch-constraints-v2
demoncoder-crypto May 4, 2025
7201b72
style: Apply formatting fixes to test_frontend_from_exported_program.py
demoncoder-crypto May 4, 2025
f7e23f4
style: Fix trailing whitespace in test file
demoncoder-crypto May 4, 2025
bcab702
feat(relax): Enhance PyTorch ExportedProgram range constraints support
demoncoder-crypto May 4, 2025
70bff93
feat: Enhance PyTorch range constraints support
demoncoder-crypto May 4, 2025
54885dd
style: Fix lint errors reported by CI
demoncoder-crypto May 4, 2025
4717288
style: Apply final lint fixes for translator and test files
demoncoder-crypto May 4, 2025
073ec93
Apply Black code formatting to exported_program_translator.py
demoncoder-crypto May 4, 2025
6162a27
Add logging module for PyTorch frontend
demoncoder-crypto May 4, 2025
249c808
fix: coerce bounds to int and update R.relu to R.nn.relu
demoncoder-crypto May 4, 2025
137 changes: 119 additions & 18 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1,5 +1,5 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
@@ -20,15 +20,19 @@
"""PyTorch ExportedProgram of Relax."""
from collections import ChainMap, OrderedDict
from functools import partial
from typing import Callable, Dict, List, Tuple
from typing import Callable, Dict, List, Tuple, Optional

import torch
import tvm
from tvm import relax

import tvm.tir as tir # pylint: disable=unused-import, consider-using-from-import
from tvm.relax.frontend.torch.log import get_logger
from .base_fx_graph_translator import BaseFXGraphImporter


logger = get_logger(__name__)


class ExportedProgramImporter(BaseFXGraphImporter):
"""An importer from ExportedProgram to Relax."""

@@ -503,13 +507,16 @@ def create_convert_map(
"item.default": self._item,
}

def create_input_vars(
self, exported_program: torch.export.ExportedProgram
) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
def create_input_vars(self, exported_program: torch.export.ExportedProgram) -> Tuple[
Dict[str, relax.Var],
Dict[str, relax.Var],
Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]],
]:
"""Create relax input vars."""
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
relax_range_constraints: Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]] = {}

for spec in exported_program.graph_signature.input_specs:
name_hint = spec.arg.name
@@ -527,13 +534,17 @@ def create_input_vars
torch_shape = exported_program.state_dict[spec.target].shape
torch_dtype = exported_program.state_dict[spec.target].dtype

# TODO(mshr-h): Support range constraints
relax_shape = [
torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
if isinstance(s, torch.SymInt)
else s
for s in torch_shape
]
# Create SizeVars and map SymInts
relax_shape = []
for s in torch_shape:
if isinstance(s, torch.SymInt):
s_str = str(s)
if s_str not in torch_symbol_to_relax_var:
torch_symbol_to_relax_var[s_str] = tvm.tir.SizeVar(s_str, "int64")
relax_shape.append(torch_symbol_to_relax_var[s_str])
else:
relax_shape.append(s)

dtype = self._convert_data_type(torch_dtype)

relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
@@ -542,7 +553,71 @@ def create_input_vars(
else:
parameters_buffers_constants[name_hint] = relax_var

return parameters_buffers_constants, user_inputs
# Extract range constraints for TIR vars
if hasattr(exported_program, "range_constraints") and exported_program.range_constraints:
for torch_sym_expr, rc in exported_program.range_constraints.items():
# Convert sympy expression to string for mapping
torch_sym_expr_str = str(torch_sym_expr)

if torch_sym_expr_str in torch_symbol_to_relax_var:
relax_tir_var = torch_symbol_to_relax_var[torch_sym_expr_str]
# Extract min/max values - Note: these could be None if constraint is one-sided
min_val = rc.lower
max_val = rc.upper
# Call helper to add/refine constraint
self._add_range_constraint(
relax_range_constraints, relax_tir_var, min_val, max_val
)
# Add debug info - can be removed in production
logger.debug(
f"Added constraint for {torch_sym_expr_str}: "
f"[{min_val}, {max_val}] -> {relax_range_constraints[relax_tir_var]}"
)
else:
# For complex expressions (e.g., s0 + 1) we don't yet have support
logger.debug(f"Skipping complex constraint expression: {torch_sym_expr}")

return parameters_buffers_constants, user_inputs, relax_range_constraints

# Helper method for handling range constraints
def _add_range_constraint(self, constraints_dict, relax_tir_var, min_val, max_val):
Reviewer comment (Contributor): @staticmethod would be better since it doesn't access any instance variables or methods.

"""Adds or refines a range constraint for a TIR variable.

Parameters
----------
constraints_dict : Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]]
Dictionary that maps TIR variables to their range constraints
relax_tir_var : tvm.tir.Var
The TIR variable to constrain
min_val : Optional[int]
The minimum value (inclusive) or None if no minimum
max_val : Optional[int]
The maximum value (inclusive) or None if no maximum
"""

if relax_tir_var not in constraints_dict:
constraints_dict[relax_tir_var] = (min_val, max_val)
else:
# Refine existing constraints if the new one is tighter
existing_min, existing_max = constraints_dict[relax_tir_var]

# Merge lower bounds (take the max)
if existing_min is None:
new_min = min_val
elif min_val is None:
new_min = existing_min
else:
new_min = max(existing_min, min_val)

# Merge upper bounds (take the min)
if existing_max is None:
new_max = max_val
elif max_val is None:
new_max = existing_max
else:
new_max = min(existing_max, max_val)

constraints_dict[relax_tir_var] = (new_min, new_max)

def from_exported_program(
self,
@@ -554,23 +629,49 @@
"""Convert a PyTorch ExportedProgram to a Relax program."""
from torch import fx # type: ignore

# Create input variables.
parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
# Create input variables and get range constraints.
(
parameter_buffer_constant_vars,
user_input_vars,
relax_range_constraints,
) = self.create_input_vars(exported_program)
inputs_vars = user_input_vars.copy()
inputs_vars.update(parameter_buffer_constant_vars)

# Initialize the block builder with a function and a dataflow block.
self.block_builder = relax.BlockBuilder()
func_name = "main"
func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None

# Prepare function attributes
func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else {}

# Add range constraints to function attributes if they exist
if relax_range_constraints:
lower_bounds = {}
upper_bounds = {}
for var, (lower, upper) in relax_range_constraints.items():
if lower is not None:
# For min constraints, use the exact value
lower_bounds[var] = tvm.tir.IntImm("int64", int(lower))
if upper is not None:
# For max constraints, use the exact value (skipped entirely if None)
upper_bounds[var] = tvm.tir.IntImm("int64", int(upper))

if lower_bounds:
func_attrs["tir_var_lower_bound"] = lower_bounds
if upper_bounds:
func_attrs["tir_var_upper_bound"] = upper_bounds

# Use None if func_attrs is empty, otherwise use the dictionary
final_func_attrs = func_attrs if func_attrs else None

nodes: List[fx.Node] = exported_program.graph.nodes

# Find all the missing function types
self._check_unsupported_func_type(nodes)

with self.block_builder.function(
name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs
name=func_name, params=list(inputs_vars.values()).copy(), attrs=final_func_attrs
):
output = None
with self.block_builder.dataflow():
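As a standalone illustration (not part of the diff), the bound-merging rule in `_add_range_constraint` reduces to a small sketch: lower bounds are merged with `max`, upper bounds with `min`, and `None` on either side means that side is unconstrained. The helper name `merge_bounds` below is hypothetical.

from typing import Optional, Tuple

Bounds = Tuple[Optional[int], Optional[int]]

def merge_bounds(existing: Bounds, new: Bounds) -> Bounds:
    """Tighten one (min, max) pair with another; None means unbounded on that side."""
    lo_a, hi_a = existing
    lo_b, hi_b = new
    # Lower bounds: keep the larger (tighter) one.
    lo = lo_b if lo_a is None else lo_a if lo_b is None else max(lo_a, lo_b)
    # Upper bounds: keep the smaller (tighter) one.
    hi = hi_b if hi_a is None else hi_a if hi_b is None else min(hi_a, hi_b)
    return (lo, hi)

# Mirrors the constraint test below: Dim("B", min=2, max=10) refined by Dim("B", min=3, max=15).
assert merge_bounds((2, 10), (3, 15)) == (3, 10)
# One-sided bounds stay one-sided until a matching bound appears.
assert merge_bounds((1, None), (None, 7)) == (1, 7)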
18 changes: 18 additions & 0 deletions python/tvm/relax/frontend/torch/log.py
@@ -0,0 +1,18 @@
import logging


def get_logger(name: str) -> logging.Logger:
"""Get a logger with a specific name.

Parameters
----------
name : str
The name of the logger.

Returns
-------
logging.Logger
A logger object.
"""
logger = logging.getLogger(name)
return logger
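Because the translator only emits the constraint messages via `logger.debug`, they stay silent under the default logging configuration. A minimal sketch for surfacing them while debugging, assuming the standard `logging` setup (the logger name follows the module path shown in the diff above):

import logging

from tvm.relax.frontend.torch.log import get_logger

# Enable DEBUG output globally so the constraint messages become visible.
logging.basicConfig(level=logging.DEBUG)

# get_logger is a thin wrapper around logging.getLogger, so per-module control also works.
translator_logger = get_logger("tvm.relax.frontend.torch.exported_program_translator")
translator_logger.setLevel(logging.DEBUG)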
127 changes: 127 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -20,6 +20,8 @@
from torch.nn import Module
from torch.export import export

import tvm.tir as tir # pylint: disable=unused-import, consider-using-from-import

import tvm
from tvm import relax
import tvm.testing
@@ -4756,6 +4758,61 @@ def main(
verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)


def test_dynamic_shape_with_constraints():
# Define SymInts with constraints
B_dim = torch.export.Dim("B", min=2, max=10)
# Reuse the name "B" for another dimension to test bound refinement (merged to min=3, max=10 below)
B_refined = torch.export.Dim("B", min=3, max=15)
S_dim = torch.export.Dim("S", min=1) # Test min constraint only (-> (1, None))

# Example args matching initial B dim (max=10)
example_args = (torch.randn(3, 4, dtype=torch.float32), torch.randn(5, 2, dtype=torch.float32))

# Dynamic shapes using the Dim objects
# Input 0: Dim 0 uses B_dim (min=2, max=10), Dim 1 uses S_dim (min=1)
# Input 1: Dim 0 uses B_refined (min=3, max=15)
# The final constraint for tir.Var("B") should be max(2,3) to min(10,15) => min=3, max=10
dynamic_shapes = {0: {0: B_dim, 1: S_dim}, 1: {0: B_refined}}

class SimpleDynamic(torch.nn.Module):
# Simple op, the main thing is testing the input signature and constraints
def forward(self, x, y):
# Adding tensors with different shapes would require broadcasting,
# but we only care about the input signature here.
# Use an op that doesn't depend on exact shapes matching.
return torch.relu(x) # Return just one to simplify output signature

B = tir.Var("B", "int64")
S = tir.Var("S", "int64")

# Define the expected Relax IRModule
@tvm.script.ir_module
class Expected:
@R.function
def main(
x: R.Tensor((B, S), dtype="float32"), # Uses B, S defined outside
y: R.Tensor((B, 2), dtype="float32"), # Uses B defined outside
) -> R.Tuple(R.Tensor((B, S), dtype="float32")):
# Add expected constraints as function attributes
R.func_attr(
{
"tir_var_upper_bound": {B: T.int64(10), S: T.int64(9223372036854775807)},
"tir_var_lower_bound": {B: T.int64(3), S: T.int64(1)},
"num_input": 2, # Two user inputs: x and y
}
)
with R.dataflow():
# Use the parameters x and y passed in
lv: R.Tensor((B, S), dtype="float32") = R.nn.relu(x)
# The output shape must match the signature
gv: R.Tuple(R.Tensor((B, S), dtype="float32")) = (lv,)
R.output(gv)
return gv

# Verify the model conversion, including constraints
verify_model(SimpleDynamic(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)


def test_broadcast_to():
class BroadcastTo(Module):
def forward(self, x):
@@ -4998,6 +5055,76 @@ def main(
verify_model(Linspace(), example_args, {}, Expected)


def test_dynamic_shape_single_sided_constraints():
"""Test importing ExportedProgram with single-sided constraints (min only or max only)."""

# --- Test Case 1: Min constraint only ---
B_min = torch.export.Dim("B_min", min=5)
S_min = torch.export.Dim("S_min", min=2)

example_args_min = (torch.randn(6, 3, dtype=torch.float32),)
dynamic_shapes_min = {0: {0: B_min, 1: S_min}}

# Define the expected Relax IRModule for min-only
B_min_tir = tir.Var("B_min", "int64")
S_min_tir = tir.Var("S_min", "int64")

@tvm.script.ir_module
class ExpectedMin:
@R.function
def main(
x: R.Tensor((B_min_tir, S_min_tir), dtype="float32")
) -> R.Tuple(R.Tensor((B_min_tir, S_min_tir), dtype="float32")):
# No function attributes since only one-sided constraints
with R.dataflow():
lv: R.Tensor((B_min_tir, S_min_tir), dtype="float32") = R.nn.relu(x)
gv: R.Tuple(R.Tensor((B_min_tir, S_min_tir), dtype="float32")) = (lv,)
R.output(gv)
return gv

# Model just needs to accept the inputs
class SimpleModelMin(torch.nn.Module):
def forward(self, x):
return torch.relu(x)

verify_model(
SimpleModelMin(), example_args_min, {}, ExpectedMin, dynamic_shapes=dynamic_shapes_min
)

# --- Test Case 2: Max constraint only ---
B_max = torch.export.Dim("B_max", max=20)
S_max = torch.export.Dim("S_max", max=10)

example_args_max = (torch.randn(15, 8, dtype=torch.float32),)
dynamic_shapes_max = {0: {0: B_max, 1: S_max}}

# Define the expected Relax IRModule for max-only
B_max_tir = tir.Var("B_max", "int64")
S_max_tir = tir.Var("S_max", "int64")

@tvm.script.ir_module
class ExpectedMax:
@R.function
def main(
x: R.Tensor((B_max_tir, S_max_tir), dtype="float32")
) -> R.Tuple(R.Tensor((B_max_tir, S_max_tir), dtype="float32")):
# No function attributes since only one-sided constraints
with R.dataflow():
lv: R.Tensor((B_max_tir, S_max_tir), dtype="float32") = R.nn.relu(x)
gv: R.Tuple(R.Tensor((B_max_tir, S_max_tir), dtype="float32")) = (lv,)
R.output(gv)
return gv

# Model just needs to accept the inputs
class SimpleModelMax(torch.nn.Module):
def forward(self, x):
return torch.relu(x)

verify_model(
SimpleModelMax(), example_args_max, {}, ExpectedMax, dynamic_shapes=dynamic_shapes_max
)


def test_bfloat16():
# TODO(mshr-h): Add tests for all the dtypes supported in fx frontend
example_args = (
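For completeness, a hedged end-to-end sketch of how the new constraint handling would be exercised from user code. It assumes the public entry point `tvm.relax.frontend.torch.from_exported_program` wraps the importer modified above; the exact printed attribute contents may differ.

import torch
from torch.export import Dim, export

from tvm.relax.frontend.torch import from_exported_program


class ReluModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


# A batch dimension constrained to [2, 10]; the example input must satisfy it.
B = Dim("B", min=2, max=10)
example_args = (torch.randn(4, 8),)

exported = export(ReluModel(), example_args, dynamic_shapes={"x": {0: B}})
mod = from_exported_program(exported)

# With this change, the bounds should surface as "tir_var_lower_bound" /
# "tir_var_upper_bound" entries in the main function's attributes.
print(mod["main"].attrs)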