From 5905d357c7ed991661bdd72d8363a85099d8b4e8 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 29 Oct 2018 23:48:23 -0700 Subject: [PATCH 1/6] [PASS] Simplify inference. --- python/tvm/relay/expr.py | 62 ++++++++++++++++++- python/tvm/relay/ir_pass.py | 2 + python/tvm/relay/op/__init__.py | 8 +++ src/relay/pass/simplify_inference.cc | 90 ++++++++++++++++++++++++++++ 4 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 src/relay/pass/simplify_inference.cc diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 0650a493d9a6..6e9ff6357854 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -1,6 +1,7 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression nodes of Relay.""" from __future__ import absolute_import +from numbers import Number as _Number import numpy as _np from .base import RelayNode, register_relay_node @@ -11,6 +12,8 @@ from .. import nd as _nd from .. import convert +# will be registered afterwards +_op_make = None class Expr(RelayNode): """The base type for all Relay expressions.""" @@ -48,6 +51,63 @@ def astype(self, dtype): """ return _make.dtype_cast(self, dtype) + def __add__(self, other): + if isinstance(other, Expr): + return _op_make.add(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + if isinstance(other, Expr): + return _op_make.subtract(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __rsub__(self, other): + if isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __mul__(self, other): + if isinstance(other, Expr): + return _op_make.multiply(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __rmul__(self, other): + return self.__mul__(other) + + def __div__(self, other): + print('divide') + if isinstance(other, Expr): + return _op_make.divide(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __rdiv__(self, other): + if isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __truediv__(self, other): + return self.__div__(other) + + def __rtruediv__(self, other): + return self.__rdiv__(other) + @register_relay_node class Constant(Expr): @@ -305,7 +365,7 @@ def __len__(self): def __repr__(self): return ("TupleWrapper(" + self.tuple_value.__repr__() + - ", " + self.size + ")") + ", " + str(self.size) + ")") def astype(self, _): raise TypeError("astype cannot be used on tuple") diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 6adfaacdc86d..742e8f1248ed 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -131,6 +131,8 @@ def free_type_vars(expr): """ return _ir_pass.free_type_vars(expr) +def simplify_inference(expr): + return _ir_pass.simplify_inference(expr) def dead_code_elimination(expr): """ Remove expressions which does not effect the program result (dead code). diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index c0af986be4f7..6e0e6937310c 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -15,3 +15,11 @@ from . import _tensor from ..expr import Expr from ..base import register_relay_node + + +def _register_op_make(): + import sys + from .. import expr + expr._op_make = sys.modules['tvm.relay.op._make'] + +_register_op_make() diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc new file mode 100644 index 000000000000..23b21b87355c --- /dev/null +++ b/src/relay/pass/simplify_inference.cc @@ -0,0 +1,90 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file simplify_inference.cc + */ +#include +#include +#include +#include "./pattern_util.h" + +namespace tvm { +namespace relay { + +// TODO: make type generic +Constant make_const(float x) { + DLDataType dtype{kDLFloat, 32, 1}; + runtime::NDArray data = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0}); + float* pdata = static_cast(data->data); + *pdata = x; + Constant n = ConstantNode::make(data); + return n; +} + +Expr +BatchNormToInferUnpack(const Attrs attrs, + Expr data, + Expr gamma, + Expr beta, + Expr moving_mean, + Expr moving_var) { + const auto param = attrs.as(); + Expr epsilon = make_const(param->epsilon); + Expr var_add_eps = CallNode::make(Op::Get("add"), {moving_var, epsilon}); + Expr sqrt = CallNode::make(Op::Get("sqrt"), {var_add_eps}); + Expr scale = CallNode::make(Op::Get("divide"), {make_const(1.0f), sqrt}); + + if (param->scale) { + scale = CallNode::make( + Op::Get("multiply"), {scale, gamma}); + } + Expr neg_mean = CallNode::make(Op::Get("negative"), {moving_mean}); + Expr shift = CallNode::make(Op::Get("multiply"), {neg_mean, scale}); + if (param->center) { + shift = CallNode::make(Op::Get("add"), {shift, beta}); + } + + int axis = param->axis; + const auto* tdata = data->type_as(); + CHECK(tdata) << "require checked type"; + Array dshape; + for (auto e : tdata->shape) { + CHECK(is_const(e)); + const IntImm* imm = e.as(); + CHECK(imm); + dshape.push_back(Integer(imm->value)); + } + scale = ExpandBiasToMatchAxis(scale, axis, dshape); + shift = ExpandBiasToMatchAxis(shift, axis, dshape); + + Expr out = CallNode::make(Op::Get("multiply"), {data, scale}); + out = CallNode::make(Op::Get("add"), {out, shift}); + return out; +} + +class Simplifier : public ExprMutator { + public: + Expr VisitExpr_(const CallNode* n) final { + if (const OpNode* op = n->op.as()) { + LOG(INFO) << "op: " << op->name; + if (op->name == "nn.batch_norm") { + LOG(INFO) << n->args; + return BatchNormToInferUnpack(n->attrs, n->args[0], n->args[1], n->args[2], n->args[3], n->args[4]); + } else if (op->name == "nn.dropout") { + return n->args[0]; + } + } + return GetRef(n); + } +}; + +Expr SimplifyInference(const Expr& e) { + return Simplifier().Mutate(e); +} + +TVM_REGISTER_API("relay._ir_pass.simplify_inference") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = SimplifyInference(args[0]); + }); + +} // namespace relay +} // namespace tvm From c2e217cb642191e39c999b9821d60874cad74699 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 30 Oct 2018 16:36:06 -0700 Subject: [PATCH 2/6] [PASS] Update. --- .../compiler/test_simplify_inference.py | 1 - python/tvm/relay/expr.py | 1 - python/tvm/relay/op/__init__.py | 4 +- src/relay/pass/pattern_util.h | 7 +++ src/relay/pass/simplify_inference.cc | 62 +++++++------------ .../relay/test_pass_simplify_inference.py | 47 ++++++++++++++ 6 files changed, 80 insertions(+), 42 deletions(-) create mode 100644 tests/python/relay/test_pass_simplify_inference.py diff --git a/nnvm/tests/python/compiler/test_simplify_inference.py b/nnvm/tests/python/compiler/test_simplify_inference.py index e2826765995e..fd0e1e3c182e 100644 --- a/nnvm/tests/python/compiler/test_simplify_inference.py +++ b/nnvm/tests/python/compiler/test_simplify_inference.py @@ -10,7 +10,6 @@ def simple_bn(x, gamma, beta, moving_mean, moving_var, scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma) shift = sym.elemwise_add( sym.elemwise_mul(sym.negative(moving_mean), scale), beta) - shape = [-1 if i == axis else 1 for i in range(len(shape))] # for 2D num_newaxis=len(shape) - axis - 1 if num_newaxis: diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 6e9ff6357854..43ec46d35a82 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -88,7 +88,6 @@ def __rmul__(self, other): return self.__mul__(other) def __div__(self, other): - print('divide') if isinstance(other, Expr): return _op_make.divide(self, other) elif isinstance(other, _Number): diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 6e0e6937310c..7b61fd10f5b0 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -18,8 +18,8 @@ def _register_op_make(): - import sys + from . import _make from .. import expr - expr._op_make = sys.modules['tvm.relay.op._make'] + expr._op_make = _make _register_op_make() diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index a395e74cdf0b..574ae9ea8334 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -116,6 +116,13 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) { return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } +template +inline Constant MakeConstantScalar(DataType dtype, T value) { + CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch"; + runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0}); + *static_cast(arr->data) = value; + return ConstantNode::make(arr); +} } // namespace relay diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 23b21b87355c..bc8de07d2c15 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -10,28 +10,17 @@ namespace tvm { namespace relay { -// TODO: make type generic -Constant make_const(float x) { - DLDataType dtype{kDLFloat, 32, 1}; - runtime::NDArray data = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0}); - float* pdata = static_cast(data->data); - *pdata = x; - Constant n = ConstantNode::make(data); - return n; -} - -Expr -BatchNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Expr moving_mean, - Expr moving_var) { +Expr BatchNormToInferUnpack(const Attrs attrs, + Expr data, + Expr gamma, + Expr beta, + Expr moving_mean, + Expr moving_var) { const auto param = attrs.as(); - Expr epsilon = make_const(param->epsilon); + Expr epsilon = MakeConstantScalar(Float(32), static_cast(param->epsilon)); Expr var_add_eps = CallNode::make(Op::Get("add"), {moving_var, epsilon}); - Expr sqrt = CallNode::make(Op::Get("sqrt"), {var_add_eps}); - Expr scale = CallNode::make(Op::Get("divide"), {make_const(1.0f), sqrt}); + Expr sqrt_var = CallNode::make(Op::Get("sqrt"), {var_add_eps}); + Expr scale = CallNode::make(Op::Get("divide"), {MakeConstantScalar(Float(32), 1.0f), sqrt_var}); if (param->scale) { scale = CallNode::make( @@ -46,15 +35,8 @@ BatchNormToInferUnpack(const Attrs attrs, int axis = param->axis; const auto* tdata = data->type_as(); CHECK(tdata) << "require checked type"; - Array dshape; - for (auto e : tdata->shape) { - CHECK(is_const(e)); - const IntImm* imm = e.as(); - CHECK(imm); - dshape.push_back(Integer(imm->value)); - } - scale = ExpandBiasToMatchAxis(scale, axis, dshape); - shift = ExpandBiasToMatchAxis(shift, axis, dshape); + scale = ExpandBiasToMatchAxis(scale, tdata->shape.size(), {axis}); + shift = ExpandBiasToMatchAxis(shift, tdata->shape.size(), {axis}); Expr out = CallNode::make(Op::Get("multiply"), {data, scale}); out = CallNode::make(Op::Get("add"), {out, shift}); @@ -63,17 +45,21 @@ BatchNormToInferUnpack(const Attrs attrs, class Simplifier : public ExprMutator { public: - Expr VisitExpr_(const CallNode* n) final { - if (const OpNode* op = n->op.as()) { - LOG(INFO) << "op: " << op->name; - if (op->name == "nn.batch_norm") { - LOG(INFO) << n->args; - return BatchNormToInferUnpack(n->attrs, n->args[0], n->args[1], n->args[2], n->args[3], n->args[4]); - } else if (op->name == "nn.dropout") { - return n->args[0]; + Expr VisitExpr_(const TupleGetItemNode* n) final { + static const Op& batch_norm = Op::Get("nn.batch_norm"); + static const Op& dropout = Op::Get("nn.dropout"); + + Expr new_e = ExprMutator::VisitExpr_(n); + const auto* new_n = new_e.as(); + if (const auto* call = new_n->tuple.as()) { + if (call->op.same_as(batch_norm)) { + return BatchNormToInferUnpack(call->attrs, + call->args[0], call->args[1], call->args[2], call->args[3], call->args[4]); + } else if (call->op.same_as(dropout)) { + return call->args[0]; } } - return GetRef(n); + return new_e; } }; diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py new file mode 100644 index 000000000000..9830b83dc6e5 --- /dev/null +++ b/tests/python/relay/test_pass_simplify_inference.py @@ -0,0 +1,47 @@ +from tvm import relay as rly +from tvm.relay.ir_pass import simplify_inference, alpha_equal + +def test_simplify_batchnorm(): + def simple_bn(x, gamma, beta, moving_mean, moving_var, + axis=1, epsilon=1e-5, shape=None): + # expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta + scale = rly.multiply(rly.const(1, 'float32') / + rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma) + shift = rly.add( + rly.multiply(rly.negative(moving_mean), scale), beta) + num_newaxis = len(shape) - (axis + 1) + if num_newaxis: + scale = rly.expand_dims(scale, axis=1, num_newaxis=num_newaxis) + shift = rly.expand_dims(shift, axis=1, num_newaxis=num_newaxis) + return x * scale + shift + + def check(dim, axis, nstep): + eps = 0.01 + ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32') + ttype2 = rly.TensorType((10,), 'float32') + x = rly.var("x", ttype1) + beta = rly.var("beta", ttype2) + gamma = rly.var("gamma", ttype2) + moving_var = rly.var("moving_var", ttype2) + moving_mean = rly.var("moving_mean", ttype2) + y1, y2 = x, x + + for _ in range(nstep): + y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'), + gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis) + y1 = rly.nn.dropout(y1) + y1 = rly.ir_pass.infer_type(y1) + y1 = simplify_inference(y1) + + y2 = simple_bn(y2 + rly.const(1, 'float32'), + gamma, beta, moving_mean, moving_var, + epsilon=eps, axis=axis, shape=ttype1.shape) + assert rly.ir_pass.graph_equal(y1, y2) + + check(2, 1, 1) + check(4, 1, 1) + check(4, 0, 3) + + +if __name__ == "__main__": + test_simplify_batchnorm() From c7f0d9977c244a33f51c368f5aa7bb027489ea9e Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 30 Oct 2018 16:54:26 -0700 Subject: [PATCH 3/6] [PASS] Fix lint. --- src/relay/pass/pattern_util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index effe18c313e8..9b5fc1c85350 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -137,7 +137,7 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) { return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } - + template inline Constant MakeConstantScalar(DataType dtype, T value) { CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch"; From fd8df1942f27efc9326a0440e316ee1d6be5acad Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 30 Oct 2018 20:41:25 -0700 Subject: [PATCH 4/6] [PASS] Update. --- python/tvm/relay/ir_pass.py | 13 +++++++++ src/relay/pass/pattern_util.h | 43 ++++++++++++++++++++++------ src/relay/pass/simplify_inference.cc | 19 ++++++------ 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 910a7dd72f36..f3950fffc45f 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -161,6 +161,19 @@ def free_type_vars(expr): return _ir_pass.free_type_vars(expr) def simplify_inference(expr): + """ Simplify the data-flow graph for inference phase. + + Parameters + ---------- + e: tvm.relay.Expr + The input Expression + + Returns + ------- + result: tvm.relay.Expr + An expression which is semantically equal to the input expression, + but with some simplification + """ return _ir_pass.simplify_inference(expr) def dead_code_elimination(expr): diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 9b5fc1c85350..f8e67bac33c5 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -120,6 +120,40 @@ inline bool IsDepthwiseConv2D(const Call& call, } +/*! + * \brief Create a Constant with a scalar + * + * \param dtype The data type. + * \param value The value of the scalar. + * \return A Constant. + */ +template +inline Constant MakeConstantScalar(DataType dtype, T value) { + CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch"; + runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0}); + *static_cast(arr->data) = value; + return ConstantNode::make(arr); +} + + +inline Expr Negative(Expr x) { + static const Op& op = Op::Get("negative"); + return CallNode::make(op, {x}, Attrs(), {}); +} + + +inline Expr Sqrt(Expr x) { + static const Op& op = Op::Get("sqrt"); + return CallNode::make(op, {x}, Attrs(), {}); +} + + +inline Expr Add(Expr lhs, Expr rhs) { + static const Op& op = Op::Get("add"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + + inline Expr Multiply(Expr lhs, Expr rhs) { static const Op& op = Op::Get("multiply"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); @@ -137,15 +171,6 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) { return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } - -template -inline Constant MakeConstantScalar(DataType dtype, T value) { - CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch"; - runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0}); - *static_cast(arr->data) = value; - return ConstantNode::make(arr); -} - } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_PATTERN_UTIL_H_ diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index bc8de07d2c15..2c6493f0a31b 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -18,18 +18,17 @@ Expr BatchNormToInferUnpack(const Attrs attrs, Expr moving_var) { const auto param = attrs.as(); Expr epsilon = MakeConstantScalar(Float(32), static_cast(param->epsilon)); - Expr var_add_eps = CallNode::make(Op::Get("add"), {moving_var, epsilon}); - Expr sqrt_var = CallNode::make(Op::Get("sqrt"), {var_add_eps}); - Expr scale = CallNode::make(Op::Get("divide"), {MakeConstantScalar(Float(32), 1.0f), sqrt_var}); + Expr var_add_eps = Add(moving_var, epsilon); + Expr sqrt_var = Sqrt(var_add_eps); + Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), sqrt_var); if (param->scale) { - scale = CallNode::make( - Op::Get("multiply"), {scale, gamma}); + scale = Multiply(scale, gamma); } - Expr neg_mean = CallNode::make(Op::Get("negative"), {moving_mean}); - Expr shift = CallNode::make(Op::Get("multiply"), {neg_mean, scale}); + Expr neg_mean = Negative(moving_mean); + Expr shift = Multiply(neg_mean, scale); if (param->center) { - shift = CallNode::make(Op::Get("add"), {shift, beta}); + shift = Add(shift, beta); } int axis = param->axis; @@ -38,8 +37,8 @@ Expr BatchNormToInferUnpack(const Attrs attrs, scale = ExpandBiasToMatchAxis(scale, tdata->shape.size(), {axis}); shift = ExpandBiasToMatchAxis(shift, tdata->shape.size(), {axis}); - Expr out = CallNode::make(Op::Get("multiply"), {data, scale}); - out = CallNode::make(Op::Get("add"), {out, shift}); + Expr out = Multiply(data, scale); + out = Add(out, shift); return out; } From 28cda838653db20fde2565239d5fd3a1fe4333c1 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 30 Oct 2018 20:58:19 -0700 Subject: [PATCH 5/6] [PASS] Update. --- src/relay/pass/simplify_inference.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 2c6493f0a31b..4d1a7973e3cf 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -33,7 +33,6 @@ Expr BatchNormToInferUnpack(const Attrs attrs, int axis = param->axis; const auto* tdata = data->type_as(); - CHECK(tdata) << "require checked type"; scale = ExpandBiasToMatchAxis(scale, tdata->shape.size(), {axis}); shift = ExpandBiasToMatchAxis(shift, tdata->shape.size(), {axis}); @@ -50,6 +49,9 @@ class Simplifier : public ExprMutator { Expr new_e = ExprMutator::VisitExpr_(n); const auto* new_n = new_e.as(); + if (new_n->index != 0) { + return new_e; + } if (const auto* call = new_n->tuple.as()) { if (call->op.same_as(batch_norm)) { return BatchNormToInferUnpack(call->attrs, From f75651585056164852c8b20a378679bd8f7c0437 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 30 Oct 2018 22:44:17 -0700 Subject: [PATCH 6/6] [PASS] Update. --- src/relay/pass/simplify_inference.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 4d1a7973e3cf..785b486ddc06 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -41,7 +41,7 @@ Expr BatchNormToInferUnpack(const Attrs attrs, return out; } -class Simplifier : public ExprMutator { +class InferenceSimplifier : public ExprMutator { public: Expr VisitExpr_(const TupleGetItemNode* n) final { static const Op& batch_norm = Op::Get("nn.batch_norm"); @@ -65,7 +65,7 @@ class Simplifier : public ExprMutator { }; Expr SimplifyInference(const Expr& e) { - return Simplifier().Mutate(e); + return InferenceSimplifier().Mutate(e); } TVM_REGISTER_API("relay._ir_pass.simplify_inference")