diff --git a/nnvm/tests/python/compiler/test_simplify_inference.py b/nnvm/tests/python/compiler/test_simplify_inference.py
index e2826765995e..fd0e1e3c182e 100644
--- a/nnvm/tests/python/compiler/test_simplify_inference.py
+++ b/nnvm/tests/python/compiler/test_simplify_inference.py
@@ -10,7 +10,6 @@ def simple_bn(x, gamma, beta, moving_mean, moving_var,
     scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
     shift = sym.elemwise_add(
         sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
-    shape = [-1 if i == axis else 1 for i in range(len(shape))]
     # for 2D
     num_newaxis=len(shape) - axis - 1
     if num_newaxis:
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 0650a493d9a6..43ec46d35a82 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -1,6 +1,7 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """The expression nodes of Relay."""
 from __future__ import absolute_import
+from numbers import Number as _Number
 import numpy as _np
 
 from .base import RelayNode, register_relay_node
@@ -11,6 +12,8 @@
 from .. import nd as _nd
 from .. import convert
 
+# will be registered afterwards
+_op_make = None
 
 class Expr(RelayNode):
     """The base type for all Relay expressions."""
@@ -48,6 +51,62 @@ def astype(self, dtype):
         """
         return _make.dtype_cast(self, dtype)
 
+    def __add__(self, other):
+        if isinstance(other, Expr):
+            return _op_make.add(self, other)
+        elif isinstance(other, _Number):
+            raise TypeError('convert "%s" with `const` first' % str(other))
+        else:
+            raise TypeError("type %s not supported" % str(type(other)))
+
+    def __radd__(self, other):
+        return self.__add__(other)
+
+    def __sub__(self, other):
+        if isinstance(other, Expr):
+            return _op_make.subtract(self, other)
+        elif isinstance(other, _Number):
+            raise TypeError('convert "%s" with `const` first' % str(other))
+        else:
+            raise TypeError("type %s not supported" % str(type(other)))
+
+    def __rsub__(self, other):
+        if isinstance(other, _Number):
+            raise TypeError('convert "%s" with `const` first' % str(other))
+        else:
+            raise TypeError("type %s not supported" % str(type(other)))
+
+    def __mul__(self, other):
+        if isinstance(other, Expr):
+            return _op_make.multiply(self, other)
+        elif isinstance(other, _Number):
+            raise TypeError('convert "%s" with `const` first' % str(other))
+        else:
+            raise TypeError("type %s not supported" % str(type(other)))
+
+    def __rmul__(self, other):
+        return self.__mul__(other)
+
+    def __div__(self, other):
+        if isinstance(other, Expr):
+            return _op_make.divide(self, other)
+        elif isinstance(other, _Number):
+            raise TypeError('convert "%s" with `const` first' % str(other))
+        else:
+            raise TypeError("type %s not supported" % str(type(other)))
+
+    def __rdiv__(self, other):
+        if isinstance(other, _Number):
+            raise TypeError('convert "%s" with `const` first' % str(other))
+        else:
+            raise TypeError("type %s not supported" % str(type(other)))
+
+    def __truediv__(self, other):
+        return self.__div__(other)
+
+    def __rtruediv__(self, other):
+        return self.__rdiv__(other)
+
 
 @register_relay_node
 class Constant(Expr):
@@ -305,7 +364,7 @@ def __len__(self):
 
     def __repr__(self):
         return ("TupleWrapper(" + self.tuple_value.__repr__() +
-                ", " + self.size + ")")
+                ", " + str(self.size) + ")")
 
     def astype(self, _):
         raise TypeError("astype cannot be used on tuple")
diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 68a07f190d42..f3950fffc45f 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -160,6 +160,21 @@ def free_type_vars(expr):
     """
     return _ir_pass.free_type_vars(expr)
 
+def simplify_inference(expr):
+    """ Simplify the data-flow graph for inference phase.
+
+    Parameters
+    ----------
+    expr: tvm.relay.Expr
+        The input Expression
+
+    Returns
+    -------
+    result: tvm.relay.Expr
+        An expression which is semantically equal to the input expression,
+        but with some simplification
+    """
+    return _ir_pass.simplify_inference(expr)
 
 def dead_code_elimination(expr):
     """ Remove expressions which does not effect the program result (dead code).
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index c0af986be4f7..7b61fd10f5b0 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -15,3 +15,11 @@
 from . import _tensor
 from ..expr import Expr
 from ..base import register_relay_node
+
+
+def _register_op_make():
+    from . import _make
+    from .. import expr
+    expr._op_make = _make
+
+_register_op_make()
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index a41e6c35b93a..f8e67bac33c5 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -120,6 +120,40 @@ inline bool IsDepthwiseConv2D(const Call& call,
 }
 
 
+/*!
+ * \brief Create a Constant with a scalar
+ *
+ * \param dtype The data type.
+ * \param value The value of the scalar.
+ * \return A Constant.
+ */
+template<typename T>
+inline Constant MakeConstantScalar(DataType dtype, T value) {
+  CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch";
+  runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
+  *static_cast<T*>(arr->data) = value;
+  return ConstantNode::make(arr);
+}
+
+
+inline Expr Negative(Expr x) {
+  static const Op& op = Op::Get("negative");
+  return CallNode::make(op, {x}, Attrs(), {});
+}
+
+
+inline Expr Sqrt(Expr x) {
+  static const Op& op = Op::Get("sqrt");
+  return CallNode::make(op, {x}, Attrs(), {});
+}
+
+
+inline Expr Add(Expr lhs, Expr rhs) {
+  static const Op& op = Op::Get("add");
+  return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+}
+
+
 inline Expr Multiply(Expr lhs, Expr rhs) {
   static const Op& op = Op::Get("multiply");
   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc
new file mode 100644
index 000000000000..785b486ddc06
--- /dev/null
+++ b/src/relay/pass/simplify_inference.cc
@@ -0,0 +1,77 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file simplify_inference.cc
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include "./pattern_util.h"
+
+namespace tvm {
+namespace relay {
+
+Expr BatchNormToInferUnpack(const Attrs attrs,
+                            Expr data,
+                            Expr gamma,
+                            Expr beta,
+                            Expr moving_mean,
+                            Expr moving_var) {
+  const auto param = attrs.as<BatchNormAttrs>();
+  Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
+  Expr var_add_eps = Add(moving_var, epsilon);
+  Expr sqrt_var = Sqrt(var_add_eps);
+  Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), sqrt_var);
+
+  if (param->scale) {
+    scale = Multiply(scale, gamma);
+  }
+  Expr neg_mean = Negative(moving_mean);
+  Expr shift = Multiply(neg_mean, scale);
+  if (param->center) {
+    shift = Add(shift, beta);
+  }
+
+  int axis = param->axis;
+  const auto* tdata = data->type_as<TensorTypeNode>();
+  scale = ExpandBiasToMatchAxis(scale, tdata->shape.size(), {axis});
+  shift = ExpandBiasToMatchAxis(shift, tdata->shape.size(), {axis});
+
+  Expr out = Multiply(data, scale);
+  out = Add(out, shift);
+  return out;
+}
+
+class InferenceSimplifier : public ExprMutator {
+ public:
+  Expr VisitExpr_(const TupleGetItemNode* n) final {
+    static const Op& batch_norm = Op::Get("nn.batch_norm");
+    static const Op& dropout = Op::Get("nn.dropout");
+
+    Expr new_e = ExprMutator::VisitExpr_(n);
+    const auto* new_n = new_e.as<TupleGetItemNode>();
+    if (new_n->index != 0) {
+      return new_e;
+    }
+    if (const auto* call = new_n->tuple.as<CallNode>()) {
+      if (call->op.same_as(batch_norm)) {
+        return BatchNormToInferUnpack(call->attrs,
+          call->args[0], call->args[1], call->args[2], call->args[3], call->args[4]);
+      } else if (call->op.same_as(dropout)) {
+        return call->args[0];
+      }
+    }
+    return new_e;
+  }
+};
+
+Expr SimplifyInference(const Expr& e) {
+  return InferenceSimplifier().Mutate(e);
+}
+
+TVM_REGISTER_API("relay._ir_pass.simplify_inference")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = SimplifyInference(args[0]);
+  });
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py
new file mode 100644
index 000000000000..9830b83dc6e5
--- /dev/null
+++ b/tests/python/relay/test_pass_simplify_inference.py
@@ -0,0 +1,47 @@
+from tvm import relay as rly
+from tvm.relay.ir_pass import simplify_inference, alpha_equal
+
+def test_simplify_batchnorm():
+    def simple_bn(x, gamma, beta, moving_mean, moving_var,
+                  axis=1, epsilon=1e-5, shape=None):
+        # expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
+        scale = rly.multiply(rly.const(1, 'float32') /
+                rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma)
+        shift = rly.add(
+            rly.multiply(rly.negative(moving_mean), scale), beta)
+        num_newaxis = len(shape) - (axis + 1)
+        if num_newaxis:
+            scale = rly.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
+            shift = rly.expand_dims(shift, axis=1, num_newaxis=num_newaxis)
+        return x * scale + shift
+
+    def check(dim, axis, nstep):
+        eps = 0.01
+        ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32')
+        ttype2 = rly.TensorType((10,), 'float32')
+        x = rly.var("x", ttype1)
+        beta = rly.var("beta", ttype2)
+        gamma = rly.var("gamma", ttype2)
+        moving_var = rly.var("moving_var", ttype2)
+        moving_mean = rly.var("moving_mean", ttype2)
+        y1, y2 = x, x
+
+        for _ in range(nstep):
+            y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
+                gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
+            y1 = rly.nn.dropout(y1)
+            y1 = rly.ir_pass.infer_type(y1)
+            y1 = simplify_inference(y1)
+
+            y2 = simple_bn(y2 + rly.const(1, 'float32'),
+                           gamma, beta, moving_mean, moving_var,
+                           epsilon=eps, axis=axis, shape=ttype1.shape)
+        assert rly.ir_pass.graph_equal(y1, y2)
+
+    check(2, 1, 1)
+    check(4, 1, 1)
+    check(4, 0, 3)
+
+
+if __name__ == "__main__":
+    test_simplify_batchnorm()
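Note on the rewrite performed by BatchNormToInferUnpack: at inference time batch norm is a fixed affine transform, so

    (x - mean) / sqrt(var + eps) * gamma + beta  ==  x * scale + shift

with scale = gamma / sqrt(var + eps) and shift = beta - mean * scale, leaving only one multiply and one add in the runtime graph. A minimal NumPy sketch (not part of the patch) that checks the identity:

    import numpy as np

    rng = np.random.RandomState(0)
    x = rng.randn(2, 10, 4, 4)                  # NCHW, channel axis = 1
    gamma, beta = rng.randn(10), rng.randn(10)
    mean, var = rng.randn(10), np.abs(rng.randn(10))
    eps = 1e-5

    def bcast(v):
        # expand (10,) to (1, 10, 1, 1), mirroring ExpandBiasToMatchAxis
        return v.reshape(1, 10, 1, 1)

    # reference: inference-mode batch norm
    ref = (x - bcast(mean)) / np.sqrt(bcast(var) + eps) * bcast(gamma) + bcast(beta)

    # unpacked form emitted by the pass
    scale = gamma / np.sqrt(var + eps)
    shift = beta - mean * scale
    out = x * bcast(scale) + bcast(shift)

    np.testing.assert_allclose(out, ref, rtol=1e-5)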
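The operator overloads added to Expr accept only other Expr operands; plain Python numbers are rejected with a pointer to `const` so that the dtype is always explicit. Subtraction and division get their own __rsub__/__rdiv__ because they are not commutative. A small sketch of the resulting API, assuming this patch is applied:

    from tvm import relay

    x = relay.var("x", relay.TensorType((10,), "float32"))
    y = relay.var("y", relay.TensorType((10,), "float32"))

    z = x + y              # builds add(x, y) through _op_make
    w = x * y / y          # multiply/divide work the same way

    try:
        bad = x + 1        # plain numbers are rejected on purpose
    except TypeError as err:
        print(err)         # convert "1" with `const` first

    ok = x + relay.const(1, "float32")   # the supported spelling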
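Intended workflow for the new pass, mirroring the test: build the model, run type inference so BatchNormToInferUnpack can read the input's TensorType when expanding scale/shift to the right axis, then simplify. A sketch, assuming this patch is applied (shape and names are hypothetical):

    from tvm import relay

    shape = (1, 16, 32, 32)                    # hypothetical NCHW input
    x = relay.var("x", relay.TensorType(shape, "float32"))
    bn_args = [relay.var(n, relay.TensorType((16,), "float32"))
               for n in ("gamma", "beta", "moving_mean", "moving_var")]

    y, _, _ = relay.nn.batch_norm(x, *bn_args, epsilon=1e-5, axis=1)
    y = relay.nn.dropout(y)

    y = relay.ir_pass.infer_type(y)            # required: the pass reads checked types
    y = relay.ir_pass.simplify_inference(y)    # batch_norm -> mul/add, dropout dropped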