From b5aeb83b15ae8ceb5cd17fce762d4da54262fbe3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 7 Aug 2024 09:11:11 -0500 Subject: [PATCH 1/2] [Relax][Transform] Handle tuple return in RemoveUnusedOutputs Prior to this commit, the `relax.transform.RemoveUnusedOutputs` pass only marked a tuple element as used if it occurred in a `TupleGetItem` node. This ignored use cases where a tuple is used as an aggregate object, such as returning a tuple from a function. This would collect incorrect results for a Relax function that calls a subroutine, receives a tuple as the return value of the subroutine, then returns that tuple. This commit updates `RemoveUnusedOutputs` to look for usage of a tuple object, not just for usage in `TupleGetItem`. Closes https://github.com/apache/tvm/issues/17247 --- src/relax/transform/remove_unused_outputs.cc | 59 ++++++++++++------- .../test_transform_remove_unused_outputs.py | 20 +++++++ 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index e3bf12382c67..9a5c31e79ba0 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -92,29 +92,48 @@ class PartialTupleUsageCollector : ExprVisitor { } void VisitExpr_(const TupleGetItemNode* op) override { - Expr tuple = UnwrapBindings(op->tuple); - - if (auto call = tuple.as()) { - if (auto opt_callee = call->op.as()) { - auto callee = opt_callee.value(); - if (auto it = output_usage_mask_.find(callee); it != output_usage_mask_.end()) { - auto& used_indices = it->second; - - CHECK_GE(op->index, 0) << "IndexError: " - << "Indices for TupleGetItem must be non-negative, " - << "but expression " << GetRef(op) - << " uses a tuple index of " << op->index; - size_t index = op->index; - - CHECK_LT(index, used_indices.size()) - << "IndexError: " - << "Indices for TupleGetItem must be less than the size of the tuple, " - << "but expression " << GetRef(op) << " uses a tuple index of " << op->index - << " for a tuple of size " << used_indices.size(); - used_indices[index] = true; + if (auto* usage_mask_ptr = GetCalleeUsageMask(op->tuple)) { + auto& used_indices = *usage_mask_ptr; + + CHECK_GE(op->index, 0) << "IndexError: " + << "Indices for TupleGetItem must be non-negative, " + << "but expression " << GetRef(op) << " uses a tuple index of " + << op->index; + size_t index = op->index; + + CHECK_LT(index, used_indices.size()) + << "IndexError: " + << "Indices for TupleGetItem must be less than the size of the tuple, " + << "but expression " << GetRef(op) << " uses a tuple index of " << op->index + << " for a tuple of size " << used_indices.size(); + used_indices[index] = true; + } + } + + void VisitExpr_(const VarNode* op) override { + if (auto* usage_mask_ptr = GetCalleeUsageMask(GetRef(op))) { + auto& usage_mask = *usage_mask_ptr; + for (size_t i = 0; i < usage_mask.size(); i++) { + usage_mask[i] = true; + } + } + } + + std::vector* GetCalleeUsageMask(Expr expr) { + if (!expr->struct_info_.as()) { + return nullptr; + } + + expr = UnwrapBindings(expr); + if (auto call = expr.as()) { + if (auto callee = call->op.as()) { + if (auto it = output_usage_mask_.find(callee.value()); it != output_usage_mask_.end()) { + return &it->second; } } } + + return nullptr; } Expr UnwrapBindings(Expr expr) const { diff --git a/tests/python/relax/test_transform_remove_unused_outputs.py b/tests/python/relax/test_transform_remove_unused_outputs.py index c0405ca58d00..d07ecf578d89 100644 --- a/tests/python/relax/test_transform_remove_unused_outputs.py +++ b/tests/python/relax/test_transform_remove_unused_outputs.py @@ -119,5 +119,25 @@ def func() -> R.Tuple([R.Tensor([16, 16], "int32"), R.Tensor([32, 32], "int32")] return (A, C) +class TestReturnTuple(BaseCompare): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16, 16], "int32")): + B = R.add(A, A) + out_tuple = Before.func(B) + return out_tuple + + @R.function(private=True) + def func(B: R.Tensor([16, 16], "int32")) -> R.Tuple( + R.Tensor([16, 16], "int32"), R.Tensor([16, 16], "int32") + ): + C = R.multiply(B, B) + D = R.add(B, B) + return (C, D) + + Expected = Before + + if __name__ == "__main__": tvm.testing.main() From 06ed9822c9ae3d01bf13c76b7c7d137cc63b8be9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 9 Aug 2024 15:29:46 -0500 Subject: [PATCH 2/2] lint fix --- tests/python/relax/test_transform_remove_unused_outputs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_transform_remove_unused_outputs.py b/tests/python/relax/test_transform_remove_unused_outputs.py index d07ecf578d89..365ce1695d0e 100644 --- a/tests/python/relax/test_transform_remove_unused_outputs.py +++ b/tests/python/relax/test_transform_remove_unused_outputs.py @@ -129,9 +129,9 @@ def main(A: R.Tensor([16, 16], "int32")): return out_tuple @R.function(private=True) - def func(B: R.Tensor([16, 16], "int32")) -> R.Tuple( - R.Tensor([16, 16], "int32"), R.Tensor([16, 16], "int32") - ): + def func( + B: R.Tensor([16, 16], "int32") + ) -> R.Tuple(R.Tensor([16, 16], "int32"), R.Tensor([16, 16], "int32")): C = R.multiply(B, B) D = R.add(B, B) return (C, D)