From 40e4f0afce37e1531670e51351f2462be42d89a9 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 2 Apr 2024 20:57:05 +0000
Subject: [PATCH] [Relax] Provide well-formed output in `transform.LazyGetInput`

Prior to this commit, symbolic variables inferred from the parameters
were retained in the output function's `ret_struct_info`.  This is
ill-formed, as the parameters from which these symbolic variables are
inferred are no longer part of the function signature.

This commit updates `LazyGetInput` to use `EraseToWellDefined` to
remove any symbolic variables from `ret_struct_info` that cannot be
inferred from the remaining arguments.
---
 src/relax/transform/lazy_transform_params.cc | 14 ++++++++
 .../test_transform_lazy_transform_params.py  | 34 +++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc
index 21608af7dba0..37827fbe0e6c 100644
--- a/src/relax/transform/lazy_transform_params.cc
+++ b/src/relax/transform/lazy_transform_params.cc
@@ -71,8 +71,22 @@ class LazyInputMutator : public ExprMutator {
     Array<Var> new_params(func->params.begin(), func->params.begin() + num_input_params);
     new_params.push_back(fget_param);
 
+    auto array_externally_visible_vars =
+        DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo)));
+    std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> externally_visible_vars(
+        array_externally_visible_vars.begin(), array_externally_visible_vars.end());
+    StructInfo new_ret_struct_info =
+        EraseToWellDefined(func->ret_struct_info, [&](const tir::Var& var) -> Optional<PrimExpr> {
+          if (externally_visible_vars.count(var)) {
+            return var;
+          } else {
+            return NullOpt;
+          }
+        });
+
     auto node = GetRef<Function>(func);
     node.CopyOnWrite()->params = new_params;
+    node.CopyOnWrite()->ret_struct_info = new_ret_struct_info;
     node = WithAttr(node, attr::kNumInput, Integer(num_input_params + 1));
 
     plan_ = FunctionPlan{std::move(param_lookup), fget_param};
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index 833cbd460c0f..040aea28909d 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -951,6 +951,40 @@ def transform_params(
     tvm.ir.assert_structural_equal(After, Expected)
 
 
+def test_get_item_callback_dynamic_shape():
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(
+            A: R.Tensor(["m", "n"], "float32"), B: R.Tensor(["m", "n"], "float32")
+        ) -> R.Tuple(R.Tensor(["m", "n"], "float32"), R.Tensor(["m", "n"], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object)
+        ) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2, dtype="float32")):
+            R.func_attr({"num_input": 1})
+            m = T.int64()
+            n = T.int64()
+
+            A = fget_param(R.prim_value(0), R.str("A"))
+            A = R.match_cast(A, R.Tensor([m, n], "float32"))
+            C = R.multiply(A, R.const(2, "float32"))
+
+            B = fget_param(R.prim_value(1), R.str("B"))
+            B = R.match_cast(B, R.Tensor([m, n], "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    After = relax.transform.LazyGetInput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
 def test_set_output_callback():
     """fset_output is called for each element of the output tuple