From 531f48a170179af4950be6d286ad7fc1cec6665a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 17 Oct 2023 19:17:08 +0000 Subject: [PATCH] [Unity] Handle duplicate outputs in LazyTransformParams A parameter transformation may output the same tensor more than once. When this occurs, the `set_item` function should be called for each output tensor. --- .../relax/transform/lazy_transform_params.py | 26 +++++----- .../test_transform_lazy_transform_params.py | 52 +++++++++++++++++++ 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py index 4774ae429139..6a8adcb64b83 100644 --- a/python/tvm/relax/transform/lazy_transform_params.py +++ b/python/tvm/relax/transform/lazy_transform_params.py @@ -56,7 +56,9 @@ def visit_var_binding_(self, binding: relax.VarBinding) -> None: if binding.var == self.out_tuple_var: assert isinstance(binding.value, relax.Tuple) for i, expr in enumerate(binding.value.fields): - self.out_tuple_map[expr] = relax.PrimValue(i) + if expr not in self.out_tuple_map: + self.out_tuple_map[expr] = [] + self.out_tuple_map[expr].append(relax.PrimValue(i)) else: self.is_tuple_get_item_input = False super().visit_var_binding_(binding) @@ -198,17 +200,17 @@ def visit_var_binding_(self, binding: relax.VarBinding) -> None: for var in self.memory_free_insertion[binding.var]: if var in self.out_tuple_map: self.killed_vars.add(var) - index = self.out_tuple_map[var] - # rewrite set item - self.builder_.emit( - relax.Call( - relax.ExternFunc("set_item"), - [index, super().visit_var_(var)], - None, - [relax.ObjectStructInfo()], - ), - name_hint="_", - ) + for index in self.out_tuple_map[var]: + # rewrite set item + self.builder_.emit( + relax.Call( + relax.ExternFunc("set_item"), + [index, super().visit_var_(var)], + None, + [relax.ObjectStructInfo()], + ), + name_hint="_", + ) if var in self.input_params_set: self.builder_.emit( diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 6ce728ba95af..af7ed1956bbf 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -376,5 +376,57 @@ def set_item(i, value): tvm.testing.assert_allclose(expected_i, transformed_i) +def test_duplicate_outputs(): + """A tensor may be repeated in the output + + This is something that should be avoided upstream, but is a legal + parameter transformation, and should produce correct output. + """ + + @I.ir_module + class Before: + @R.function + def main_transform_params( + params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], dtype="int32")) + ): + R.func_attr({"relax.force_pure": True}) + param0 = params[0] + param1 = params[1] + transformed0 = R.add(param0, R.const(1, "int32")) + transformed1 = R.add(param1, R.const(2, "int32")) + output = (transformed0, transformed1, transformed0) + return output + + @I.ir_module + class Expected: + @R.function(pure=False) + def main_transform_params() -> R.Tuple: + gv: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + gv1: R.Tensor((16,), dtype="int32") = R.match_cast(gv, R.Tensor((16,), dtype="int32")) + param0: R.Tensor((16,), dtype="int32") = gv1 + + gv2: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) + gv3: R.Tensor((16,), dtype="int32") = R.match_cast(gv2, R.Tensor((16,), dtype="int32")) + param1: R.Tensor((16,), dtype="int32") = gv3 + + transformed0: R.Tensor((16,), dtype="int32") = R.add(param0, R.const(1, "int32")) + _: R.Tuple = R.vm.kill_object(param0) + _: R.Object = R.call_packed( + "set_item", R.prim_value(0), transformed0, sinfo_args=(R.Object,) + ) + _: R.Object = R.call_packed( + "set_item", R.prim_value(2), transformed0, sinfo_args=(R.Object,) + ) + + transformed1: R.Tensor((16,), dtype="int32") = R.add(param1, R.const(2, "int32")) + _ = R.vm.kill_object(param1) + _ = R.call_packed("set_item", R.prim_value(1), transformed1, sinfo_args=(R.Object,)) + output = R.tuple() + return output + + after = LazyTransformParams()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main()