From 7c43afcd1fc01493081256f1c383722f1c1dc470 Mon Sep 17 00:00:00 2001
From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com>
Date: Fri, 14 Nov 2025 15:46:08 +0800
Subject: [PATCH] Add lower bound support for range constraints

---
 include/tvm/relax/transform.h                 | 19 +++++-----
 .../torch/exported_program_translator.py      |  6 ++-
 src/relax/transform/adjust_matmul_order.cc    | 32 ++++++++++++----
 .../transform/static_plan_block_memory.cc     | 38 ++++++++++++-------
 .../test_frontend_from_exported_program.py    |  2 +-
 ...test_transform_static_plan_block_memory.py | 12 ++++++
 6 files changed, 77 insertions(+), 32 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index a8ccc4076bb3..58cf7421b5a7 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -125,18 +125,19 @@ TVM_DLL Pass RewriteDataflowReshape();
  * The pass will reuse allocated memory to its best effort, in order to
  * reduce the total amount of allocated memory size.
  *
- * The pass "supports" dynamic shape in the way of TIR variable upper bound
- * annotation. We can optionally annotate the attribute "tir_var_upper_bound"
- * to Relax functions. The attribute value is a dict from strings to integers,
- * denoting the name of TIR variables to the upper bound values of the TIR vars.
- * Note: The annotated upper bound attribute only applies to TIR vars in the
+ * The pass "supports" dynamic shape in the way of TIR variable bound
+ * annotations. We can optionally annotate the attributes "tir_var_upper_bound"
+ * and "tir_var_lower_bound" on Relax functions. The attribute values are dicts
+ * from strings to integers, mapping the names of TIR variables to the bound
+ * values of those variables.
+ * Note: The annotated bound attributes only apply to TIR vars in the
  * function signature for clarity.
  *
  * For example, we can annotate a Relax function with
- * `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
- * It means the maximum value of variable that names "n" in the function
- * signature will have upper bound 1024. And we will use 1024 as its value
- * during memory planning.
+ * `R.func_attr({"tir_var_lower_bound": {"n": 1}, "tir_var_upper_bound": {"n": 1024}})`.
+ * It means the variable named "n" in the function signature has the range
+ * [1, 1024], and these bounds are used during memory planning.
+ * If the lower bound is not specified, it defaults to 0.
  *
  * \return The pass.
  */
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5cddf24a89dc..431a1444d172 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1181,10 +1181,12 @@ def from_exported_program(
     if range_constraints:
         if func_attrs is None:
             func_attrs = {}
-        tir_var_upper_bound = {
+        func_attrs["tir_var_lower_bound"] = {
+            var_name: lower for var_name, (lower, _) in range_constraints.items()
+        }
+        func_attrs["tir_var_upper_bound"] = {
             var_name: upper for var_name, (_, upper) in range_constraints.items()
         }
-        func_attrs["tir_var_upper_bound"] = tir_var_upper_bound
 
     nodes: List[fx.Node] = exported_program.graph.nodes
diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc
index 98fe57e11c2a..889272019174 100644
--- a/src/relax/transform/adjust_matmul_order.cc
+++ b/src/relax/transform/adjust_matmul_order.cc
@@ -73,19 +73,37 @@ std::tuple)>> pat_permuted_matmul_on_rhs;
   PrimExpr symbolic_var_constraints = Bool(true);
-  if (auto upper_bounds = func->GetAttr>("tir_var_upper_bound")) {
+  auto upper_bounds = func->GetAttr>("tir_var_upper_bound");
+  auto lower_bounds = func->GetAttr>("tir_var_lower_bound");
+
+  if (upper_bounds || lower_bounds) {
     ffi::Map<ffi::String, tir::Var> name_lookup;
     for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) {
       name_lookup.Set(tir_var->name_hint, tir_var);
       symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var);
     }
 
-    for (const auto& [key, obj_bound] : upper_bounds.value()) {
-      auto tir_var_name = Downcast<ffi::String>(key);
-      if (auto opt_var = name_lookup.Get(tir_var_name)) {
-        auto var = opt_var.value();
-        auto expr_bound = Downcast<PrimExpr>(obj_bound);
-        symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound);
+    // Add lower bound constraints
+    if (lower_bounds) {
+      for (const auto& [key, obj_bound] : lower_bounds.value()) {
+        auto tir_var_name = Downcast<ffi::String>(key);
+        if (auto opt_var = name_lookup.Get(tir_var_name)) {
+          auto var = opt_var.value();
+          auto expr_bound = Downcast<PrimExpr>(obj_bound);
+          symbolic_var_constraints = symbolic_var_constraints && (expr_bound <= var);
+        }
+      }
+    }
+
+    // Add upper bound constraints
+    if (upper_bounds) {
+      for (const auto& [key, obj_bound] : upper_bounds.value()) {
+        auto tir_var_name = Downcast<ffi::String>(key);
+        if (auto opt_var = name_lookup.Get(tir_var_name)) {
+          auto var = opt_var.value();
+          auto expr_bound = Downcast<PrimExpr>(obj_bound);
+          symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound);
+        }
       }
     }
   }
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 85076206ae53..fc3c2259ff9a 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -365,40 +365,52 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
 };
 
 /*!
- * \brief Set the upper bound of the TIR variables that appear in
+ * \brief Set the range constraints of the TIR variables that appear in
  * the input function signature in the analyzer.
  * \param func The function to be analyzed.
  * \param ana The analyzer which contains the TIR var upper bounds.
  * \param dom_map The domain map of the TIR variables.
  */
-void SetTIRVarUpperBound(Function func, arith::Analyzer* ana,
-                         ffi::Map<tir::Var, arith::IntSet>* dom_map) {
-  // Use the attribute-annotated TIR var upper bounds as the TIR var values for
+void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana,
+                               ffi::Map<tir::Var, arith::IntSet>* dom_map) {
+  // Use the attribute-annotated TIR var bounds as the TIR var values for
   // memory planning.
-  // NOTE: we only apply the annotated upper bounds to the TIR variables that
+  // NOTE: we only apply the annotated bounds to the TIR variables that
   // appear in the **function signature**.
   ffi::Map<ffi::String, IntImm> var_upper_bound_attr_raw =
       func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_upper_bound")
           .value_or(ffi::Map<ffi::String, IntImm>());
+  ffi::Map<ffi::String, IntImm> var_lower_bound_attr_raw =
+      func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_lower_bound")
+          .value_or(ffi::Map<ffi::String, IntImm>());
   ffi::Array<ffi::String> non_negative_var_attr_raw =
       func->GetAttr<ffi::Array<ffi::String>>("tir_non_negative_var")
          .value_or(ffi::Array<ffi::String>());
   std::unordered_map<ffi::String, IntImm> var_upper_bound_attr;
+  std::unordered_map<ffi::String, IntImm> var_lower_bound_attr;
   std::unordered_set<ffi::String> non_negative_var_attr;
   // We manually check the value type to ensure the values are all positive IntImm.
   for (auto [key, value] : var_upper_bound_attr_raw) {
     var_upper_bound_attr[key] = value;
   }
+  for (auto [key, value] : var_lower_bound_attr_raw) {
+    var_lower_bound_attr[key] = value;
+  }
   for (const ffi::String& var_name : non_negative_var_attr_raw) {
     non_negative_var_attr.insert(var_name);
   }
   ffi::Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
   for (const tir::Var& tir_var : var_in_signature) {
-    auto it = var_upper_bound_attr.find(tir_var->name_hint);
-    if (it != var_upper_bound_attr.end()) {
-      tvm::Range range =
-          tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
-                                    tvm::IntImm(DataType::Int(64), (*it).second->value + 1));
+    auto it_upper = var_upper_bound_attr.find(tir_var->name_hint);
+    auto it_lower = var_lower_bound_attr.find(tir_var->name_hint);
+
+    if (it_upper != var_upper_bound_attr.end() || it_lower != var_lower_bound_attr.end()) {
+      int64_t lower = (it_lower != var_lower_bound_attr.end()) ? it_lower->second->value : 0;
+      int64_t upper = (it_upper != var_upper_bound_attr.end())
+                          ? it_upper->second->value
+                          : std::numeric_limits<int64_t>::max() - 1;  // keep "upper - lower + 1" in int64 range
+      tvm::Range range = tvm::Range::FromMinExtent(
+          tvm::IntImm(DataType::Int(64), lower), tvm::IntImm(DataType::Int(64), upper - lower + 1));
       ana->Bind(tir_var, range);
       dom_map->Set(tir_var, arith::IntSet::FromRange(range));
     } else if (non_negative_var_attr.count(tir_var->name_hint)) {
@@ -485,8 +497,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
       : ctx_mod_(ctx_mod), analyzer_(analyzer) {}
 
   void VisitExpr_(const FunctionNode* func) final {
-    // Set the upper bound of TIR variables in the analyzer.
-    SetTIRVarUpperBound(ffi::GetRef<Function>(func), analyzer_, &dom_map_);
+    // Set the range constraints of TIR variables in the analyzer.
+    SetTIRVarRangeConstraints(ffi::GetRef<Function>(func), analyzer_, &dom_map_);
     // Recurse into the function to get its tokens.
     Tokens body_tokens = GetTokens(func->body);
     // Discard the tokens used by the function return value, as they are external referenced.
@@ -843,7 +855,7 @@ class StorageAllocationRewriter : public ExprMutator {
     plan_dynamic_output_ = static_cast<bool>(
         func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value);
     if (plan_dynamic_output_) {
-      SetTIRVarUpperBound(ffi::GetRef<Function>(func_), &ana_, &dom_map_);
+      SetTIRVarRangeConstraints(ffi::GetRef<Function>(func_), &ana_, &dom_map_);
     }
     token2storage_var_.clear();
     Function func = Downcast<Function>(this->VisitExpr_(func_));
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 71e400a6a8b1..157af43facbf 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6747,7 +6747,7 @@ def main(
             x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), dtype="float32")
         ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
             s0 = T.int64(is_size_var=True)
-            R.func_attr({"tir_var_upper_bound": {"s0": 64}})
+            R.func_attr({"tir_var_lower_bound": {"s0": 1}, "tir_var_upper_bound": {"s0": 64}})
             with R.dataflow():
                 lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
                 gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 83e4d264c6a3..06e4ea142e95 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1347,6 +1347,18 @@ def main(x: R.Tensor((2, "n"), dtype="float32")):
         relax.transform.StaticPlanBlockMemory()(Module)
 
 
+def test_invalid_tir_var_lower_bound():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")):
+            R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
+            return x
+
+    with pytest.raises((TVMError, TypeError)):
+        relax.transform.StaticPlanBlockMemory()(Module)
+
+
 def test_add():
     @I.ir_module
     class Module:
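
For reference, a minimal usage sketch of the two attributes together, modeled on the tests
touched above. This is not part of the patch itself; the bound values 16 and 64 and the
trivial function body are placeholders, chosen only to show the annotation shape.

# Usage sketch (illustrative, not part of the patch).
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, "n"), dtype="float32")):
        # "n" is constrained to the range [16, 64]; StaticPlanBlockMemory binds
        # this range for "n" instead of treating the variable as unbounded.
        R.func_attr(
            {
                "tir_var_lower_bound": {"n": 16},
                "tir_var_upper_bound": {"n": 64},
                "relax.force_pure": True,
            }
        )
        return x


# The pass reads both attributes and uses the range [16, 64] during memory planning.
mod = relax.transform.StaticPlanBlockMemory()(Module)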