diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index df532fd1ea04..25517eacea4b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1,5 +1,5 @@
 # Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
+# or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
 # regarding copyright ownership. The ASF licenses this file
 # to you under the Apache License, Version 2.0 (the
@@ -20,15 +20,19 @@
 """PyTorch ExportedProgram of Relax."""
 from collections import ChainMap, OrderedDict
 from functools import partial
-from typing import Callable, Dict, List, Tuple
+from typing import Callable, Dict, List, Tuple, Optional
 
 import torch
 import tvm
 from tvm import relax
-
+import tvm.tir as tir  # pylint: disable=unused-import, consider-using-from-import
+from tvm.relax.frontend.torch.log import get_logger
 
 from .base_fx_graph_translator import BaseFXGraphImporter
 
+logger = get_logger(__name__)
+
+
 class ExportedProgramImporter(BaseFXGraphImporter):
     """An importer from ExportedProgram to Relax."""
 
@@ -503,13 +507,16 @@ def create_convert_map(
             "item.default": self._item,
         }
 
-    def create_input_vars(
-        self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+    def create_input_vars(self, exported_program: torch.export.ExportedProgram) -> Tuple[
+        Dict[str, relax.Var],
+        Dict[str, relax.Var],
+        Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]],
+    ]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
         torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+        relax_range_constraints: Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]] = {}
 
         for spec in exported_program.graph_signature.input_specs:
             name_hint = spec.arg.name
