diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 962f76a376b6..2da72a4f5ac7 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -51,7 +51,7 @@ class BindingCanonicalizer : public ExprMutator { if (!CanCanonicalizeVar(v)) { return Downcast(v); } - return ExprMutator::VisitExpr_(LookupBinding(v).as()); + return ExprMutator::VisitExpr_(LookupBinding(v).as()); } void VisitBinding_(const VarBindingNode* binding) override { diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 086c316ae817..5e1d1b881e2c 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -89,6 +89,40 @@ def main(x: R.Tensor): assert_structural_equal(new_mod, Expected) +def test_assign_to_output_indataflow_block(): + @tvm.script.ir_module + class TestDataflowAssignments: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = x # is not a dataflow var + z = y + o = z + p = o + m = p + n = m + R.output(n) + return n + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = x + z = x + o = x + p = x + m = x + # we can't get rid of n because it leaves the block + n = x + R.output(n) + return x + + new_mod = relax.transform.CanonicalizeBindings()(TestDataflowAssignments) + assert_structural_equal(new_mod, Expected) + + def test_ops(): @tvm.script.ir_module class TestOps: