diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index 21608af7dba0..37827fbe0e6c 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -71,8 +71,22 @@ class LazyInputMutator : public ExprMutator { Array new_params(func->params.begin(), func->params.begin() + num_input_params); new_params.push_back(fget_param); + auto array_externally_visible_vars = + DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo))); + std::unordered_set externally_visible_vars( + array_externally_visible_vars.begin(), array_externally_visible_vars.end()); + StructInfo new_ret_struct_info = + EraseToWellDefined(func->ret_struct_info, [&](const tir::Var& var) -> Optional { + if (externally_visible_vars.count(var)) { + return var; + } else { + return NullOpt; + } + }); + auto node = GetRef(func); node.CopyOnWrite()->params = new_params; + node.CopyOnWrite()->ret_struct_info = new_ret_struct_info; node = WithAttr(node, attr::kNumInput, Integer(num_input_params + 1)); plan_ = FunctionPlan{std::move(param_lookup), fget_param}; diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 833cbd460c0f..040aea28909d 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -951,6 +951,40 @@ def transform_params( tvm.ir.assert_structural_equal(After, Expected) +def test_get_item_callback_dynamic_shape(): + @I.ir_module + class Before: + @R.function + def transform_params( + A: R.Tensor(["m", "n"], "float32"), B: R.Tensor(["m", "n"], "float32") + ) -> R.Tuple(R.Tensor(["m", "n"], "float32"), R.Tensor(["m", "n"], "float32")): + C = R.multiply(A, R.const(2, "float32")) + D = R.add(C, B) + return (D, B) + + @I.ir_module + class Expected: + @R.function + def transform_params( + fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object) + ) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2, dtype="float32")): + R.func_attr({"num_input": 1}) + m = T.int64() + n = T.int64() + + A = fget_param(R.prim_value(0), R.str("A")) + A = R.match_cast(A, R.Tensor([m, n], "float32")) + C = R.multiply(A, R.const(2, "float32")) + + B = fget_param(R.prim_value(1), R.str("B")) + B = R.match_cast(B, R.Tensor([m, n], "float32")) + D = R.add(C, B) + return (D, B) + + After = relax.transform.LazyGetInput()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + def test_set_output_callback(): """fset_output is called for each element of the output tuple