From 3f80ed3c426d4d077fcbea9d43839014cd6a2fbf Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 24 Nov 2018 14:16:36 -0800 Subject: [PATCH 1/2] [RELAY][FRONTEND] Initial MXNet frontend support. --- docs/langref/relay_op.rst | 6 +- include/tvm/relay/attrs/nn.h | 2 +- include/tvm/relay/attrs/transform.h | 2 +- nnvm/src/top/tensor/transform.cc | 4 +- python/tvm/relay/__init__.py | 4 +- python/tvm/relay/backend/compile_engine.py | 14 +- .../relay/backend/graph_runtime_codegen.py | 2 +- python/tvm/relay/build_module.py | 3 +- python/tvm/relay/frontend/__init__.py | 4 + python/tvm/relay/frontend/common.py | 129 ++++ python/tvm/relay/frontend/mxnet.py | 606 ++++++++++++++++++ python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/_reduce.py | 19 + python/tvm/relay/op/_tensor.py | 3 +- python/tvm/relay/op/_transform.py | 62 +- python/tvm/relay/op/nn/_nn.py | 2 + python/tvm/relay/op/nn/nn.py | 4 +- python/tvm/relay/testing/inception_v3.py | 10 +- python/tvm/relay/testing/squeezenet.py | 69 +- python/tvm/relay/testing/vgg.py | 39 +- src/lang/attr_functor.h | 1 + src/relay/backend/graph_plan_memory.cc | 5 +- src/relay/op/op_common.h | 2 +- src/relay/op/tensor/reduce.cc | 172 ++++- src/relay/op/tensor/transform.cc | 178 ++++- src/relay/pass/fold_scale_axis.cc | 9 +- .../frontend/mxnet/model_zoo/__init__.py | 59 ++ .../python/frontend/mxnet/model_zoo/dcgan.py | 66 ++ tests/python/frontend/mxnet/model_zoo/dqn.py | 27 + .../frontend/mxnet/model_zoo/inception_v3.py | 170 +++++ tests/python/frontend/mxnet/model_zoo/mlp.py | 40 ++ .../python/frontend/mxnet/model_zoo/resnet.py | 199 ++++++ .../frontend/mxnet/model_zoo/squeezenet.py | 76 +++ tests/python/frontend/mxnet/model_zoo/vgg.py | 85 +++ tests/python/frontend/mxnet/test_forward.py | 214 +++++++ tests/python/frontend/mxnet/test_graph.py | 101 +++ tests/python/relay/test_op_level3.py | 2 +- topi/include/topi/transform.h | 25 +- topi/include/topi/vision/yolo/region.h | 2 +- 39 files changed, 2228 insertions(+), 190 deletions(-) create mode 100644 python/tvm/relay/frontend/__init__.py create mode 100644 python/tvm/relay/frontend/common.py create mode 100644 python/tvm/relay/frontend/mxnet.py create mode 100644 python/tvm/relay/op/_reduce.py create mode 100644 tests/python/frontend/mxnet/model_zoo/__init__.py create mode 100644 tests/python/frontend/mxnet/model_zoo/dcgan.py create mode 100644 tests/python/frontend/mxnet/model_zoo/dqn.py create mode 100644 tests/python/frontend/mxnet/model_zoo/inception_v3.py create mode 100644 tests/python/frontend/mxnet/model_zoo/mlp.py create mode 100644 tests/python/frontend/mxnet/model_zoo/resnet.py create mode 100644 tests/python/frontend/mxnet/model_zoo/squeezenet.py create mode 100644 tests/python/frontend/mxnet/model_zoo/vgg.py create mode 100644 tests/python/frontend/mxnet/test_forward.py create mode 100644 tests/python/frontend/mxnet/test_graph.py diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 95581a54e5a1..e7fda319cb9c 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -82,6 +82,7 @@ This level enables additional math and transform operators. tvm.relay.reshape_like tvm.relay.copy tvm.relay.transpose + tvm.relay.squeeze tvm.relay.floor tvm.relay.ceil tvm.relay.trunc @@ -114,7 +115,7 @@ This level enables additional math and transform operators. tvm.relay.less_equal tvm.relay.maximum tvm.relay.minimum - tvm.relay.pow + tvm.relay.power tvm.relay.where tvm.relay.argmax tvm.relay.argmin @@ -196,6 +197,7 @@ Level 3 Definitions .. autofunction:: tvm.relay.reshape .. 
autofunction:: tvm.relay.reshape_like .. autofunction:: tvm.relay.copy +.. autofunction:: tvm.relay.squeeze .. autofunction:: tvm.relay.transpose .. autofunction:: tvm.relay.take .. autofunction:: tvm.relay.zeros @@ -220,7 +222,7 @@ Level 4 Definitions .. autofunction:: tvm.relay.less_equal .. autofunction:: tvm.relay.maximum .. autofunction:: tvm.relay.minimum -.. autofunction:: tvm.relay.pow +.. autofunction:: tvm.relay.power .. autofunction:: tvm.relay.where .. autofunction:: tvm.relay.argmax .. autofunction:: tvm.relay.argmin diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 33f18a89e3e8..817ee04bd844 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -89,7 +89,7 @@ struct SoftmaxAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") { - TVM_ATTR_FIELD(axis).set_default(1) + TVM_ATTR_FIELD(axis).set_default(-1) .describe("The axis to sum over when computing softmax."); } }; diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 7a8129180c4d..39cd82de83e2 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -62,7 +62,7 @@ struct TransposeAttrs : public tvm::AttrsNode { /*! \brief Attributes used in reshape operators */ struct ReshapeAttrs : public tvm::AttrsNode { - Array newshape; + Array newshape; TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") { TVM_ATTR_FIELD(newshape) .describe("The new shape. Should be compatible with the original shape."); diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 492208ed7a7c..6d8b75118a77 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -420,9 +420,9 @@ along which to split the array. return Array{ topi::split_sections(inputs[0], param.indices_or_sections[0], param.axis) }; } else { - Array indices; + Array indices; for (auto i : param.indices_or_sections) { - indices.push_back(tvm::make_const(tvm::Int(32), i)); + indices.push_back(static_cast(i)); } return Array{ topi::split(inputs[0], indices, param.axis) }; } diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 92e1e72fdac2..6b071f65a794 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -7,7 +7,7 @@ from . import expr from . import module from . import ir_pass -from .build_module import build, create_executor +from .build_module import build, build_config, create_executor # Root operators from .op import Op @@ -17,6 +17,7 @@ from . import nn from . import vision from . import image +from . import frontend from . import backend from .scope_builder import ScopeBuilder @@ -40,6 +41,7 @@ scalar_type = ty.scalar_type # Expr +Expr = expr.Expr Constant = expr.Constant Tuple = expr.Tuple Var = expr.Var diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index a02579e2ac7a..1f7ab18677c4 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -72,8 +72,18 @@ def lower(self, source_func, target=None): cached_func: CachedFunc The result of lowering. 
""" - key = _get_cache_key(source_func, target) - return _backend._CompileEngineLower(self, key) + # pylint: disable=broad-except + try: + key = _get_cache_key(source_func, target) + return _backend._CompileEngineLower(self, key) + except Exception: + import traceback + msg = traceback.format_exc() + msg += "Error during compile func\n" + msg += "--------------------------\n" + msg += source_func.astext(show_meta_data=False) + msg += "--------------------------\n" + raise RuntimeError(msg) def jit(self, source_func, target=None): """JIT a source_func to a tvm.Function. diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 50568b58607b..4351fea6b459 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -357,4 +357,4 @@ def _get_unique_name(self, name): return name index = self._name_map[name] self._name_map[name] += 1 - return self.get_unique_name(name + str(index)) + return self._get_unique_name(name + str(index)) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 5a45ac276de9..d67bc89702d3 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -13,7 +13,7 @@ # List of optimization pass and level when switch on OPT_PASS_LEVEL = { "SimplifyInference": 0, - "CombineParallelConv2D": 1, + "CombineParallelConv2D": 4, "OpFusion": 1, "FoldConstant": 2, "FoldScaleAxis": 3, @@ -157,7 +157,6 @@ def optimize(func, params=None): if cfg.pass_enabled("FoldConstant"): func = ir_pass.fold_constant(func) - return func diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py new file mode 100644 index 000000000000..28766b9ae3be --- /dev/null +++ b/python/tvm/relay/frontend/__init__.py @@ -0,0 +1,4 @@ +"""Relay frontends.""" +from __future__ import absolute_import + +from .mxnet import from_mxnet diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py new file mode 100644 index 000000000000..8e037d4bc554 --- /dev/null +++ b/python/tvm/relay/frontend/common.py @@ -0,0 +1,129 @@ +"""Common utilities""" +from __future__ import absolute_import as _abs + + +class RequiredAttr(object): + """Dummpy class to represent required attr""" + pass + + +class StrAttrsDict(object): + """Helper class to parse attrs stored as Dict[str, str]. + + Parameters + ---------- + attrs : Dict[str, str] + The attributes to be used. + """ + def __init__(self, attrs): + self.attrs = attrs + + def get_float(self, key, default=RequiredAttr()): + """Get float attribute + + Parameters + ---------- + key : str + The attribute key + + default : float + The default value. + + Returns + ------- + value : The result + """ + if key in self.attrs: + return float(self.attrs[key]) + if isinstance(default, RequiredAttr): + raise AttributeError("Required attribute {} not found.".format(key)) + return default + + def get_int(self, key, default=RequiredAttr()): + """Get int attribute + + Parameters + ---------- + key : str + The attribute key + + default : float + The default value. 
+ + Returns + ------- + value : The result + """ + if key in self.attrs: + val = self.attrs[key] + if val == "None": + return None + return int(val) + if isinstance(default, RequiredAttr): + raise AttributeError("Required attribute {} not found.".format(key)) + return default + + def get_str(self, key, default=RequiredAttr()): + """Get str attribute + + Parameters + ---------- + key : str + The attribute key + + default : float + The default value. + + Returns + ------- + value : The result + """ + if key in self.attrs: + return self.attrs[key] + if isinstance(default, RequiredAttr): + raise AttributeError("Required attribute {} not found.".format(key)) + return default + + def get_int_tuple(self, key, default=RequiredAttr()): + """Get int tuple attribute + + Parameters + ---------- + key : str + The attribute key + + default : float + The default value. + + Returns + ------- + value : The result + """ + if key in self.attrs: + tshape = self.attrs[key] + return tuple(int(x.strip()) for x in tshape.strip('()').split(',')) + if isinstance(default, RequiredAttr): + raise AttributeError("Required attribute {} not found.".format(key)) + return default + + def get_bool(self, key, default=RequiredAttr()): + """Get bool tuple attribute + + Parameters + ---------- + key : str + The attribute key + + default : float + The default value. + + Returns + ------- + value : The result + """ + if key in self.attrs: + val = self.attrs[key] + return val.strip().lower() in ['true', '1', 't', 'y', 'yes'] + if isinstance(default, RequiredAttr): + raise AttributeError("Required attribute {} not found.".format(key)) + return default diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py new file mode 100644 index 000000000000..c5f200ccf835 --- /dev/null +++ b/python/tvm/relay/frontend/mxnet.py @@ -0,0 +1,606 @@ +# pylint: disable=invalid-name, import-self, len-as-condition +"""MXNet symbol frontend.""" +from __future__ import absolute_import as _abs + +import json +from .. import ir_pass +from .. import expr as _expr +from .. import op as _op +from ... import nd as _nd +from .common import StrAttrsDict + +__all__ = ['from_mxnet'] + + +def _get_relay_op(op_name): + op = getattr(_op, op_name) + if not op: + raise RuntimeError("Unable to map op_name {} to relay".format(op_name)) + return op + + +def _warn_not_used(attr, op='nnvm'): + import warnings + err = "{} is ignored in {}.".format(attr, op) + warnings.warn(err) + + +def _rename(new_op): + if isinstance(new_op, str): + new_op = _get_relay_op(new_op) + # attrs are ignored. 
+ def impl(inputs, _): + return new_op(*inputs) + return impl + + +def _reshape(inputs, attrs): + if attrs.get_bool("reverse", False): + raise RuntimeError("reshape do not support option reverse") + shape = attrs.get_int_tuple("shape") + return _op.reshape(inputs[0], newshape=shape) + + +def _init_op(new_op): + """Init ops like zeros/ones""" + def _impl(inputs, attrs): + assert len(inputs) == 0 + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_str("dtype", "float32") + return new_op(shape=shape, dtype=dtype) + return _impl + + +def _softmax_op(new_op): + """softmax/log_softmax""" + def _impl(inputs, attrs): + assert len(inputs) == 1 + axis = attrs.get_int("axis", -1) + return new_op(inputs[0], axis=axis) + return _impl + + +def _reduce(new_op): + """Reduction ops like sum/min/max""" + def _impl(inputs, attrs): + assert len(inputs) == 1 + axis = attrs.get_int_tuple("axis", []) + keepdims = attrs.get_bool("keepdims", False) + # use None for reduce over all axis. + axis = None if len(axis) == 0 else axis + return new_op(inputs[0], axis=axis, keepdims=keepdims) + return _impl + + +def _arg_reduce(new_op): + """Arg Reduction ops like argmin/argmax""" + def _impl(inputs, attrs): + assert len(inputs) == 1 + axis = attrs.get_int("axis", None) + keepdims = attrs.get_bool("keepdims", False) + res = new_op(inputs[0], axis=[axis], keepdims=keepdims) + # cast to dtype. + res = res.astype("float32") + return res + return _impl + + +def _cast(inputs, attrs): + """Type cast""" + dtype = attrs.get_str("dtype") + return _op.cast(inputs[0], dtype=dtype) + + +def _clip(inputs, attrs): + a_min = attrs.get_float("a_min") + a_max = attrs.get_float("a_max") + return _op.clip(inputs[0], a_min=a_min, a_max=a_max) + + +def _transpose(inputs, attrs): + axes = attrs.get_int_tuple("axes", None) + # translate default case + axes = None if len(axes) == 0 else axes + return _op.transpose(inputs[0], axes=axes) + + +def _upsampling(inputs, attrs): + scale = attrs.get_int("scale") + return _op.nn.upsampling(inputs[0], scale=scale) + + +def _elemwise_sum(inputs, _): + assert len(inputs) > 0 + res = inputs[0] + for x in inputs[1:]: + res = _op.add(res, x) + return res + + +def _binop_scalar(new_op): + def _impl(inputs, attrs): + assert len(inputs) == 1 + scalar = attrs.get_float("scalar") + # Note: binary scalar only works for float op for now + scalar = _expr.const(scalar, dtype="float32") + return new_op(inputs[0], scalar) + return _impl + + +def _rbinop_scalar(new_op): + def _impl(inputs, attrs): + assert len(inputs) == 1 + scalar = attrs.get_float("scalar") + # Note: binary scalar only works for float op for now + scalar = _expr.const(scalar, dtype="float32") + return new_op(scalar, inputs[0]) + return _impl + +# All the functions with _mx prefix specific to MXNet. +# The functions without _mx prefix can be reused for +# NNVMv1 conversion to _op. 
+ +def _mx_fully_connected(inputs, attrs): + import mxnet as mx + units = attrs.get_int("num_hidden") + use_bias = not attrs.get_bool("no_bias", False) + try: + _ = mx.sym.FullyConnected(mx.sym.var("x"), num_hidden=1, flatten=True) + has_flatten = True + except mx.base.MXNetError: + # no flatten attribute in old mxnet + has_flatten = False + use_flatten = attrs.get_bool("flatten", True) + if has_flatten and use_flatten: + inputs[0] = _op.nn.batch_flatten(inputs[0]) + res = _op.nn.dense(inputs[0], inputs[1], units=units) + if use_bias: + assert len(inputs) == 3 + res = _op.nn.bias_add(res, inputs[2]) + return res + + +def _get_channel_axis(layout, op_name): + if layout == "NCHW": + return 1 + elif layout == "NHWC": + return 3 + raise RuntimeError("layout: {} is not supported in {}".format(layout, op_name)) + + +def _mx_activations(inputs, attrs): + act_type = attrs.get_str("act_type") + assert len(inputs) == 1 + if act_type == "sigmoid": + return _op.sigmoid(inputs[0]) + elif act_type == "tanh": + return _op.tanh(inputs[0]) + elif act_type == "relu": + return _op.nn.relu(inputs[0]) + elif act_type == "softrelu": + def _stable_softrelu(x): + # log(1 + exp(-abs(x))) + relu(x) + one = _expr.const(1, dtype="float32") + exp_neg_abs_x = _op.exp(_op.negative(_op.abs(x))) + return _op.add(_op.log(_op.add(one, exp_neg_abs_x)), + _op.nn.relu(x)) + return _stable_softrelu(inputs[0]) + raise RuntimeError("Do not support act_type: {}".format(act_type)) + + +def _mx_conv2d(inputs, attrs): + kernel_size = attrs.get_int_tuple("kernel") + if len(kernel_size) != 2: + raise RuntimeError("non-2d kernel is not supported in conv2d") + data_layout = attrs.get_str("layout", "NCHW") + channel_axis = _get_channel_axis(data_layout, "conv2d") + + if "kernel_layout" in attrs.attrs: + weight_layout = attrs.get_str("kernel_layout") + else: + weight_layout = "HWIO" if data_layout == "NHWC" else "OIHW" + + new_attrs = {} + new_attrs["channels"] = attrs.get_int("num_filter") + new_attrs["kernel_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0)) + new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1)) + new_attrs["groups"] = attrs.get_int("num_group", 1) + new_attrs["data_layout"] = data_layout + new_attrs["weight_layout"] = weight_layout + use_bias = not attrs.get_bool("no_bias", False) + res = _op.nn.conv2d(inputs[0], inputs[1], **new_attrs) + if use_bias: + assert len(inputs) == 3 + res = _op.nn.bias_add(res, inputs[2], axis=channel_axis) + return res + + +def _mx_conv2d_transpose(inputs, attrs): + if "target_shape" in attrs.attrs: + raise RuntimeError("target_shape is not supported in conv2d_transpose") + kernel_size = attrs.get_int_tuple("kernel") + if len(kernel_size) != 2: + raise RuntimeError("non-2d kernel is not supported in conv2d") + data_layout = attrs.get_str("layout", "NCHW") + channel_axis = _get_channel_axis(data_layout, "conv2d_transpose") + + if "kernel_layout" in attrs.attrs: + weight_layout = attrs.get_str("kernel_layout") + else: + weight_layout = "HWIO" if data_layout == "NHWC" else "OIHW" + + new_attrs = {} + new_attrs["channels"] = attrs.get_int("num_filter") + new_attrs["kernel_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1)) + new_attrs["output_padding"] = attrs.get_int_tuple("adj", (0, 0)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0)) + new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1)) + new_attrs["groups"] = attrs.get_int("num_group", 1) 
+ new_attrs["data_layout"] = data_layout + new_attrs["weight_layout"] = weight_layout + use_bias = not attrs.get_bool("no_bias", False) + res = _op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs) + + if use_bias: + assert len(inputs) == 3 + res = _op.nn.bias_add(res, inputs[2], axis=channel_axis) + return res + + +def _mx_pooling(inputs, attrs): + global_pool = attrs.get_bool("global_pool", False) + pool_type = attrs.get_str("pool_type") + + def _pool2d(new_op, is_avg): + kernel_size = attrs.get_int_tuple("kernel") + if len(kernel_size) != 2: + raise RuntimeError("non-2d kernel is not supported in pool2d") + new_attrs = {} + new_attrs["pool_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0)) + new_attrs["ceil_mode"] = (attrs.get_str("pooling_convention", "valid") == "full") + if is_avg: + new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True) + return new_op(inputs[0], **new_attrs) + + if pool_type == "max": + if global_pool: + return _op.nn.global_max_pool2d(inputs[0]) + return _pool2d(_op.nn.max_pool2d, False) + elif pool_type == "avg": + if global_pool: + return _op.nn.global_avg_pool2d(inputs[0]) + return _pool2d(_op.nn.avg_pool2d, True) + raise RuntimeError("Do not support pool_type:{}".format(pool_type)) + + +def _mx_dropout(inputs, attrs): + rate = attrs.get_float("p", 0.5) + return _op.nn.dropout(inputs[0], rate=rate) + + +def _mx_batch_norm(inputs, attrs): + if attrs.get_bool("output_mean_var", False): + raise RuntimeError("batch_norm do not support output_mean_var") + if attrs.get_bool("use_global_stats", False): + _warn_not_used("use_global_stats", "batch_norm") + new_attrs = {} + new_attrs["axis"] = attrs.get_int("axis", 1) + new_attrs["epsilon"] = attrs.get_float("eps", 0.001) + new_attrs["center"] = True + new_attrs["scale"] = not attrs.get_bool("fix_gamma", False) + return _op.nn.batch_norm(*inputs, **new_attrs) + + +def _mx_split(inputs, attrs): + axis = attrs.get_int("axis", 1) + new_attrs = {} + new_attrs["indices_or_sections"] = attrs.get_int("num_outputs") + new_attrs["axis"] = axis + res = _op.split(inputs[0], **new_attrs) + if attrs.get_bool("squeeze_axis", False): + return tuple([_op.squeeze(x, axis=[axis]) for x in res]) + return res + + +def _mx_softmax_activation(inputs, attrs): + mode = attrs.get_str("mode", "instance") + axis = 0 if mode == "instance" else 1 + return _op.nn.softmax(inputs[0], axis=axis) + + +def _mx_softmax_output(inputs, attrs): + if attrs.get_bool("multi_output", False): + return _op.nn.softmax(inputs[0], axis=1) + return _op.nn.softmax(inputs[0]) + + +def _mx_concat(inputs, attrs): + axis = attrs.get_int("dim", 1) + return _op.concatenate(tuple(inputs), axis=axis) + + +def _mx_expand_dims(inputs, attrs): + axis = attrs.get_int("axis") + return _op.expand_dims(inputs[0], axis=axis) + + +def _mx_leaky_relu(inputs, attrs): + act_type = attrs.get_str("act_type") + if act_type == "leaky": + return _op.nn.leaky_relu(inputs[0], alpha=attrs.get_float("slope", 0.25)) + elif act_type == "prelu": + assert len(inputs) == 2 + return _op.nn.prelu(*inputs) + elif act_type == "elu": + # -slope * relu(1-exp(x)) + relu(x) + slope = attrs.get_float("slope", 0.25) + one = _expr.const(1, dtype="float32") + x = inputs[0] + mslope = _op.nn.relu(_op.subtract(one, _op.exp(x))) + mslope = _op.multiply(mslope, _expr.const(-slope, dtype="float32")) + return _op.add(mslope, _op.nn.relu(x)) + elif act_type == "rrelu": + # NOTE this is only converted for 
inference. + lower_bound = attrs.get_float("lower_bound") + upper_bound = attrs.get_float("upper_bound") + alpha = (lower_bound + upper_bound) / 2.0 + return _op.nn.leaky_relu(inputs[0], alpha=alpha) + raise RuntimeError("act_type: {} is not supported".format(act_type)) + + +def _mx_lrn(inputs, attrs): + new_attrs = {} + new_attrs["alpha"] = attrs.get_float("alpha", 0.0001) + new_attrs["beta"] = attrs.get_float("beta", 0.75) + new_attrs["bias"] = attrs.get_float("knorm", 2) + # NCHW format and normalization along channel axis + new_attrs["axis"] = 1 + new_attrs["size"] = attrs.get_int("nsize") + assert len(inputs) == 1 + return _op.nn.lrn(inputs[0], **new_attrs) + + +# Note: due to attribute conversion constraint +# ops in the identity set must be attribute free +_identity_list = [ + "log", + "exp", + "sigmoid", + "tanh", + "exp", + "negative", + "reshape_like", + "slice_like", + "zeros_like", + "ones_like", +] + +_convert_map = { + "_copy" : _rename(_op.copy), + "relu" : _rename(_op.nn.relu), + "broadcast_add" : _rename(_op.add), + "broadcast_sub" : _rename(_op.subtract), + "broadcast_mul" : _rename(_op.multiply), + "broadcast_div" : _rename(_op.divide), + "elemwise_add" : _rename(_op.add), + "elemwise_sub" : _rename(_op.subtract), + "elemwise_mul" : _rename(_op.multiply), + "elemwise_div" : _rename(_op.divide), + "flatten" : _rename(_op.nn.batch_flatten), + "Flatten" : _rename(_op.nn.batch_flatten), + "_plus_scalar" : _binop_scalar(_op.add), + "__add_scalar__": _binop_scalar(_op.add), + "__sub_scalar__": _binop_scalar(_op.subtract), + "_minus_scalar" : _binop_scalar(_op.subtract), + "__mul_scalar__": _binop_scalar(_op.multiply), + "_mul_scalar" : _binop_scalar(_op.multiply), + "__div_scalar__": _binop_scalar(_op.divide), + "_div_scalar" : _binop_scalar(_op.divide), + "__pow_scalar__": _binop_scalar(_op.power), + "_rminus_scalar": _rbinop_scalar(_op.subtract), + "__rsub_scalar__": _rbinop_scalar(_op.subtract), + "_rdiv_scalar" : _rbinop_scalar(_op.divide), + "__rdiv_scalar__" : _rbinop_scalar(_op.divide), + "__rpow_scalar__": _rbinop_scalar(_op.power), + # reduction ops + "max" : _reduce(_op.max), + "min" : _reduce(_op.min), + "sum" : _reduce(_op.sum), + "max_axis" : _reduce(_op.max), + "min_axis" : _reduce(_op.min), + "sum_axis" : _reduce(_op.sum), + "argmax" : _arg_reduce(_op.argmax), + "argmin" : _arg_reduce(_op.argmin), + # init ops + "_ones" : _init_op(_op.ones), + "_zeros" : _init_op(_op.zeros), + # softmax + "softmax" : _softmax_op(_op.nn.softmax), + "log_softmax" : _softmax_op(_op.nn.log_softmax), + "Softmax" : _softmax_op(_op.nn.softmax), + # per op specialization + "Reshape" : _reshape, + "reshape" : _reshape, + "Cast" : _cast, + "clip" : _clip, + "transpose" : _transpose, + "UpSampling" : _upsampling, + "add_n" : _elemwise_sum, + # MXNet specific implementations + "FullyConnected": _mx_fully_connected, + "Activation" : _mx_activations, + "Convolution" : _mx_conv2d, + "Convolution_v1": _mx_conv2d, + "Deconvolution" : _mx_conv2d_transpose, + "Pooling" : _mx_pooling, + "Pooling_v1" : _mx_pooling, + "Dropout" : _mx_dropout, + "BatchNorm" : _mx_batch_norm, + "BatchNorm_v1" : _mx_batch_norm, + "LRN" : _mx_lrn, + "SliceChannel" : _mx_split, + "split" : _mx_split, + "expand_dims" : _mx_expand_dims, + "Concat" : _mx_concat, + "concat" : _mx_concat, + "LeakyReLU" : _mx_leaky_relu, + "SoftmaxOutput" : _mx_softmax_output, + "SoftmaxActivation" : _mx_softmax_activation, + # List of missing operators that are present in NNVMv1 + # TODO(tvm-tvm): support all operators. 
+    #
+    # "broadcast_to",
+    # "gather_nd",
+    # "_contrib_MultiBoxPrior" : _rename("multibox_prior"),
+    # "_contrib_MultiBoxDetection" : _contrib_multibox_detection,
+    # "Crop" : _crop_like,
+
+}
+
+# set identity list
+_convert_map.update({k : _rename(k) for k in _identity_list})
+
+
+def _from_mxnet_impl(symbol, shape_dict, dtype_info):
+    """Convert an MXNet symbol graph to Relay.
+
+    Reconstruct a relay Function by traversing the mxnet symbol.
+
+    Parameters
+    ----------
+    symbol : mxnet.sym.Symbol
+        Incoming MXNet symbol.
+        The op_name and attrs inside are not always directly compatible with Relay.
+
+    shape_dict : dict
+        Known parameter shapes
+
+    dtype_info : dict or str
+        Known parameter dtypes
+
+    Returns
+    -------
+    func : tvm.relay.Function
+        The converted Relay Function
+    """
+    assert symbol is not None
+    jgraph = json.loads(symbol.tojson())
+    jnodes = jgraph["nodes"]
+    node_map = {}
+
+    for nid, node in enumerate(jnodes):
+        children = [node_map[e[0]][e[1]] for e in node["inputs"]]
+        attrs = StrAttrsDict(node.get("attrs", {}))
+        node_name = node["name"]
+        op_name = node["op"]
+        if op_name == "null":
+            shape = shape_dict[node_name] if node_name in shape_dict else None
+            if isinstance(dtype_info, dict):
+                dtype = dtype_info[node_name] if node_name in dtype_info else "float32"
+            else:
+                dtype = dtype_info
+            node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
+        elif op_name in _convert_map:
+            res = _convert_map[op_name](children, attrs)
+            if isinstance(res, (_expr.TupleWrapper, tuple, list)):
+                pass
+            elif isinstance(res, _expr.Expr):
+                res = [res]
+            else:
+                raise RuntimeError("unexpected type %s" % type(res))
+            node_map[nid] = res
+        else:
+            raise RuntimeError("{} is not supported in relay frontend".format(op_name))
+
+    outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
+    outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
+    func = _expr.Function(ir_pass.free_vars(outputs), outputs)
+    return func
+
+
+def _update_shape_dtype(shape, dtype, params):
+    """Update shape and dtype given params information"""
+    shape = {} if shape is None else shape
+    if not params:
+        return shape, dtype
+    shape = shape.copy()
+    shape.update({k : v.shape for k, v in params.items()})
+    if isinstance(dtype, str):
+        for k, v in params.items():
+            if v.dtype != dtype:
+                raise ValueError(
+                    "%s: dtype not expected %s vs %s" % (k, dtype, v.dtype))
+    else:
+        dtype = dtype.copy()
+        dtype.update({k : str(v.dtype) for k, v in params.items()})
+    return shape, dtype
+
+
+def from_mxnet(symbol,
+               shape=None,
+               dtype="float32",
+               arg_params=None,
+               aux_params=None):
+    """Convert an MXNet model into an equivalent Relay Function.
+
+    Parameters
+    ----------
+    symbol : mxnet.Symbol or mxnet.gluon.HybridBlock
+        MXNet symbol.
+
+    shape : dict of str to tuple, optional
+        The input shape to the graph
+
+    dtype : str or dict of str to str
+        The input types to the graph
+
+    arg_params : dict of str to mx.NDArray
+        The argument parameters in mxnet
+
+    aux_params : dict of str to mx.NDArray
+        The auxiliary parameters in mxnet
+
+    Returns
+    -------
+    sym : tvm.relay.Function
+        Compatible Relay Function
+
+    params : dict of str to tvm.NDArray
+        The parameter dict to be used by Relay
+    """
+    try:
+        import mxnet as mx
+    except ImportError as e:
+        raise ImportError("{}. MXNet is required to parse symbols.".format(e))
+
+    if isinstance(symbol, mx.sym.Symbol):
+        params = {}
+        arg_params = arg_params if arg_params else {}
+        aux_params = aux_params if aux_params else {}
+        for k, v in arg_params.items():
+            params[k] = _nd.array(v.asnumpy())
+        for k, v in aux_params.items():
+            params[k] = _nd.array(v.asnumpy())
+        shape, dtype = _update_shape_dtype(shape, dtype, params)
+        sym = _from_mxnet_impl(symbol, shape, dtype)
+    elif isinstance(symbol, mx.gluon.HybridBlock):
+        if arg_params is not None or aux_params is not None:
+            raise ValueError("arg_params and aux_params are not used when importing HybridBlock")
+        params = {}
+        for k, v in symbol.collect_params().items():
+            params[k] = _nd.array(v.data().asnumpy())
+        data = mx.sym.Variable("data")
+        sym = symbol(data)
+        shape, dtype = _update_shape_dtype(shape, dtype, params)
+        sym = _from_mxnet_impl(sym, shape, dtype)
+    elif isinstance(symbol, mx.gluon.Block):
+        raise NotImplementedError("Only Hybrid Blocks are supported now.")
+    else:
+        msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol))
+        raise ValueError(msg)
+    return sym, params
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index 30aef433d7c6..b32db4c23f3e 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -14,6 +14,7 @@
 # operator registry
 from . import _tensor
 from . import _transform
+from . import _reduce
 from ..expr import Expr
 from ..base import register_relay_node
 
diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py
new file mode 100644
index 000000000000..fd18c0e71d53
--- /dev/null
+++ b/python/tvm/relay/op/_reduce.py
@@ -0,0 +1,19 @@
+"""Backend compiler related feature registration"""
+from __future__ import absolute_import
+
+import topi
+from . import op as _reg
+
+
+def _schedule_reduce(_, outs, target):
+    """Generic schedule for reduce"""
+    with target:
+        return topi.generic.schedule_reduce(outs)
+
+
+_reg.register_schedule("argmax", _schedule_reduce)
+_reg.register_schedule("argmin", _schedule_reduce)
+_reg.register_schedule("sum", _schedule_reduce)
+_reg.register_schedule("max", _schedule_reduce)
+_reg.register_schedule("prod", _schedule_reduce)
+_reg.register_schedule("mean", _schedule_reduce)
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 7aef4d4377af..4832a195f9e8 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -273,4 +273,5 @@ def concatenate_compute(attrs, inputs, output_type, target):
     return [topi.concatenate(inputs, axis=attrs.axis)]
 
 register_schedule("concatenate", schedule_injective)
-register_pattern("concatenate", OpPattern.INJECTIVE)
+# TODO(tqchen): re-enable concat as injective
+register_pattern("concatenate", OpPattern.OPAQUE)
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index c05fbe8ec61e..d087526b7b88 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1,63 +1,17 @@
-#pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
+# pylint: disable=invalid-name
 from __future__ import absolute_import
-import topi
-import topi.cuda
-from tvm import container
 from . 
import op as _reg -from .op import (schedule_injective, register_compute, register_schedule, - register_pattern, OpPattern) -schedule_broadcast = schedule_injective +schedule_injective = _reg.schedule_injective +schedule_broadcast = _reg.schedule_injective -# squeeze -@register_compute("squeeze") -def squeeze_compiler(attrs, inputs, output_type, target): - """Compiler for squeeze dims.""" - assert len(inputs) == 1 - - if attrs.axis is None: - axis = None - elif isinstance(attrs.axis, container.Array): - axis = tuple(attrs.axis) - else: - axis = int(attrs.axis) - - return [topi.squeeze(inputs[0], axis)] - -register_pattern("squeeze", OpPattern.INJECTIVE) -register_schedule("squeeze", schedule_injective) - -# expand_dims -@register_compute("expand_dims") -def expand_dims_compiler(attrs, inputs, output_type, target): - """Compiler for expand_dims.""" - assert len(inputs) == 1 - - new_axis = int(attrs.num_newaxis) - assert new_axis >= 0 - - # axis should be in range [-data.ndim - 1, data.ndim] - axis = int(attrs.axis) - assert axis >= -len(inputs[0].shape) - 1 - assert axis <= len(inputs[0].shape) - - return [topi.expand_dims(inputs[0], axis, new_axis)] +_reg.register_schedule("squeeze", schedule_injective) _reg.register_schedule("expand_dims", schedule_broadcast) -_reg.register_pattern("expand_dims", OpPattern.BROADCAST) - -# strided_slice -_reg.register_schedule("strided_slice", schedule_injective) - -# slice_like -_reg.register_schedule("slice_like", schedule_injective) -_reg.register_pattern("slice_like", OpPattern.INJECTIVE) - -# reshape _reg.register_schedule("reshape", schedule_injective) -_reg.register_pattern("reshape", OpPattern.INJECTIVE) - -# reshape_like _reg.register_schedule("reshape_like", schedule_injective) -_reg.register_pattern("reshape_like", OpPattern.INJECTIVE) +_reg.register_schedule("cast", schedule_broadcast) +_reg.register_schedule("strided_slice", schedule_injective) +_reg.register_schedule("slice_like", schedule_injective) +_reg.register_schedule("split", schedule_injective) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index b48bfde97f33..9c988b86e8bc 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1,5 +1,7 @@ #pylint: disable=invalid-name, unused-argument """Backend compiler related feature registration""" +from __future__ import absolute_import + import topi from topi.util import get_const_int, get_const_tuple from .. import op as reg diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 61c930436167..63b1e206e72c 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -145,7 +145,7 @@ def conv2d_transpose(data, weight_layout, output_padding, out_dtype) -def softmax(data, axis=1): +def softmax(data, axis=-1): r"""Computes softmax. .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)} @@ -169,7 +169,7 @@ def softmax(data, axis=1): return _make.softmax(data, axis) -def log_softmax(data, axis): +def log_softmax(data, axis=-1): r"""Computes log softmax. .. 
math:: diff --git a/python/tvm/relay/testing/inception_v3.py b/python/tvm/relay/testing/inception_v3.py index 96684c5d6e1d..491b221fbe0a 100644 --- a/python/tvm/relay/testing/inception_v3.py +++ b/python/tvm/relay/testing/inception_v3.py @@ -54,7 +54,7 @@ def Inception7A(data, name=('%s_pool_%s_pool' % (pool, name))) cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv') - concat = relay.concatenate((tower_1x1, tower_5x5, tower_3x3, cproj), axis=0) + concat = relay.concatenate((tower_1x1, tower_5x5, tower_3x3, cproj), axis=1) return concat # First Downsample @@ -72,7 +72,7 @@ def Inception7B(data, name=('%s_tower' % name), suffix='_conv_2') pooling = Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0, 0), pool_type="max", name=('max_pool_%s_pool' % name)) - concat = relay.concatenate((tower_3x3, tower_d3x3, pooling), axis=0) + concat = relay.concatenate((tower_3x3, tower_d3x3, pooling), axis=1) return concat def Inception7C(data, @@ -101,7 +101,7 @@ def Inception7C(data, cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv') # concat - concat = relay.concatenate((tower_1x1, tower_d7, tower_q7, cproj), axis=0) + concat = relay.concatenate((tower_1x1, tower_d7, tower_q7, cproj), axis=1) return concat def Inception7D(data, @@ -124,7 +124,7 @@ def Inception7D(data, pooling = Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, pad=(0, 0), name=('%s_pool_%s_pool' % (pool, name))) # concat - concat = relay.concatenate((tower_3x3, tower_d7_3x3, pooling), axis=0) + concat = relay.concatenate((tower_3x3, tower_d7_3x3, pooling), axis=1) return concat def Inception7E(data, @@ -153,7 +153,7 @@ def Inception7E(data, suffix='_conv') # concat concat = relay.concatenate( - (tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj), axis=0) + (tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj), axis=1) return concat def get_net(batch_size, diff --git a/python/tvm/relay/testing/squeezenet.py b/python/tvm/relay/testing/squeezenet.py index fa55cafbf2b4..c7b8e8db166b 100644 --- a/python/tvm/relay/testing/squeezenet.py +++ b/python/tvm/relay/testing/squeezenet.py @@ -31,19 +31,21 @@ from . 
import layers # Helpers -def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels): - net = _make_fire_conv(net, squeeze_channels, 1, 0) +def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels, prefix): + net = _make_fire_conv(net, squeeze_channels, 1, 0, "%s_input" % prefix) - left = _make_fire_conv(net, expand1x1_channels, 1, 0) - right = _make_fire_conv(net, expand3x3_channels, 3, 1) + left = _make_fire_conv(net, expand1x1_channels, 1, 0, "%s_left" % prefix) + right = _make_fire_conv(net, expand3x3_channels, 3, 1, "%s_right" % prefix) # NOTE : Assume NCHW layout here net = relay.concatenate((left, right), axis=1) - return net -def _make_fire_conv(net, channels, kernel_size, padding=0): - net = layers.conv2d(net, channels=channels, kernel_size=(kernel_size, kernel_size), - padding=(padding, padding), name="conv2d") +def _make_fire_conv(net, channels, kernel_size, padding=0, prefix=""): + net = layers.conv2d(net, + channels=channels, + kernel_size=(kernel_size, kernel_size), + padding=(padding, padding), name="%s_conv" % prefix) + net = relay.nn.bias_add(net, relay.var("%s_conv_bias" % prefix)) net = relay.nn.relu(net) return net @@ -75,41 +77,44 @@ def get_net(batch_size, image_shape, num_classes, version, dtype): kernel_size=(7, 7), strides=(2, 2), padding=(3, 3), - name="conv2d") - net = relay.nn.bias_add(net, relay.var("dense1_bias")) + name="conv1") + net = relay.nn.bias_add(net, relay.var("conv1_bias")) net = relay.nn.relu(net) net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2)) - net = _make_fire(net, 16, 64, 64) - net = _make_fire(net, 16, 64, 64) - net = _make_fire(net, 32, 128, 128) + net = _make_fire(net, 16, 64, 64, "fire1") + net = _make_fire(net, 16, 64, 64, "fire2") + net = _make_fire(net, 32, 128, 128, "fire3") net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2)) - net = _make_fire(net, 32, 128, 128) - net = _make_fire(net, 48, 192, 192) - net = _make_fire(net, 48, 192, 192) - net = _make_fire(net, 64, 256, 256) + net = _make_fire(net, 32, 128, 128, "fire4") + net = _make_fire(net, 48, 192, 192, "fire5") + net = _make_fire(net, 48, 192, 192, "fire6") + net = _make_fire(net, 64, 256, 256, "fire7") net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2)) - net = _make_fire(net, 64, 256, 256) + net = _make_fire(net, 64, 256, 256, "fire8") else: net = layers.conv2d(net, channels=64, kernel_size=(3, 3), strides=(2, 2), padding=(1, 1), - name="conv2d") + name="conv1") + net = relay.nn.bias_add(net, relay.var("conv1_bias")) net = relay.nn.relu(net) net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2)) - net = _make_fire(net, 16, 64, 64) - net = _make_fire(net, 16, 64, 64) + net = _make_fire(net, 16, 64, 64, "fire1") + net = _make_fire(net, 16, 64, 64, "fire2") net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2)) - net = _make_fire(net, 32, 128, 128) - net = _make_fire(net, 32, 128, 128) + net = _make_fire(net, 32, 128, 128, "fire3") + net = _make_fire(net, 32, 128, 128, "fire4") net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2)) - net = _make_fire(net, 48, 192, 192) - net = _make_fire(net, 48, 192, 192) - net = _make_fire(net, 64, 256, 256) - net = _make_fire(net, 64, 256, 256) + net = _make_fire(net, 48, 192, 192, "fire5") + net = _make_fire(net, 48, 192, 192, "fire6") + net = _make_fire(net, 64, 256, 256, "fire7") + net = _make_fire(net, 64, 256, 256, "fire8") net = relay.nn.dropout(net, rate=0.5) - net = layers.conv2d(net, channels=num_classes, kernel_size=(1, 1), 
name="conv2d") + net = layers.conv2d( + net, channels=num_classes, kernel_size=(1, 1), name="conv_final") + net = relay.nn.bias_add(net, relay.var("conv_final_bias")) net = relay.nn.relu(net) net = relay.nn.global_avg_pool2d(net) net = relay.nn.batch_flatten(net) @@ -117,8 +122,12 @@ def get_net(batch_size, image_shape, num_classes, version, dtype): args = relay.ir_pass.free_vars(net) return relay.Function(args, net) -def get_workload(batch_size=1, num_classes=1000, version='1.0', - image_shape=(3, 224, 224), dtype="float32"): + +def get_workload(batch_size=1, + num_classes=1000, + version='1.0', + image_shape=(3, 224, 224), + dtype="float32"): """Get benchmark workload for SqueezeNet Parameters diff --git a/python/tvm/relay/testing/vgg.py b/python/tvm/relay/testing/vgg.py index 7ec6669f6346..811de33c579a 100644 --- a/python/tvm/relay/testing/vgg.py +++ b/python/tvm/relay/testing/vgg.py @@ -24,20 +24,24 @@ from .init import create_workload from . import layers as wrapper -def get_feature(internel_layer, layers, filters, batch_norm=False): + +def get_feature(internal_layer, layers, filters, batch_norm=False): """Get VGG feature body as stacks of convoltions.""" for i, num in enumerate(layers): for j in range(num): - internel_layer = wrapper.conv2d( - data=internel_layer, kernel_size=(3, 3), padding=(1, 1), - channels=filters[i], name="conv%s_%s"%(i + 1, j + 1)) + internal_layer = wrapper.conv2d( + data=internal_layer, kernel_size=(3, 3), padding=(1, 1), + channels=filters[i], name="conv%s_%s" % (i + 1, j + 1)) + internal_layer = relay.nn.bias_add( + internal_layer, relay.var("conv%s_%s_bias" % (i + 1, j + 1))) if batch_norm: - internel_layer = wrapper.batch_norm_infer( - data=internel_layer, name="bn%s_%s" %(i + 1, j + 1)) - internel_layer = relay.nn.relu(data=internel_layer) - internel_layer = relay.nn.max_pool2d( - data=internel_layer, pool_size=(2, 2), strides=(2, 2)) - return internel_layer + internal_layer = wrapper.batch_norm_infer( + data=internal_layer, name="bn%s_%s" %(i + 1, j + 1)) + internal_layer = relay.nn.relu(data=internal_layer) + internal_layer = relay.nn.max_pool2d( + data=internal_layer, pool_size=(2, 2), strides=(2, 2)) + return internal_layer + def get_classifier(input_data, num_classes): """Get VGG classifier layers as fc layers.""" @@ -51,6 +55,7 @@ def get_classifier(input_data, num_classes): fc8 = wrapper.dense_add_bias(data=drop7, units=num_classes, name="fc8") return fc8 + def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_norm=False): """ Parameters @@ -68,7 +73,7 @@ def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_no The data type num_layers : int - Number of layers for the variant of densenet. Options are 11, 13, 16, 19. + Number of layers for the variant of vgg. Options are 11, 13, 16, 19. batch_norm : bool, default False Use batch normalization. @@ -88,7 +93,12 @@ def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_no args = relay.ir_pass.free_vars(symbol) return relay.Function(args, symbol) -def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"): + +def get_workload(batch_size, + num_classes=1000, + image_shape=(3, 224, 224), + dtype="float32", + num_layers=11): """Get benchmark workload for VGG nets. Parameters @@ -105,6 +115,9 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype= dtype : str, optional The data type + num_layers : int + Number of layers for the variant of vgg. Options are 11, 13, 16, 19. 
+ Returns ------- net : nnvm.Symbol @@ -113,5 +126,5 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype= params : dict of str to NDArray The parameters. """ - net = get_net(batch_size, image_shape, num_classes, dtype) + net = get_net(batch_size, image_shape, num_classes, dtype, num_layers) return create_workload(net) diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 9257ad3b5490..69b7ec1f6e60 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -163,6 +163,7 @@ class AttrsHashHandler : * \param node The node to be hashed. */ size_t Hash(const NodeRef& node) { + if (!node.defined()) return 0; return this->VisitAttr(node); } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index f3c3e2935d22..5001e2cd4fea 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -31,7 +31,10 @@ class StorageAllocaBaseVisitor : public ExprVisitor { for (Var param : func->params) { CreateToken(param.operator->(), false); } - this->VisitExpr(func->body); + // must always keep output alive. + for (StorageToken* tok : GetToken(func->body)) { + tok->ref_counter += 1; + } } void VisitExpr_(const ConstantNode* op) final { diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 4c814bc1614f..5bb2f24cae81 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -16,7 +16,7 @@ namespace tvm { namespace relay { template -std::vector AsVector(const Array &array) { +inline std::vector AsVector(const Array &array) { std::vector result; result.reserve(array.size()); for (const T& ele : array) { diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 0a955fad631b..95c26c3ab7e4 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -5,6 +5,8 @@ */ #include #include +#include +#include #include #include #include "../op_common.h" @@ -15,12 +17,12 @@ namespace relay { /*! \brief Attributes for Reduce operators */ struct ReduceAttrs : public tvm::AttrsNode { - Array axis; + Array axis; bool keepdims; bool exclude; TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue>()) + TVM_ATTR_FIELD(axis).set_default(NullValue>()) .describe(R"code(The axis or axes along which to perform the reduction. The default, `axis=()`, will compute over all elements into a @@ -50,7 +52,7 @@ struct ReduceAttrs : public tvm::AttrsNode { * \return r_axes The new reduced axes of the output. */ inline std::vector GetReduceAxes(const uint32_t indim, - const Array& inaxis, + const Array& inaxis, bool exclude) { if (!inaxis.defined()) { std::vector r_axes(indim); @@ -60,9 +62,7 @@ inline std::vector GetReduceAxes(const uint32_t indim, std::vector in_axes; for (auto i : inaxis) { - const int64_t* k = as_const_int(i); - CHECK(k != nullptr) << "Reduce axis need to be constant, cannot be symbolic"; - int64_t axis = k[0]; + int64_t axis = i->value; if (axis < 0) { axis = axis + indim; } @@ -97,6 +97,53 @@ inline std::vector GetReduceAxes(const uint32_t indim, return r_axes; } + +// Get axis under exclude condition. 
+Array GetExcludeAxes(size_t indim, + const Array& inaxis) { + std::vector axis_flag(indim, true); + for (auto i : inaxis) { + int64_t axis = i->value; + if (axis < 0) { + axis = axis + static_cast(indim); + } + // Check out of bounds error + CHECK_GE(axis, 0) + << "Axis out of bounds in reduce operator."; + CHECK_LT(axis, static_cast(indim)) + << "Axis out of bounds in reduce operator."; + axis_flag[axis] = false; + } + + Array r_axes; + + for (size_t i = 0; i < axis_flag.size(); ++i) { + if (axis_flag[i]) { + r_axes.push_back(static_cast(i)); + } + } + return r_axes; +} + + +template +Array ReduceCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target, + F f) { + const ReduceAttrs* param = attrs.as(); + CHECK(param != nullptr); + auto axes = param->axis; + if (param->exclude) { + axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); + } + if (axes.size() == 0) { + return { topi::identity(inputs[0]) }; + } + return { f(inputs[0], axes, param->keepdims, false) }; +} + /*! * \brief ReduceShapeImpl get the outshape for the reduction operator * \param in_shape Shape of input data. @@ -200,7 +247,7 @@ bool ReduceRel(const Array& types, TVM_REGISTER_API("relay.op._make." OpName) \ .set_body([](const TVMArgs& args, TVMRetValue* rv) { \ auto make_func = [](Expr data, \ - Array axis, \ + Array axis, \ bool keepdims, \ bool exclude) { \ auto attrs = make_node(); \ @@ -217,6 +264,14 @@ bool ReduceRel(const Array& types, .add_argument("data", "Tensor", "The input tensor.") +Array ArgMaxCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return ReduceCompute(attrs, inputs, out_type, target, topi::argmax); +} + + RELAY_REGISTER_REDUCE_OP("argmax") .describe(R"code(Creates an operation that finds the indices of the maximum values over a given axis. @@ -224,8 +279,17 @@ values over a given axis. )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) -.add_type_rel("ArgReduce", ArgReduceRel); +.add_type_rel("ArgReduce", ArgReduceRel) +.set_attr("FTVMCompute", ArgMaxCompute) +.set_attr("TOpPattern", kCommReduce); + +Array ArgMinCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return ReduceCompute(attrs, inputs, out_type, target, topi::argmin); +} RELAY_REGISTER_REDUCE_OP("argmin") .describe(R"code(Creates an operation that finds the indices of the minimum @@ -234,7 +298,16 @@ values over a given axis. )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) -.add_type_rel("ArgReduce", ArgReduceRel); +.add_type_rel("ArgReduce", ArgReduceRel) +.set_attr("FTVMCompute", ArgMinCompute) +.set_attr("TOpPattern", kCommReduce); + +Array SumCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return ReduceCompute(attrs, inputs, out_type, target, topi::sum); +} RELAY_REGISTER_REDUCE_OP("sum") @@ -257,16 +330,35 @@ Example:: )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) -.add_type_rel("Reduce", ReduceRel); +.add_type_rel("Reduce", ReduceRel) +.set_attr("FTVMCompute", SumCompute) +.set_attr("TOpPattern", kCommReduce); +Array MaxCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return ReduceCompute(attrs, inputs, out_type, target, topi::max); +} + RELAY_REGISTER_REDUCE_OP("max") .describe(R"code(Computes the max of array elements over given axes. 
)code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) -.add_type_rel("Reduce", ReduceRel); +.add_type_rel("Reduce", ReduceRel) +.set_attr("FTVMCompute", MaxCompute) +.set_attr("TOpPattern", kCommReduce); + + +Array MinCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return ReduceCompute(attrs, inputs, out_type, target, topi::min); +} RELAY_REGISTER_REDUCE_OP("min") @@ -275,11 +367,20 @@ RELAY_REGISTER_REDUCE_OP("min") )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) -.add_type_rel("Reduce", ReduceRel); +.add_type_rel("Reduce", ReduceRel) +.set_attr("FTVMCompute", MinCompute) +.set_attr("TOpPattern", kCommReduce); -RELAY_REGISTER_REDUCE_OP("mean") -.describe(R"code(Computes the mean of array elements over given axes. +Array ProdCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return ReduceCompute(attrs, inputs, out_type, target, topi::prod); +} + +RELAY_REGISTER_REDUCE_OP("prod") +.describe(R"code(Computes the products of array elements over given axes. Example:: @@ -287,20 +388,40 @@ Example:: [[1,4],[4,3],[5,2]], [[7,1],[7,2],[7,3]]] - mean(data) - [3.22] + mean(data, axis=1) + [35562240] mean(data, axis=[1,2]) - [ 2. 3.16666667 4.5] + [ 36 480 2058] )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) -.add_type_rel("Reduce", ReduceRel); +.add_type_rel("Reduce", ReduceRel) +.set_attr("FTVMCompute", ProdCompute) +.set_attr("TOpPattern", kCommReduce); -RELAY_REGISTER_REDUCE_OP("prod") -.describe(R"code(Computes the products of array elements over given axes. +Array MeanCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + IndexExpr count = make_const(inputs[0]->dtype, 1); + const ReduceAttrs* param = attrs.as(); + CHECK(param != nullptr); + auto axes = param->axis; + for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), + param->axis, + param->exclude)) { + count *= inputs[0]->shape[i]; + } + auto res = ReduceCompute(attrs, inputs, out_type, target, topi::sum); + return {topi::divide(res[0], count)}; +} + + +RELAY_REGISTER_REDUCE_OP("mean") +.describe(R"code(Computes the mean of array elements over given axes. Example:: @@ -308,16 +429,17 @@ Example:: [[1,4],[4,3],[5,2]], [[7,1],[7,2],[7,3]]] - mean(data, axis=1) - [35562240] + mean(data) + [3.22] mean(data, axis=[1,2]) - [ 36 480 2058] + [ 2. 
3.16666667 4.5] )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) -.add_type_rel("Reduce", ReduceRel); - +.add_type_rel("Reduce", ReduceRel) +.set_attr("FTVMCompute", MeanCompute) +.set_attr("TOpPattern", kCommReduce); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 52363e8af92a..83a4c9067f43 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -8,9 +8,10 @@ #include #include #include +#include #include #include "../op_common.h" - +#include "../../../arithmetic/compute_expr.h" namespace tvm { namespace relay { @@ -37,6 +38,16 @@ bool CastRel(const Array& types, return true; } +Array CastCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const CastAttrs *param = attrs.as(); + CHECK(param != nullptr); + DataType dtype = param->dtype; + return { topi::cast(inputs[0], dtype) }; +} + Expr MakeCast(Expr data, DataType dtype) { auto attrs = make_node(); @@ -58,8 +69,9 @@ RELAY_REGISTER_OP("cast") .set_attrs_type_key("relay.attrs.CastAttrs") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) -.add_type_rel("Cast", CastRel); - +.add_type_rel("Cast", CastRel) +.set_attr("FTVMCompute", CastCompute) +.set_attr("TOpPattern", kElemWise); // relay.expand_dims TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); @@ -104,6 +116,15 @@ bool ExpandDimsRel(const Array& types, return true; } +Array ExpandDimsCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const ExpandDimsAttrs *param = attrs.as(); + CHECK(param != nullptr); + return { topi::expand_dims(inputs[0], param->axis, param->num_newaxis) }; +} + Expr MakeExpandDims(Expr data, int axis, int num_newaxis) { @@ -129,7 +150,9 @@ RELAY_REGISTER_OP("expand_dims") .set_attrs_type_key("relay.attrs.ExpandDimsAttrs") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(1) -.add_type_rel("ExpandDims", ExpandDimsRel); +.add_type_rel("ExpandDims", ExpandDimsRel) +.set_attr("FTVMCompute", ExpandDimsCompute) +.set_attr("TOpPattern", kBroadcast); TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); @@ -303,13 +326,81 @@ bool ReshapeRel(const Array& types, << types[0]; return false; } + const auto* param = attrs.as(); - reporter->Assign(types[1], TensorTypeNode::make(param->newshape, data->dtype)); + Array oshape; + size_t src_idx = 0; + int infer_idx = -1; + + for (size_t i = 0; i < param->newshape.size(); ++i) { + int svalue = param->newshape[i]->value; + // special flag handling for shape inference. 
+ if (svalue > 0) { + oshape.push_back(param->newshape[i]); + ++src_idx; + } else if (svalue == 0) { + // keep same + CHECK_LT(src_idx, data->shape.size()); + oshape.push_back(data->shape[src_idx++]); + } else if (svalue == -1) { + // inference based on rest + CHECK_LT(infer_idx, 0) + << "One and only one dim can be inferred"; + infer_idx = i; + oshape.push_back(1); + ++src_idx; + } else if (svalue == -2) { + // copy all remaining dims from source + while (src_idx < data->shape.size()) { + oshape.push_back(data->shape[src_idx++]); + } + } else if (svalue == -3) { + // merge two dims from source + CHECK_LT(src_idx + 1, data->shape.size()); + IndexExpr d1 = data->shape[src_idx++]; + IndexExpr d2 = data->shape[src_idx++]; + oshape.push_back(d1 * d2); + } else if (svalue == -4) { + // split the source dim s into two dims + // read the left dim and then the right dim (either can be -1) + CHECK_LT(i + 2, param->newshape.size()); + CHECK_LT(src_idx, data->shape.size()); + IndexExpr d0 = data->shape[src_idx++]; + Integer d1 = param->newshape[++i]; + Integer d2 = param->newshape[++i]; + if (d1->value == -1) { + CHECK(d2->value != -1) + << "Split dims cannot both be -1."; + oshape.push_back(d0 / d2); + oshape.push_back(d2); + } else { + CHECK_EQ(d2->value, -1); + oshape.push_back(d1); + oshape.push_back(d0 / d1); + } + } + } + + if (infer_idx >= 0) { + IndexExpr new_size = arith::ComputeReduce(oshape, 1); + IndexExpr old_size = arith::ComputeReduce(data->shape, 1); + oshape.Set(infer_idx, old_size / new_size); + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); return true; } +Array ReshapeCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const auto* out_ttype = out_type.as(); + CHECK(out_ttype != nullptr); + return { topi::reshape(inputs[0], out_ttype->shape) }; +} + Expr MakeReshape(Expr data, - Array newshape) { + Array newshape) { auto attrs = make_node(); attrs->newshape = std::move(newshape); static const Op& op = Op::Get("reshape"); @@ -377,14 +468,8 @@ Example:: .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) .add_type_rel("Reshape", ReshapeRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type, - const Target& target) { - const auto* param = attrs.as(); - CHECK(param != nullptr); - return Array{ topi::reshape(inputs[0], param->newshape) }; -}); +.set_attr("FTVMCompute", ReshapeCompute) +.set_attr("TOpPattern", kInjective); /*! 
@@ -440,12 +525,8 @@ the input array into an output array with the same shape as the second input arr .add_argument("shape_like", "Tensor", "Shape tensor.") .set_support_level(3) .add_type_rel("ReshapeLike", ReshapeLikeRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type, - const Target& target) { - return Array{ topi::reshape(inputs[0], inputs[1]->shape) }; -}); +.set_attr("FTVMCompute", ReshapeCompute) +.set_attr("TOpPattern", kInjective); // Take @@ -788,6 +869,7 @@ TVM_REGISTER_API("relay.op._make.squeeze") runtime::detail::unpack_call(MakeSqueeze, args, rv); }); + bool SqueezeRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -816,7 +898,13 @@ bool SqueezeRel(const Array& types, original_shape.push_back(std::pair(e, true)); } for (const auto& e : param->axis) { - original_shape.at(e->value).second = false; + int64_t axis_val = e->value; + if (axis_val < 0) { + axis_val += static_cast(original_shape.size()); + } + CHECK_GE(axis_val, 0); + CHECK_LT(axis_val, original_shape.size()); + original_shape.at(axis_val).second = false; } for (const auto p : original_shape) { if (p.second) { @@ -832,6 +920,16 @@ bool SqueezeRel(const Array& types, return true; } +Array SqueezeCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const SqueezeAttrs *param = attrs.as(); + CHECK(param != nullptr); + return { topi::squeeze(inputs[0], param->axis) }; +} + + RELAY_REGISTER_OP("squeeze") .describe(R"code(Squeeze the input tensor at the dimensions given by axes @@ -842,7 +940,10 @@ RELAY_REGISTER_OP("squeeze") .set_attrs_type_key("relay.attrs.SqueezeAttrs") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) -.add_type_rel("Squeeze", SqueezeRel); +.add_type_rel("Squeeze", SqueezeRel) +.set_attr("FTVMCompute", SqueezeCompute) +.set_attr("TOpPattern", kInjective); + // Have no idea how to assert the constraint. // CollapseSumLike: -> B where BroadCast(A, B) = A @@ -1034,8 +1135,8 @@ Array StridedSliceCompute(const Attrs& attrs, TVM_REGISTER_API("relay.op._make.strided_slice") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeStridedSlice, args, rv); +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeStridedSlice, args, rv); }); @@ -1082,7 +1183,7 @@ bool SplitRel(const Array& types, // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) return false; CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; const auto param = attrs.as(); CHECK(param != nullptr); @@ -1131,6 +1232,23 @@ bool SplitRel(const Array& types, return true; } +Array SplitCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const auto param = attrs.as(); + CHECK(param != nullptr); + + if (const IntImm* sections = param->indices_or_sections.as()) { + int64_t num_sections = sections->value; + return Array{ + topi::split_sections(inputs[0], num_sections, param->axis) }; + } else { + auto indices = Downcast >(param->indices_or_sections); + return Array{ topi::split(inputs[0], indices, param->axis) }; + } +} + Expr MakeSplit(Expr data, NodeRef indices_or_sections, int axis) { @@ -1165,7 +1283,9 @@ the entries indicate where along axis the array is split. 
.set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) -.add_type_rel("Split", SplitRel); +.add_type_rel("Split", SplitRel) +.set_attr("FTVMCompute", SplitCompute) +.set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(SliceLikeAttrs); @@ -1249,12 +1369,11 @@ Array GetIntArray(Array arr) { return Array(arr.node_); } -template Array SliceLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, const Target& target) { - const auto* param = attrs.as(); + const auto* param = attrs.as(); CHECK(param != nullptr); Array src_shape = inputs[0]->shape; Array target_shape = inputs[1]->shape; @@ -1312,7 +1431,8 @@ RELAY_REGISTER_OP("slice_like") .add_argument("shape_like", "Tensor", "Shape tensor.") .set_support_level(10) .add_type_rel("SliceLike", SliceLikeRel) -.set_attr("FTVMCompute", SliceLikeCompute); +.set_attr("FTVMCompute", SliceLikeCompute) +.set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 96fe030c2d03..bcb91e7e5737 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -344,6 +344,7 @@ Expr MultiplyForwardRewrite(const Call& ref_call, const Array& new_args, const AxesSet& expected_out_axes) { if (!expected_out_axes.defined()) return Expr(); + if (expected_out_axes.size() == 0) return Expr(); // TODO(tvm-team) allow same axes accumulation // not as important because it is less common in nn. const auto* slhs = new_args[0].as(); @@ -681,7 +682,9 @@ AxesSet AddSubBackwardPrep(const Call& call, const Array& in_axes) { // add of two elements. return in_axes[0]; } else { - return NullValue(); + auto res = NullValue(); + CHECK(!res.defined()); + return res; } } @@ -751,14 +754,14 @@ Expr MultiplyBackwardTransform(const Call& call, const auto* trhs = call->args[1]->type_as(); AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]); AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]); - if (lhs_axes.defined()) { + if (lhs_axes.defined() && lhs_axes.size() != 0) { // NOTE we won't recursively call mutating on scale part. // since there won't be scale chance within scale part. Expr rhs = call->args[1]; if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) { return transformer->Transform(call->args[0], lhs_axes, rhs); } - } else if (rhs_axes.defined()) { + } else if (rhs_axes.defined() && rhs_axes.size() != 0) { Expr lhs = call->args[0]; if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) { return transformer->Transform(call->args[1], rhs_axes, lhs); diff --git a/tests/python/frontend/mxnet/model_zoo/__init__.py b/tests/python/frontend/mxnet/model_zoo/__init__.py new file mode 100644 index 000000000000..eba8f8df0bba --- /dev/null +++ b/tests/python/frontend/mxnet/model_zoo/__init__.py @@ -0,0 +1,59 @@ +"""MXNet model zoo for testing purposes.""" +from __future__ import absolute_import +from . 
import mlp, vgg, resnet, dqn, inception_v3, squeezenet, dcgan +import tvm.relay.testing + +# mlp +def mx_mlp(): + num_class = 10 + return mlp.get_symbol(num_class) + +def relay_mlp(): + num_class = 10 + return tvm.relay.testing.mlp.get_workload(1, num_class)[0] + +# vgg +def mx_vgg(num_layers): + num_class = 1000 + return vgg.get_symbol(num_class, num_layers) + +def relay_vgg(num_layers): + num_class = 1000 + return tvm.relay.testing.vgg.get_workload( + 1, num_class, num_layers=num_layers)[0] + +# resnet +def mx_resnet(num_layers): + num_class = 1000 + return resnet.get_symbol(num_class, num_layers, '3,224,224') + +def relay_resnet(num_layers): + num_class = 1000 + return tvm.relay.testing.resnet.get_workload( + 1, num_class, num_layers=num_layers)[0] + + +# dqn +mx_dqn = dqn.get_symbol + +def relay_dqn(): + return tvm.relay.testing.dqn.get_workload(1)[0] + +# squeezenet +def mx_squeezenet(version): + return squeezenet.get_symbol(version=version) + +def relay_squeezenet(version): + return tvm.relay.testing.squeezenet.get_workload(1, version=version)[0] + +# inception +mx_inception_v3 = inception_v3.get_symbol + +def relay_inception_v3(): + return tvm.relay.testing.inception_v3.get_workload(1)[0] + +# dcgan generator +mx_dcgan = dcgan.get_symbol + +def relay_dcgan(batch_size): + return tvm.relay.testing.dcgan.get_workload(batch_size=batch_size)[0] diff --git a/tests/python/frontend/mxnet/model_zoo/dcgan.py b/tests/python/frontend/mxnet/model_zoo/dcgan.py new file mode 100644 index 000000000000..8af030b6b184 --- /dev/null +++ b/tests/python/frontend/mxnet/model_zoo/dcgan.py @@ -0,0 +1,66 @@ +# pylint: disable=unused-argument +""" +The MXNet symbol of DCGAN generator + +Adopted from: +https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py + +Reference: +Radford, Alec, Luke Metz, and Soumith Chintala. +"Unsupervised representation learning with deep convolutional generative adversarial networks." +arXiv preprint arXiv:1511.06434 (2015). 
+""" + +import mxnet as mx + +def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)): + """a deconv layer that enlarges the feature map""" + target_shape = (oshape[-2], oshape[-1]) + pad_y = (kshape[0] - 1) // 2 + pad_x = (kshape[1] - 1) // 2 + adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0] + adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1] + + net = mx.sym.Deconvolution(data, + kernel=kshape, + stride=stride, + pad=(pad_y, pad_x), + adj=(adj_y, adj_x), + num_filter=oshape[0], + no_bias=True, + name=name) + return net + +def deconv2d_bn_relu(data, prefix, **kwargs): + """a block of deconv + batch norm + relu""" + eps = 1e-5 + 1e-12 + + net = deconv2d(data, name="%s_deconv" % prefix, **kwargs) + net = mx.sym.BatchNorm(net, eps=eps, name="%s_bn" % prefix) + net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu') + return net + +def get_symbol(oshape=(3, 64, 64), ngf=128, code=None): + """get symbol of dcgan generator""" + assert oshape[-1] == 64, "Only support 64x64 image" + assert oshape[-2] == 64, "Only support 64x64 image" + + code = mx.sym.Variable("data") if code is None else code + net = mx.sym.FullyConnected(code, name="g1", num_hidden=ngf*8*4*4, no_bias=True, flatten=False) + net = mx.sym.Activation(net, act_type='relu') + # 4 x 4 + net = mx.sym.reshape(net, shape=(-1, ngf * 8, 4, 4)) + # 8 x 8 + net = deconv2d_bn_relu( + net, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2") + # 16x16 + net = deconv2d_bn_relu( + net, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3") + # 32x32 + net = deconv2d_bn_relu( + net, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4") + # 64x64 + net = deconv2d( + net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv") + net = mx.sym.Activation(net, act_type='tanh') + return net diff --git a/tests/python/frontend/mxnet/model_zoo/dqn.py b/tests/python/frontend/mxnet/model_zoo/dqn.py new file mode 100644 index 000000000000..e037511efdf2 --- /dev/null +++ b/tests/python/frontend/mxnet/model_zoo/dqn.py @@ -0,0 +1,27 @@ +""" +The mxnet symbol of Nature DQN + +Reference: +Mnih, Volodymyr, et al. +"Human-level control through deep reinforcement learning." +Nature 518.7540 (2015): 529. +""" + +import mxnet as mx + +def get_symbol(num_action=18): + data = mx.sym.Variable(name='data') + net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4), + num_filter=32, name='conv1') + net = mx.sym.Activation(net, act_type='relu', name='relu1') + net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2), + num_filter=64, name='conv2') + net = mx.sym.Activation(net, act_type='relu', name='relu2') + net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1), + num_filter=64, name='conv3') + net = mx.sym.Activation(net, act_type='relu', name='relu3') + net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4') + net = mx.sym.Activation(net, act_type='relu', name='relu4') + net = mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False) + + return net diff --git a/tests/python/frontend/mxnet/model_zoo/inception_v3.py b/tests/python/frontend/mxnet/model_zoo/inception_v3.py new file mode 100644 index 000000000000..b8585bf05037 --- /dev/null +++ b/tests/python/frontend/mxnet/model_zoo/inception_v3.py @@ -0,0 +1,170 @@ +""" +Inception V3, suitable for images with around 299 x 299 + +Reference: +Szegedy, Christian, et al. "Rethinking the Inception Architecture for Computer Vision." 
arXiv preprint arXiv:1512.00567 (2015). + +Adopted from https://github.com/apache/incubator-mxnet/blob/ + master/example/image-classification/symbols/inception-v3.py +""" +import mxnet as mx +import numpy as np + +def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''): + conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix)) + bn = mx.sym.BatchNorm(data=conv, eps=2e-5, name='%s%s_batchnorm' % (name, suffix)) + act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix)) + return act + + +def Inception7A(data, + num_1x1, + num_3x3_red, num_3x3_1, num_3x3_2, + num_5x5_red, num_5x5, + pool, proj, + name): + tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name)) + tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv') + tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name), suffix='_conv_1') + tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv') + tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1') + tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv') + concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name) + return concat + +# First Downsample +def Inception7B(data, + num_3x3, + num_d3x3_red, num_d3x3_1, num_d3x3_2, + pool, + name): + tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name)) + tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv') + tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1') + tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name)) + concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name) + return concat + +def Inception7C(data, + num_1x1, + num_d7_red, num_d7_1, num_d7_2, + num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4, + pool, proj, + name): + tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name)) + tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv') + tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower' % name), suffix='_conv_1') + tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower' % name), suffix='_conv_2') + tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_1') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3') + tower_q7 = Conv(data=tower_q7, 
num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv') + # concat + concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name) + return concat + +def Inception7D(data, + num_3x3_red, num_3x3, + num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3, + pool, + name): + tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name), suffix='_conv') + tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0,0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_1') + tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name), suffix='_conv') + tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1') + tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2') + tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + # concat + concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name) + return concat + +def Inception7E(data, + num_1x1, + num_d3_red, num_d3_1, num_d3_2, + num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2, + pool, proj, + name): + tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name)) + tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv') + tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower' % name), suffix='_mixed_conv') + tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower' % name), suffix='_mixed_conv_1') + tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name), suffix='_conv') + tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1') + tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv') + tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv') + # concat + concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name) + return concat + +def get_symbol(num_classes=1000, **kwargs): + data = mx.sym.Variable(name="data") + # stage 1 + conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv") + conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1") + conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2") + pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool") + # stage 2 + conv_3 = Conv(pool, 80, kernel=(1, 
1), name="conv_3") + conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4") + pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1") + + # # stage 3 + in3a = Inception7A(pool1, 64, + 64, 96, 96, + 48, 64, + "avg", 32, "mixed") + in3b = Inception7A(in3a, 64, + 64, 96, 96, + 48, 64, + "avg", 64, "mixed_1") + in3c = Inception7A(in3b, 64, + 64, 96, 96, + 48, 64, + "avg", 64, "mixed_2") + in3d = Inception7B(in3c, 384, + 64, 96, 96, + "max", "mixed_3") + # stage 4 + in4a = Inception7C(in3d, 192, + 128, 128, 192, + 128, 128, 128, 128, 192, + "avg", 192, "mixed_4") + in4b = Inception7C(in4a, 192, + 160, 160, 192, + 160, 160, 160, 160, 192, + "avg", 192, "mixed_5") + in4c = Inception7C(in4b, 192, + 160, 160, 192, + 160, 160, 160, 160, 192, + "avg", 192, "mixed_6") + in4d = Inception7C(in4c, 192, + 192, 192, 192, + 192, 192, 192, 192, 192, + "avg", 192, "mixed_7") + in4e = Inception7D(in4d, 192, 320, + 192, 192, 192, 192, + "max", "mixed_8") + # stage 5 + in5a = Inception7E(in4e, 320, + 384, 384, 384, + 448, 384, 384, 384, + "avg", 192, "mixed_9") + in5b = Inception7E(in5a, 320, + 384, 384, 384, + 448, 384, 384, 384, + "max", 192, "mixed_10") + # pool + pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool") + flatten = mx.sym.Flatten(data=pool, name="flatten") + fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1', flatten=False) + softmax = mx.sym.SoftmaxOutput(data=fc1, name='softmax') + return softmax diff --git a/tests/python/frontend/mxnet/model_zoo/mlp.py b/tests/python/frontend/mxnet/model_zoo/mlp.py new file mode 100644 index 000000000000..922b208749bf --- /dev/null +++ b/tests/python/frontend/mxnet/model_zoo/mlp.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +a simple multilayer perceptron +""" +import mxnet as mx + +def get_symbol(num_classes=10, **kwargs): + data = mx.symbol.Variable('data') + data = mx.sym.Flatten(data=data) + try: + fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128, flatten=False) + act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") + fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, flatten=False) + act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") + fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes, flatten=False) + mlp = mx.symbol.softmax(data = fc3, name = 'softmax') + except: + fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128) + act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") + fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64) + act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") + fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes) + mlp = mx.symbol.softmax(data = fc3, name = 'softmax') + return mlp diff --git a/tests/python/frontend/mxnet/model_zoo/resnet.py b/tests/python/frontend/mxnet/model_zoo/resnet.py new file mode 100644 index 000000000000..3f9a870d31c0 --- /dev/null +++ b/tests/python/frontend/mxnet/model_zoo/resnet.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +''' +Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py +Original author Wei Wu + +Implemented the following paper: + +Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 
"Identity Mappings in Deep Residual Networks" +''' +import mxnet as mx +import numpy as np + +def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False): + """Return ResNet Unit symbol for building ResNet + Parameters + ---------- + data : str + Input data + num_filter : int + Number of output channels + bnf : int + Bottle neck channels factor with regard to num_filter + stride : tuple + Stride used in convolution + dim_match : Boolean + True means channel number between input and output is the same, otherwise means differ + name : str + Base name of the operators + workspace : int + Workspace used in convolution operator + """ + if bottle_neck: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1') + act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') + conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=stride, pad=(0,0), + no_bias=True, workspace=workspace, name=name + '_conv1') + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2') + act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') + conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1), + no_bias=True, workspace=workspace, name=name + '_conv2') + bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3') + act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3') + conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True, + workspace=workspace, name=name + '_conv3') + if dim_match: + shortcut = data + else: + shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True, + workspace=workspace, name=name+'_sc') + if memonger: + shortcut._set_attr(mirror_stage='True') + return conv3 + shortcut + else: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1') + act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') + conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1), + no_bias=True, workspace=workspace, name=name + '_conv1') + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2') + act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') + conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1), + no_bias=True, workspace=workspace, name=name + '_conv2') + if dim_match: + shortcut = data + else: + shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True, + workspace=workspace, name=name+'_sc') + if memonger: + shortcut._set_attr(mirror_stage='True') + return conv2 + shortcut + +def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False): + """Return ResNet symbol of + Parameters + ---------- + units : list + Number of units in each stage + num_stages : int + Number of stage + filter_list : list + Channel size of each stage + num_classes : int + Ouput size of symbol + dataset : str + Dataset type, only cifar10 and imagenet supports + workspace : int + Workspace used in convolution operator + dtype : str + Precision (float32 or float16) + """ + 
num_unit = len(units) + assert(num_unit == num_stages) + data = mx.sym.Variable(name='data') + if dtype == 'float32': + # data = mx.sym.identity(data=data, name='id') + data = data + else: + if dtype == 'float16': + data = mx.sym.Cast(data=data, dtype=np.float16) + data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data') + (nchannel, height, width) = image_shape + if height <= 32: # such as cifar10 + body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(3, 3), stride=(1,1), pad=(1, 1), + no_bias=True, name="conv0", workspace=workspace) + else: # often expected to be 224 such as imagenet + body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3), + no_bias=True, name="conv0", workspace=workspace) + body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0') + body = mx.sym.Activation(data=body, act_type='relu', name='relu0') + body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max') + + for i in range(num_stages): + body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False, + name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace, + memonger=memonger) + for j in range(units[i]-1): + body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2), + bottle_neck=bottle_neck, workspace=workspace, memonger=memonger) + bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1') + relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1') + # Although kernel is not used here when global_pool=True, we should put one + pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1') + flat = mx.sym.Flatten(data=pool1) + try: + fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1', flatten=False) + except: + fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1') + if dtype == 'float16': + fc1 = mx.sym.Cast(data=fc1, dtype=np.float32) + return mx.sym.softmax(data=fc1, name='softmax') + +def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', **kwargs): + """ + Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py + Original author Wei Wu + """ + image_shape = [int(l) for l in image_shape.split(',')] + (nchannel, height, width) = image_shape + if height <= 28: + num_stages = 3 + if (num_layers-2) % 9 == 0 and num_layers >= 164: + per_unit = [(num_layers-2)//9] + filter_list = [16, 64, 128, 256] + bottle_neck = True + elif (num_layers-2) % 6 == 0 and num_layers < 164: + per_unit = [(num_layers-2)//6] + filter_list = [16, 16, 32, 64] + bottle_neck = False + else: + raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers)) + units = per_unit * num_stages + else: + if num_layers >= 50: + filter_list = [64, 256, 512, 1024, 2048] + bottle_neck = True + else: + filter_list = [64, 64, 128, 256, 512] + bottle_neck = False + num_stages = 4 + if num_layers == 18: + units = [2, 2, 2, 2] + elif num_layers == 34: + units = [3, 4, 6, 3] + elif num_layers == 50: + units = [3, 4, 6, 3] + elif num_layers == 101: + units = [3, 4, 23, 3] + elif num_layers == 152: + units = [3, 8, 36, 3] + elif num_layers == 200: + units = [3, 24, 36, 3] + elif num_layers == 269: + units = [3, 30, 48, 8] + else: + raise ValueError("no experiments done on num_layers {}, you can 
do it yourself".format(num_layers)) + + return resnet(units = units, + num_stages = num_stages, + filter_list = filter_list, + num_classes = num_classes, + image_shape = image_shape, + bottle_neck = bottle_neck, + workspace = conv_workspace, + dtype = dtype) diff --git a/tests/python/frontend/mxnet/model_zoo/squeezenet.py b/tests/python/frontend/mxnet/model_zoo/squeezenet.py new file mode 100644 index 000000000000..deb896a21385 --- /dev/null +++ b/tests/python/frontend/mxnet/model_zoo/squeezenet.py @@ -0,0 +1,76 @@ +""" +Symbol of SqueezeNet + +Reference: +Iandola, Forrest N., et al. +"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016). +""" + +import mxnet as mx + +# Helpers +def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels): + net = _make_fire_conv(net, squeeze_channels, 1, 0) + + left = _make_fire_conv(net, expand1x1_channels, 1, 0) + right = _make_fire_conv(net, expand3x3_channels, 3, 1) + # NOTE : Assume NCHW layout here + net = mx.sym.concat(left, right, dim=1) + + return net + +def _make_fire_conv(net, channels, kernel_size, padding=0): + net = mx.sym.Convolution(net, num_filter=channels, kernel=(kernel_size, kernel_size), + pad=(padding, padding)) + net = mx.sym.Activation(net, act_type='relu') + return net + +# Net +def get_symbol(num_classes=1000, version='1.0', **kwargs): + """Get symbol of SqueezeNet + + Parameters + ---------- + num_classes: int + The number of classification results + + version : str, optional + "1.0" or "1.1" of SqueezeNet + """ + assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}:" + "1.0 or 1.1 expected".format(version=version)) + net = mx.sym.Variable("data") + if version == '1.0': + net = mx.sym.Convolution(net, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3)) + net = mx.sym.Activation(net, act_type='relu') + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 16, 64, 64) + net = _make_fire(net, 16, 64, 64) + net = _make_fire(net, 32, 128, 128) + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 32, 128, 128) + net = _make_fire(net, 48, 192, 192) + net = _make_fire(net, 48, 192, 192) + net = _make_fire(net, 64, 256, 256) + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 64, 256, 256) + else: + net = mx.sym.Convolution(net, num_filter=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1)) + net = mx.sym.Activation(net, act_type='relu') + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 16, 64, 64) + net = _make_fire(net, 16, 64, 64) + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 32, 128, 128) + net = _make_fire(net, 32, 128, 128) + net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2)) + net = _make_fire(net, 48, 192, 192) + net = _make_fire(net, 48, 192, 192) + net = _make_fire(net, 64, 256, 256) + net = _make_fire(net, 64, 256, 256) + net = mx.sym.Dropout(net, p=0.5) + net = mx.sym.Convolution(net, num_filter=num_classes, kernel=(1, 1)) + net = mx.sym.Activation(net, act_type='relu') + net = mx.sym.Pooling(data=net, global_pool=True, kernel=(13, 13), pool_type='avg') + net = mx.sym.flatten(net) + return mx.sym.softmax(net) diff --git a/tests/python/frontend/mxnet/model_zoo/vgg.py b/tests/python/frontend/mxnet/model_zoo/vgg.py new file mode 100644 index 
000000000000..68215bb80aaa --- /dev/null +++ b/tests/python/frontend/mxnet/model_zoo/vgg.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""References: + +Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for +large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014). +""" + +import mxnet as mx +import numpy as np + +def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs): + for i, num in enumerate(layers): + for j in range(num): + internel_layer = mx.sym.Convolution(data = internel_layer, kernel=(3, 3), pad=(1, 1), num_filter=filters[i], name="conv%s_%s" %(i + 1, j + 1)) + if batch_norm: + internel_layer = mx.symbol.BatchNorm(data=internel_layer, name="bn%s_%s" %(i + 1, j + 1)) + internel_layer = mx.sym.Activation(data=internel_layer, act_type="relu", name="relu%s_%s" %(i + 1, j + 1)) + internel_layer = mx.sym.Pooling(data=internel_layer, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool%s" %(i + 1)) + return internel_layer + +def get_classifier(input_data, num_classes, **kwargs): + flatten = mx.sym.Flatten(data=input_data, name="flatten") + try: + fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6", flatten=False) + relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6") + drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6") + fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7", flatten=False) + relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7") + drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7") + fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8", flatten=False) + except: + fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6") + relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6") + drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6") + fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7") + relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7") + drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7") + fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8") + return fc8 + +def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs): + """ + Parameters + ---------- + num_classes : int, default 1000 + Number of classification classes. + num_layers : int + Number of layers for the variant of densenet. Options are 11, 13, 16, 19. + batch_norm : bool, default False + Use batch normalization. + dtype: str, float32 or float16 + Data precision. 
+ """ + vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]), + 13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]), + 16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]), + 19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])} + if num_layers not in vgg_spec: + raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers)) + layers, filters = vgg_spec[num_layers] + data = mx.sym.Variable(name="data") + if dtype == 'float16': + data = mx.sym.Cast(data=data, dtype=np.float16) + feature = get_feature(data, layers, filters, batch_norm) + classifier = get_classifier(feature, num_classes) + if dtype == 'float16': + classifier = mx.sym.Cast(data=classifier, dtype=np.float32) + symbol = mx.sym.softmax(data=classifier, name='softmax') + return symbol diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py new file mode 100644 index 000000000000..81a12b041ed7 --- /dev/null +++ b/tests/python/frontend/mxnet/test_forward.py @@ -0,0 +1,214 @@ +import numpy as np + +import tvm +from tvm.contrib import graph_runtime +from tvm.relay.testing.config import ctx_list +from tvm import relay +import mxnet as mx + +from mxnet import gluon +from mxnet.gluon.model_zoo import vision +import model_zoo + + +def verify_mxnet_frontend_impl(mx_symbol, + data_shape=(1, 3, 224, 224), + out_shape=(1, 1000), + gluon_impl=False, + name=None, + dtype='float32'): + """Use name different from test to avoid let nose pick it up""" + if gluon_impl: + def get_gluon_output(name, x): + net = vision.get_model(name) + net.collect_params().initialize(mx.init.Xavier()) + net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')), + inputs=mx.sym.var('data'), + params=net.collect_params()) + out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy() + return out, net_sym + else: + def get_mxnet_output(symbol, x, dtype='float32'): + from collections import namedtuple + Batch = namedtuple('Batch', ['data']) + mod = mx.mod.Module(symbol, label_names=None) + mod.bind(data_shapes=[('data', x.shape)], for_training=False) + mod.init_params() + mod.forward(Batch([mx.nd.array(x.astype(dtype))])) + out = mod.get_outputs()[0].asnumpy() + args, auxs = mod.get_params() + return out, args, auxs + + def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'): + shape_dict = {"data": x.shape} + if gluon_impl: + new_sym, params = relay.frontend.from_mxnet(symbol, shape_dict) + else: + new_sym, params = relay.frontend.from_mxnet(symbol, + shape_dict, + arg_params=args, + aux_params=auxs) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(new_sym, target, params=params) + m = graph_runtime.create(graph, lib, ctx) + # set inputs + m.set_input("data", tvm.nd.array(x.astype(dtype))) + m.set_input(**params) + m.run() + # get outputs + out = m.get_output(0, tvm.nd.empty(out_shape, dtype)) + return out.asnumpy() + + # random input + x = np.random.uniform(size=data_shape) + if gluon_impl: + gluon_out, gluon_sym = get_gluon_output(name, x) + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype) + tvm.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5) + else: + mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype) + assert "data" not in args + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype) + tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5) + +def test_forward_mlp(): + mlp = model_zoo.mx_mlp() + 
verify_mxnet_frontend_impl(mlp, + data_shape=(1, 1, 28, 28), + out_shape=(1, 10)) + +def test_forward_vgg(): + for n in [11]: + mx_sym = model_zoo.mx_vgg(n) + verify_mxnet_frontend_impl(mx_sym) + +def test_forward_resnet(): + for n in [18]: + mx_sym = model_zoo.mx_resnet(18) + verify_mxnet_frontend_impl(mx_sym) + +def test_forward_elu(): + data = mx.sym.var('data') + data = mx.sym.concat(data, -data, dim=1) # negative part explicitly + mx_sym = mx.sym.LeakyReLU(data, act_type='elu') + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) + +def test_forward_rrelu(): + data = mx.sym.var('data') + data = mx.sym.concat(data, -data, dim=1) # negative part explicitly + mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) + +def test_forward_prelu(): + data = mx.sym.var('data') + data = mx.sym.concat(data, -data, dim=1) # negative part explicitly + mx_sym = mx.sym.LeakyReLU(data, act_type='prelu') + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) + +def test_forward_softrelu(): + data = mx.sym.var('data') + data = mx.sym.concat(data, -data, dim=1) # negative part explicitly + mx_sym = mx.sym.Activation(data, act_type='softrelu') + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) + +def test_forward_fc_flatten(): + # test flatten=True option in mxnet 0.11.1 + data = mx.sym.var('data') + try: + mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100)) + mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100)) + except: + pass + +def test_forward_clip(): + data = mx.sym.var('data') + data = mx.sym.concat(data, -data, dim=1) # negative part explicity + mx_sym = mx.sym.clip(data, a_min=0, a_max=1) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) + +def test_forward_split(): + data = mx.sym.var('data') + mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False) + verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1)) + +def test_forward_split_squeeze(): + data = mx.sym.var('data') + mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True) + verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1)) + +def test_forward_expand_dims(): + data = mx.sym.var('data') + mx_sym = mx.sym.expand_dims(data, axis=1) + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4)) + +def test_forward_pooling(): + data = mx.sym.var('data') + mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='avg') + verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8)) + + mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max') + verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8)) + +def test_forward_lrn(): + data = mx.sym.var('data') + mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5) + verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24)) + +def test_forward_ones(): + data = mx.sym.var('data') + ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32') + mx_sym = mx.sym.elemwise_add(data, ones) + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) + +def test_forward_zeros(): + data = mx.sym.var('data') + zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32') + mx_sym = mx.sym.elemwise_add(data, zeros) + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) + +def 
test_forward_ones_like(): + data = mx.sym.var('data') + mx_sym = mx.sym.ones_like(data, dtype='float32') + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) + +def test_forward_zeros_like(): + data = mx.sym.var('data') + mx_sym = mx.sym.zeros_like(data, dtype='float32') + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4)) + +def test_forward_argmax(): + data = mx.sym.var('data') + mx_sym = mx.sym.argmax(data, axis=1) + verify_mxnet_frontend_impl(mx_sym, (5, 3), (5,)) + +def test_forward_argmin(): + data = mx.sym.var('data') + mx_sym = mx.sym.argmin(data, axis=0) + verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,)) + + +if __name__ == '__main__': + test_forward_mlp() + test_forward_vgg() + test_forward_resnet() + test_forward_elu() + test_forward_rrelu() + test_forward_prelu() + test_forward_softrelu() + test_forward_fc_flatten() + test_forward_clip() + test_forward_split() + test_forward_split_squeeze() + test_forward_expand_dims() + test_forward_pooling() + test_forward_lrn() + test_forward_ones() + test_forward_zeros() + test_forward_ones_like() + test_forward_zeros_like() + test_forward_argmax() + test_forward_argmin() diff --git a/tests/python/frontend/mxnet/test_graph.py b/tests/python/frontend/mxnet/test_graph.py new file mode 100644 index 000000000000..c2bed8829b81 --- /dev/null +++ b/tests/python/frontend/mxnet/test_graph.py @@ -0,0 +1,101 @@ +import mxnet as mx +from tvm import relay +import model_zoo + +def compare_graph(f1, f2): + f1 = relay.ir_pass.infer_type(f1) + f2 = relay.ir_pass.infer_type(f2) + assert relay.ir_pass.alpha_equal(f1, f2) + +def test_mlp(): + shape = {"data": (1, 1, 28, 28)} + mx_fun = model_zoo.mx_mlp() + from_mx_fun, _ = relay.frontend.from_mxnet(mx_fun, shape=shape) + relay_fun = model_zoo.relay_mlp() + compare_graph(from_mx_fun, relay_fun) + + +def test_vgg(): + shape = {"data": (1, 3, 224, 224)} + for n in [11, 13, 16, 19]: + mx_sym = model_zoo.mx_vgg(n) + from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) + relay_sym = model_zoo.relay_vgg(n) + compare_graph(from_mx_sym, relay_sym) + + +def test_resnet(): + shape = {"data": (1, 3, 224, 224)} + for n in [18, 34, 50, 101]: + mx_sym = model_zoo.mx_resnet(n) + from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) + relay_sym = model_zoo.relay_resnet(n) + compare_graph(from_mx_sym, relay_sym) + + +def test_squeezenet(): + shape = {"data": (1, 3, 224, 224)} + for version in ['1.0', '1.1']: + mx_sym = model_zoo.mx_squeezenet(version) + from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) + relay_sym = model_zoo.relay_squeezenet(version) + compare_graph(from_mx_sym, relay_sym) + + +def test_inception_v3(): + shape = {"data": (1, 3, 299, 299)} + mx_sym = model_zoo.mx_inception_v3() + from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) + relay_sym = model_zoo.relay_inception_v3() + compare_graph(from_mx_sym, relay_sym) + + +def test_dqn(): + shape = {"data": (1, 4, 84, 84)} + mx_sym = model_zoo.mx_dqn() + from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) + relay_sym = model_zoo.relay_dqn() + compare_graph(from_mx_sym, relay_sym) + + +def test_dcgan(): + shape = {"data": (2, 100)} + mx_sym = model_zoo.mx_dcgan() + from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) + relay_sym = model_zoo.relay_dcgan(batch_size=2) + compare_graph(from_mx_sym, relay_sym) + + +def test_multi_outputs(): + xshape = (10, 27) + yshape = (10, 9) + + def mx_compose(F, **kwargs): + x = F.sym.Variable("x") + y = F.sym.Variable("y") + z = F.sym.split(x, **kwargs) + return 
F.sym.broadcast_sub(F.sym.broadcast_add(z[0], z[2]), y) + + def relay_compose(F, **kwargs): + x = F.var("x", shape=xshape) + y = F.var("y", shape=yshape) + z = F.split(x, **kwargs) + z = F.subtract(F.add(z[0], z[2]), y) + return relay.Function(relay.ir_pass.free_vars(z), z) + + mx_sym = mx_compose(mx, num_outputs=3, axis=1) + from_mx_sym, _ = relay.frontend.from_mxnet( + mx_sym, shape={"x":xshape, "y":yshape}) + relay_sym = relay_compose(relay, indices_or_sections=3, axis=1) + compare_graph(from_mx_sym, relay_sym) + + +if __name__ == "__main__": + test_mlp() + test_resnet() + test_vgg() + test_multi_outputs() + test_dqn() + test_dcgan() + test_squeezenet() + test_inception_v3() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 43c11c4509d1..806b63b7c6f5 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -115,7 +115,7 @@ def test_squeeze_bad_axes_infer_type(): def test_reshape_infer_type(): - n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20 + n, t, d1, d2 = 10, 20, 100, 20 x = relay.var("x", relay.TensorType((n, t, d1, d2), "float32")) y = relay.reshape(x, newshape=(n, t, 2000)) assert "newshape=" in y.astext() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 9bc62b2c0249..ef7f1221a70c 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -332,7 +332,7 @@ inline Tensor concatenate(const Array& inputs, * \return A Tensor whose op member is the split operation */ inline Array split(const Tensor& x, - Array split_indices, + Array split_indices, int axis, std::string name = "tensor", std::string tag = kInjective) { @@ -342,14 +342,15 @@ inline Array split(const Tensor& x, CHECK_LT(axis, x->shape.size()) << "axis out of bounds"; auto src_axis_size = static_cast(GetConstInt(x->shape[axis])); - - auto split_indices_val = GetConstIntValues(split_indices, "split_indices"); - CHECK(std::is_sorted(split_indices_val.begin(), split_indices_val.end())) << - "split_indices must be sorted"; - std::vector begin_ids; begin_ids.push_back(0); - std::copy(split_indices_val.begin(), split_indices_val.end(), std::back_inserter(begin_ids)); + + for (Integer idx : split_indices) { + int val = static_cast(idx->value); + CHECK_GT(val, begin_ids.back()) + << "split_indices must be sorted"; + begin_ids.push_back(val); + } Array< Array > out_shapes; for (size_t i = 0; i < begin_ids.size(); ++i) { @@ -508,10 +509,10 @@ inline Tensor strided_slice(const Tensor& x, * \return A Tensor whose op member is the split operation */ inline Array split_sections(const Tensor& x, - int num_sections, - int axis, - std::string name = "tensor", - std::string tag = kInjective) { + int num_sections, + int axis, + std::string name = "tensor", + std::string tag = kInjective) { if (axis < 0) { axis += static_cast(x->shape.size()); } @@ -524,7 +525,7 @@ inline Array split_sections(const Tensor& x, << "num_sections must be an integer factor of the size of axis " << axis << " (" << src_axis_size << ")"; - Array split_indices; + Array split_indices; auto seg_size = src_axis_size / num_sections; for (int i = 0; i < num_sections; ++i) { // region at index 0 is added by split() diff --git a/topi/include/topi/vision/yolo/region.h b/topi/include/topi/vision/yolo/region.h index 88553fc29b8a..7d303f445ac4 100644 --- a/topi/include/topi/vision/yolo/region.h +++ b/topi/include/topi/vision/yolo/region.h @@ -53,7 +53,7 @@ inline Tensor region(const Tensor &data, input_shape[2], input_shape[3]}; auto 
data_block = reshape(data, intermediate_shape); - Array<Expr> split_indices; + Array<Integer> split_indices; for (int i = 1; i < split_size; ++i) { split_indices.push_back(i); } From 1845ff9c7a9f92859be65934ed8381610acafa31 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 25 Nov 2018 09:38:50 -0800 Subject: [PATCH 2/2] fix lint --- python/tvm/relay/frontend/mxnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index c5f200ccf835..9d1bd0deffa9 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -183,7 +183,7 @@ def _stable_softrelu(x): one = _expr.const(1, dtype="float32") exp_neg_abs_x = _op.exp(_op.negative(_op.abs(x))) return _op.add(_op.log(_op.add(one, exp_neg_abs_x)), - _op.nn.relu(x)) + _op.nn.relu(x)) return _stable_softrelu(inputs[0]) raise RuntimeError("Do not support act_type: {}".format(act_type))
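Note on the final hunk: _stable_softrelu relies on the identity log(1 + exp(x)) = log(1 + exp(-|x|)) + relu(x), which avoids overflow for large positive x while leaving the value unchanged. A minimal NumPy sketch illustrating the point (illustration only, not part of the patch; the helper names softplus_naive and softplus_stable are made up here):

    import numpy as np

    def softplus_naive(x):
        # log(1 + exp(x)): exp(x) overflows for large positive x
        return np.log(1.0 + np.exp(x))

    def softplus_stable(x):
        # log(1 + exp(-|x|)) + relu(x): the form used by _stable_softrelu
        return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)

    x = np.array([-50.0, 0.0, 1.0, 50.0, 1000.0])
    print(softplus_naive(x))   # last entry overflows to inf (with a runtime warning)
    print(softplus_stable(x))  # finite everywhere; matches the naive form where it is finite

Both forms agree wherever the naive one does not overflow, which is also why the hunk above only reindents the continuation line of the return statement without changing behavior.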