Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 22 additions & 27 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,6 @@ class BindingCanonicalizer : public ExprMutator {

using ExprMutator::VisitExpr_;

Expr VisitExpr_(const VarNode* op) override {
// remap first
Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
if (!CanCanonicalizeVar(v)) {
return Downcast<Expr>(v);
}
// visit again in case we need to do a substitution in the value
return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
}

Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
if (auto tuple_value = LookupBinding(tuple_var.value())) {
Expand All @@ -71,22 +61,34 @@ class BindingCanonicalizer : public ExprMutator {
Expr new_value = this->VisitExpr(binding->value);
Var new_var = this->VisitVarDef(binding->var);

if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
if (auto opt_var = new_value.as<Var>();
opt_var && CanCanonicalizeVar(new_var, opt_var.value())) {
var_remap_[new_var->vid] = opt_var.value();
} else if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
return;
} else {
this->builder_->EmitNormalized(VarBinding(new_var, new_value));
}

this->builder_->EmitNormalized(VarBinding(new_var, new_value));
}

void VisitBinding_(const MatchCastNode* binding) override {
// If we have a trivial shape check (the shape_ of LHS and RHS is the same),
// we can canonicalize to a var binding
Expr new_value = this->VisitExpr(binding->value);

// if the LHS and RHS have the same struct info, we canonicalize to a var binding instead
if (StructuralEqual()(binding->struct_info, GetStructInfo(new_value))) {
builder_->EmitNormalized(VarBinding(binding->var, new_value));
bool has_same_struct_info = StructuralEqual()(binding->struct_info, GetStructInfo(new_value));

if (has_same_struct_info) {
if (auto parent = new_value.as<Var>();
parent && CanCanonicalizeVar(binding->var, parent.value())) {
// LHS and RHS have the same struct info, and occur in a
// context where the RHS can replace the LHS.
var_remap_[binding->var->vid] = parent.value();
} else {
// LHS and RHS have the same struct info, but the RHS is not a
// drop-in replacement for the LHS.
builder_->EmitNormalized(VarBinding(binding->var, new_value));
}
} else if (new_value.same_as(binding->value)) {
builder_->EmitNormalized(GetRef<MatchCast>(binding));
} else {
Expand All @@ -104,24 +106,17 @@ class BindingCanonicalizer : public ExprMutator {
return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2));
}

bool CanCanonicalizeVar(Var v) {
Optional<Expr> value = LookupBinding(v);
// can replace only if the value is also a var
if (!value || !value.as<VarNode>()) {
return false;
}
Var parent_var = Downcast<Var>(value);

bool CanCanonicalizeVar(Var var, Var parent_var) {
// Cases when we conservatively do not unify:
// 1. checked_type_ or shape_ of the child differs from that of the parent
// In this case, we could be overriding user annotations.
// 2. If the child is a Var and the parent is a DataflowVar.
// That could result in a DataflowVar leaving the current DataflowBlock.
bool annotations_differ = AnnotationsDiffer(v->struct_info_, parent_var->struct_info_,
bool annotations_differ = AnnotationsDiffer(var->struct_info_, parent_var->struct_info_,
[&](const ObjectRef& lhs, const ObjectRef& rhs) {
return tvm::StructuralEqual()(lhs, rhs);
});
bool var_to_dataflow = (!v.as<DataflowVarNode>() && parent_var.as<DataflowVarNode>());
bool var_to_dataflow = (!var.as<DataflowVarNode>() && parent_var.as<DataflowVarNode>());
return !annotations_differ && !var_to_dataflow;
}
};
Expand Down
41 changes: 4 additions & 37 deletions tests/python/relax/test_transform_canonicalize_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,10 @@ def main(x: R.Tensor):
o = p
return o

# a little annoying to have these unused bindings around
# but they can be eliminated in a separate pass
@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
y = x
z = x
q = x
p = x
o = x
return x

new_mod = relax.transform.CanonicalizeBindings()(TestChainAssignments)
Expand All @@ -68,19 +61,16 @@ def main(x: R.Tensor):
R.output(n)
return n

# a little annoying to have these unused bindings around
# but they can be eliminated in a separate pass
@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = R.const(1)
z = y
o = y
p = y
m = y
# we can't get rid of n because it leaves the block
# We can't get rid of n because it leaves the block.
# CanonicalizeBindings does not do a full dead-code
# elimination, and only does local analysis of trivial
# bindings that it may produce.
n = y
R.output(n)
return n
Expand Down Expand Up @@ -108,15 +98,6 @@ def main(x: R.Tensor):
class Expected:
@R.function
def main(x: R.Tensor):
with R.dataflow():
y = x
z = x
o = x
p = x
m = x
# we can't get rid of n because it leaves the block
n = x
R.output(n)
return x

new_mod = relax.transform.CanonicalizeBindings()(TestDataflowAssignments)
Expand All @@ -137,8 +118,6 @@ def main(x: R.Tensor, y: R.Tensor):
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor):
w = y
q = x
z = R.add(y, x)
return R.add(x, z)

Expand All @@ -161,7 +140,6 @@ def main(x: R.Tensor) -> R.Object:
class Expected:
@R.function
def main(x: R.Tensor) -> R.Object:
y = x
# Cannot unify because the cast indicates user intent
z: R.Object = x
return z
Expand All @@ -185,11 +163,9 @@ def main(x: R.Tensor):
class Expected:
@R.function
def main(x: R.Tensor):
q = x
# can't get rid of z because its shape_ is different from x's
m, n = T.int64(), T.int64()
z = R.match_cast(x, R.Tensor((m, n)))
w = z
return z

new_mod = relax.transform.CanonicalizeBindings()(TestMatchCast)
Expand All @@ -213,11 +189,6 @@ def main(x: R.Tensor(("m", "n"), "float32")):
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"), "float32")):
m, n = T.int64(), T.int64()
y = x
# canonicalized into a var binding
z = x
w = x
q = R.add(x, x)
return R.add(q, x)

Expand All @@ -242,10 +213,8 @@ def main(x: R.Tensor(("m", "n"))):
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"))):
y = x
o, p = T.int64(), T.int64()
z = R.match_cast(x, R.Tensor((o, p)))
w = z
# the shape_ field on q will need to be updated
q = R.add(z, x)
return R.add(q, z)
Expand All @@ -270,8 +239,6 @@ class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor):
tuple_var = (x, y)
w = x
q = y
z = R.add(x, y)
return R.add(y, z)

Expand Down