From 31d6dd0f086df4b9b64c37165988a1642cfaf1ea Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Mon, 20 Feb 2023 15:54:31 +0800 Subject: [PATCH] [Relay] ExprMutator Return Origin Expr When All Fields Isn't Changed --- python/tvm/relay/expr_functor.py | 56 ++++++++++++++++++------- tests/python/relay/test_expr_functor.py | 2 +- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 95a8c79dc2d8..48941b2b23b9 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -204,17 +204,23 @@ class ExprMutator(ExprFunctor): def visit_function(self, fn): new_params = [self.visit(x) for x in fn.params] new_body = self.visit(fn.body) + if new_params == list(fn.params) and new_body == fn.body: + return fn return FunctionWithFields(fn, list(new_params), new_body) def visit_let(self, let): new_var = self.visit(let.var) new_val = self.visit(let.value) new_body = self.visit(let.body) + if new_var == let.var and new_val == let.value and new_body == let.body: + return let return Let(new_var, new_val, new_body) def visit_call(self, call): new_fn = self.visit(call.op) new_args = [self.visit(arg) for arg in call.args] + if new_fn == call.op and new_args == list(call.args): + return call return Call(new_fn, new_args, call.attrs, call.type_args, call.span) def visit_var(self, var): @@ -224,16 +230,28 @@ def visit_global_id(self, global_var): return global_var def visit_if(self, ite): - return If(self.visit(ite.cond), self.visit(ite.true_branch), self.visit(ite.false_branch)) + new_cond = self.visit(ite.cond) + new_true_branch = self.visit(ite.true_branch) + new_false_branch = self.visit(ite.false_branch) + if ( + new_cond == ite.cond + and new_true_branch == ite.true_branch + and new_false_branch == ite.false_branch + ): + return ite + return If(new_cond, new_true_branch, new_false_branch) def visit_tuple(self, tup): - return Tuple([self.visit(field) for field in tup.fields], tup.span) + new_fields = [self.visit(field) for field in tup.fields] + if new_fields == list(tup.fields): + return tup + return Tuple(new_fields, tup.span) def visit_tuple_getitem(self, op): - tuple_value = self.visit(op.tuple_value) - if not tuple_value.same_as(op.tuple_value): - return TupleGetItem(tuple_value, op.index) - return op + new_tuple_value = self.visit(op.tuple_value) + if new_tuple_value == op.tuple_value: + return op + return TupleGetItem(new_tuple_value, op.index) def visit_global_var(self, gvar): return gvar @@ -248,17 +266,27 @@ def visit_constructor(self, con): return con def visit_match(self, m): - return Match( - self.visit(m.data), - [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses], - complete=m.complete, - ) + new_data = self.visit(m.data) + new_clauses = [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses] + if new_data == m.data and all(x.rhs == y.rhs for x, y in zip(new_clauses, m.clauses)): + return m + return Match(new_data, new_clauses, complete=m.complete) def visit_ref_create(self, r): - return RefCreate(self.visit(r.value)) + new_value = self.visit(r.value) + if new_value == r.value: + return r + return RefCreate(new_value) def visit_ref_write(self, r): - return RefWrite(self.visit(r.ref), self.visit(r.value)) + new_ref = self.visit(r.ref) + new_value = self.visit(r.value) + if new_ref == r.ref and new_value == r.value: + return r + return RefWrite(new_ref, new_value) def visit_ref_read(self, r): - return RefRead(self.visit(r.ref)) + new_ref = self.visit(r.ref) + if new_ref == r.ref: + return r + return RefRead(new_ref) diff --git a/tests/python/relay/test_expr_functor.py b/tests/python/relay/test_expr_functor.py index 45317836faf0..930cbd926080 100644 --- a/tests/python/relay/test_expr_functor.py +++ b/tests/python/relay/test_expr_functor.py @@ -32,7 +32,7 @@ def check_visit(expr): ev.visit(expr) em = ExprMutator() - assert em.visit(expr) + assert expr == em.visit(expr) def test_constant():