diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc
index e3bf12382c67..9a5c31e79ba0 100644
--- a/src/relax/transform/remove_unused_outputs.cc
+++ b/src/relax/transform/remove_unused_outputs.cc
@@ -92,29 +92,48 @@ class PartialTupleUsageCollector : ExprVisitor {
   }
 
   void VisitExpr_(const TupleGetItemNode* op) override {
-    Expr tuple = UnwrapBindings(op->tuple);
-
-    if (auto call = tuple.as<CallNode>()) {
-      if (auto opt_callee = call->op.as<GlobalVar>()) {
-        auto callee = opt_callee.value();
-        if (auto it = output_usage_mask_.find(callee); it != output_usage_mask_.end()) {
-          auto& used_indices = it->second;
-
-          CHECK_GE(op->index, 0) << "IndexError: "
-                                 << "Indices for TupleGetItem must be non-negative, "
-                                 << "but expression " << GetRef<Expr>(op)
-                                 << " uses a tuple index of " << op->index;
-          size_t index = op->index;
-
-          CHECK_LT(index, used_indices.size())
-              << "IndexError: "
-              << "Indices for TupleGetItem must be less than the size of the tuple, "
-              << "but expression " << GetRef<Expr>(op) << " uses a tuple index of " << op->index
-              << " for a tuple of size " << used_indices.size();
-          used_indices[index] = true;
+    if (auto* usage_mask_ptr = GetCalleeUsageMask(op->tuple)) {
+      auto& used_indices = *usage_mask_ptr;
+
+      CHECK_GE(op->index, 0) << "IndexError: "
+                             << "Indices for TupleGetItem must be non-negative, "
+                             << "but expression " << GetRef<Expr>(op) << " uses a tuple index of "
+                             << op->index;
+      size_t index = op->index;
+
+      CHECK_LT(index, used_indices.size())
+          << "IndexError: "
+          << "Indices for TupleGetItem must be less than the size of the tuple, "
+          << "but expression " << GetRef<Expr>(op) << " uses a tuple index of " << op->index
+          << " for a tuple of size " << used_indices.size();
+      used_indices[index] = true;
+    }
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    if (auto* usage_mask_ptr = GetCalleeUsageMask(GetRef<Var>(op))) {
+      auto& usage_mask = *usage_mask_ptr;
+      for (size_t i = 0; i < usage_mask.size(); i++) {
+        usage_mask[i] = true;
+      }
+    }
+  }
+
+  std::vector<bool>* GetCalleeUsageMask(Expr expr) {
+    if (!expr->struct_info_.as<TupleStructInfoNode>()) {
+      return nullptr;
+    }
+
+    expr = UnwrapBindings(expr);
+    if (auto call = expr.as<CallNode>()) {
+      if (auto callee = call->op.as<GlobalVar>()) {
+        if (auto it = output_usage_mask_.find(callee.value()); it != output_usage_mask_.end()) {
+          return &it->second;
         }
       }
     }
+
+    return nullptr;
   }
 
   Expr UnwrapBindings(Expr expr) const {
diff --git a/tests/python/relax/test_transform_remove_unused_outputs.py b/tests/python/relax/test_transform_remove_unused_outputs.py
index c0405ca58d00..365ce1695d0e 100644
--- a/tests/python/relax/test_transform_remove_unused_outputs.py
+++ b/tests/python/relax/test_transform_remove_unused_outputs.py
@@ -119,5 +119,25 @@ def func() -> R.Tuple([R.Tensor([16, 16], "int32"), R.Tensor([32, 32], "int32")]
         return (A, C)
 
 
+class TestReturnTuple(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16, 16], "int32")):
+            B = R.add(A, A)
+            out_tuple = Before.func(B)
+            return out_tuple
+
+        @R.function(private=True)
+        def func(
+            B: R.Tensor([16, 16], "int32")
+        ) -> R.Tuple(R.Tensor([16, 16], "int32"), R.Tensor([16, 16], "int32")):
+            C = R.multiply(B, B)
+            D = R.add(B, B)
+            return (C, D)
+
+    Expected = Before
+
+
 if __name__ == "__main__":
     tvm.testing.main()