From cc90ca6955b3c88ba036e97842b3ad812614afc2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 28 Aug 2024 10:46:54 -0500 Subject: [PATCH] [Relax][Bugfix] Infer TIR values from shapes inside a tuple If a Relax function contains an `R.match_cast` that defines a symbolic shape, and the value provided to the `R.match_cast` has a known static shape, the `relax.transform.CanoncalizeBindings()` pass can in-line the known static shape. However, while these known TIR values were only collected if the expression used in `R.match_cast` was a `R.Tensor`, `R.Shape`, and `R.Prim` (Relax types which may contain symbolic TIR values), they were not collected if the `R.match_cast` expression was a `R.Tuple`. For example, while using `R.match_cast` to convert from `R.Tensor([16])` to `R.Tensor([batch_size])` would identify that `batch_size` must be `16`, using `R.match_cast` to convert from `R.Tuple(R.Tensor([16]))` to `R.Tuple(R.Tensor([batch_size]))` would not. This commit updates the `InferSymbolicVarMap` to collect all symbolic shapes, even if they occur within a `R.Tuple`. --- src/relax/utils.cc | 27 ++++++++++++--- .../test_transform_canonicalize_bindings.py | 34 +++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 77416dc92b1d..96fd5578e40a 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -159,13 +159,32 @@ tvm::Map InferSymbolicVarMap( GetStructInfo(expr_tensor->shape.value())); }; + std::function bind_from_struct_info = nullptr; + auto bind_from_tuple = [&bind_from_struct_info](const StructInfo& var, const StructInfo& expr) { + auto var_tuple = var.as(); + if (!var_tuple) return; + + auto expr_tuple = expr.as(); + if (!expr_tuple) return; + + if (var_tuple->fields.size() != expr_tuple->fields.size()) return; + + for (size_t i = 0; i < var_tuple->fields.size(); i++) { + bind_from_struct_info(var_tuple->fields[i], expr_tuple->fields[i]); + } + }; + + bind_from_struct_info = [&](const StructInfo& var, const StructInfo& expr) { + bind_from_tensor(var, expr); + bind_from_shape(var, expr); + bind_from_prim_value(var, expr); + bind_from_tuple(var, expr); + }; + for (const auto& [relax_var, relax_expr] : relax_var_remap) { auto var_sinfo = GetStructInfo(relax_var); auto expr_sinfo = GetStructInfo(relax_expr); - - bind_from_tensor(var_sinfo, expr_sinfo); - bind_from_shape(var_sinfo, expr_sinfo); - bind_from_prim_value(var_sinfo, expr_sinfo); + bind_from_struct_info(var_sinfo, expr_sinfo); } return tir_var_remap; diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index ea3b1c249b8b..a7ff8cdc3202 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -253,6 +253,40 @@ def main(x: R.Tensor(("m", "n"))): verify(TestChangeShape, Expected) +def test_replace_symbolic_variable_and_remove_match_cast_of_tuple(): + """Symbolic variables may be defined in R.match_cast of tuple + + This test is similar to + `test_replace_symbolic_variable_and_remove_match_cast`, except + that the MatchCast is performed on a Relax tuple. + + This is a regression test. Earlier implementations only inferred + TIR variables from `R.match_cast` of tensors, shapes, and prim + values, but omitted tuples. + + """ + + @I.ir_module + class Before: + @R.function + def main(x: R.Tuple(R.Tensor(("m", "n")))): + y = x + o, p = T.int64(), T.int64() + z = R.match_cast(x, R.Tuple(R.Tensor((o, p)))) + w = z + q = R.add(w[0], y[0]) + return R.add(q, w[0]) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tuple(R.Tensor(("m", "n")))): + q = R.add(x[0], x[0]) + return R.add(q, x[0]) + + verify(Before, Expected) + + def test_unwrap_tuple(): @I.ir_module class Before: