From 0363f0b8f8af39b07a37aca8d35e4cc5923350ce Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 28 Sep 2023 14:38:34 -0500 Subject: [PATCH] [Unity] Avoid trivial var-to-var bindings in CanonicalizeBindings Prior to this commit, the `relax.transform.CanonicalizeBindings` transform would detect trivial bindings `var_y = var_x`, and replace later usage of `var_y` with `var_x`. However, the trivial binding `var_y = var_x` would be left in the canonicalized function. This commit updates the `CanonicalizeBindings` transform to remove trivial bindings. This is not intended as a full dead-code elimination, as that is better handled as a separate pass, but is instead intended to avoid introduction of dead code during canonicalization. --- src/relax/transform/canonicalize_bindings.cc | 49 +++++++++---------- .../test_transform_canonicalize_bindings.py | 41 ++-------------- 2 files changed, 26 insertions(+), 64 deletions(-) diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index ea5a612e1ac0..d8e3a9ba982c 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -38,16 +38,6 @@ class BindingCanonicalizer : public ExprMutator { using ExprMutator::VisitExpr_; - Expr VisitExpr_(const VarNode* op) override { - // remap first - Var v = Downcast(ExprMutator::VisitExpr_(op)); - if (!CanCanonicalizeVar(v)) { - return Downcast(v); - } - // visit again in case we need to do a substitution in the value - return ExprMutator::VisitExpr_(LookupBinding(v).as()); - } - Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override { if (auto tuple_var = tuple_get_item->tuple.as()) { if (auto tuple_value = LookupBinding(tuple_var.value())) { @@ -71,12 +61,14 @@ class BindingCanonicalizer : public ExprMutator { Expr new_value = this->VisitExpr(binding->value); Var new_var = this->VisitVarDef(binding->var); - if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + if (auto opt_var = new_value.as(); + opt_var && CanCanonicalizeVar(new_var, opt_var.value())) { + var_remap_[new_var->vid] = opt_var.value(); + } else if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { this->builder_->EmitNormalized(GetRef(binding)); - return; + } else { + this->builder_->EmitNormalized(VarBinding(new_var, new_value)); } - - this->builder_->EmitNormalized(VarBinding(new_var, new_value)); } void VisitBinding_(const MatchCastNode* binding) override { @@ -84,9 +76,19 @@ class BindingCanonicalizer : public ExprMutator { // we can canonicalize to a var binding Expr new_value = this->VisitExpr(binding->value); - // if the LHS and RHS have the same struct info, we canonicalize to a var binding instead - if (StructuralEqual()(binding->struct_info, GetStructInfo(new_value))) { - builder_->EmitNormalized(VarBinding(binding->var, new_value)); + bool has_same_struct_info = StructuralEqual()(binding->struct_info, GetStructInfo(new_value)); + + if (has_same_struct_info) { + if (auto parent = new_value.as(); + parent && CanCanonicalizeVar(binding->var, parent.value())) { + // LHS and RHS have the same struct info, and occur in a + // context where the RHS can replace the LHS. + var_remap_[binding->var->vid] = parent.value(); + } else { + // LHS and RHS have the same struct info, but the RHS is not a + // drop-in replacement for the LHS. + builder_->EmitNormalized(VarBinding(binding->var, new_value)); + } } else if (new_value.same_as(binding->value)) { builder_->EmitNormalized(GetRef(binding)); } else { @@ -104,24 +106,17 @@ class BindingCanonicalizer : public ExprMutator { return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2)); } - bool CanCanonicalizeVar(Var v) { - Optional value = LookupBinding(v); - // can replace only if the value is also a var - if (!value || !value.as()) { - return false; - } - Var parent_var = Downcast(value); - + bool CanCanonicalizeVar(Var var, Var parent_var) { // Cases when we conservatively do not unify: // 1. checked_type_ or shape_ of the child differs from that of the parent // In this case, we could be overriding user annotations. // 2. If the child is a Var and the parent is a DataflowVar. // That could result in a DataflowVar leaving the current DataflowBlock. - bool annotations_differ = AnnotationsDiffer(v->struct_info_, parent_var->struct_info_, + bool annotations_differ = AnnotationsDiffer(var->struct_info_, parent_var->struct_info_, [&](const ObjectRef& lhs, const ObjectRef& rhs) { return tvm::StructuralEqual()(lhs, rhs); }); - bool var_to_dataflow = (!v.as() && parent_var.as()); + bool var_to_dataflow = (!var.as() && parent_var.as()); return !annotations_differ && !var_to_dataflow; } }; diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 91396ccb13e2..52bf5a6e43c3 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -36,17 +36,10 @@ def main(x: R.Tensor): o = p return o - # a little annoying to have these unused bindings around - # but they can be eliminated in a separate pass @tvm.script.ir_module class Expected: @R.function def main(x: R.Tensor): - y = x - z = x - q = x - p = x - o = x return x new_mod = relax.transform.CanonicalizeBindings()(TestChainAssignments) @@ -68,19 +61,16 @@ def main(x: R.Tensor): R.output(n) return n - # a little annoying to have these unused bindings around - # but they can be eliminated in a separate pass @tvm.script.ir_module class Expected: @R.function def main(x: R.Tensor): with R.dataflow(): y = R.const(1) - z = y - o = y - p = y - m = y - # we can't get rid of n because it leaves the block + # We can't get rid of n because it leaves the block. + # CanonicalizeBindings does not do a full dead-code + # elimination, and only does local analysis of trivial + # bindings that it may produce. n = y R.output(n) return n @@ -108,15 +98,6 @@ def main(x: R.Tensor): class Expected: @R.function def main(x: R.Tensor): - with R.dataflow(): - y = x - z = x - o = x - p = x - m = x - # we can't get rid of n because it leaves the block - n = x - R.output(n) return x new_mod = relax.transform.CanonicalizeBindings()(TestDataflowAssignments) @@ -137,8 +118,6 @@ def main(x: R.Tensor, y: R.Tensor): class Expected: @R.function def main(x: R.Tensor, y: R.Tensor): - w = y - q = x z = R.add(y, x) return R.add(x, z) @@ -161,7 +140,6 @@ def main(x: R.Tensor) -> R.Object: class Expected: @R.function def main(x: R.Tensor) -> R.Object: - y = x # Cannot unify because the cast indicates user intent z: R.Object = x return z @@ -185,11 +163,9 @@ def main(x: R.Tensor): class Expected: @R.function def main(x: R.Tensor): - q = x # can't get rid of z because its shape_ is different from x's m, n = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((m, n))) - w = z return z new_mod = relax.transform.CanonicalizeBindings()(TestMatchCast) @@ -213,11 +189,6 @@ def main(x: R.Tensor(("m", "n"), "float32")): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")): - m, n = T.int64(), T.int64() - y = x - # canonicalized into a var binding - z = x - w = x q = R.add(x, x) return R.add(q, x) @@ -242,10 +213,8 @@ def main(x: R.Tensor(("m", "n"))): class Expected: @R.function def main(x: R.Tensor(("m", "n"))): - y = x o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) - w = z # the shape_ field on q will need to be updated q = R.add(z, x) return R.add(q, z) @@ -270,8 +239,6 @@ class Expected: @R.function def main(x: R.Tensor, y: R.Tensor): tuple_var = (x, y) - w = x - q = y z = R.add(x, y) return R.add(y, z)