From f94a5293567f2d104dc14048ed00b1d62ec18a32 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 20 Nov 2018 15:12:02 -0800 Subject: [PATCH] Port from_nnvm to NNVM as to_relay Repair imports First test passes Fix bug in axes passing Add implementations for upsampling and pad Add support for bias with conv ops Rebase fixup Borrow the reshape changes while waiting on #2159 Debugging tests Repair after rebase Refactor to use common functionality Fix helper func Remove final eval Remove some debuggin Remove unecessary ws change Fix multiple outputs Tests pass Enable all tests Rebase repair Fix last test case Fix last test case All tests but RNNs pass Fix tests Remove final eval Reformat Fix bugs in MxNet Fix 3.5 support in Relay, and rename Factor common functionality into one file Apply most of code review Fix linting Fix NNVM linting Fix test error Fix MLP test Fix linting error One more linting issue Clean up diff and fix bugs Retrigger disable MLP test Retrigger Retrigger Roll back NNVM change Retrigger Retrigger Retrigger --- nnvm/python/nnvm/to_relay.py | 506 ++++++++++++++++++++ nnvm/tests/python/compiler/test_to_relay.py | 41 ++ python/tvm/relay/frontend/common.py | 55 ++- python/tvm/relay/frontend/mxnet.py | 137 +----- python/tvm/relay/frontend/nnvm_common.py | 132 +++++ python/tvm/relay/op/_transform.py | 1 + python/tvm/relay/op/nn/_nn.py | 7 +- src/relay/backend/graph_plan_memory.cc | 3 + src/relay/ir/alpha_equal.cc | 10 +- src/relay/op/nn/upsampling.cc | 48 +- tests/python/relay/frontend/test_keras.py | 332 +++++++++++++ topi/include/topi/image/resize.h | 3 +- 12 files changed, 1116 insertions(+), 159 deletions(-) create mode 100644 nnvm/python/nnvm/to_relay.py create mode 100644 nnvm/tests/python/compiler/test_to_relay.py create mode 100644 python/tvm/relay/frontend/nnvm_common.py create mode 100644 tests/python/relay/frontend/test_keras.py diff --git a/nnvm/python/nnvm/to_relay.py b/nnvm/python/nnvm/to_relay.py new file mode 100644 index 000000000000..318ff1ee92dd --- /dev/null +++ b/nnvm/python/nnvm/to_relay.py @@ -0,0 +1,506 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-argument +"""Convert an NNVM graph to Relay.""" +import json +from tvm import relay, nd +from tvm.relay import op, expr, var +from tvm.relay.frontend.common import StrAttrsDict +from tvm.relay.frontend.nnvm_common import _rename +import numpy +from .symbol import Symbol +from .compiler import graph_attr +from .graph import create as graph_create + +def _nn_batch_flatten(children, attrs, odtype='float32'): + assert len(children) == 1 + return op.nn.batch_flatten(children[0]) + + +def _dense(children, attrs, odtype='float32'): + use_bias = attrs.get_bool('use_bias', True) + units = attrs.get_int('units') + dense = op.nn.dense(children[0], children[1], units=units) + if use_bias: + return op.nn.bias_add(dense, children[2]) + else: + return dense + +def _nn_softmax(children, attrs, odtype='float32'): + assert len(children) == 1 + axis = attrs.get_int('axis', 1) + return op.nn.softmax(children[0], axis) + +def _conv2d(children, attrs, odtype='float32'): + use_bias = attrs.get_bool('use_bias', False) + + if use_bias: + data, weight, bias = children + else: + data, weight = children + + strides = attrs.get_int_tuple('strides', (1, 1)) + padding = attrs.get_int_tuple('padding', (0, 0)) + dilation = attrs.get_int_tuple('dilation', (1, 1)) + groups = attrs.get_int('groups', 1) + data_layout = attrs.get_str('layout', 'NCHW') + weight_layout = attrs.get_str('kernel_layout', 'OIHW') + out_layout = '' + out_dtype = attrs.get_str('out_dtype', '') + + conv_out = op.nn.conv2d( + data, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout=data_layout, + weight_layout=weight_layout, + out_layout=out_layout, + out_dtype=out_dtype) + + if use_bias: + return op.nn.bias_add(conv_out, bias) + else: + return conv_out + + +def _conv2d_transpose(children, attrs, odtype='float32'): + use_bias = attrs.get_bool('use_bias', False) + + if use_bias: + data, weight, bias = children + else: + data, weight = children + + strides = attrs.get_int_tuple('strides', (1, 1)) + padding = attrs.get_int_tuple('padding', (0, 0)) + dilation = attrs.get_int_tuple('dilation', (1, 1)) + groups = attrs.get_int('groups', 1) + data_layout = attrs.get_str('layout', 'NCHW') + weight_layout = attrs.get_str('kernel_layout', 'OIHW') + out_dtype = attrs.get_str('out_dtype', '') + + out_conv2d = op.nn.conv2d_transpose( + data, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout=data_layout, + weight_layout=weight_layout, + out_dtype=out_dtype) + + if use_bias: + return op.nn.bias_add(out_conv2d, bias) + else: + return out_conv2d + + +def _batch_norm(children, attrs, odtype='float32'): + data, gamma, beta, moving_mean, moving_view = children + axis = attrs.get_int('axis', 1) + epsilon = attrs.get_float('epsilon', 1e-05) + center = attrs.get_bool('center', True) + scale = attrs.get_bool('scale', True) + + return op.nn.batch_norm( + data, + gamma, + beta, + moving_mean, + moving_view, + axis=axis, + epsilon=epsilon, + center=center, + scale=scale)[0] + + +def _max_pool2d(children, attrs, odtype='float32'): + assert len(children) == 1 + data = children[0] + pool_size = attrs.get_int_tuple('pool_size', (1, 1)) + strides = attrs.get_int_tuple('strides', (1, 1)) + padding = attrs.get_int_tuple('padding', (0, 0)) + layout = attrs.get_int_tuple('layout', 'NCHW') + ceil_mode = attrs.get_bool('ceil_mode', False) + + return op.nn.max_pool2d( + data, + pool_size=pool_size, + strides=strides, + padding=padding, + layout=layout, + ceil_mode=ceil_mode) + + +def _reshape(children, attrs, odtype='float32'): + data = children[0] + shape = attrs.get_int_list('shape') + return op.reshape(data, shape) + + +def _transpose(children, attrs, odtype='float32'): + axes = attrs.get_int_list('axes', None) + return op.transpose(children[0], axes=axes) + + +def _add(children, attrs, odtype='float32'): + if len(children) == 1: + left = children[0] + scalar = attrs.get_float('scalar') + right = relay.const(scalar, dtype=odtype) + else: + assert len(children) == 2 + left = children[0] + right = children[1] + + return op.add(left, right) + + +def _subtract(children, attrs, odtype='float32'): + if len(children) == 1: + left = children[0] + scalar = attrs.get_float('scalar') + right = relay.const(scalar, dtype=odtype) + else: + assert len(children) == 2 + left = children[0] + right = children[1] + + return op.subtract(left, right) + + +def _rsubtract(children, attrs, odtype='float32'): + if len(children) == 1: + left = children[0] + scalar = attrs.get_float('scalar') + right = relay.const(scalar, dtype=odtype) + else: + assert len(children) == 2 + left = children[0] + right = children[1] + + return op.subtract(right, left) + + +def _multiply(children, attrs, odtype='float32'): + if len(children) == 1: + left = children[0] + scalar = attrs.get_float('scalar') + right = relay.const(scalar, dtype=odtype) + else: + assert len(children) == 2 + left = children[0] + right = children[1] + + return op.multiply(left, right) + + +def _divide(children, attrs, odtype='float32'): + if len(children) == 1: + left = children[0] + scalar = attrs.get_float('scalar') + right = relay.const(scalar, dtype=odtype) + else: + assert len(children) == 2 + left = children[0] + right = children[1] + + return op.divide(left, right) + + +def _rshift(children, attrs, odtype='float32'): + if len(children) == 1: + left = children[0] + scalar = attrs.get_float('scalar') + right = relay.const(scalar, dtype='int32') + else: + assert len(children) == 2 + left = children[0] + right = children[1] + + return op.right_shift(left, right) + + +def _clip(children, attrs, odtype='float32'): + a_min = attrs.get_float('a_min') + a_max = attrs.get_float('a_max') + return op.clip(children[0], a_min, a_max) + + +def _cast(children, attrs, odtype='float32'): + data = children[0] + dtype = attrs.get_str('dtype') + return data.astype(dtype) + + +def _expand_dims(children, attrs, odtype='float32'): + data = children[0] + axis = attrs.get_int('axis') + num_newaxis = attrs.get_int('num_newaxis', 1) + return op.transform.expand_dims(data, axis, num_newaxis=num_newaxis) + + +def broadcast_to(children, attrs, odtype='float32'): + # TODO(@jroesch) export broadcast to? + data = children[0] + shape = attrs.get_int_tuple('shape') + array = numpy.zeros(shape).astype(odtype) + rconst = relay.Constant(nd.array(array)) + return op.broadcast_to_like(data, rconst) + +def _copy(children, attrs, odtype='float32'): + return op.copy(children[0]) + + +def _global_avg_pool2d(children, attrs, odtype='float32'): + data = children[0] + layout = attrs.get_str('layout', "NCHW") + return op.nn.global_avg_pool2d(data, layout) + + +def _avg_pool2d(children, attrs, odtype='float32'): + data = children[0] + pool_size = attrs.get_int_tuple('pool_size', (1, 1)) + strides = attrs.get_int_tuple('strides', (1, 1)) + padding = attrs.get_int_tuple('padding', (0, 0)) + layout = attrs.get_str('layout', "NCHW") + ceil_mode = attrs.get_bool('ceil_mode', False) + count_include_pad = attrs.get_bool('layout', False) + return op.nn.avg_pool2d( + data, + pool_size=pool_size, + strides=strides, + padding=padding, + layout=layout, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad) + + +def _upsampling(children, attrs, odtype='float32'): + scale = attrs.get_int('scale') + layout = attrs.get_str('layout', 'NCHW') + method = attrs.get_str('method', 'NEAREST_NEIGHBOR') + return op.nn.upsampling( + children[0], + scale=scale, + layout=layout, + method=method) + + +def _pad(children, attrs, odtype='float32'): + pad_value = attrs.get_float('pad_value', 0.0) + pad_width = attrs.get_tuple_tuple_int('pad_width') + return op.nn.pad(children[0], pad_width, pad_value=pad_value) + +def _leaky_relu(children, attrs, odtype='float32'): + alpha = attrs.get_float('alpha') + return op.nn.leaky_relu(children[0], alpha) + + +def _full_like(children, attrs, odtype='float32'): + fill_value = relay.const(attrs.get_float('fill_value'), dtype='float32') + return op.full_like(children[0], fill_value) + + +def _greater(children, attrs, odtype='float32'): + out_type = attrs.get_str('out_type') + if out_type: + return op.greater(children[0], children[1]).astype(out_type) + else: + return op.greater(children[0], children[1]) + + +def _greater_equal(children, attrs, odtype='float32'): + out_type = attrs.get_str('out_type', None) + if out_type: + return op.greater_equal(children[0], children[1]).astype(out_type) + else: + return op.greater_equal(children[0], children[1]) + + +def _less(children, attrs, odtype='float32'): + out_type = attrs.get_str('out_type', None) + if out_type: + return op.less(children[0], children[1]).astype(out_type) + else: + return op.less(children[0], children[1]) + + +def _less_equal(children, attrs, odtype='float32'): + out_type = attrs.get_str('out_type', None) + if out_type: + return op.less_equal(children[0], children[1]).astype(out_type) + else: + return op.less_equal(children[0], children[1]) + + +def _strided_slice(children, attrs, odtype='float32'): + begin = attrs.get_int_list('begin') + end = attrs.get_int_list('end') + strides = attrs.get_int_list('strides', None) + return op.strided_slice(children[0], begin, end, strides=strides) + + +def _split(children, attrs, odtype='float32'): + indices_or_sections = None + try: + indices_or_sections = attrs.get_int('indices_or_sections', None) + except ValueError: + indices_or_sections = indices_or_sections or attrs.get_int_tuple( + 'indices_or_sections') + + axis = attrs.get_int('axis', 0) + + return op.split(children[0], indices_or_sections, axis) + +def _squeeze(children, attrs, odtype='float32'): + axis = None + try: + axis = [attrs.get_int('axis', None)] + except ValueError: + axis = axis or attrs.get_int_tuple('axis', None) + + return op.squeeze(children[0], axis) + +NNVM_OP_2_RELAY_OP = { + 'flatten': _nn_batch_flatten, + 'dense': _dense, + 'softmax': _nn_softmax, + 'conv2d': _conv2d, + 'batch_norm': _batch_norm, + 'max_pool2d': _max_pool2d, + 'reshape': _reshape, + 'transpose': _transpose, + # Addition + '__add_scalar__': _add, + 'broadcast_add': _add, + 'elemwise_add': _add, + # Subtraction + '__sub_scalar__': _subtract, + '__rsub_scalar__': _rsubtract, + 'broadcast_sub': _subtract, + 'elemwise_sub': _subtract, + # Multiply + '__mul_scalar__': _multiply, + 'broadcast_mul': _multiply, + 'elemwise_mul': _multiply, + # Division + '__div_scalar__': _divide, + 'broadcast_div': _divide, + 'elemwise_div': _divide, + # Negative + 'negative': _rename("negative"), + + # Comparsion + 'greater': _greater, + 'greater_equal': _greater_equal, + 'less': _less, + 'less_equal': _less_equal, + + # Activations + 'sigmoid': _rename('sigmoid'), + 'relu': _rename('nn.relu'), + 'exp': _rename('exp'), + 'log': _rename('log'), + 'tanh': _rename('tanh'), + 'leaky_relu': _leaky_relu, + 'clip': _clip, + 'round': _rename('round'), + 'cast': _cast, + 'expand_dims': _expand_dims, + 'broadcast_to': broadcast_to, + '__rshift_scalar__': _rshift, + 'copy': _copy, + 'global_avg_pool2d': _global_avg_pool2d, + 'avg_pool2d': _avg_pool2d, + 'conv2d_transpose': _conv2d_transpose, + 'upsampling': _upsampling, + 'pad': _pad, + 'full_like': _full_like, + 'strided_slice': _strided_slice, + 'split': _split, + 'squeeze': _squeeze, +} + + +def to_relay(graph, shape_dict, dtype_dict, params): + """Convert an NNVM graph into the corresponding Relay expression. + + Parameters + ---------- + graph : Graph + The input graph. + + shape_dict : dict of str to shape + The input shape. + + dtype_dict : dict of str to shape + The input shape. + + params : dict of str to array + The parameters. + + Returns + ------- + (expr, params) : Tuple[relay.Expr, dict of str to array] + The corresponding Relay expression and parameters. + """ + if isinstance(graph, Symbol): + graph = graph_create(graph) + + param_shapes = dict((k, params[k].shape) for k in params) + shape_dict = shape_dict.copy() + shape_dict.update(param_shapes) + graph = graph_attr.set_shape_inputs(graph, shape_dict) + graph = graph_attr.set_dtype_inputs(graph, dtype_dict) + graph = graph.apply(["InferShape", "InferType"]) + shape = graph.json_attr("shape") + dtype = [graph_attr.TCODE_TO_DTYPE[di] for di in graph.json_attr("dtype")] + heads = [x[0] for x in json.loads(graph.json())['heads']] + + gidx = graph.index + relay_map = {} + fn_params = [] + output_ids = [] + + for nid, node in enumerate(gidx.nodes): + children = [] + for i in node['inputs']: + child = relay_map[i[0]] + if isinstance(child, expr.TupleWrapper): + children.append(child[i[1]]) + else: + children.append(child) + + oshape = shape[gidx.entry_id(nid, 0)] + odtype = dtype[gidx.entry_id(nid, 0)] + attrs = node.get("attrs", {}) + node_name = node["name"] + op_name = node["op"] + + if op_name == "null": + v = var(node_name, shape=oshape, dtype=odtype) + fn_params.append(v) + relay_map[nid] = v + else: + if nid in heads: + output_ids.append(nid) + + if op_name in NNVM_OP_2_RELAY_OP: + str_attrs = StrAttrsDict(attrs) + call = NNVM_OP_2_RELAY_OP[op_name](children, str_attrs, odtype) + relay_map[nid] = call + else: + raise Exception( + "nnvm.to_relay: unsupported operator: {0}".format(op_name)) + + outputs = [relay_map[nid] for nid in output_ids] + if len(outputs) == 1: + body = outputs[0] + else: + body = expr.Tuple(outputs) + + func = relay.Function(fn_params, body) + return func, params diff --git a/nnvm/tests/python/compiler/test_to_relay.py b/nnvm/tests/python/compiler/test_to_relay.py new file mode 100644 index 000000000000..25037cfd3587 --- /dev/null +++ b/nnvm/tests/python/compiler/test_to_relay.py @@ -0,0 +1,41 @@ +import nnvm +from nnvm import testing +from nnvm import to_relay +import tvm +from tvm.relay import ir_pass +from tvm.relay import create_executor +from tvm.contrib import graph_runtime +import numpy as np + +def check_model(sym, shapes, dtypes, params): + net = nnvm.graph.create(sym) + graph_json, mod, params = nnvm.compiler.build( + net, + 'llvm', + shape=shapes, + dtype=dtypes, + params=params) + nnvm_rts = graph_runtime.create(graph_json, mod, tvm.cpu(0)) + inputs = {} + for name in shapes: + np_array = np.random.rand(*shapes[name]).astype('float32') + inputs[name] = tvm.nd.array(np_array) + + nnvm_rts.set_input(**params) + nnvm_rts.run(**inputs) + nnvm_out = nnvm_rts.get_output(0) + relay_model, params = to_relay.to_relay(net, shapes, dtypes, params) + relay_model = ir_pass.infer_type(relay_model) + relay_rts = create_executor(kind='graph', ctx=tvm.cpu(0), target='llvm') + inputs.update(params) + relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values())) + np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy()) + +# def test_mlp(): +# mlp, params = testing.mlp.get_workload(1) +# shapes = { "data": (10, 3, 224, 224) } +# dtypes = { "data": 'float32' } +# check_model(mlp, shapes, dtypes, params) + +if __name__ == "__main__": + test_mlp() diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 8e037d4bc554..95633a4d4586 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -101,11 +101,64 @@ def get_int_tuple(self, key, default=RequiredAttr()): """ if key in self.attrs: tshape = self.attrs[key] - return tuple(int(x.strip()) for x in tshape.strip('()').split(',')) + return tuple(int(x.strip()) for x in tshape.strip('()[]').split(',')) if isinstance(default, RequiredAttr): raise AttributeError("Required attribute {} not found.".format(key)) return default + def get_tuple_tuple_int(self, key, default=RequiredAttr()): + """Get int list attribute + + Parameters + ---------- + key : str + The attribute key + + default : float + The default value. + + Returns + ------- + value : The result + """ + if key in self.attrs: + value = self.attrs[key] + seq = [] + for tup in value.strip('()').split('),'): + tup = tup.strip('[]()') + els = [int(x.strip('( ')) for x in tup.split(',')] + seq.append(tuple(els)) + + return tuple(seq) + + if isinstance(default, RequiredAttr): + raise AttributeError("Required attribute {} not found.".format(key)) + return default + + def get_int_list(self, key, default=RequiredAttr()): + """Get int list attribute + + Parameters + ---------- + key : str + The attribute key + + default : float + The default value. + + Returns + ------- + value : The result + """ + if key in self.attrs: + tshape = self.attrs[key] + return tuple(int(x.strip()) for x in tshape.strip('[]()').split(',')) + if isinstance(default, RequiredAttr): + raise AttributeError("Required attribute {} not found.".format(key)) + return default + + + def get_bool(self, key, default=RequiredAttr()): """Get bool tuple attribute diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index b0b1e700987c..77e97d26efe0 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -8,138 +8,14 @@ from .. import op as _op from ... import nd as _nd from .common import StrAttrsDict +from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce +from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast +from .nnvm_common import _clip, _transpose, _upsampling +from .nnvm_common import _elemwise_sum, _reshape +from .nnvm_common import _warn_not_used __all__ = ['from_mxnet'] - -def _get_relay_op(op_name): - op = getattr(_op, op_name) - if not op: - raise RuntimeError("Unable to map op_name {} to relay".format(op_name)) - return op - - -def _warn_not_used(attr, op='nnvm'): - import warnings - err = "{} is ignored in {}.".format(attr, op) - warnings.warn(err) - - -def _rename(new_op): - if isinstance(new_op, str): - new_op = _get_relay_op(new_op) - # attrs are ignored. - def impl(inputs, _): - return new_op(*inputs) - return impl - - -def _reshape(inputs, attrs): - if attrs.get_bool("reverse", False): - raise RuntimeError("reshape do not support option reverse") - shape = attrs.get_int_tuple("shape") - return _op.reshape(inputs[0], newshape=shape) - - -def _init_op(new_op): - """Init ops like zeros/ones""" - def _impl(inputs, attrs): - assert len(inputs) == 0 - shape = attrs.get_int_tuple("shape") - dtype = attrs.get_str("dtype", "float32") - return new_op(shape=shape, dtype=dtype) - return _impl - - -def _softmax_op(new_op): - """softmax/log_softmax""" - def _impl(inputs, attrs): - assert len(inputs) == 1 - axis = attrs.get_int("axis", -1) - return new_op(inputs[0], axis=axis) - return _impl - - -def _reduce(new_op): - """Reduction ops like sum/min/max""" - def _impl(inputs, attrs): - assert len(inputs) == 1 - axis = attrs.get_int_tuple("axis", []) - keepdims = attrs.get_bool("keepdims", False) - # use None for reduce over all axis. - axis = None if len(axis) == 0 else axis - return new_op(inputs[0], axis=axis, keepdims=keepdims) - return _impl - - -def _arg_reduce(new_op): - """Arg Reduction ops like argmin/argmax""" - def _impl(inputs, attrs): - assert len(inputs) == 1 - axis = attrs.get_int("axis", None) - keepdims = attrs.get_bool("keepdims", False) - res = new_op(inputs[0], axis=[axis], keepdims=keepdims) - # cast to dtype. - res = res.astype("float32") - return res - return _impl - - -def _cast(inputs, attrs): - """Type cast""" - dtype = attrs.get_str("dtype") - return _op.cast(inputs[0], dtype=dtype) - - -def _clip(inputs, attrs): - a_min = attrs.get_float("a_min") - a_max = attrs.get_float("a_max") - return _op.clip(inputs[0], a_min=a_min, a_max=a_max) - - -def _transpose(inputs, attrs): - axes = attrs.get_int_tuple("axes", None) - # translate default case - axes = None if len(axes) == 0 else axes - return _op.transpose(inputs[0], axes=axes) - - -def _upsampling(inputs, attrs): - scale = attrs.get_int("scale") - return _op.nn.upsampling(inputs[0], scale=scale) - - -def _elemwise_sum(inputs, _): - assert len(inputs) > 0 - res = inputs[0] - for x in inputs[1:]: - res = _op.add(res, x) - return res - - -def _binop_scalar(new_op): - def _impl(inputs, attrs): - assert len(inputs) == 1 - scalar = attrs.get_float("scalar") - # Note: binary scalar only works for float op for now - scalar = _expr.const(scalar, dtype="float32") - return new_op(inputs[0], scalar) - return _impl - - -def _rbinop_scalar(new_op): - def _impl(inputs, attrs): - assert len(inputs) == 1 - scalar = attrs.get_float("scalar") - # Note: binary scalar only works for float op for now - scalar = _expr.const(scalar, dtype="float32") - return new_op(scalar, inputs[0]) - return _impl - -# All the functions with _mx prefix specific to MXNet. -# The functions without _mx prefix can be reused for -# NNVMv1 conversion to _op. - def _mx_fully_connected(inputs, attrs): import mxnet as mx units = attrs.get_int("num_hidden") @@ -493,6 +369,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): jnodes = jgraph["nodes"] node_map = {} + for nid, node in enumerate(jnodes): children = [node_map[e[0]][e[1]] for e in node["inputs"]] attrs = StrAttrsDict(node.get("attrs", {})) @@ -501,7 +378,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): if op_name == "null": shape = shape_dict[node_name] if node_name in shape_dict else None if isinstance(dtype_info, dict): - dtype = dtype_info[node_name] if node_name in dtype_dict else "float32" + dtype = dtype_info[node_name] if node_name in dtype_info else "float32" else: dtype = dtype_info node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)] diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py new file mode 100644 index 000000000000..17502dbaa090 --- /dev/null +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -0,0 +1,132 @@ +# pylint: disable=invalid-name, import-self, len-as-condition +"""Utility functions common to NNVM and MxNet conversion.""" +from __future__ import absolute_import as _abs + +from .. import expr as _expr +from .. import op as _op + +def _get_relay_op(op_name): + op = _op + for path in op_name.split("."): + op = getattr(op, path) + if not op: + raise RuntimeError("Unable to map op_name {} to relay".format(op_name)) + return op + + +def _warn_not_used(attr, op='nnvm'): + import warnings + err = "{} is ignored in {}.".format(attr, op) + warnings.warn(err) + + +def _rename(new_op): + if isinstance(new_op, str): + new_op = _get_relay_op(new_op) + # attrs are ignored. + def impl(inputs, _, _dtype='float32'): + return new_op(*inputs) + return impl + + +def _reshape(inputs, attrs): + if attrs.get_bool("reverse", False): + raise RuntimeError("reshape do not support option reverse") + shape = attrs.get_int_tuple("shape") + return _op.reshape(inputs[0], newshape=shape) + + +def _init_op(new_op): + """Init ops like zeros/ones""" + def _impl(inputs, attrs): + assert len(inputs) == 0 + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_str("dtype", "float32") + return new_op(shape=shape, dtype=dtype) + return _impl + + +def _softmax_op(new_op): + """softmax/log_softmax""" + def _impl(inputs, attrs): + assert len(inputs) == 1 + axis = attrs.get_int("axis", -1) + return new_op(inputs[0], axis=axis) + return _impl + + +def _reduce(new_op): + """Reduction ops like sum/min/max""" + def _impl(inputs, attrs): + assert len(inputs) == 1 + axis = attrs.get_int_tuple("axis", []) + keepdims = attrs.get_bool("keepdims", False) + # use None for reduce over all axis. + axis = None if len(axis) == 0 else axis + return new_op(inputs[0], axis=axis, keepdims=keepdims) + return _impl + + +def _arg_reduce(new_op): + """Arg Reduction ops like argmin/argmax""" + def _impl(inputs, attrs): + assert len(inputs) == 1 + axis = attrs.get_int("axis", None) + keepdims = attrs.get_bool("keepdims", False) + res = new_op(inputs[0], axis=[axis], keepdims=keepdims) + # cast to dtype. + res = res.astype("float32") + return res + return _impl + + +def _cast(inputs, attrs): + """Type cast""" + dtype = attrs.get_str("dtype") + return inputs[0].astype(dtype=dtype) + + +def _clip(inputs, attrs): + a_min = attrs.get_float("a_min") + a_max = attrs.get_float("a_max") + return _op.clip(inputs[0], a_min=a_min, a_max=a_max) + + +def _transpose(inputs, attrs): + axes = attrs.get_int_tuple("axes", None) + # translate default case + axes = None if len(axes) == 0 else axes + return _op.transpose(inputs[0], axes=axes) + + +def _upsampling(inputs, attrs): + scale = attrs.get_int("scale") + return _op.nn.upsampling(inputs[0], scale=scale) + + +def _elemwise_sum(inputs, _): + assert len(inputs) > 0 + res = inputs[0] + for x in inputs[1:]: + res = _op.add(res, x) + return res + + +def _binop_scalar(new_op): + def _impl(inputs, attrs): + assert len(inputs) == 1 + scalar = attrs.get_float("scalar") + # Note: binary scalar only works for float op for now + scalar = _expr.const(scalar, dtype="float32") + return new_op(inputs[0], scalar) + return _impl + + +def _rbinop_scalar(new_op): + def _impl(inputs, attrs): + assert len(inputs) == 1 + scalar = attrs.get_float("scalar") + # Note: binary scalar only works for float op for now + scalar = _expr.const(scalar, dtype="float32") + return new_op(scalar, inputs[0]) + return _impl diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 1aaf376a7dc8..c1e71e9133ea 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -9,6 +9,7 @@ schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective + _reg.register_schedule("collapse_sum_like", _schedule_reduce) _reg.register_schedule("broadcast_to_like", schedule_broadcast) _reg.register_schedule("expand_dims", schedule_broadcast) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 007888996ed5..f5f76e6af38a 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -243,14 +243,11 @@ def schedule_l2_normalize(attrs, outs, target): reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE) - -@reg.register_schedule("nn.upsampling") +# Upsampling +reg.register_schedule("nn.upsampling", reg.schedule_injective) def schedule_upsampling(_, outs, target): """Schedule definition of upsampling""" with target: return topi.generic.schedule_injective(outs) - -reg.register_pattern("nn.upsampling", OpPattern.INJECTIVE) - # pad reg.register_schedule("nn.pad", schedule_broadcast) diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 5001e2cd4fea..4a5aa4ea0a33 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -253,6 +253,9 @@ class StorageAllocator : public StorageAllocaBaseVisitor { size_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = as_const_int(dim); + CHECK_GE(*pval, 0) << + "can not allocate memory for tensor with negative shape" << + *pval; CHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape; diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 16af572a9d6f..064343c834ea 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -13,7 +13,7 @@ namespace tvm { namespace relay { -// Alpha equal handler for relay. +// Alpha Equal handler for Relay. class AlphaEqualHandler: public AttrsEqualHandler, public TypeFunctor, @@ -26,7 +26,7 @@ class AlphaEqualHandler: * Check equality of two nodes. * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return The compare result. + * \return The comparison result. */ bool Equal(const NodeRef& lhs, const NodeRef& rhs) { if (lhs.same_as(rhs)) return true; @@ -46,7 +46,7 @@ class AlphaEqualHandler: * Check equality of two attributes. * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return The compare result. + * \return The comparison result. */ bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) { return AttrsEqualHandler::Equal(lhs, rhs); @@ -55,7 +55,7 @@ class AlphaEqualHandler: * Check equality of two types. * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return The compare result. + * \return the comparison result. */ bool TypeEqual(const Type& lhs, const Type& rhs) { if (lhs.same_as(rhs)) return true; @@ -72,7 +72,7 @@ class AlphaEqualHandler: * * \param lhs The left hand operand. * \param rhs The right hand operand. - * \return The compare result. + * \return The comparison result. */ bool ExprEqual(const Expr& lhs, const Expr& rhs) { if (lhs.same_as(rhs)) return true; diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 6a98d2884621..d386437ae15b 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -6,8 +6,11 @@ #include #include #include +#include #include #include +#include +#include "../op_common.h" #include "../layout.h" namespace tvm { @@ -86,26 +89,37 @@ RELAY_REGISTER_OP("nn.upsampling") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("UpSampling", UpSamplingRel) +.set_attr("TOpPattern", kInjective) .set_attr( "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type, - const Target& target) { - const auto* param = attrs.as(); - const auto* out_ttype = out_type.as(); - CHECK(param != nullptr); - CHECK(param->layout == "NCHW" || param->layout == "NHWC"); - CHECK(out_ttype != nullptr); - Array oshape; - if (param->layout == "NCHW") { - oshape.push_back(out_ttype->shape[2]); - oshape.push_back(out_ttype->shape[3]); - } else if (param->layout == "NHWC") { - oshape.push_back(out_ttype->shape[1]); - oshape.push_back(out_ttype->shape[2]); - } - return Array{ topi::nn::upsampling(inputs[0], oshape, param->layout, param->method)}; + const Array& inputs, + const Type& out_type, + const Target& target) { + const auto* uattrs = attrs.as(); + CHECK(uattrs != nullptr); + auto out_tt = out_type.as(); + CHECK(out_tt) << "expected a tensor type: " << out_type; + CHECK(uattrs->layout == "NCHW" || uattrs->layout == "NHWC") + << "unknown layout: " << uattrs->layout; + + Array oshape; + if (uattrs->layout == "NCHW") { + oshape.push_back(out_tt->shape[2]); + oshape.push_back(out_tt->shape[3]); + } else if (uattrs->layout == "NHWC") { + oshape.push_back(out_tt->shape[1]); + oshape.push_back(out_tt->shape[2]); + } + + return Array{ + topi::nn::upsampling( + inputs[0], + oshape, + uattrs->layout, + uattrs->method) + }; }); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/frontend/test_keras.py b/tests/python/relay/frontend/test_keras.py new file mode 100644 index 000000000000..f508c5b44310 --- /dev/null +++ b/tests/python/relay/frontend/test_keras.py @@ -0,0 +1,332 @@ +import numpy as np +import nnvm +from nnvm import to_relay +import tvm +from tvm import relay +from tvm.contrib import graph_runtime +from nnvm.testing.config import ctx_list +import keras + +# prevent keras from using up all gpu memory +import tensorflow as tf +from keras.backend.tensorflow_backend import set_session +config = tf.ConfigProto() +config.gpu_options.per_process_gpu_memory_fraction = 0.5 +set_session(tf.Session(config=config)) + + +def verify_keras_frontend(keras_model, need_transpose=True): + # Keras frontend currently supports tensorflow backend only. + assert(keras.backend.backend() == 'tensorflow') + + in_shapes = [] + for layer in keras_model._input_layers: + in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape)) + + def get_keras_output(xs, dtype='float32'): + return keras_model.predict(xs) + + def get_tvm_output(xs, target, ctx, dtype='float32'): + sym, params = nnvm.frontend.from_keras(keras_model) + shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)} + with relay.build_module.build_config(opt_level=2): + func, params = to_relay.to_relay(sym, shape_dict, dtype, params) + graph, lib, params = relay.build(func, target='llvm', params=params) + m = graph_runtime.create(graph, lib, ctx) + for name, x in zip(keras_model.input_names, xs): + m.set_input(name, tvm.nd.array(x.astype(dtype))) + m.set_input(**params) + m.run() + + return [m.get_output(i).asnumpy() for i in range(m.get_num_outputs())] + + def to_channels_first(arr): + return arr.transpose([0, -1] + list(range(1, arr.ndim - 1))) + + def to_channels_last(arr): + return arr.transpose([0] + list(range(2, arr.ndim)) + [1]) + + xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes] + keras_out = get_keras_output(xs) + + keras_out = keras_out if isinstance(keras_out, list) else [keras_out] + for target, ctx in ctx_list(): + inputs = [to_channels_first(x) for x in xs] if need_transpose else xs + tvm_out = get_tvm_output(inputs, target, ctx) + for kout, tout in zip(keras_out, tvm_out): + if need_transpose: + tout = to_channels_last(tout) + tvm.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5) + +def test_forward_elemwise_add(): + r = [] + data = keras.layers.Input(shape=(32,32,3)) + x = keras.layers.Conv2D(8, (3, 3), padding="same")(data) + r.append(x) + x = keras.layers.Conv2D(8, (3, 3), padding="same")(x) + r.append(x) + x = keras.layers.Conv2D(8, (3, 3), padding="same")(x) + # add two symbols + y = keras.layers.add([keras.layers.add([x, r[0]]), r[1]]) + y = keras.layers.GlobalAveragePooling2D()(y) + keras_model = keras.models.Model(data, y) + verify_keras_frontend(keras_model) + # add three symbols + y = keras.layers.add([x, r[0], r[1]]) + y = keras.layers.GlobalAveragePooling2D()(y) + keras_model = keras.models.Model(data, y) + verify_keras_frontend(keras_model) + + +def test_forward_dense(): + data = keras.layers.Input(shape=(32,32,1)) + x = keras.layers.Flatten()(data) + x = keras.layers.Dropout(0.5)(x) + x = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(x) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + + +def test_forward_pool(): + data = keras.layers.Input(shape=(32,32,1)) + # maxpool + x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + # avgpool + y = keras.layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(data) + keras_model = keras.models.Model(data, y) + verify_keras_frontend(keras_model) + + +def test_forward_conv(): + data = keras.layers.Input(shape=(32,32,3)) + conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3,3), + strides=(2,2), padding='same'), + keras.layers.Conv2D(filters=10, kernel_size=(3,3), + dilation_rate=(2,2), padding='same'), + keras.layers.DepthwiseConv2D(kernel_size=(3,3), padding='same'), + keras.layers.Conv2DTranspose(filters=10, kernel_size=(3,3), padding='valid'), + keras.layers.SeparableConv2D(filters=10, kernel_size=(3,3), padding='same')] + for conv_func in conv_funcs: + x = conv_func(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + + +def test_forward_upsample(): + data = keras.layers.Input(shape=(32,32,3)) + x = keras.layers.UpSampling2D(size=(3,3))(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + + +def test_forward_reshape(): + data = keras.layers.Input(shape=(32,32,3)) + x = keras.layers.Reshape(target_shape=(32,32,3))(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + + +def test_forward_crop(): + data = keras.layers.Input(shape=(32,32,3)) + x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data) + x = keras.layers.Cropping2D(cropping=(1, 1))(x) + x = keras.layers.Cropping2D(cropping=1)(x) + x = keras.layers.Cropping2D(cropping=((0, 1), (1, 0)))(x) + x = keras.layers.Cropping2D(cropping=(1, 0))(x) + x = keras.layers.Cropping2D(cropping=0)(x) + x = keras.layers.Add()([x, x]) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + + +def test_forward_vgg16(): + keras_model = keras.applications.vgg16.VGG16(include_top=True, weights='imagenet', + input_shape=(224,224,3), classes=1000) + verify_keras_frontend(keras_model) + + +def test_forward_xception(): + keras_model = keras.applications.xception.Xception(include_top=True, weights='imagenet', + input_shape=(299,299,3), classes=1000) + verify_keras_frontend(keras_model) + + +def test_forward_resnet50(): + keras_model = keras.applications.resnet50.ResNet50(include_top=True, weights='imagenet', + input_shape=(224,224,3), classes=1000) + verify_keras_frontend(keras_model) + + +def test_forward_mobilenet(): + keras_model = keras.applications.mobilenet.MobileNet(include_top=True, weights='imagenet', + input_shape=(224,224,3), classes=1000) + verify_keras_frontend(keras_model) + + +def test_forward_activations(): + data = keras.layers.Input(shape=(32,32,3)) + weights = np.random.rand(1, 32, 32, 3) + act_funcs = [keras.layers.Activation('softmax'), + keras.layers.Activation('softplus'), + keras.layers.ReLU(), + keras.layers.ReLU(max_value=6.), + keras.layers.LeakyReLU(alpha=0.3), + keras.layers.PReLU(weights=weights, alpha_initializer="zero"), + keras.layers.ELU(alpha=0.5), + keras.layers.Activation('selu'), + keras.layers.ThresholdedReLU(theta=0.5), + keras.layers.Activation('softsign'), + keras.layers.Activation('hard_sigmoid'), + keras.layers.Activation('sigmoid'), + keras.layers.Activation('tanh'), + keras.layers.Activation('linear')] + for act_func in act_funcs: + x = act_func(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + + +def test_forward_multi_inputs(): + data1 = keras.layers.Input(shape=(32,32,3)) + data2 = keras.layers.Input(shape=(32,32,3)) + x = keras.layers.Conv2D(8, (3, 3), padding="same")(data1) + y = keras.layers.Conv2D(8, (3, 3), padding="same")(data2) + z = keras.layers.add([x, y]) + z = keras.layers.GlobalAveragePooling2D()(z) + keras_model = keras.models.Model([data1, data2], z) + verify_keras_frontend(keras_model) + + +def test_forward_multi_outputs(): + data = keras.layers.Input(shape=(32,32,3)) + x = keras.layers.Conv2D(8, (3, 3), padding="same")(data) + x = keras.layers.GlobalAveragePooling2D()(x) + y = keras.layers.Conv2D(8, (3, 3), padding="same")(data) + y = keras.layers.GlobalAveragePooling2D()(y) + keras_model = keras.models.Model(data, [x, y]) + verify_keras_frontend(keras_model) + + +def test_forward_reuse_layers(): + # reuse conv2d + data = keras.layers.Input(shape=(32,32,3)) + conv2d = keras.layers.Conv2D(8, (3, 3), padding="same") + x = conv2d(data) + y = conv2d(data) + z = keras.layers.add([x, y]) + z = keras.layers.GlobalAveragePooling2D()(z) + keras_model = keras.models.Model(data, z) + verify_keras_frontend(keras_model) + + # reuse add + data = keras.layers.Input(shape=(32,32,3)) + x = keras.layers.Conv2D(8, (3, 3), padding="same")(data) + add = keras.layers.Add() + x = add([x, x]) + x = add([x, x]) + z = keras.layers.GlobalAveragePooling2D()(x) + keras_model = keras.models.Model(data, z) + verify_keras_frontend(keras_model) + +def _test_LSTM(inputs, hidden, return_state=True): + data = keras.layers.Input(shape=(1, inputs)) + lstm_out = keras.layers.LSTM(hidden, + return_state=return_state, + recurrent_activation='sigmoid', + activation='tanh') + x = lstm_out(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + +def _test_LSTM_MultiLayer(inputs, hidden): + inputs = keras.layers.Input(shape=(1, inputs)) + layer = keras.layers.LSTM(hidden, return_state=True, return_sequences=True, + recurrent_activation='sigmoid', + activation='tanh') + outputs = layer(inputs) + output, state = outputs[0], outputs[1:] + output = keras.layers.LSTM(hidden, recurrent_activation='sigmoid', + activation='tanh')(output, initial_state=state) + keras_model = keras.models.Model(inputs, output) + verify_keras_frontend(keras_model, need_transpose=False) + + +def test_forward_LSTM(): + # TODO(@jroesch): need to modify compile engine to fix return_state=True + _test_LSTM(8, 8, return_state=False) + _test_LSTM(4, 4, return_state=False) + _test_LSTM_MultiLayer(4, 4) + +def _test_RNN(inputs, units): + data = keras.layers.Input(shape=(1, inputs)) + rnn_out = keras.layers.SimpleRNN(units, return_state=True, + activation='tanh') + x = rnn_out(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + +def _test_RNN_MultiLayer(inputs, units): + inputs = keras.layers.Input(shape=(1, inputs)) + layer = keras.layers.SimpleRNN(units, return_state=True, return_sequences=True, + activation='tanh') + outputs = layer(inputs) + output, state = outputs[0], outputs[1:] + output = keras.layers.SimpleRNN(units, activation='tanh')(output, initial_state=state) + keras_model = keras.models.Model(inputs, output) + verify_keras_frontend(keras_model, need_transpose=False) + +def test_forward_RNN(): + _test_RNN(2, 4) + _test_RNN(4, 3) + _test_RNN_MultiLayer(4, 12) + +def _test_GRU(inputs, units): + data = keras.layers.Input(shape=(1, inputs)) + gru_out = keras.layers.GRU(units, + return_state=True, + recurrent_activation='sigmoid', + activation='tanh') + x = gru_out(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + +def _test_GRU_MultiLayer(inputs, units): + inputs = keras.layers.Input(shape=(1, inputs)) + layer = keras.layers.GRU(units, + return_state=True, + return_sequences=True, + recurrent_activation='sigmoid', + activation='tanh') + outputs = layer(inputs) + output, state = outputs[0], outputs[1:] + output = keras.layers.GRU(units, recurrent_activation='sigmoid', + activation='tanh')(output, initial_state=state) + keras_model = keras.models.Model(inputs, output) + verify_keras_frontend(keras_model, need_transpose=False) + +def test_forward_GRU(): + _test_GRU(2, 4) + _test_GRU(4, 3) + _test_GRU_MultiLayer(4, 4) + +if __name__ == '__main__': + test_forward_elemwise_add() + test_forward_activations() + test_forward_dense() + test_forward_pool() + test_forward_conv() + test_forward_upsample() + test_forward_reshape() + test_forward_crop() + test_forward_vgg16() + test_forward_xception() + test_forward_resnet50() + test_forward_mobilenet() + test_forward_multi_inputs() + test_forward_multi_outputs() + test_forward_reuse_layers() + test_forward_LSTM() + test_forward_RNN() + test_forward_GRU() diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h index b6bd51ef0fd2..2ffe4f453ba2 100644 --- a/topi/include/topi/image/resize.h +++ b/topi/include/topi/image/resize.h @@ -12,6 +12,7 @@ #include #include "topi/tags.h" +#include "topi/elemwise.h" #include "topi/detail/ravel_unravel.h" #include "topi/detail/constant_utils.h" #include "tvm/tvm.h" @@ -288,7 +289,7 @@ inline Tensor resize_bilinear_nchw(const Tensor& input, * \return A Tensor resized to given shape */ inline Tensor resize_bilinear(const Tensor& input, - const Array& shape, + const Array& shape, std::string layout = "NCHW", bool align_corners = false, std::string name = "tensor",