@@ -527,13 +534,17 @@ def create_input_vars(
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype
 
-            # TODO(mshr-h): Support range constraints
-            relax_shape = [
-                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
-                if isinstance(s, torch.SymInt)
-                else s
-                for s in torch_shape
-            ]
+            # Create SizeVars and map SymInts
+            relax_shape = []
+            for s in torch_shape:
+                if isinstance(s, torch.SymInt):
+                    s_str = str(s)
+                    if s_str not in torch_symbol_to_relax_var:
+                        torch_symbol_to_relax_var[s_str] = tvm.tir.SizeVar(s_str, "int64")
+                    relax_shape.append(torch_symbol_to_relax_var[s_str])
+                else:
+                    relax_shape.append(s)
+
             dtype = self._convert_data_type(torch_dtype)
 
             relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
@@ -542,7 +553,71 @@ def create_input_vars(
             else:
                 parameters_buffers_constants[name_hint] = relax_var
 
-        return parameters_buffers_constants, user_inputs
+        # Extract range constraints for TIR vars
+        if hasattr(exported_program, "range_constraints") and exported_program.range_constraints:
+            for torch_sym_expr, rc in exported_program.range_constraints.items():
+                # Convert sympy expression to string for mapping
+                torch_sym_expr_str = str(torch_sym_expr)
+
+                if torch_sym_expr_str in torch_symbol_to_relax_var:
+                    relax_tir_var = torch_symbol_to_relax_var[torch_sym_expr_str]
+                    # Extract min/max values - Note: these could be None if constraint is one-sided
+                    min_val = rc.lower
+                    max_val = rc.upper
+                    # Call helper to add/refine constraint
+                    self._add_range_constraint(
+                        relax_range_constraints, relax_tir_var, min_val, max_val
+                    )
+                    # Log the merged bounds for debugging
+                    logger.debug(
+                        f"Added constraint for {torch_sym_expr_str}: "
+                        f"[{min_val}, {max_val}] -> {relax_range_constraints[relax_tir_var]}"
+                    )
+                else:
+                    # Composite symbolic expressions (e.g., s0 + 1) are not supported yet
+                    logger.debug(f"Skipping complex constraint expression: {torch_sym_expr}")
+
+        return parameters_buffers_constants, user_inputs, relax_range_constraints
+
+    # Helper method for handling range constraints
+    def _add_range_constraint(self, constraints_dict, relax_tir_var, min_val, max_val):
+        """Adds or refines a range constraint for a TIR variable.
+
+        Parameters
+        ----------
+        constraints_dict : Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]]
+            Dictionary that maps TIR variables to their range constraints
+        relax_tir_var : tvm.tir.Var
+            The TIR variable to constrain
+        min_val : Optional[int]
+            The minimum value (inclusive) or None if no minimum
+        max_val : Optional[int]
+            The maximum value (inclusive) or None if no maximum
+        """
+
+        if relax_tir_var not in constraints_dict:
+            constraints_dict[relax_tir_var] = (min_val, max_val)
+        else:
+            # Refine existing constraints if the new one is tighter
+            existing_min, existing_max = constraints_dict[relax_tir_var]
+
+            # Merge lower bounds (take the max)
+            if existing_min is None:
+                new_min = min_val
+            elif min_val is None:
+                new_min = existing_min
+            else:
+                new_min = max(existing_min, min_val)
+
+            # Merge upper bounds (take the min)
+            if existing_max is None:
+                new_max = max_val
+            elif max_val is None:
+                new_max = existing_max
+            else:
+                new_max = min(existing_max, max_val)
+
+            constraints_dict[relax_tir_var] = (new_min, new_max)
 
     def from_exported_program(
         self,
@@ -554,15 +629,41 @@ def from_exported_program(
         """Convert a PyTorch ExportedProgram to a Relax program."""
        from torch import fx  # type: ignore
 
-        # Create input variables.
-        parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
+        # Create input variables and get range constraints.
+        (
+            parameter_buffer_constant_vars,
+            user_input_vars,
+            relax_range_constraints,
+        ) = self.create_input_vars(exported_program)
         inputs_vars = user_input_vars.copy()
         inputs_vars.update(parameter_buffer_constant_vars)
 
         # Initialize the block builder with a function and a dataflow block.
         self.block_builder = relax.BlockBuilder()
         func_name = "main"
-        func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
+
+        # Prepare function attributes
+        func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else {}
+
+        # Add range constraints to function attributes if they exist
+        if relax_range_constraints:
+            lower_bounds = {}
+            upper_bounds = {}
+            for var, (lower, upper) in relax_range_constraints.items():
+                if lower is not None:
+                    # Record the inclusive lower bound when present
+                    lower_bounds[var] = tvm.tir.IntImm("int64", int(lower))
+                if upper is not None:
+                    # Record the inclusive upper bound when present
+                    upper_bounds[var] = tvm.tir.IntImm("int64", int(upper))
+
+            if lower_bounds:
+                func_attrs["tir_var_lower_bound"] = lower_bounds
+            if upper_bounds:
+                func_attrs["tir_var_upper_bound"] = upper_bounds
+
+        # Use None if func_attrs is empty, otherwise use the dictionary
+        final_func_attrs = func_attrs if func_attrs else None
 
         nodes: List[fx.Node] = exported_program.graph.nodes
 
@@ -570,7 +671,7 @@ def from_exported_program(
         self._check_unsupported_func_type(nodes)
 
         with self.block_builder.function(
-            name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs
+            name=func_name, params=list(inputs_vars.values()).copy(), attrs=final_func_attrs
         ):
             output = None
             with self.block_builder.dataflow():
diff --git a/python/tvm/relax/frontend/torch/log.py b/python/tvm/relax/frontend/torch/log.py
new file mode 100644
index 000000000000..13f0fb656f63
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/log.py
@@ -0,0 +1,18 @@
+import logging
+
+
+def get_logger(name: str) -> logging.Logger:
+    """Get a logger with a specific name.
+
+    Parameters
+    ----------
+    name : str
+        The name of the logger.
+
+    Returns
+    -------
+    logging.Logger
+        A logger object.
+    """
+    logger = logging.getLogger(name)
+    return logger
\ No newline at end of file
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index f0bb33964ef2..2cdd9fa3f34c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -20,6 +20,8 @@
 from torch.nn import Module
 from torch.export import export
 
+import tvm.tir as tir  # pylint: disable=consider-using-from-import
+
 import tvm
 from tvm import relax
 import tvm.testing
@@ -4756,6 +4758,61 @@ def main(
     verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
 
 
+def test_dynamic_shape_with_constraints():
+    # Define dynamic dims with range constraints
+    B_dim = torch.export.Dim("B", min=2, max=10)
+    # Reuse the name "B" with a different range to exercise bound refinement
+    B_refined = torch.export.Dim("B", min=3, max=15)
+    S_dim = torch.export.Dim("S", min=1)  # Min constraint only
+
+    # Example args; dim 0 of both inputs must agree since both map to "B"
+    example_args = (torch.randn(5, 4, dtype=torch.float32), torch.randn(5, 2, dtype=torch.float32))
+
+    # Dynamic shapes using the Dim objects
+    # Input 0: Dim 0 uses B_dim (min=2, max=10), Dim 1 uses S_dim (min=1)
+    # Input 1: Dim 0 uses B_refined (min=3, max=15)
+    # The final constraint for tir.Var("B") should be max(2,3) to min(10,15) => min=3, max=10
+    dynamic_shapes = {0: {0: B_dim, 1: S_dim}, 1: {0: B_refined}}
+
+    class SimpleDynamic(torch.nn.Module):
+        # A trivial model; the test only checks the input signature and constraints
+        def forward(self, x, y):
+            # y is deliberately unused: the op must not depend on the
+            # relationship between the two input shapes, since this test
+            # only checks the input signature and the range constraints.
+            return torch.relu(x)  # Return a single tensor to keep the output signature simple
+
+    B = tir.Var("B", "int64")
+    S = tir.Var("S", "int64")
+
+    # Define the expected Relax IRModule
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((B, S), dtype="float32"),  # Uses B, S defined outside
+            y: R.Tensor((B, 2), dtype="float32"),  # Uses B defined outside
+        ) -> R.Tuple(R.Tensor((B, S), dtype="float32")):
+            # Add expected constraints as function attributes
+            R.func_attr(
+                {
+                    "tir_var_upper_bound": {B: T.int64(10), S: T.int64(9223372036854775807)},
+                    "tir_var_lower_bound": {B: T.int64(3), S: T.int64(1)},
+                    "num_input": 2,  # Two user inputs: x and y
+                }
+            )
+            with R.dataflow():
+                # Use the parameters x and y passed in
+                lv: R.Tensor((B, S), dtype="float32") = R.nn.relu(x)
+                # The output shape must match the signature
+                gv: R.Tuple(R.Tensor((B, S), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Verify the model conversion, including constraints
+    verify_model(SimpleDynamic(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
+
+
 def test_broadcast_to():
     class BroadcastTo(Module):
         def forward(self, x):
@@ -4998,6 +5055,76 @@ def main(
     verify_model(Linspace(), example_args, {}, Expected)
 
 
+def test_dynamic_shape_single_sided_constraints():
+    """Test importing ExportedProgram with single-sided constraints (min only or max only)."""
+
+    # --- Test Case 1: Min constraint only ---
+    B_min = torch.export.Dim("B_min", min=5)
+    S_min = torch.export.Dim("S_min", min=2)
+
+    example_args_min = (torch.randn(6, 3, dtype=torch.float32),)
+    dynamic_shapes_min = {0: {0: B_min, 1: S_min}}
+
+    # Define the expected Relax IRModule for min-only
+    B_min_tir = tir.Var("B_min", "int64")
+    S_min_tir = tir.Var("S_min", "int64")
+
+    @tvm.script.ir_module
+    class ExpectedMin:
+        @R.function
+        def main(
+            x: R.Tensor((B_min_tir, S_min_tir), dtype="float32")
+        ) -> R.Tuple(R.Tensor((B_min_tir, S_min_tir), dtype="float32")):
+            # No function attributes are expected since the constraints are one-sided
+            with R.dataflow():
+                lv: R.Tensor((B_min_tir, S_min_tir), dtype="float32") = R.nn.relu(x)
+                gv: R.Tuple(R.Tensor((B_min_tir, S_min_tir), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # The model just needs to accept the input
+    class SimpleModelMin(torch.nn.Module):
+        def forward(self, x):
+            return torch.relu(x)
+
+    verify_model(
+        SimpleModelMin(), example_args_min, {}, ExpectedMin, dynamic_shapes=dynamic_shapes_min
+    )
+
+    # --- Test Case 2: Max constraint only ---
+    B_max = torch.export.Dim("B_max", max=20)
+    S_max = torch.export.Dim("S_max", max=10)
+
+    example_args_max = (torch.randn(15, 8, dtype=torch.float32),)
+    dynamic_shapes_max = {0: {0: B_max, 1: S_max}}
+
+    # Define the expected Relax IRModule for max-only
+    B_max_tir = tir.Var("B_max", "int64")
+    S_max_tir = tir.Var("S_max", "int64")
+
+    @tvm.script.ir_module
+    class ExpectedMax:
+        @R.function
+        def main(
+            x: R.Tensor((B_max_tir, S_max_tir), dtype="float32")
+        ) -> R.Tuple(R.Tensor((B_max_tir, S_max_tir), dtype="float32")):
+            # No function attributes are expected since the constraints are one-sided
+            with R.dataflow():
+                lv: R.Tensor((B_max_tir, S_max_tir), dtype="float32") = R.nn.relu(x)
+                gv: R.Tuple(R.Tensor((B_max_tir, S_max_tir), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # The model just needs to accept the input
+    class SimpleModelMax(torch.nn.Module):
+        def forward(self, x):
+            return torch.relu(x)
+
+    verify_model(
+        SimpleModelMax(), example_args_max, {}, ExpectedMax, dynamic_shapes=dynamic_shapes_max
+    )
+
+
 def test_bfloat16():
     # TODO(mshr-h): Add tests for all the dtypes supported in fx frontend
     example_args = (