From 908ff9a4cf297ecd9c1d40f962196cf7ddbd782a Mon Sep 17 00:00:00 2001 From: Mark Rogers Date: Tue, 12 Mar 2019 21:53:49 +0000 Subject: [PATCH 1/2] unified error handling in nnvm and relay --- nnvm/python/nnvm/frontend/__init__.py | 7 + nnvm/python/nnvm/frontend/caffe2.py | 9 +- nnvm/python/nnvm/frontend/common.py | 16 +- nnvm/python/nnvm/frontend/coreml.py | 23 +-- nnvm/python/nnvm/frontend/darknet.py | 156 ++++++++----------- nnvm/python/nnvm/frontend/keras.py | 44 +++--- nnvm/python/nnvm/frontend/mxnet.py | 190 ++++++++++-------------- nnvm/python/nnvm/frontend/onnx.py | 5 +- nnvm/python/nnvm/frontend/tensorflow.py | 21 ++- python/tvm/error_handling/__init__.py | 44 ++++++ python/tvm/relay/frontend/__init__.py | 5 + python/tvm/relay/frontend/caffe2.py | 13 +- python/tvm/relay/frontend/coreml.py | 22 +-- python/tvm/relay/frontend/keras.py | 40 +++-- python/tvm/relay/frontend/mxnet.py | 50 ++++--- python/tvm/relay/frontend/onnx.py | 10 +- python/tvm/relay/frontend/tensorflow.py | 18 +-- python/tvm/relay/frontend/tflite.py | 26 ++-- 18 files changed, 344 insertions(+), 355 deletions(-) create mode 100644 python/tvm/error_handling/__init__.py diff --git a/nnvm/python/nnvm/frontend/__init__.py b/nnvm/python/nnvm/frontend/__init__.py index 49f53df1174f..f95e134cf0dd 100644 --- a/nnvm/python/nnvm/frontend/__init__.py +++ b/nnvm/python/nnvm/frontend/__init__.py @@ -7,3 +7,10 @@ from .darknet import from_darknet from .tensorflow import from_tensorflow from .caffe2 import from_caffe2 +from .common import raise_not_supported, get_nnvm_op, required_attr, \ + warn_not_used, parse_tshape, parse_bool_str +from tvm.error_handling import raise_attribute_required, \ + raise_attribute_invalid, \ + raise_operator_unimplemented, \ + raise_attribute_unimplemented, \ + warn_not_used diff --git a/nnvm/python/nnvm/frontend/caffe2.py b/nnvm/python/nnvm/frontend/caffe2.py index 8211971a8c3c..32d08678a0f8 100755 --- a/nnvm/python/nnvm/frontend/caffe2.py +++ b/nnvm/python/nnvm/frontend/caffe2.py @@ -73,8 +73,7 @@ def get_converter(cls): if hasattr(cls, '_impl'): return getattr(cls, '_impl') - raise NotImplementedError('{} not implemented'.format( - cls.__name__)) + raise_operator_unimplemented(cls.__name__) _caffe2_internal_args = { @@ -176,8 +175,7 @@ def _get_axis_from_order_str(order): return 1 if order == 'NHWC': return 3 - raise RuntimeError( - "Unsupported storage order: {} in caffe2".format(order)) + raise_attribute_invalid(order, 'storage order', 'concat') return AttrCvt( op_name='concatenate', @@ -427,8 +425,7 @@ def _convert_operator(self, # Add a sanitizing step to convert all byte strings in args to strings sym = convert_map[op_type](inputs, args, self._params) else: - raise NotImplementedError( - "Operator {} not implemented.".format(op_type)) + raise_operator_unimplemented(op_type) return sym diff --git a/nnvm/python/nnvm/frontend/common.py b/nnvm/python/nnvm/frontend/common.py index 7b8c4621029d..58ce6703b28d 100644 --- a/nnvm/python/nnvm/frontend/common.py +++ b/nnvm/python/nnvm/frontend/common.py @@ -7,9 +7,23 @@ def get_nnvm_op(op_name): op = getattr(_sym, op_name) if not op: - raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name)) + raise_operator_unimplemented(op_name) return op +def required_attr(attr, key, op_name): + assert isinstance(attr, dict) + if key not in attr: + raise_attribute_required(key, op_name) + return attr[key] + +def parse_tshape(tshape): + """Parse tshape in string.""" + return [int(x.strip()) for x in tshape.strip('()').split(',')] + +def 
parse_bool_str(attr, key, default='False'): + """Parse bool string to boolean.""" + return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes'] + class Renamer(object): """A simply renamer for operators. diff --git a/nnvm/python/nnvm/frontend/coreml.py b/nnvm/python/nnvm/frontend/coreml.py index 77285efe7a76..e7c5a0d7eda8 100644 --- a/nnvm/python/nnvm/frontend/coreml.py +++ b/nnvm/python/nnvm/frontend/coreml.py @@ -83,7 +83,7 @@ def BatchnormLayerParams(op, insym, symtab): """Get layer of batchnorm parameter""" # this changes the symbol if op.instanceNormalization: - raise NotImplementedError("instance normalization not implemented") + raise_operator_unimplemented('instance normalization') else: params = {'gamma':symtab.new_const(list(op.gamma.floatValue)), 'beta':symtab.new_const(list(op.beta.floatValue)), @@ -136,7 +136,7 @@ def ActivationParams(op, insym, symtab): betasym = symtab.new_const(beta) return _sym.broadcast_mul(_sym.log(_sym.broadcast_add( _sym.exp(insym), betasym)), alphasym) - raise NotImplementedError('%s not implemented' % whichActivation) + raise_operator_unimplemented(whichActivation) def ScaleLayerParams(op, insym, symtab): """Scale layer params.""" @@ -158,7 +158,7 @@ def PoolingLayerParams(op, insym, symtab): return _sym.global_max_pool2d(insym) if op.type == 1: return _sym.global_avg_pool2d(insym) - raise NotImplementedError("Only max and average pooling implemented") + raise_operator_unimplemented('pooling (not max or average)') else: params = {'pool_size':list(op.kernelSize), @@ -178,7 +178,8 @@ def PoolingLayerParams(op, insym, symtab): params['padding'] = padding params['ceil_mode'] = True else: - raise NotImplementedError("Other convolution padding not implemented") + raise_attribute_invalid(op.WhichOneof('PoolingPaddingType'), + 'PoolingPaddingType', 'pooling') # consume padding layer if symtab.in_padding: @@ -190,7 +191,7 @@ def PoolingLayerParams(op, insym, symtab): return _sym.max_pool2d(insym, **params) if op.type == 1: return _sym.avg_pool2d(insym, **params) - raise NotImplementedError("Only max and average pooling implemented") + raise_operator_unimplemented('pooling (not max or average)') def SoftmaxLayerParams(op, insym, symtab): return _sym.softmax(_sym.flatten(insym)) @@ -229,7 +230,7 @@ def ConcatLayerParams(op, insyms, symtab): if not isinstance(insyms, list): insyms = [insyms] if op.sequenceConcat: - raise NotImplementedError("Sequence Concat not supported") + raise_operator_unimplemented('sequence concat') ret = _sym.concatenate(*insyms, axis=1) return ret @@ -243,14 +244,14 @@ def PaddingLayerParams(op, insym, symtab): if op.WhichOneof('PaddingType') == 'constant': constant = op.constant if constant.value != 0: - raise NotImplementedError("Padding value {} not supported.".format(constant.value)) + raise_attribute_invalid(constant.value, 'padding value', 'padding') padding = [b.startEdgeSize for b in op.paddingAmounts.borderAmounts] padding2 = [b.endEdgeSize for b in op.paddingAmounts.borderAmounts] for i, j in zip(padding, padding2): assert i == j symtab.set_padding(padding) else: - raise NotImplementedError("Only constant padding is supported now.") + raise_operator_unimplemented('non-constant padding') return insym def PermuteLayerParams(op, insym, symtab): @@ -259,8 +260,8 @@ def PermuteLayerParams(op, insym, symtab): def UpsampleLayerParams(op, insym, symtab): if op.scalingFactor[0] != op.scalingFactor[1]: - raise NotImplementedError("Upsampling only supported with same \ - height and width scaling factor.") + 
raise_attribute_invalid(op.scalingFactor, 'scaling factors', + 'upsample') interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR' return _sym.upsampling(insym, scale=op.scalingFactor[0], method=interpolationMode) @@ -341,7 +342,7 @@ def coreml_op_to_nnvm(op, inname, outname, symtab): """ classname = type(op).__name__ if classname not in _convert_map: - raise NotImplementedError("%s is not supported" % (classname)) + raise_operator_unimplemented(classname) if isinstance(inname, string_types): insym = symtab.get_var(inname) else: diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py index 154c83c90ec6..bbb0926f29c8 100644 --- a/nnvm/python/nnvm/frontend/darknet.py +++ b/nnvm/python/nnvm/frontend/darknet.py @@ -57,45 +57,11 @@ class ACTIVATION(object): __all__ = ['from_darknet'] -def _darknet_get_nnvm_op(op_name): - """Get the nnvm operation from opname, raise error if not supported.""" - op = getattr(_sym, op_name) - if not op: - raise RuntimeError("Not to map op_name {} to nnvm.sym".format(op_name)) - return op - -def _darknet_required_attr(attr, key): - """Check the attribute exists and return if exists, if not return error.""" - assert isinstance(attr, dict) - if key not in attr: - raise AttributeError("Required attribute {} not found.".format(key)) - return attr[key] - -def _darknet_raise_not_supported(attr, op='nnvm'): - """Raise error if any operation is not supported.""" - err = "{} is not supported in {}.".format(attr, op) - raise NotImplementedError(err) - -def _darknet_warn_not_used(attr, op='nnvm'): - """Raise warning if any operation not supported.""" - import warnings - err = "{} is ignored in {}.".format(attr, op) - warnings.warn(err) - -def _darknet_parse_tshape(tshape): - """Parse tshape in string.""" - return [int(x.strip()) for x in tshape.strip('()').split(',')] - -def _darknet_parse_bool_str(attr, key, default='False'): - """Parse bool string to boolean.""" - return attr.get(key, default).strip().lower() in \ - ['true', '1', 't', 'y', 'yes'] - def _darknet_maxpooling(inputs, attrs): """Process the max pool 2d operation.""" - kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel')) + kernel = parse_tshape(required_attr(attrs, 'kernel', 'maxpool')) if len(kernel) != 1: - _darknet_raise_not_supported('non-2d kernel', 'pool_2d') + raise_attribute_unimplemented('non-2d kernel', 'pool_2d') op_name, new_attrs = 'max_pool2d', {} strides = int(attrs.get('stride', (1, 1))) @@ -107,13 +73,13 @@ def _darknet_maxpooling(inputs, attrs): if extra_pad_size: pad_width = ((0, 0), (0, 0), (0, extra_pad_size), (0, extra_pad_size)) inputs = _sym.pad(*inputs, pad_width=pad_width, pad_value=np.finfo(np.float32).min) - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + return get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_avgpooling(inputs, attrs): """Process the average pool 2d operation.""" - kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel')) + kernel = parse_tshape(required_attr(attrs, 'kernel', 'avgpool')) if len(kernel) != 1: - _darknet_raise_not_supported('non-2d kernel', 'pool_2d') + raise_attribute_unimplemented('non-2d kernel', 'pool_2d') op_name, new_attrs = 'avg_pool2d', {} strides = int(attrs.get('stride', (1, 1))) @@ -122,7 +88,7 @@ def _darknet_avgpooling(inputs, attrs): new_attrs['strides'] = str((strides, strides)) new_attrs['padding'] = str((pads, pads)) - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + return get_nnvm_op(op_name)(*inputs, 
**new_attrs), None def _darknet_batch_norm(inputs, attrs): """Process the batchnormalization operation.""" @@ -131,21 +97,21 @@ def _darknet_batch_norm(inputs, attrs): new_attrs['epsilon'] = attrs.get('eps', 0.000001) new_attrs['center'] = True new_attrs['scale'] = True - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + return get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_conv2d(inputs, attrs): """Process the convolution 2d operation.""" - kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel')) + kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d')) if len(kernel) != 1: - _darknet_raise_not_supported('non 2d kernel', 'conv2d') + raise_attribute_unimplemented('non 2d kernel', 'conv2d') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - _darknet_raise_not_supported('layout: ' + layout, 'conv2d') + raise_attribute_invalid(layout, 'layout', 'conv2d') strides = int(attrs.get('stride', (1, 1))) pads = int(attrs.get('pad', (0, 0))) op_name, new_attrs = 'conv2d', {} - new_attrs['channels'] = _darknet_required_attr(attrs, 'num_filter') + new_attrs['channels'] = required_attr(attrs, 'num_filter', 'conv2d') new_attrs['kernel_size'] = [kernel[0], kernel[0]] new_attrs['strides'] = (strides, strides) new_attrs['padding'] = (pads, pads) @@ -157,13 +123,13 @@ def _darknet_conv2d(inputs, attrs): else: new_attrs['use_bias'] = True out_name = {} - sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) + sym = get_nnvm_op(op_name)(*inputs, **new_attrs) out_name[0] = sym.list_output_names()[0].replace('_output', '') if attrs.get('use_batchNorm', False) is True: op_name, new_attrs = 'batch_norm', {} new_attrs['epsilon'] = 0.000001 - sym = _darknet_get_nnvm_op(op_name)(*sym, **new_attrs) + sym = get_nnvm_op(op_name)(*sym, **new_attrs) out_name[1] = sym.list_output_names()[0].replace('_output', '') if 'activation' in attrs: new_attrs = {} @@ -176,15 +142,15 @@ def _darknet_conv2d(inputs, attrs): def _darknet_conv2d_transpose(inputs, attrs): """Process the convolution 2d transpose operation.""" if 'target_shape' in attrs: - _darknet_raise_not_supported('target_shape', 'conv2d_transpose') - kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel')) + raise_attribute_unimplemented('target_shape', 'conv2d_transpose') + kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d_transpose')) if len(kernel) != 2: - _darknet_raise_not_supported('non-2d kernel', 'conv2d_transpose') + raise_attribute_unimplemented('non-2d kernel', 'conv2d_transpose') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - _darknet_raise_not_supported('layout: ' + layout, 'conv2d_transpose') + raise_attribute_invalid(layout, 'layout', 'conv2d_transpose') op_name, new_attrs = 'conv2d_transpose', {} - new_attrs['channels'] = _darknet_required_attr(attrs, 'num_filter') + new_attrs['channels'] = required_attr(attrs, 'num_filter', 'conv2d_transpose') new_attrs['kernel_size'] = kernel new_attrs['strides'] = attrs.get('stride', (1, 1)) new_attrs['output_padding'] = attrs.get('adj', (0, 0)) @@ -192,8 +158,8 @@ def _darknet_conv2d_transpose(inputs, attrs): new_attrs['dilation'] = attrs.get('dilate', (1, 1)) new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['layout'] = layout - new_attrs['use_bias'] = not _darknet_parse_bool_str(attrs, 'no_bias') - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + new_attrs['use_bias'] = not parse_bool_str(attrs, 'no_bias') + return get_nnvm_op(op_name)(*inputs, **new_attrs), 
None def _darknet_shortcut(inputs, attrs): """Process the shortcut operation.""" @@ -219,7 +185,7 @@ def _darknet_shortcut(inputs, attrs): pad_value=0.) new_inputs = _as_list([input_0, input_1]) - sym = _darknet_get_nnvm_op(op_name)(*new_inputs, **new_attrs) + sym = get_nnvm_op(op_name)(*new_inputs, **new_attrs) out_name = sym.list_output_names()[0].replace('_output', '') if 'activation' in attrs: new_attrs['activation'] = attrs['activation'] @@ -229,17 +195,17 @@ def _darknet_shortcut(inputs, attrs): def _darknet_dense(inputs, attrs): """Process the dense operation.""" op_name, new_attrs = 'dense', {} - new_attrs['units'] = _darknet_required_attr(attrs, 'num_hidden') + new_attrs['units'] = required_attr(attrs, 'num_hidden', 'dense') out_name = {} new_attrs['use_bias'] = attrs.get('use_bias', False) if attrs.get('use_flatten', False) is True: inputs[0] = _sym.flatten(inputs[0]) - sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) + sym = get_nnvm_op(op_name)(*inputs, **new_attrs) out_name[0] = sym.list_output_names()[0].replace('_output', '') if 'use_batchNorm' in attrs: op_name, new_attrs = 'batch_norm', {} new_attrs['epsilon'] = 0.000001 - sym = _darknet_get_nnvm_op(op_name)(*sym, **new_attrs) + sym = get_nnvm_op(op_name)(*sym, **new_attrs) out_name[1] = sym.list_output_names()[0].replace('_output', '') if 'activation' in attrs: new_attrs = {} @@ -251,28 +217,28 @@ def _darknet_dropout(inputs, attrs): """Process the dropout operation, its a blank operation.""" op_name, new_attrs = 'dropout', {} new_attrs['rate'] = attrs.get('p', 0.5) - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + return get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_reshape(inputs, attrs): """Process the reshape operation.""" - if _darknet_parse_bool_str(attrs, 'reverse'): - _darknet_raise_not_supported('reverse', 'reshape') + if parse_bool_str(attrs, 'reverse'): + raise_attribute_unimplemented('reverse', 'reshape') op_name, new_attrs = 'reshape', {} - new_attrs['shape'] = _darknet_required_attr(attrs, 'shape') - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + new_attrs['shape'] = required_attr(attrs, 'shape', 'reshape') + return get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_upsampling(inputs, attrs): """Process the upsampling operation.""" op_name, new_attrs = 'upsampling', {} new_attrs['scale'] = attrs.get('scale', 1) - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + return get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_l2normalize(inputs, attrs): """Process the l2 normalization operation.""" op_name, new_attrs = 'l2_normalize', {} new_attrs['eps'] = attrs.get('eps', 0) new_attrs['axis'] = attrs.get('axis', 1) - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + return get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_softmax_output(inputs, attrs): """Process the softmax operation.""" @@ -280,25 +246,25 @@ def _darknet_softmax_output(inputs, attrs): if temperature != 1: inputs[0] = inputs[0] / float(temperature) op_name, new_attrs = 'softmax', {} - if _darknet_parse_bool_str(attrs, 'multi_output'): + if parse_bool_str(attrs, 'multi_output'): new_attrs['axis'] = 1 if attrs.get('use_flatten', False) is True: inputs[0] = _sym.flatten(inputs[0]) - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + return get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_route(inputs, attrs): """Process the route operation, which is equivalent to concat.""" op_name = 
'concatenate' new_attrs = {'axis': attrs.get('dim', 1)} - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + return get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_reorg(inputs, attrs): """Process the reorg operation.""" op_name, new_attrs = 'yolo_reorg', {} if 'stride' in attrs: new_attrs = {'stride': attrs.get('stride', 1)} - return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None + return get_nnvm_op(op_name)(*inputs, **new_attrs), None def _darknet_region(inputs, attrs): """Process the region operation.""" @@ -344,7 +310,7 @@ def _darknet_yolo(inputs, attrs): def _darknet_activations(inputs, attrs): """Process the activation function.""" - act = _darknet_required_attr(attrs, 'activation') + act = required_attr(attrs, 'activation', 'activations') if ACTIVATION.LOGISTIC == act: act_type = 'sigmoid' elif ACTIVATION.RELU == act: @@ -358,22 +324,22 @@ def _darknet_activations(inputs, attrs): elif ACTIVATION.ELU == act: act_type = 'elu' else: - _darknet_raise_not_supported('act: ' + act) + raise_operator_unimplemented('act: ' + act) if act_type in ['relu', 'tanh']: op_name, new_attrs = act_type, {} - sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) + sym = get_nnvm_op(op_name)(*inputs, **new_attrs) elif act_type in ['leaky_relu']: op_name, new_attrs = act_type, {} new_attrs['alpha'] = attrs.get('slope', 0.1) - sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) + sym = get_nnvm_op(op_name)(*inputs, **new_attrs) elif act_type in ['elu']: sym = -1 * _sym.relu(1 - _sym.exp(*inputs)) + _sym.relu(*inputs) elif act_type in ['sigmoid']: op_name, new_attrs = act_type, {} - sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) + sym = get_nnvm_op(op_name)(*inputs, **new_attrs) else: - _darknet_raise_not_supported('act_type: ' + act_type) + raise_operator_unimplemented('act_type: ' + act_type) return sym, None def _darknet_op_not_support(inputs, attrs): @@ -436,7 +402,7 @@ def _darknet_convert_symbol(op_name, inputs, attrs): if op_name in _DARKNET_CONVERT_MAP: sym, out_name = _DARKNET_CONVERT_MAP[op_name](inputs, attrs) else: - _darknet_raise_not_supported('Operator type ' + str(op_name)) + raise_operator_unimplemented(op_name) if out_name is None: out_name = sym.list_output_names()[0].replace('_output', '') return out_name, sym @@ -483,7 +449,8 @@ def _get_convolution_weights(self, layer, opname): return if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: - raise RuntimeError("layer weights size not matching with n c h w") + raise_attribute_invalid(layer.n * layer.c * layer.size * layer.size, + 'layer weights size', 'conv2d') shape = (layer.n, layer.c, layer.size, layer.size) weights = self._read_memory_buffer(shape, layer.weights) @@ -663,8 +630,7 @@ def _get_darknet_attrs(self, layer, layer_num): pass else: - err = "Darknet layer type {} is not supported in nnvm.".format(layer.type) - raise NotImplementedError(err) + raise_operator_unimplemented(layer.type) return attr @@ -761,7 +727,7 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): op_name, new_attrs = 'elemwise_add', {} new_inputs = _as_list([sym, state]) - state = _darknet_get_nnvm_op(op_name)(*new_inputs, **new_attrs) + state = get_nnvm_op(op_name)(*new_inputs, **new_attrs) self._outs.append(state) output_layer = layer.output_layer @@ -786,7 +752,7 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): op_name, new_attrs = 'elemwise_add', {} new_inputs = _as_list([sym, state]) - state = _darknet_get_nnvm_op(op_name)(*new_inputs, **new_attrs) + state = 
get_nnvm_op(op_name)(*new_inputs, **new_attrs) self._outs.append(state) output_layer = layer.output_layer @@ -797,7 +763,7 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): elif LAYERTYPE.LSTM == layer.type: if layer.steps > 1: - raise NotImplementedError("Currently support only single step GRU") + raise_attribute_invalid(layer.steps, 'number of steps', 'RNN') op_name_add = 'elemwise_add' op_name_mul = 'elemwise_mul' @@ -819,16 +785,16 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): sym_uo = self._get_darknet_rnn_attrs(layer.uo, input_sym) new_inputs = _as_list([sym_wf, sym_uf]) - add_f = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs) + add_f = get_nnvm_op(op_name_add)(*new_inputs, **attrs) new_inputs = _as_list([sym_wi, sym_ui]) - add_i = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs) + add_i = get_nnvm_op(op_name_add)(*new_inputs, **attrs) new_inputs = _as_list([sym_wg, sym_ug]) - add_g = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs) + add_g = get_nnvm_op(op_name_add)(*new_inputs, **attrs) new_inputs = _as_list([sym_wo, sym_uo]) - add_o = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs) + add_o = get_nnvm_op(op_name_add)(*new_inputs, **attrs) act_attr['activation'] = ACTIVATION.LOGISTIC act_f, _ = _darknet_activations(_as_list(add_f), act_attr) @@ -843,19 +809,19 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): act_o, _ = _darknet_activations(_as_list(add_o), act_attr) new_inputs = _as_list([act_i, act_g]) - mul_t = _darknet_get_nnvm_op(op_name_mul)(*new_inputs, **attrs) + mul_t = get_nnvm_op(op_name_mul)(*new_inputs, **attrs) new_inputs = _as_list([act_f, c_state]) - c_state = _darknet_get_nnvm_op(op_name_mul)(*new_inputs, **attrs) + c_state = get_nnvm_op(op_name_mul)(*new_inputs, **attrs) new_inputs = _as_list([mul_t, c_state]) - c_state = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs) + c_state = get_nnvm_op(op_name_add)(*new_inputs, **attrs) act_attr['activation'] = ACTIVATION.TANH h_state, _ = _darknet_activations(_as_list(c_state), act_attr) new_inputs = _as_list([act_o, h_state]) - h_state = _darknet_get_nnvm_op(op_name_mul)(*new_inputs, **attrs) + h_state = get_nnvm_op(op_name_mul)(*new_inputs, **attrs) self._outs = self._outs + [c_state, h_state] sym = h_state self._sym_array[layer_num] = sym @@ -863,7 +829,7 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): elif LAYERTYPE.GRU == layer.type: if layer.steps > 1: - raise NotImplementedError("Currently support only single step GRU") + raise_attribute_invalid(layer.steps, 'number of steps', 'RNN') op_name_add = 'elemwise_add' op_name_mul = 'elemwise_mul' @@ -881,10 +847,10 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): sym_uh = self._get_darknet_rnn_attrs(layer.uh, input_sym) new_inputs = _as_list([sym_uz, sym_wz]) - add_z = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs) + add_z = get_nnvm_op(op_name_add)(*new_inputs, **attrs) new_inputs = _as_list([sym_ur, sym_wr]) - add_r = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs) + add_r = get_nnvm_op(op_name_add)(*new_inputs, **attrs) act_attr['activation'] = ACTIVATION.LOGISTIC act_z, _ = _darknet_activations(_as_list(add_z), act_attr) @@ -893,12 +859,12 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): act_r, _ = _darknet_activations(_as_list(add_r), act_attr) new_inputs = _as_list([act_r, state]) - forgot = _darknet_get_nnvm_op(op_name_mul)(*new_inputs, **attrs) + forgot = get_nnvm_op(op_name_mul)(*new_inputs, **attrs) sym_wh = self._get_darknet_rnn_attrs(layer.wh, forgot) 
new_inputs = _as_list([sym_uh, sym_wh]) - h_state = _darknet_get_nnvm_op(op_name_add)(*new_inputs, **attrs) + h_state = get_nnvm_op(op_name_add)(*new_inputs, **attrs) if layer.tanh == 1: act_attr['activation'] = ACTIVATION.TANH diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py index 56758ada5f46..d15d2b3f01ab 100644 --- a/nnvm/python/nnvm/frontend/keras.py +++ b/nnvm/python/nnvm/frontend/keras.py @@ -74,7 +74,7 @@ def _convert_activation(insym, keras_layer, _): if act_type == 'hard_sigmoid': transformX = (0.2 * insym) + 0.5 return _sym.clip(transformX, a_min=0, a_max=1) - raise TypeError("Unsupported activation type : {}".format(act_type)) + raise_operator_unimplemented(act_type) def _convert_advanced_activation(insym, keras_layer, symtab): @@ -100,7 +100,7 @@ def _convert_advanced_activation(insym, keras_layer, symtab): theta = keras_layer.theta if hasattr(keras_layer, "theta") else 1.0 theta_tensor = _sym.full_like(insym[0], fill_value=float(theta)) return _sym.elemwise_mul(insym[0], _sym.greater(insym[0], theta_tensor, out_type="float32")) - raise TypeError("Unsupported advanced activation type : {}".format(act_type)) + raise_operator_unimplemented(act_type) def _convert_merge(insym, keras_layer, _): @@ -114,11 +114,11 @@ def _convert_merge(insym, keras_layer, _): elif merge_type == 'Multiply': ret = _sym.elemwise_mul(ret, insym[i]) elif merge_type == 'Average': - raise NotImplementedError('Average merge not implemented') + raise_operator_unimplemented('average merge') elif merge_type == 'Maximum': - raise NotImplementedError('Maximum merge not implemented') + raise_operator_unimplemented('maximum merge') else: - raise TypeError("Unsupported merge type : {}".format(merge_type)) + raise_operator_unimplemented(merge_type) return ret @@ -135,7 +135,7 @@ def _convert_dense(insym, keras_layer, symtab): if input_dim > 2: input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0]) if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1: - raise ValueError("Cannot flatten the inputs with shape.", input_shape, " for dense.") + raise_attribute_invalid(input_shape, 'input shape', 'dense') insym = _sym.squeeze(insym, axis=0) out = _sym.dense(data=insym, **params) # defuse activation @@ -199,7 +199,7 @@ def _convert_convolution(insym, keras_layer, symtab): else: insym = _sym.pad(data=insym, pad_width=((0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise TypeError("Unsupported padding type : {}".format(keras_layer.padding)) + raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) if is_deconv: out = _sym.conv2d_transpose(data=insym, **params) else: @@ -240,7 +240,7 @@ def _convert_separable_convolution(insym, keras_layer, symtab): insym = _sym.pad(data=insym, pad_width=( (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise TypeError("Unsupported padding type : {}".format(keras_layer.padding)) + raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) depthconv = _sym.conv2d(data=insym, **params0) # pointwise conv weight1 = weightList[1].transpose([3, 2, 0, 1]) @@ -294,13 +294,13 @@ def _convert_pooling(insym, keras_layer, symtab): pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w) params['padding'] = [pad_t, pad_l, pad_b, pad_r] else: - raise TypeError("Unsupported padding type : {}".format(keras_layer.padding)) + raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) if pool_type == 'MaxPooling2D': return _sym.max_pool2d(insym, **params) if pool_type == 
'AveragePooling2D': # TODO: in keras, padded zeros are not calculated return _sym.avg_pool2d(insym, **params) - raise TypeError("Unsupported pooling type : {}".format(keras_layer)) + raise_operator_unimplemented('pooling with {}'.format(keras_layer)) def _convert_upsample(insym, keras_layer, _): @@ -312,17 +312,15 @@ def _convert_upsample(insym, keras_layer, _): elif upsample_type == "UpSampling2D": h, w = keras_layer.size if h != w: - raise TypeError("Unsupported upsampling type with different axes size : {}" - .format(keras_layer.size)) + raise_attribute_invalid(keras_layer.size, 'size', 'upsample') params = {'scale': h} elif upsample_type == "UpSampling3D": h, w, d = keras_layer.size if h != w or w != d: - raise TypeError("Unsupported upsampling type with different axes size : {}" - .format(keras_layer.size)) + raise_attribute_invalid(keras_layer.size, 'size', 'upsample') params = {'scale': h} else: - raise TypeError("Unsupported upsampling type : {}".format(upsample_type)) + raise_operator_unimplemented(upsample_type) return _sym.upsampling(insym, **params) @@ -330,12 +328,12 @@ def _convert_cropping(insym, keras_layer, _): _check_data_format(keras_layer) crop_type = type(keras_layer).__name__ if crop_type == "Cropping1D": - raise NotImplementedError("Cropping1D not implemented") + raise_operator_unimplemented(crop_type) elif crop_type == "Cropping2D": (_, in_h, in_w, _) = keras_layer.input_shape ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping else: - raise TypeError("Unrecognized cropping type : {}".format(crop_type)) + raise_operator_unimplemented(crop_type) int32_max = np.iinfo(np.int32).max return _sym.strided_slice(insym, begin=[0, 0, crop_t, crop_l], end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r]) @@ -379,13 +377,11 @@ def _convert_padding(insym, keras_layer, _): top, bottom = padding[0] left, right = padding[1] else: - raise ValueError("Unrecognized padding option: {}".format(str(padding))) + raise_attribute_invalid(str(padding), 'padding', padding_type) else: - raise ValueError("Unrecognized padding option: {}".format(str(padding))) - elif padding_type == 'ZeroPadding1D': - raise NotImplementedError("ZeroPadding1D not implemented") + raise_attribute_invalid(str(padding), 'padding', padding_type) else: - raise ValueError("Unrecognized padding type: {}".format(padding_type)) + raise_operator_unimplemented(padding_type) return _sym.pad(data=insym, pad_width=((0, 0), (0, 0), (top, bottom), (left, right))) @@ -593,7 +589,7 @@ def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument def _check_unsupported_layers(model): for layer in model.layers: if type(layer).__name__ not in _convert_map: - raise ValueError("Keras layer {} not supported.".format(type(layer).__name__)) + raise_operator_unimplemented(type(layer).__name__) def _as_list(arr): """Force being a list, ignore if already is.""" @@ -619,7 +615,7 @@ def keras_op_to_nnvm(insym, keras_layer, outname, symtab): The global symbol table to be updated """ if type(keras_layer).__name__ not in _convert_map: - raise NotImplementedError("{} is not supported".format((type(keras_layer).__name__))) + raise_operator_unimplemented(type(keras_layer).__name__) outs = _convert_map[type(keras_layer).__name__](insym, keras_layer, symtab) outs = _as_list(outs) diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 47d7ede96e5f..372f10bd98b9 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -7,48 +7,19 @@ __all__ = ['from_mxnet'] -def 
_get_nnvm_op(op_name): - op = getattr(_sym, op_name) - if not op: - raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name)) - return op - -def _required_attr(attr, key): - assert isinstance(attr, dict) - if key not in attr: - raise AttributeError("Required attribute {} not found.".format(key)) - return attr[key] - -def _raise_not_supported(attr, op='nnvm'): - err = "{} is not supported in {}.".format(attr, op) - raise NotImplementedError(err) - -def _warn_not_used(attr, op='nnvm'): - import warnings - err = "{} is ignored in {}.".format(attr, op) - warnings.warn(err) - -def _parse_tshape(tshape): - """Parse tshape in string.""" - return [int(x.strip()) for x in tshape.strip('()').split(',')] - -def _parse_bool_str(attr, key, default='False'): - """Parse bool string to boolean.""" - return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes'] - def _rename(new_name): def impl(inputs, attrs): - return _get_nnvm_op(new_name)(*inputs, **attrs) + return get_nnvm_op(new_name)(*inputs, **attrs) return impl def _pooling(inputs, attrs): - kernel = _parse_tshape(_required_attr(attrs, 'kernel')) + kernel = parse_tshape(required_attr(attrs, 'kernel', 'pooling')) if len(kernel) != 2: - _raise_not_supported('non-2d kernel', 'pool_2d') - global_pool = 'global' if _parse_bool_str(attrs, 'global_pool') else '' - pool_type = _required_attr(attrs, 'pool_type') + raise_attribute_unimplemented('non-2d kernel', 'pool_2d') + global_pool = 'global' if parse_bool_str(attrs, 'global_pool') else '' + pool_type = required_attr(attrs, 'pool_type', 'pooling') if pool_type not in ['avg', 'max']: - _raise_not_supported('non-avg/max', 'pool2d') + raise_attribute_unimplemented('non-avg/max', 'pool2d') op_name, new_attrs = '_'.join([global_pool, pool_type, 'pool2d']).strip('_'), {} # new_attrs['layout'] = 'NCHW' if not global_pool: @@ -58,42 +29,41 @@ def _pooling(inputs, attrs): new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full') if pool_type == 'avg': new_attrs['count_include_pad'] = attrs.get('count_include_pad', True) - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _batch_norm(inputs, attrs): - if _parse_bool_str(attrs, 'output_mean_var'): - _raise_not_supported('output_mean_var', 'batch_norm') - # if _parse_bool_str(attrs, 'fix_gamma'): + raise_attribute_unimplemented('output_mean_var', 'batch_norm') + # if parse_bool_str(attrs, 'fix_gamma'): # _warn_not_used('fix_gamma', 'batch_norm') - if _parse_bool_str(attrs, 'use_global_stats'): - _warn_not_used('use_global_stats', 'batch_norm') - # if _parse_bool_str(attrs, 'momentum'): + if parse_bool_str(attrs, 'use_global_stats'): + warn_not_used('use_global_stats', 'batch_norm') + # if parse_bool_str(attrs, 'momentum'): # _warn_not_used('momentum', 'batch_norm') op_name, new_attrs = 'batch_norm', {} new_attrs['axis'] = attrs.get('axis', 1) new_attrs['epsilon'] = attrs.get('eps', 0.001) new_attrs['center'] = True - new_attrs['scale'] = not _parse_bool_str(attrs, 'fix_gamma', default="False") - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + new_attrs['scale'] = not parse_bool_str(attrs, 'fix_gamma', default="False") + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _concat(inputs, attrs): op_name = 'concatenate' new_attrs = {'axis': attrs.get('dim', 1)} - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _conv2d(inputs, attrs): - kernel = _parse_tshape(_required_attr(attrs, 'kernel')) + 
kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d')) if len(kernel) != 2: - _raise_not_supported('non 2d kernel', 'conv2d') + raise_attribute_unimplemented('non 2d kernel', 'conv2d') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - _raise_not_supported('layout: ' + layout, 'conv2d') + raise_attribute_unimplemented('layout: ' + layout, 'conv2d') if 'kernel_layout' in attrs: kernel_layout = attrs['kernel_layout'] else: kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW' op_name, new_attrs = 'conv2d', {} - new_attrs['channels'] = _required_attr(attrs, 'num_filter') + new_attrs['channels'] = required_attr(attrs, 'num_filter', 'conv2d') new_attrs['kernel_size'] = kernel new_attrs['strides'] = attrs.get('stride', (1, 1)) new_attrs['padding'] = attrs.get('pad', (0, 0)) @@ -102,23 +72,23 @@ def _conv2d(inputs, attrs): new_attrs['layout'] = layout new_attrs['kernel_layout'] = kernel_layout new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False' - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _conv2d_transpose(inputs, attrs): if 'target_shape' in attrs: - _raise_not_supported('target_shape', 'conv2d_transpose') - kernel = _parse_tshape(_required_attr(attrs, 'kernel')) + raise_attribute_unimplemented('target_shape', 'conv2d_transpose') + kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d_transpose')) if len(kernel) != 2: - _raise_not_supported('non-2d kernel', 'conv2d_transpose') + raise_attribute_invalid(len(kernel), 'kernel dim', 'conv2d_transpose') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - _raise_not_supported('layout: ' + layout, 'conv2d_transpose') + raise_attribute_unimplemented('layout: ' + layout, 'conv2d_transpose') if 'kernel_layout' in attrs: kernel_layout = attrs['kernel_layout'] else: kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW' op_name, new_attrs = 'conv2d_transpose', {} - new_attrs['channels'] = _required_attr(attrs, 'num_filter') + new_attrs['channels'] = required_attr(attrs, 'num_filter', 'conv2d_transpose') new_attrs['kernel_size'] = kernel new_attrs['strides'] = attrs.get('stride', (1, 1)) new_attrs['output_padding'] = attrs.get('adj', (0, 0)) @@ -127,67 +97,67 @@ def _conv2d_transpose(inputs, attrs): new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['layout'] = layout new_attrs['kernel_layout'] = kernel_layout - new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias') - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + new_attrs['use_bias'] = not parse_bool_str(attrs, 'no_bias') + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _dense(inputs, attrs): import mxnet as mx op_name, new_attrs = 'dense', {} - new_attrs['units'] = _required_attr(attrs, 'num_hidden') - new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias') + new_attrs['units'] = required_attr(attrs, 'num_hidden', 'dense') + new_attrs['use_bias'] = not parse_bool_str(attrs, 'no_bias') try: _ = mx.sym.FullyConnected(mx.sym.var('x'), num_hidden=1, flatten=True) has_flatten = True except mx.base.MXNetError: # no flatten attribute in old mxnet has_flatten = False - use_flatten = _parse_bool_str(attrs, 'flatten', 'True') + use_flatten = parse_bool_str(attrs, 'flatten', 'True') if has_flatten and use_flatten: inputs[0] = _sym.flatten(inputs[0]) - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _dropout(inputs, attrs): op_name, new_attrs = 'dropout', {} new_attrs['rate'] = 
attrs.get('p', 0.5) - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _leaky_relu(inputs, attrs): - act_type = _required_attr(attrs, 'act_type') + act_type = required_attr(attrs, 'act_type', 'leaky_relu') if act_type in ['leaky', 'prelu']: op_name, new_attrs = act_type, {} if act_type == 'leaky': new_attrs['alpha'] = attrs.get('slope', 0.25) - sym = _get_nnvm_op(op_name)(*inputs, **new_attrs) + sym = get_nnvm_op(op_name)(*inputs, **new_attrs) elif act_type == 'elu': slope = attrs.get('slope', 0.25) sym = -slope * _sym.relu(1 - _sym.exp(*inputs)) + _sym.relu(*inputs) elif act_type == 'rrelu': - lower_bound = float(_required_attr(attrs, 'lower_bound')) - upper_bound = float(_required_attr(attrs, 'upper_bound')) + lower_bound = float(required_attr(attrs, 'lower_bound', 'leaky_relu')) + upper_bound = float(required_attr(attrs, 'upper_bound', 'leaky_relu')) slope = (lower_bound + upper_bound) / 2.0 op_name, new_attrs = 'leaky_relu', {'alpha': str(slope)} - sym = _get_nnvm_op(op_name)(*inputs, **new_attrs) + sym = get_nnvm_op(op_name)(*inputs, **new_attrs) else: - _raise_not_supported('act_type: ' + act_type) + raise_attribute_unimplemented([act_type]) return sym def _activations(inputs, attrs): - act_type = _required_attr(attrs, 'act_type') + act_type = required_attr(attrs, 'act_type', 'activations') if act_type in ['relu', 'sigmoid', 'tanh']: op_name, new_attrs = act_type, {} - sym = _get_nnvm_op(op_name)(*inputs, **new_attrs) + sym = get_nnvm_op(op_name)(*inputs, **new_attrs) elif act_type == 'softrelu': sym = _sym.log((1 + _sym.exp(*inputs))) else: - _raise_not_supported('act_type: ' + act_type) + raise_operator_unimplemented(act_type) return sym def _reshape(inputs, attrs): - if _parse_bool_str(attrs, 'reverse'): - _raise_not_supported('reverse', 'reshape') + if parse_bool_str(attrs, 'reverse'): + raise_attribute_unimplemented('reverse', 'reshape') op_name, new_attrs = 'reshape', {} - new_attrs['shape'] = _required_attr(attrs, 'shape') - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + new_attrs['shape'] = required_attr(attrs, 'shape', 'reshape') + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _slice(inputs, attrs): begin = attrs.get('begin', None) @@ -200,60 +170,60 @@ def _slice(inputs, attrs): new_attrs = {'begin': begin, 'end': end} if stride is not None: new_attrs['stride'] = stride - return _get_nnvm_op('strided_slice')(inputs[0], **new_attrs) + return get_nnvm_op('strided_slice')(inputs[0], **new_attrs) def _split(inputs, attrs): op_name, new_attrs = 'split', {} axis = attrs.get('axis', 1) - new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs') + new_attrs['indices_or_sections'] = required_attr(attrs, 'num_outputs', 'split') new_attrs['axis'] = axis - outputs = _get_nnvm_op(op_name)(*inputs, **new_attrs) - if _parse_bool_str(attrs, 'squeeze_axis'): + outputs = get_nnvm_op(op_name)(*inputs, **new_attrs) + if parse_bool_str(attrs, 'squeeze_axis'): squeeze_attrs = {'axis': axis} - outputs = _sym.Group([_get_nnvm_op('squeeze')(o, **squeeze_attrs) for o in outputs]) + outputs = _sym.Group([get_nnvm_op('squeeze')(o, **squeeze_attrs) for o in outputs]) return outputs def _softmax_activation(inputs, attrs): op_name, new_attrs = 'softmax', {} mode = attrs.get('mode', 'instance') new_attrs['axis'] = 0 if mode == 'instance' else 1 - return _get_nnvm_op(op_name)(inputs[0], **new_attrs) + return get_nnvm_op(op_name)(inputs[0], **new_attrs) def _softmax_output(inputs, attrs): op_name, new_attrs = 'softmax', {} 
- if _parse_bool_str(attrs, 'multi_output'): + if parse_bool_str(attrs, 'multi_output'): new_attrs['axis'] = 1 - return _get_nnvm_op(op_name)(inputs[0], **new_attrs) + return get_nnvm_op(op_name)(inputs[0], **new_attrs) def _upsampling(inputs, attrs): scale = attrs.get('scale') new_attrs = {'scale':int(scale)} - return _get_nnvm_op('upsampling')(inputs[0], **new_attrs) + return get_nnvm_op('upsampling')(inputs[0], **new_attrs) def _clip(inputs, attrs): op_name, new_attrs = "clip", {} - new_attrs['a_min'] = _required_attr(attrs, 'a_min') - new_attrs['a_max'] = _required_attr(attrs, 'a_max') - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + new_attrs['a_min'] = required_attr(attrs, 'a_min', 'clip') + new_attrs['a_max'] = required_attr(attrs, 'a_max', 'clip') + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _contrib_multibox_detection(inputs, attrs): - clip = _parse_bool_str(attrs, 'clip', default='True') + clip = parse_bool_str(attrs, 'clip', default='True') threshold = attrs.get('threshold') or 0.01 nms_threshold = attrs.get('nms_threshold') or 0.5 - force_suppress = _parse_bool_str(attrs, 'force_suppress', default='False') + force_suppress = parse_bool_str(attrs, 'force_suppress', default='False') variances = tuple([float(x.strip()) for x in attrs.get('variances').strip('()').split(',')]) \ if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2) nms_topk = attrs.get('nms_topk') or -1 new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances} new_attrs1 = {'return_indices': False, 'iou_threshold': float(nms_threshold), 'force_suppress': force_suppress, 'top_k': int(nms_topk)} - data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1], + data, valid_count = get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1], inputs[2], **new_attrs0) - return _get_nnvm_op('non_max_suppression')(data, valid_count, **new_attrs1) + return get_nnvm_op('non_max_suppression')(data, valid_count, **new_attrs1) def _elemwise_sum(inputs, _): new_attrs = {'num_args':len(inputs)} - return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs) + return get_nnvm_op('elemwise_sum')(*inputs, **new_attrs) def _crop_like(inputs, attrs): new_attrs = {} @@ -261,20 +231,20 @@ def _crop_like(inputs, attrs): tuple([float(x.strip()) for x in attrs.get('offsets').strip('()').split(',')]) \ if attrs.get('offsets') is not None else (0, 0) if offsets != (0, 0): - raise RuntimeError("Currently only supports offsets to be zero.") - center_crop = _parse_bool_str(attrs, 'center_crop', default="False") + raise_attribute_invalid(offsets, 'offsets', 'crop_like') + center_crop = parse_bool_str(attrs, 'center_crop', default="False") if center_crop: - raise RuntimeError("center crop is not supported.") + raise_attribute_unimplemented('center crop', 'crop_like') if len(inputs) < 2: raise RuntimeError("Only support crop_like pattern.") new_attrs["axis"] = [2, 3] - return _get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs) + return get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs) def _expand_dims(inputs, attrs): op_name, new_attrs = 'expand_dims', {} - new_attrs['axis'] = _required_attr(attrs, 'axis') - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + new_attrs['axis'] = required_attr(attrs, 'axis', 'expand_dims') + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _lrn(inputs, attrs): op_name, new_attrs = 'lrn', {} @@ -283,36 +253,36 @@ def _lrn(inputs, attrs): new_attrs['bias'] = attrs.get('knorm', 2) # NCHW format and normalization along channel 
axis new_attrs['axis'] = 1 - new_attrs['size'] = _required_attr(attrs, 'nsize') - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + new_attrs['size'] = required_attr(attrs, 'nsize', 'lrn') + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _minimum(inputs, attrs): - return _get_nnvm_op('broadcast_min')(*inputs, **attrs) + return get_nnvm_op('broadcast_min')(*inputs, **attrs) def _maximum(inputs, attrs): - return _get_nnvm_op('broadcast_max')(*inputs, **attrs) + return get_nnvm_op('broadcast_max')(*inputs, **attrs) def _ones(_, attrs): op_name = 'ones' - return _get_nnvm_op(op_name)(**attrs) + return get_nnvm_op(op_name)(**attrs) def _zeros(_, attrs): op_name = 'zeros' - return _get_nnvm_op(op_name)(**attrs) + return get_nnvm_op(op_name)(**attrs) def _argmax(inputs, attrs): op_name, new_attrs = 'argmax', {} new_attrs['dtype'] = 'float32' new_attrs['axis'] = attrs.get('axis', 0) - new_attrs['keepdims'] = _parse_bool_str(attrs, 'keepdims', default="False") - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + new_attrs['keepdims'] = parse_bool_str(attrs, 'keepdims', default="False") + return get_nnvm_op(op_name)(*inputs, **new_attrs) def _argmin(inputs, attrs): op_name, new_attrs = 'argmin', {} new_attrs['dtype'] = 'float32' new_attrs['axis'] = attrs.get('axis', 0) - new_attrs['keepdims'] = _parse_bool_str(attrs, 'keepdims', default="False") - return _get_nnvm_op(op_name)(*inputs, **new_attrs) + new_attrs['keepdims'] = parse_bool_str(attrs, 'keepdims', default="False") + return get_nnvm_op(op_name)(*inputs, **new_attrs) _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__', @@ -406,12 +376,12 @@ def _convert_symbol(op_name, inputs, attrs, identity_list = identity_list if identity_list else _identity_list convert_map = convert_map if convert_map else _convert_map if op_name in identity_list: - op = _get_nnvm_op(op_name) + op = get_nnvm_op(op_name) sym = op(*inputs, **attrs) elif op_name in convert_map: sym = convert_map[op_name](inputs, attrs) else: - _raise_not_supported('Operator: ' + op_name) + raise_operator_unimplemented(op_name) return sym def _as_list(arr): diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index ad0acc31a521..1262bebbb85f 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -397,7 +397,7 @@ def _impl_v7(cls, inputs, attr, params): elif mode == b'linear': method = "BILINEAR" else: - raise ValueError("Invalid ONNX upsample mode: {}".format(mode)) + raise_attribute_invalid(mode, 'mode', 'upsample') return _sym.upsampling(inputs[0], scale=int(scales[-1]), method=method, layout='NCHW') @@ -922,8 +922,7 @@ def _convert_operator(self, elif op_name in convert_map: sym = convert_map[op_name](inputs, attrs, self._params) else: - raise NotImplementedError( - "Operator {} not implemented.".format(op_name)) + raise_operator_unimplemented(op_name) return sym def _fix_outputs(self, op_name, outputs): diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index f4065cc544e1..140fa900eefa 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -11,7 +11,7 @@ from .. import symbol as _sym from .. import graph as _graph from .. 
compiler import graph_util, build_module -from .common import get_nnvm_op, AttrConverter as AttrConvert +from .common import AttrConverter as AttrConvert __all__ = ['from_tensorflow'] @@ -68,7 +68,7 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise NotImplementedError("Only 2d kernel supported.") + raise_attribute_unimplemented('non-2d kernel', prefix) return _impl def _dimension_constraint(): @@ -129,7 +129,7 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise TypeError("Unsupported data_format type : {}".format(attr['data_format'])) + raise_attribute_invalid(attr['data_format'], 'data_format', 'pooling') if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": tmp_shape = attr['_input_shapes'][inputs[0]] @@ -158,7 +158,7 @@ def _impl(inputs, attr, params): attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: - raise TypeError("Unsupported padding type : {}".format(attr['padding'])) + raise_attribute_unimplemented(attr['padding'], 'padding', 'pooling') if name == "avg_pool": attr['count_include_pad'] = False @@ -232,7 +232,7 @@ def _impl(inputs, attr, params): attr['dilations'] = (attr['dilations'][2], attr['dilations'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise TypeError("Unsupported data format type : {}".format(attr['data_format'])) + raise_attribute_invalid(attr['data_format'], 'data_format', 'conv') if opname == 'depthwise': @@ -276,7 +276,7 @@ def _impl(inputs, attr, params): attr['padding'] = [0, 0] else: - raise TypeError("Unsupported padding type : {}".format(attr['padding'])) + raise_attribute_invalid(attr['padding'], 'padding', 'conv') if 'kernel_layout' not in attr: if opname == 'conv': @@ -432,7 +432,7 @@ def _impl(inputs, attr, params): op_name="reshape", extras={'shape':tuple(params_new[0].asnumpy().flatten())}, ignores=['Tshape'])(inputs, attr) - raise RuntimeError("Reshape with dynamic shape input not supported yet.") + raise_attribute_unimplemented('dynamic shape', 'reshape') return _impl def _bias_add(): @@ -736,7 +736,7 @@ def _impl(inputs, attr, params): if padlist_key in params: padlist = params.pop(padlist_key).asnumpy() else: - raise RuntimeError("Required parameter {} not fount.".format(padlist_key)) + raise_attribute_required(padlist_key, 'pad') paddings = tuple([tuple(l) for l in padlist]) attr['pad_width'] = paddings attr['pad_value'] = 0 @@ -1188,8 +1188,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): missing_operators = self._parse_import_prerequisites(graph) if missing_operators: - raise NotImplementedError( \ - "The following operators are not implemented: {}".format(missing_operators)) + raise_operator_unimplemented(*missing_operators) for node in graph.node: if node.op == 'Placeholder': @@ -1529,7 +1528,7 @@ def _convert_operator(self, op_name, inputs, attrs, self._params, graph, convert_map_rnn) else: - raise NotImplementedError("Operator {} not implemented.".format(op_name)) + raise_operator_unimplemented(op_name) return sym def _fix_extranodes(self, op_name, attr, inputs): diff --git a/python/tvm/error_handling/__init__.py b/python/tvm/error_handling/__init__.py new file mode 100644 index 000000000000..8616d1ba973a --- /dev/null +++ b/python/tvm/error_handling/__init__.py @@ -0,0 +1,44 @@ +import warnings +import traceback +import sys + +def _excepthook(type, value, tb): + 
print(''.join(traceback.format_exception(type, value, tb))) + +sys.excepthook = _excepthook + +class OperatorError(Exception): + pass + +def _raise_error_helper(exception, msg, *args): + raise exception(msg.format(*args)) + +def raise_attribute_required(key, op_name): + class OperatorAttributeRequired(OperatorError): + pass + msg = 'Required attribute {} not found in operator {}.' + _raise_error_helper(OperatorAttributeRequired, msg, key, op_name) + +def raise_attribute_invalid(val, attr, op_name): + class OperatorAttributeValueNotValid(OperatorError): + pass + msg = 'Value {} in attr {} is not valid in operator {}.' + _raise_error_helper(OperatorAttributeValueNotValid, msg, val, attr, + op_name) + +def raise_operator_unimplemented(*missing_ops): + class OperatorNotImplemented(OperatorError): + pass + missing_ops = str(missing_ops).strip('(,)') + msg = 'The following operators are not supported: {}.' + _raise_error_helper(OperatorNotImplemented, msg, missing_ops) + +def raise_attribute_unimplemented(key, op_name): + class OperatorAttributeNotImplemented(OperatorError): + pass + msg = 'Attribute {} is not supported in operator {}.' + _raise_error_helper(OperatorAttributeNotImplemented, msg, key, op_name) + +def warn_not_used(attr, op_name): + msg = '{} is ignored in {}.'.format(attr, op_name) + warnings.warn(msg) diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index dee3999ad3f1..6ba2f0bde12d 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -14,3 +14,8 @@ from .coreml import from_coreml from .caffe2 import from_caffe2 from .tensorflow import from_tensorflow +from tvm.error_handling import raise_attribute_required, \ + raise_attribute_invalid, \ + raise_operator_unimplemented, \ + raise_attribute_unimplemented, \ + warn_not_used diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index 519dfc185add..5ae7a294d306 100755 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -15,7 +15,7 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise NotImplementedError("Only 2d kernel supported.") + raise_operator_unimplemented('non 2d kernel') return _impl @@ -27,7 +27,7 @@ def revert_caffe2_pad(pads): elif len(pads) == 2: pass else: - raise ValueError("Invalid caffe2 type padding: {}".format(pads)) + raise_attribute_invalid(str(len(pads)), 'len(pads)', 'padding') return pads @@ -103,8 +103,7 @@ def get_converter(cls): if hasattr(cls, '_impl'): return getattr(cls, '_impl') - raise NotImplementedError('{} not implemented'.format( - cls.__name__)) + raise_operator_unimplemented(cls.__name__) _caffe2_internal_args = [ @@ -224,8 +223,7 @@ def _get_axis_from_order_str(order): return 1 if order == 'NHWC': return 3 - raise RuntimeError( - "Unsupported storage order: {} in caffe2".format(order)) + raise_attribute_unimplemented(order, 'Concat') return AttrCvt( op_name='concatenate', @@ -517,8 +515,7 @@ def _convert_operator(self, # Add a sanitizing step to convert all byte strings in args to strings func = convert_map[op_type](inputs, args, self._params) else: - raise NotImplementedError( - "Operator {} not implemented.".format(op_type)) + raise_operator_unimplemented(op_type) return func diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index a4f9b39b70e2..369a5d4bb3a4 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ 
-81,7 +81,7 @@ def _BatchnormLayerParams(op, inexpr, etab): """Get layer of batchnorm parameter""" # this changes the symbol if op.instanceNormalization: - raise NotImplementedError("instance normalization not implemented") + raise_operator_unimplemented('instance normalization') else: params = {'gamma':etab.new_const(list(op.gamma.floatValue)), 'beta':etab.new_const(list(op.beta.floatValue)), @@ -142,7 +142,7 @@ def _ActivationParams(op, inexpr, etab): alpha_expr = etab.new_const(alpha) beta_expr = etab.new_const(beta) return _op.multiply(_op.log(_op.add(_op.exp(inexpr), beta_expr)), alpha_expr) - raise NotImplementedError('%s not implemented' % whichActivation) + raise_operator_unimplemented(whichActivation) def _ScaleLayerParams(op, inexpr, etab): @@ -164,7 +164,7 @@ def _PoolingLayerParams(op, inexpr, etab): return _op.nn.global_max_pool2d(inexpr) if op.type == 1: return _op.nn.global_avg_pool2d(inexpr) - raise NotImplementedError("Only max and average pooling implemented") + raise_operator_unimplemented('pooling (not max or average)') else: params = {'pool_size':list(op.kernelSize), @@ -184,7 +184,8 @@ def _PoolingLayerParams(op, inexpr, etab): params['padding'] = padding params['ceil_mode'] = True else: - raise NotImplementedError("Other convolution padding not implemented") + raise_attribute_unimplemented(op.WhichOneof('PoolingPaddingType'), + 'PoolingPaddingType', 'pooling') # consume padding layer if etab.in_padding: @@ -196,7 +197,7 @@ def _PoolingLayerParams(op, inexpr, etab): return _op.nn.max_pool2d(inexpr, **params) if op.type == 1: return _op.nn.avg_pool2d(inexpr, **params) - raise NotImplementedError("Only max and average pooling implemented") + raise_operator_unimplemented('pooling (not max or average)') def _SoftmaxLayerParams(op, inexpr, etab): @@ -239,7 +240,7 @@ def _ConcatLayerParams(op, inexpr, etab): if not isinstance(inexpr, list): inexpr = [inexpr] if op.sequenceConcat: - raise NotImplementedError("Sequence Concat not supported") + raise_operator_unimplemented('Sequence Concat') ret = _op.concatenate(inexpr, axis=1) return ret @@ -255,14 +256,14 @@ def _PaddingLayerParams(op, inexpr, etab): if op.WhichOneof('PaddingType') == 'constant': constant = op.constant if constant.value != 0: - raise NotImplementedError("Padding value {} not supported.".format(constant.value)) + raise_attribute_unimplemented(constant.value, 'padding value', 'padding') padding = [b.startEdgeSize for b in op.paddingAmounts.borderAmounts] padding2 = [b.endEdgeSize for b in op.paddingAmounts.borderAmounts] for i, j in zip(padding, padding2): assert i == j etab.set_padding(padding) else: - raise NotImplementedError("Only constant padding is supported now.") + raise_operator_unimplemented('non-constant padding') return inexpr @@ -273,8 +274,7 @@ def _PermuteLayerParams(op, inexpr, etab): def _UpsampleLayerParams(op, inexpr, etab): if op.scalingFactor[0] != op.scalingFactor[1]: - raise NotImplementedError("Upsampling only supported with same \ - height and width scaling factor.") + raise_attribute_unimplemented('unequal height/width scaling factors', 'upsample') interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR' return _op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode) @@ -364,7 +364,7 @@ def coreml_op_to_relay(op, inname, outname, etab): """ classname = type(op).__name__ if classname not in _convert_map: - raise NotImplementedError("%s is not supported" % (classname)) + raise_operator_unimplemented(classname) if isinstance(inname, 
_base.string_types): insym = etab.get_expr(inname) else: diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index a865f08243eb..2e266852f9dc 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -91,7 +91,7 @@ def _convert_activation(inexpr, keras_layer, _): x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32') return _op.clip(x, a_min=0., a_max=1.) - raise TypeError("Unsupported activation type : {}".format(act_type)) + raise_operator_unimplemented(act_type) def _convert_advanced_activation(inexpr, keras_layer, etab): @@ -118,7 +118,7 @@ def _convert_advanced_activation(inexpr, keras_layer, etab): return _op.multiply(inexpr, _op.greater(inexpr, \ _expr.const(theta, dtype='float32')).astype('float32')) - raise TypeError("Unsupported advanced activation type : {}".format(act_type)) + raise_operator_unimplemented(act_type) def _convert_merge(inexpr, keras_layer, _): @@ -136,7 +136,7 @@ def _convert_merge(inexpr, keras_layer, _): ret = _op.add(ret, inexpr[i]) ret = ret / _expr.const(len(inexpr), dtype='float32') else: - raise TypeError("Unsupported merge type : {}".format(merge_type)) + raise_operator_unimplemented(merge_type) return ret @@ -150,7 +150,7 @@ def _convert_dense(inexpr, keras_layer, etab): if input_dim > 2: input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0]) if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1: - raise ValueError("Cannot flatten the inputs with shape.", input_shape, " for dense.") + raise_attribute_invaid(input_shape, 'input shape', 'dense') inexpr = _op.squeeze(inexpr, axis=0) out = _op.nn.dense(data=inexpr, **params) if keras_layer.use_bias: @@ -214,7 +214,7 @@ def _convert_convolution(inexpr, keras_layer, etab): inexpr = _op.nn.pad(data=inexpr, pad_width=( (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise TypeError("Unsupported padding type : {}".format(keras_layer.padding)) + raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) if is_deconv: out = _op.nn.conv2d_transpose(data=inexpr, **params) else: @@ -260,7 +260,7 @@ def _convert_separable_convolution(inexpr, keras_layer, etab): inexpr = _op.nn.pad(data=inexpr, pad_width=( (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise TypeError("Unsupported padding type : {}".format(keras_layer.padding)) + raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) depthconv = _op.nn.conv2d(data=inexpr, **params0) # pointwise conv weight1 = weightList[1].transpose([3, 2, 0, 1]) @@ -313,13 +313,13 @@ def _convert_pooling(inexpr, keras_layer, etab): pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w) params['padding'] = [pad_t, pad_l, pad_b, pad_r] else: - raise TypeError("Unsupported padding type : {}".format(keras_layer.padding)) + raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) if pool_type == 'MaxPooling2D': return _op.nn.max_pool2d(inexpr, **params) if pool_type == 'AveragePooling2D': params['count_include_pad'] = False return _op.nn.avg_pool2d(inexpr, **params) - raise TypeError("Unsupported pooling type : {}".format(keras_layer)) + raise_operator_unimplemented('pooling type {}'.format(keras_layer)) def _convert_upsample(inexpr, keras_layer, _): @@ -331,8 +331,7 @@ def _convert_upsample(inexpr, keras_layer, _): elif upsample_type == 'UpSampling2D': h, w = keras_layer.size if h != w: - raise TypeError("Unsupported upsampling type with different axes size : {}" - .format(keras_layer.size)) 
+ raise_attribute_invalid(keras_layer.size, 'size', 'upsample') params = {'scale': h} if hasattr(keras_layer, 'interpolation'): @@ -345,11 +344,10 @@ def _convert_upsample(inexpr, keras_layer, _): elif upsample_type == 'UpSampling3D': h, w, d = keras_layer.size if h != w or w != d: - raise TypeError("Unsupported upsampling type with different axes size : {}" - .format(keras_layer.size)) + raise_attribute_invalid(keras_layer.size, 'size', 'upsample') params = {'scale': h} else: - raise TypeError("Unsupported upsampling type : {}".format(upsample_type)) + raise_operator_unimplemented(upsample_type) return _op.nn.upsampling(inexpr, **params) @@ -357,12 +355,12 @@ def _convert_cropping(inexpr, keras_layer, _): _check_data_format(keras_layer) crop_type = type(keras_layer).__name__ if crop_type == 'Cropping1D': - raise NotImplementedError("Cropping1D not implemented") + raise_operator_unimplemented(crop_type) elif crop_type == 'Cropping2D': (_, in_h, in_w, _) = keras_layer.input_shape ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping else: - raise TypeError("Unrecognized cropping type : {}".format(crop_type)) + raise_operator_unimplemented(crop_type) int32_max = np.iinfo(np.int32).max return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \ end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r]) @@ -407,13 +405,13 @@ def _convert_padding(inexpr, keras_layer, _): top, bottom = padding[0] left, right = padding[1] else: - raise ValueError("Unrecognized padding option: {}".format(str(padding))) + raise_attribute_invalid(str(padding), 'padding', 'padding') else: - raise ValueError("Unrecognized padding option: {}".format(str(padding))) + raise_attribute_invalid(str(padding), 'padding', 'padding') elif padding_type == 'ZeroPadding1D': - raise NotImplementedError("ZeroPadding1D not implemented") + raise_operator_unimplemented(padding_type) else: - raise ValueError("Unrecognized padding type: {}".format(padding_type)) + raise_operator_unimplemented(padding_type) return _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0), (top, bottom), (left, right))) @@ -602,7 +600,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument def _check_unsupported_layers(model): for layer in model.layers: if type(layer).__name__ not in _convert_map: - raise ValueError("Keras layer {} not supported.".format(type(layer).__name__)) + raise_operator_unimplemented(type(layer).__name__) def keras_op_to_relay(inexpr, keras_layer, outname, etab): @@ -623,7 +621,7 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab): The global expression table to be updated. 
""" if type(keras_layer).__name__ not in _convert_map: - raise NotImplementedError("{} is not supported".format((type(keras_layer).__name__))) + raise_operator_unimplemented(type(keras_layer).__name__) outs = _convert_map[type(keras_layer).__name__](inexpr, keras_layer, etab) outs = _as_list(outs) for t_idx, out in enumerate(outs): diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 758793c980d6..b28558bb25f9 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -41,7 +41,7 @@ def _get_channel_axis(layout, op_name): return 1 if layout == "NHWC": return 3 - raise RuntimeError("layout: {} is not supported in {}".format(layout, op_name)) + raise_attribute_invalid(layout, 'layout', op_name) def _mx_activations(inputs, attrs): @@ -61,7 +61,7 @@ def _stable_softrelu(x): return _op.add(_op.log(_op.add(one, exp_neg_abs_x)), _op.nn.relu(x)) return _stable_softrelu(inputs[0]) - raise RuntimeError("Do not support act_type: {}".format(act_type)) + raise_operator_unimplemented(act_type) def _mx_compare(new_op, wrapper): @@ -74,7 +74,7 @@ def impl(inputs, attrs): def _mx_conv2d(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: - raise RuntimeError("non-2d kernel is not supported in conv2d") + raise_attribute_invalid(kernel_size, 'kernel size', 'conv2d') data_layout = attrs.get_str("layout", "NCHW") channel_axis = _get_channel_axis(data_layout, "conv2d") @@ -102,10 +102,10 @@ def _mx_conv2d(inputs, attrs): def _mx_conv2d_transpose(inputs, attrs): if "target_shape" in attrs.attrs: - raise RuntimeError("target_shape is not supported in conv2d_transpose") + raise_attribute_unimplemented('target_shape', 'conv2d_transpose') kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: - raise RuntimeError("non-2d kernel is not supported in conv2d") + raise_attribute_invalid(len(kernel_size), 'kernel dimensionality', 'conv2d') data_layout = attrs.get_str("layout", "NCHW") channel_axis = _get_channel_axis(data_layout, "conv2d_transpose") @@ -140,7 +140,7 @@ def _mx_pooling(inputs, attrs): def _pool2d(new_op, is_avg): kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: - raise RuntimeError("non-2d kernel is not supported in pool2d") + raise_attribute_invalid(len(kernel_size), 'kernel dimensionality', 'pool2d') new_attrs = {} new_attrs["pool_size"] = kernel_size new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1)) @@ -158,7 +158,7 @@ def _pool2d(new_op, is_avg): if global_pool: return _op.nn.global_avg_pool2d(inputs[0]) return _pool2d(_op.nn.avg_pool2d, True) - raise RuntimeError("Do not support pool_type:{}".format(pool_type)) + raise_operator_unimplemented(pool_type) def _mx_dropout(inputs, attrs): @@ -172,7 +172,7 @@ def _mx_BlockGrad(inputs, attrs): #pylint: disable=unused-argument def _mx_batch_norm(inputs, attrs): if attrs.get_bool("output_mean_var", False): - raise RuntimeError("batch_norm do not support output_mean_var") + raise_attribute_unimplemented('output_mean_var', 'batch_norm') if attrs.get_bool("use_global_stats", False): _warn_not_used("use_global_stats", "batch_norm") new_attrs = {} @@ -188,10 +188,14 @@ def _mx_slice(inputs, attrs): begin = attrs.get_int_tuple('begin', None) end = attrs.get_int_tuple('end', None) stride = attrs.get_int_tuple('step', None) - if begin is None or end is None: - raise RuntimeError("begin and end are required parameters.") - if None in begin or None in end: - raise RuntimeError("None in begin or end is not supported 
yet.") + if begin is None: + raise_attribute_required('begin', 'slice') + if end is None: + raise_attribute_required('end', 'slice') + if None in begin: + raise_attribute_unimplemented('None in begin', 'slice') + if None in end: + raise_attribute_unimplemented('None in end', 'slice') new_attrs = {'begin': begin, 'end': end} if stride is not None: new_attrs['strides'] = stride @@ -295,7 +299,7 @@ def _mx_leaky_relu(inputs, attrs): upper_bound = attrs.get_float("upper_bound") alpha = (lower_bound + upper_bound) / 2.0 return _op.nn.leaky_relu(inputs[0], alpha=alpha) - raise RuntimeError("act_type: {} is not supported".format(act_type)) + raise_operator_unimplemented(act_type) def _mx_make_power(power): @@ -389,7 +393,7 @@ def _mx_batch_dot(inputs, attrs): transpose_a = attrs.get_bool("transpose_a", False) transpose_b = attrs.get_bool("transpose_b", False) if transpose_a is True: - raise RuntimeError("batch_dot: only support transpose_a=False") + raise_attribute_invalid(transpose_a, 'transpose_a', 'batch_dot') if transpose_b is False: b = _op.transpose(b, axes=[0, 2, 1]) return _op.batch_matmul(a, b) @@ -398,7 +402,7 @@ def _mx_batch_dot(inputs, attrs): def _mx_arange(inputs, attrs): assert len(inputs) == 0 if attrs.get_int("repeat", 1) != 1: - raise RuntimeError("arange doesn't support repeat") + raise_attribute_unimplemented('repeat', 'arange') new_attrs = {} new_attrs["start"] = attrs.get_float("start", 0) new_attrs["stop"] = attrs.get_float("stop") @@ -482,15 +486,15 @@ def _mx_box_nms(inputs, attrs): in_format = attrs.get_str('in_format', 'corner') out_format = attrs.get_str('out_format', 'corner') if coord_start != 2: - raise RuntimeError('coord_start %s is not supported.' % coord_start) + raise_attribute_invalid(coord_start, 'coord_start', 'box_nms') if score_index != 1: - raise RuntimeError('score_index %s is not supported.' % score_index) + raise_attribute_invalid(score_index, 'score_index', 'box_nms') if id_index != -1 and int(id_index) != 0: - raise RuntimeError('id_index %s is not supported.' % id_index) + raise_attribute_invalid(id_index, 'id_index', 'box_nms') if in_format != 'corner': - raise RuntimeError('in_format %s is not supported.' % in_format) + raise_attribute_invalid(in_format, 'in_format', 'box_nms') if out_format != 'corner': - raise RuntimeError('out_format %s is not supported.' % out_format) + raise_attribute_invalid(out_format, 'out_format', 'box_nms') ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh) nms_out = _op.vision.non_max_suppression(ret[1], @@ -508,7 +512,7 @@ def _mx_l2_normalize(inputs, attrs): new_attrs = {} mode = attrs.get_str('mode', 'instance') if mode != 'channel': - raise RuntimeError('mode %s is not supported.' 
% mode) + raise_attribute_invalid(mode, 'mode', 'l2_normalize') new_attrs['eps'] = attrs.get_float('eps', 1e-10) new_attrs['axis'] = [1] return _op.nn.l2_normalize(inputs[0], **new_attrs) @@ -768,10 +772,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): elif isinstance(res, _expr.Expr): res = [res] else: - raise RuntimeError("unexpected type %s" % type(res)) + raise_attribute_invalid(type(res), 'type(res)', op_name) node_map[nid] = res else: - raise RuntimeError("{} is not supported in relay frontend".format(op_name)) + raise_operator_unimplemented(op_name) outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index e92aa203b401..1bffdfd4bcd9 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -18,7 +18,7 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise NotImplementedError("Only 2d kernel supported.") + raise_attribute_invalid(len(kernel), 'kernel dimensionality', prefix) return _impl @@ -29,7 +29,7 @@ def revert_caffe2_pad(pads): elif len(pads) == 2: pass else: - raise ValueError("Invalid caffe2 type padding: {}".format(pads)) + raise_attribute_invalid(len(pads), 'len(pads)', 'padding') return pads def dimension_constraint(): @@ -461,7 +461,7 @@ def _impl_v9(cls, inputs, attr, params): elif mode == b'linear': method = "BILINEAR" else: - raise ValueError("Invalid ONNX upsample mode: {}".format(mode)) + raise_attribute_invalid(mode, 'mode', 'upsample') attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'} return AttrCvt('upsampling')(inputs, attr) @@ -717,9 +717,7 @@ def _impl_v1(cls, inputs, attr, params): if 'input_as_shape' in attr and attr['input_as_shape']: shape = params[get_name(inputs[0])].asnumpy() else: - if 'extra_shape' in attr: - raise ImportError( - "Extra Shape not supported with fill_like") + raise_attribute_required('extra_shape', 'ConstantFill') return _op.full_like(inputs[0], inputs[1]) if 'extra_shape' in attr: diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 304c5e11f1a5..f795aa70a596 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,7 +27,7 @@ def _get_relay_op(op_name): op = getattr(_op.image, op_name) if not op: - raise RuntimeError("Unable to map op_name {} to relay".format(op_name)) + raise_operator_unimplemented(op_name) return op class AttrCvt(object): @@ -99,7 +99,7 @@ def __call__(self, inputs, attrs, *args): new_attrs = {} for k in attrs.keys(): if k in self._excludes: - raise NotImplementedError("Attribute {} not supported yet.".format(k)) + raise_operator_unimplemented(k, op_name) elif k in self._disables: logging.warning("Attribute %s is disabled in relay.%s", k, op_name) elif k in self._ignores: @@ -148,7 +148,7 @@ def _required_attr(self, attr, key): """Wrapper for getting required attributes.""" assert isinstance(attr, dict) if key not in attr: - raise AttributeError("Required attribute {} not found.".format(key)) + raise_attribute_required(key, self._op_name) return attr[key] def _get_pad_pair(input1d, kernel1d, stride1d): @@ -178,7 +178,7 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise NotImplementedError("Only 2d kernel supported.") + raise_attribute_invalid(len(kernel), 'kernel dimensionality', prefix) return 
_impl def _dimension_constraint(): @@ -238,7 +238,7 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise TypeError("Unsupported data_format type : {}".format(attr['data_format'])) + raise_attribute_invalid(attr['data_format'], 'data_format', 'pooling') if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": tmp_shape = attr['_input_shapes'][inputs[0]] @@ -267,7 +267,7 @@ def _impl(inputs, attr, params): attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: - raise TypeError("Unsupported padding type : {}".format(attr['padding'])) + raise_attribute_invalid(attr['padding'], 'padding', 'padding') if name == "avg_pool": attr['count_include_pad'] = False @@ -341,7 +341,7 @@ def _impl(inputs, attr, params): attr['dilations'] = (attr['dilations'][2], attr['dilations'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise TypeError("Unsupported data format type : {}".format(attr['data_format'])) + raise_attribute_unimplemented(attr['data_format'], 'data_format', 'conv') if opname == 'depthwise': @@ -386,7 +386,7 @@ def _impl(inputs, attr, params): attr['padding'] = [0, 0] else: - raise TypeError("Unsupported padding type : {}".format(attr['padding'])) + raise_attribute_invalid(attr['padding'], 'padding', 'conv') if 'kernel_layout' not in attr: if opname == 'conv': @@ -791,7 +791,7 @@ def _impl(inputs, attr, params): if padlist_key in params: padlist = params.pop(padlist_key).asnumpy() else: - raise RuntimeError("Required parameter {} not fount.".format(padlist_key)) + raise_attribute_required(padlist_key, 'pad') paddings = tuple([tuple(l) for l in padlist]) attr['pad_width'] = paddings attr['pad_value'] = 0 diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d45bb33859b2..37f4e1367e53 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -59,8 +59,7 @@ def check_unsupported_ops(self): unsupported_ops_set.add(op_code_str) if unsupported_ops_set: - raise NotImplementedError("Unsupported Ops: %s" % ( - ','.join(unsupported_ops_set))) + raise_operator_unimplemented(*upsupported_ops_set) def convert_op_to_relay(self): """Convert TFLite ops to relay ops""" @@ -205,8 +204,7 @@ def convert_reshape(self, op): # finally convert back if necessary in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1)) else: - raise NotImplementedError("Not support input shape length {} of reshape : " - .format(str(input_shape_length))) + raise_attribute_invalid(input_shape_length, 'input shape length', 'reshape') out = _op.reshape(in_expr, newshape=tuple(target_shape)) @@ -223,8 +221,7 @@ def convert_reshape(self, op): elif len(target_shape) == 4: out = _op.transpose(out, axes=(0, 3, 1, 2)) else: - raise NotImplementedError("Not support to reshape to shape length {}: " - .format(str(len(target_shape)))) + raise_attribute_invalid(len(target_shape), 'shape length', 'reshape') return out @@ -330,8 +327,7 @@ def convert_squeeze(self, op): # finally convert back if necessary in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1)) else: - raise NotImplementedError("Not support input shape length {} of squeeze : " - .format(str(input_shape_length))) + raise_attribute_invalid(input_shape_length, 'input shape length', 'squeeze') out = _op.squeeze(in_expr, axis=tuple(squeeze_axis)) @@ -348,8 +344,7 @@ def convert_squeeze(self, op): elif output_shape_length == 4: out = _op.transpose(out, axes=(0, 3, 1, 
2)) else: - raise NotImplementedError("Not support to squeeze to length {} : " - .format(str(output_shape_length))) + raise_attribute_invalid(output_shape_length, 'output_shape_length', 'squeeze') return out @@ -369,8 +364,7 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): if fused_activation_fn == ActivationFunctionType.TANH: return _op.tanh(in_expr) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] - raise NotImplementedError("Unsupported fused activation fn {}" - .format(fused_activation_fn_str)) + raise_operator_unimplemented(fused_activation_fn_str) def convert_conv(self, op, conv_type): """convolution implementation.""" @@ -409,7 +403,7 @@ def convert_conv(self, op, conv_type): assert depth_multiplier == 1, "TF frontend have transformed it be 1 " \ "no matter original value be set by 0.25, 0.5 or any else" else: - raise ValueError("Not support conv type: {}".format(conv_type)) + raise_operator_unimplemented(conv_type) stride_h = conv_options.StrideH() stride_w = conv_options.StrideW() @@ -466,7 +460,7 @@ def convert_conv(self, op, conv_type): (pad_top, pad_bottom), (pad_left, pad_right))) else: - raise NotImplementedError("Not support padding format: {}".format(padding)) + raise_attribute_invalid(padding, 'padding format', 'conv') out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params) @@ -529,14 +523,14 @@ def convert_pool2d(self, op, pool_type): pad_left, pad_right = get_pad_value(input_w, filter_w, stride_w) params['padding'] = [pad_top, pad_left, pad_bottom, pad_right] else: - raise NotImplementedError("Not support padding format: {}".format(padding)) + raise_attribute_invalid(padding, 'padding', 'pool2d') if pool_type == "average": out = _op.nn.avg_pool2d(in_expr, **params) elif pool_type == "max": out = _op.nn.max_pool2d(in_expr, **params) else: - raise ValueError("Not support pool type: {}".format(pool_type)) + raise_operator_unimplemented(pool_type + ' pool') # If we have fused activations if fused_activation_fn != ActivationFunctionType.NONE: From 4273ce94755beecfeb933569c211e35b8f7d941e Mon Sep 17 00:00:00 2001 From: Mark Rogers Date: Wed, 20 Mar 2019 00:29:35 +0000 Subject: [PATCH 2/2] use latest error handling conventions --- nnvm/python/nnvm/frontend/__init__.py | 7 --- nnvm/python/nnvm/frontend/caffe2.py | 10 ++-- nnvm/python/nnvm/frontend/common.py | 6 +- nnvm/python/nnvm/frontend/coreml.py | 35 +++++++----- nnvm/python/nnvm/frontend/darknet.py | 50 +++++++++++------ nnvm/python/nnvm/frontend/keras.py | 66 +++++++++++++--------- nnvm/python/nnvm/frontend/mxnet.py | 50 +++++++++++------ nnvm/python/nnvm/frontend/onnx.py | 6 +- nnvm/python/nnvm/frontend/tensorflow.py | 30 ++++++---- python/tvm/error_handling/__init__.py | 44 --------------- python/tvm/relay/frontend/__init__.py | 5 -- python/tvm/relay/frontend/caffe2.py | 16 ++++-- python/tvm/relay/frontend/coreml.py | 33 +++++++---- python/tvm/relay/frontend/keras.py | 74 ++++++++++++++++--------- python/tvm/relay/frontend/mxnet.py | 71 ++++++++++++++++-------- python/tvm/relay/frontend/onnx.py | 16 ++++-- python/tvm/relay/frontend/tensorflow.py | 31 ++++++++--- python/tvm/relay/frontend/tflite.py | 33 +++++++---- 18 files changed, 347 insertions(+), 236 deletions(-) delete mode 100644 python/tvm/error_handling/__init__.py diff --git a/nnvm/python/nnvm/frontend/__init__.py b/nnvm/python/nnvm/frontend/__init__.py index f95e134cf0dd..49f53df1174f 100644 --- a/nnvm/python/nnvm/frontend/__init__.py +++ b/nnvm/python/nnvm/frontend/__init__.py @@ -7,10 +7,3 @@ from 
.darknet import from_darknet from .tensorflow import from_tensorflow from .caffe2 import from_caffe2 -from .common import raise_not_supported, get_nnvm_op, required_attr, \ - warn_not_used, parse_tshape, parse_bool_str -from tvm.error_handling import raise_attribute_required, \ - raise_attribute_invalid, \ - raise_operator_unimplemented, \ - raise_attribute_unimplemented, \ - warn_not_used diff --git a/nnvm/python/nnvm/frontend/caffe2.py b/nnvm/python/nnvm/frontend/caffe2.py index 32d08678a0f8..63b7913dd755 100755 --- a/nnvm/python/nnvm/frontend/caffe2.py +++ b/nnvm/python/nnvm/frontend/caffe2.py @@ -3,7 +3,7 @@ from __future__ import absolute_import as _abs import tvm from nnvm import symbol as _sym -from nnvm.frontend.common import get_nnvm_op, Renamer, AttrConverter as AttrCvt +from .common import get_nnvm_op from .onnx_caffe2_utils import dimension_picker, dimension_constraint, infer_channels, revert_caffe2_pad from . import onnx @@ -73,7 +73,8 @@ def get_converter(cls): if hasattr(cls, '_impl'): return getattr(cls, '_impl') - raise_operator_unimplemented(cls.__name__) + raise tvm.error.OpNotImplemented( + 'Operator {} is not implemented in frontend Caffe2.'.format(cls.__name__)) _caffe2_internal_args = { @@ -175,7 +176,7 @@ def _get_axis_from_order_str(order): return 1 if order == 'NHWC': return 3 - raise_attribute_invalid(order, 'storage order', 'concat') + raise tvm.error.OpAttributeInvalid('Value {} in attribute {} of operator {} is not valid.'.format(order, 'order', 'Concat')) return AttrCvt( op_name='concatenate', @@ -425,7 +426,8 @@ def _convert_operator(self, # Add a sanitizing step to convert all byte strings in args to strings sym = convert_map[op_type](inputs, args, self._params) else: - raise_operator_unimplemented(op_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Caffe2.'.format(op_type)) return sym diff --git a/nnvm/python/nnvm/frontend/common.py b/nnvm/python/nnvm/frontend/common.py index 58ce6703b28d..5a8defdb3d6e 100644 --- a/nnvm/python/nnvm/frontend/common.py +++ b/nnvm/python/nnvm/frontend/common.py @@ -7,13 +7,15 @@ def get_nnvm_op(op_name): op = getattr(_sym, op_name) if not op: - raise_operator_unimplemented(op_name) + raise OpNotImplemented( + 'Operator {} is not supported.'.format(op)) return op def required_attr(attr, key, op_name): assert isinstance(attr, dict) if key not in attr: - raise_attribute_required(key, op_name) + raise OpAttributeRequired( + 'Required attribute {} not found in operator {}'.format(key, op_name)) return attr[key] def parse_tshape(tshape): diff --git a/nnvm/python/nnvm/frontend/coreml.py b/nnvm/python/nnvm/frontend/coreml.py index e7c5a0d7eda8..1483e95cf6f0 100644 --- a/nnvm/python/nnvm/frontend/coreml.py +++ b/nnvm/python/nnvm/frontend/coreml.py @@ -2,11 +2,10 @@ """CoreML frontend.""" from __future__ import absolute_import as _abs import numpy as np - import tvm +from .common import SymbolTable from .. import symbol as _sym from .._base import string_types -from .common import SymbolTable __all__ = ['from_coreml'] @@ -83,7 +82,8 @@ def BatchnormLayerParams(op, insym, symtab): """Get layer of batchnorm parameter""" # this changes the symbol if op.instanceNormalization: - raise_operator_unimplemented('instance normalization') + msg = 'Operator "instance normalization" is not supported in frontend CoreML.' 
+ raise tvm.error.OpNotImplemented(msg) else: params = {'gamma':symtab.new_const(list(op.gamma.floatValue)), 'beta':symtab.new_const(list(op.beta.floatValue)), @@ -136,7 +136,8 @@ def ActivationParams(op, insym, symtab): betasym = symtab.new_const(beta) return _sym.broadcast_mul(_sym.log(_sym.broadcast_add( _sym.exp(insym), betasym)), alphasym) - raise_operator_unimplemented(whichActivation) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend CoreML.'.format(whichActivation)) def ScaleLayerParams(op, insym, symtab): """Scale layer params.""" @@ -158,7 +159,8 @@ def PoolingLayerParams(op, insym, symtab): return _sym.global_max_pool2d(insym) if op.type == 1: return _sym.global_avg_pool2d(insym) - raise_operator_unimplemented('pooling (not max or average)') + raise tvm.error.OpNotImplemented( + 'Operator pooling (not max or average) is not supported in frontend CoreML.') else: params = {'pool_size':list(op.kernelSize), @@ -178,8 +180,8 @@ def PoolingLayerParams(op, insym, symtab): params['padding'] = padding params['ceil_mode'] = True else: - raise_attribute_invalid(op.WhichOneof('PoolingPaddingType'), - 'PoolingPaddingType', 'pooling') + msg = 'Value {} in attribute PoolingPaddingType of operator Pooling is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(op.WhichOneof('PoolingPaddingType'))) # consume padding layer if symtab.in_padding: @@ -191,7 +193,8 @@ def PoolingLayerParams(op, insym, symtab): return _sym.max_pool2d(insym, **params) if op.type == 1: return _sym.avg_pool2d(insym, **params) - raise_operator_unimplemented('pooling (not max or average)') + msg = 'Operator pooling (not max or average) is not supported in frontend CoreML.' + raise tvm.error.OpNotImplemented(msg) def SoftmaxLayerParams(op, insym, symtab): return _sym.softmax(_sym.flatten(insym)) @@ -230,7 +233,8 @@ def ConcatLayerParams(op, insyms, symtab): if not isinstance(insyms, list): insyms = [insyms] if op.sequenceConcat: - raise_operator_unimplemented('sequence concat') + raise tvm.error.OpNotImplemented( + 'Operator Sequence Concat is not supported in frontend CoreML.') ret = _sym.concatenate(*insyms, axis=1) return ret @@ -244,14 +248,16 @@ def PaddingLayerParams(op, insym, symtab): if op.WhichOneof('PaddingType') == 'constant': constant = op.constant if constant.value != 0: - raise_attribute_invalid(constant.value, 'padding value', 'padding') + msg = 'Value {} in attribute "padding value" of operator Padding is not valid.' 
+ raise tvm.error.OpAttributeInvalid(msg.format(constant.value)) padding = [b.startEdgeSize for b in op.paddingAmounts.borderAmounts] padding2 = [b.endEdgeSize for b in op.paddingAmounts.borderAmounts] for i, j in zip(padding, padding2): assert i == j symtab.set_padding(padding) else: - raise_operator_unimplemented('non-constant padding') + raise tvm.error.OpNotImplemented( + 'Operator "non-constant padding" is not supported in frontend CoreML.') return insym def PermuteLayerParams(op, insym, symtab): @@ -260,8 +266,8 @@ def PermuteLayerParams(op, insym, symtab): def UpsampleLayerParams(op, insym, symtab): if op.scalingFactor[0] != op.scalingFactor[1]: - raise_attribute_invalid(op.scalingFactor, 'scaling factors', - 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Height and width scaling factors of Upsample operator must be equal.') interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR' return _sym.upsampling(insym, scale=op.scalingFactor[0], method=interpolationMode) @@ -342,7 +348,8 @@ def coreml_op_to_nnvm(op, inname, outname, symtab): """ classname = type(op).__name__ if classname not in _convert_map: - raise_operator_unimplemented(classname) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend CoreML.'.format(classname)) if isinstance(inname, string_types): insym = symtab.get_var(inname) else: diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py index bbb0926f29c8..bf5a832258fa 100644 --- a/nnvm/python/nnvm/frontend/darknet.py +++ b/nnvm/python/nnvm/frontend/darknet.py @@ -6,6 +6,7 @@ import numpy as np import tvm from .. import symbol as _sym +from .common import get_nnvm_op, required_attr, parse_tshape, parse_bool_str class LAYERTYPE(object): """Darknet LAYERTYPE Class constant.""" @@ -61,7 +62,8 @@ def _darknet_maxpooling(inputs, attrs): """Process the max pool 2d operation.""" kernel = parse_tshape(required_attr(attrs, 'kernel', 'maxpool')) if len(kernel) != 1: - raise_attribute_unimplemented('non-2d kernel', 'pool_2d') + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels for Max Pooling are not supported in frontend Darknet.') op_name, new_attrs = 'max_pool2d', {} strides = int(attrs.get('stride', (1, 1))) @@ -79,7 +81,8 @@ def _darknet_avgpooling(inputs, attrs): """Process the average pool 2d operation.""" kernel = parse_tshape(required_attr(attrs, 'kernel', 'avgpool')) if len(kernel) != 1: - raise_attribute_unimplemented('non-2d kernel', 'pool_2d') + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels for Average Pooling are not supported in frontend Darknet.') op_name, new_attrs = 'avg_pool2d', {} strides = int(attrs.get('stride', (1, 1))) @@ -103,10 +106,12 @@ def _darknet_conv2d(inputs, attrs): """Process the convolution 2d operation.""" kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d')) if len(kernel) != 1: - raise_attribute_unimplemented('non 2d kernel', 'conv2d') + raise tvm.error.OpAttributeUnimplemented('Non-2D kernels for Conv2D are unsupported ' + 'in frontend Darknet.') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - raise_attribute_invalid(layout, 'layout', 'conv2d') + raise tvm.error.OpAttributeInvalid( + 'Value {} in attribute "layout" of operator Conv2D is not valid.'.format(layout)) strides = int(attrs.get('stride', (1, 1))) pads = int(attrs.get('pad', (0, 0))) @@ -142,13 +147,16 @@ def _darknet_conv2d(inputs, attrs): def _darknet_conv2d_transpose(inputs, attrs): """Process the convolution 2d transpose operation.""" if 
'target_shape' in attrs: - raise_attribute_unimplemented('target_shape', 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "target_shape" is not supported in operator Conv2D-transpose.') kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d_transpose')) if len(kernel) != 2: - raise_attribute_unimplemented('non-2d kernel', 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( + 'Non-2D kernels are not supported in operator Conv2D-transpose.') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - raise_attribute_invalid(layout, 'layout', 'conv2d_transpose') + msg = 'Value {} in attribute "layout" of operator Conv2D-transpose is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(layout)) op_name, new_attrs = 'conv2d_transpose', {} new_attrs['channels'] = required_attr(attrs, 'num_filter', 'conv2d_transpose') new_attrs['kernel_size'] = kernel @@ -222,7 +230,8 @@ def _darknet_dropout(inputs, attrs): def _darknet_reshape(inputs, attrs): """Process the reshape operation.""" if parse_bool_str(attrs, 'reverse'): - raise_attribute_unimplemented('reverse', 'reshape') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "reverse" is not supported in operator Reshape.') op_name, new_attrs = 'reshape', {} new_attrs['shape'] = required_attr(attrs, 'shape', 'reshape') return get_nnvm_op(op_name)(*inputs, **new_attrs), None @@ -324,7 +333,8 @@ def _darknet_activations(inputs, attrs): elif ACTIVATION.ELU == act: act_type = 'elu' else: - raise_operator_unimplemented('act: ' + act) + raise tvm.error.OpNotImplemented( + 'Operator act: {} is not supported in framework Darknet.'.format(act)) if act_type in ['relu', 'tanh']: op_name, new_attrs = act_type, {} @@ -339,7 +349,8 @@ def _darknet_activations(inputs, attrs): op_name, new_attrs = act_type, {} sym = get_nnvm_op(op_name)(*inputs, **new_attrs) else: - raise_operator_unimplemented('act_type: ' + act_type) + raise tvm.error.OpNotImplemented( + 'Operator act: {} is not supported in framework Darknet.'.format(act)) return sym, None def _darknet_op_not_support(inputs, attrs): @@ -402,7 +413,8 @@ def _darknet_convert_symbol(op_name, inputs, attrs): if op_name in _DARKNET_CONVERT_MAP: sym, out_name = _DARKNET_CONVERT_MAP[op_name](inputs, attrs) else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Darknet.'.format(op_name)) if out_name is None: out_name = sym.list_output_names()[0].replace('_output', '') return out_name, sym @@ -448,9 +460,10 @@ def _get_convolution_weights(self, layer, opname): if layer.nweights == 0: return - if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: - raise_attribute_invalid(layer.n * layer.c * layer.size * layer.size, - 'layer weights size', 'conv2d') + if layer.n * layer.c * layer.size * layer.size != layer.nweights: + msg = 'nweights ({}) != n * c * h * w ({}) in operator {}' + msg = msg.format(layer.nweights, layer.n * layer.c * layer.size ** 2, opname) + raise tvm.error.OpAttributeInvalid(msg) shape = (layer.n, layer.c, layer.size, layer.size) weights = self._read_memory_buffer(shape, layer.weights) @@ -630,7 +643,8 @@ def _get_darknet_attrs(self, layer, layer_num): pass else: - raise_operator_unimplemented(layer.type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Darknet.'.format(layer.type)) return attr @@ -763,7 +777,8 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): elif LAYERTYPE.LSTM == layer.type: if layer.steps > 1: - 
raise_attribute_invalid(layer.steps, 'number of steps', 'RNN') + raise tvm.error.OpAttributeInvalid( + 'Number of steps {} of RNN is not valid.'.format(layer.steps)) op_name_add = 'elemwise_add' op_name_mul = 'elemwise_mul' @@ -829,7 +844,8 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): elif LAYERTYPE.GRU == layer.type: if layer.steps > 1: - raise_attribute_invalid(layer.steps, 'number of steps', 'RNN') + raise tvm.error.OpAttributeInvalid( + 'Number of steps {} is not valid in RNN.'.format(layer.steps)) op_name_add = 'elemwise_add' op_name_mul = 'elemwise_mul' diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py index d15d2b3f01ab..63b4122a4060 100644 --- a/nnvm/python/nnvm/frontend/keras.py +++ b/nnvm/python/nnvm/frontend/keras.py @@ -74,7 +74,8 @@ def _convert_activation(insym, keras_layer, _): if act_type == 'hard_sigmoid': transformX = (0.2 * insym) + 0.5 return _sym.clip(transformX, a_min=0, a_max=1) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(act_type)) def _convert_advanced_activation(insym, keras_layer, symtab): @@ -100,7 +101,8 @@ def _convert_advanced_activation(insym, keras_layer, symtab): theta = keras_layer.theta if hasattr(keras_layer, "theta") else 1.0 theta_tensor = _sym.full_like(insym[0], fill_value=float(theta)) return _sym.elemwise_mul(insym[0], _sym.greater(insym[0], theta_tensor, out_type="float32")) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(act_type)) def _convert_merge(insym, keras_layer, _): @@ -113,12 +115,9 @@ def _convert_merge(insym, keras_layer, _): ret = _sym.elemwise_sub(ret, insym[i]) elif merge_type == 'Multiply': ret = _sym.elemwise_mul(ret, insym[i]) - elif merge_type == 'Average': - raise_operator_unimplemented('average merge') - elif merge_type == 'Maximum': - raise_operator_unimplemented('maximum merge') else: - raise_operator_unimplemented(merge_type) + raise tvm.error.OpNotImplemented( + 'Operator {} Merge is not supported in frontend Keras.'.format(merge_type)) return ret @@ -135,7 +134,8 @@ def _convert_dense(insym, keras_layer, symtab): if input_dim > 2: input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0]) if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1: - raise_attribute_invalid(input_shape, 'input shape', 'dense') + msg = 'Value {} in attribute "input_shape" of operator Dense is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(input_shape)) insym = _sym.squeeze(insym, axis=0) out = _sym.dense(data=insym, **params) # defuse activation @@ -199,7 +199,8 @@ def _convert_convolution(insym, keras_layer, symtab): else: insym = _sym.pad(data=insym, pad_width=((0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Value {} in attribute "padding" of operator Convolution is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding)) if is_deconv: out = _sym.conv2d_transpose(data=insym, **params) else: @@ -240,7 +241,8 @@ def _convert_separable_convolution(insym, keras_layer, symtab): insym = _sym.pad(data=insym, pad_width=( (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Value {} in attribute "padding" of operator Separable Convolution is not valid.' 
+ raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding)) depthconv = _sym.conv2d(data=insym, **params0) # pointwise conv weight1 = weightList[1].transpose([3, 2, 0, 1]) @@ -294,13 +296,15 @@ def _convert_pooling(insym, keras_layer, symtab): pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w) params['padding'] = [pad_t, pad_l, pad_b, pad_r] else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Value {} in attribute "padding" of operator Pooling is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding)) if pool_type == 'MaxPooling2D': return _sym.max_pool2d(insym, **params) if pool_type == 'AveragePooling2D': # TODO: in keras, padded zeros are not calculated return _sym.avg_pool2d(insym, **params) - raise_operator_unimplemented('pooling with {}'.format(keras_layer)) + msg = 'Operator {} is not supported in frontend Keras.' + raise tvm.error.OpNotImplemented(msg.format(pool_type)) def _convert_upsample(insym, keras_layer, _): @@ -312,28 +316,30 @@ def _convert_upsample(insym, keras_layer, _): elif upsample_type == "UpSampling2D": h, w = keras_layer.size if h != w: - raise_attribute_invalid(keras_layer.size, 'size', 'upsample') + raise tvm.error.OpAttributeInvalid( 'Upsample height ({}) must equal width ({})'.format(h, w)) params = {'scale': h} elif upsample_type == "UpSampling3D": h, w, d = keras_layer.size if h != w or w != d: - raise_attribute_invalid(keras_layer.size, 'size', 'upsample') + raise tvm.error.OpAttributeInvalid( 'Upsample height ({}), width ({}), and depth ({}) must be equal.'.format(h, w, d)) params = {'scale': h} else: - raise_operator_unimplemented(upsample_type) + msg = 'Operator {} is not supported in frontend Keras.' + raise tvm.error.OpNotImplemented(msg.format(upsample_type)) return _sym.upsampling(insym, **params) def _convert_cropping(insym, keras_layer, _): _check_data_format(keras_layer) crop_type = type(keras_layer).__name__ - if crop_type == "Cropping1D": - raise_operator_unimplemented(crop_type) - elif crop_type == "Cropping2D": + if crop_type == "Cropping2D": (_, in_h, in_w, _) = keras_layer.input_shape ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping else: - raise_operator_unimplemented(crop_type) + raise tvm.error.OpNotImplemented( 'Operator {} is not supported in frontend Keras.'.format(crop_type)) int32_max = np.iinfo(np.int32).max return _sym.strided_slice(insym, begin=[0, 0, crop_t, crop_l], end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r]) @@ -377,11 +383,13 @@ def _convert_padding(insym, keras_layer, _): top, bottom = padding[0] left, right = padding[1] else: - raise_attribute_invalid(str(padding), 'padding', padding_type) + msg = 'Value {} in attribute "padding" of operator {} is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(str(padding), padding_type)) else: - raise_attribute_invalid(str(padding), 'padding', padding_type) + msg = 'Value {} in attribute "padding" of operator {} is not valid.'
+ raise tvm.error.OpAttributeInvalid(msg.format(str(padding), padding_type)) else: - raise_operator_unimplemented(padding_type) + raise tvm.error.OpNotImplemented('Operator {} is not supported in frontend Keras.'.format(padding_type)) return _sym.pad(data=insym, pad_width=((0, 0), (0, 0), (top, bottom), (left, right))) @@ -588,8 +596,10 @@ def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument def _check_unsupported_layers(model): for layer in model.layers: - if type(layer).__name__ not in _convert_map: - raise_operator_unimplemented(type(layer).__name__) + op_name = type(layer).__name__ + if op_name not in _convert_map: + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(op_name)) def _as_list(arr): """Force being a list, ignore if already is.""" @@ -614,9 +624,11 @@ def keras_op_to_nnvm(insym, keras_layer, outname, symtab): symtab : nnvm.frontend.common.SymbolTable The global symbol table to be updated """ - if type(keras_layer).__name__ not in _convert_map: - raise_operator_unimplemented(type(keras_layer).__name__) - outs = _convert_map[type(keras_layer).__name__](insym, keras_layer, symtab) + op_name = type(keras_layer).__name__ + if op_name not in _convert_map: + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(op_name)) + outs = _convert_map[op_name](insym, keras_layer, symtab) outs = _as_list(outs) for t_idx, out in enumerate(outs): diff --git a/nnvm/python/nnvm/frontend/mxnet.py index 372f10bd98b9..da5e154bce12 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -4,6 +4,7 @@ import json import tvm from .. import symbol as _sym +from .common import get_nnvm_op, required_attr, parse_tshape, parse_bool_str __all__ = ['from_mxnet'] @@ -15,11 +16,13 @@ def impl(inputs, attrs): def _pooling(inputs, attrs): kernel = parse_tshape(required_attr(attrs, 'kernel', 'pooling')) if len(kernel) != 2: - raise_attribute_unimplemented('non-2d kernel', 'pool_2d') + raise tvm.error.OpAttributeUnimplemented( 'Non-2D kernels are not supported for Pool2D.') global_pool = 'global' if parse_bool_str(attrs, 'global_pool') else '' pool_type = required_attr(attrs, 'pool_type', 'pooling') if pool_type not in ['avg', 'max']: - raise_attribute_unimplemented('non-avg/max', 'pool2d') + raise tvm.error.OpNotImplemented( 'Only max and average pooling are supported in frontend MXNet.') op_name, new_attrs = '_'.join([global_pool, pool_type, 'pool2d']).strip('_'), {} # new_attrs['layout'] = 'NCHW' if not global_pool: @@ -32,11 +35,15 @@ def _pooling(inputs, attrs): return get_nnvm_op(op_name)(*inputs, **new_attrs) def _batch_norm(inputs, attrs): - raise_attribute_unimplemented('output_mean_var', 'batch_norm') + if parse_bool_str(attrs, 'output_mean_var'): + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "output_mean_var" is not supported in operator batch_norm.') # if parse_bool_str(attrs, 'fix_gamma'): # _warn_not_used('fix_gamma', 'batch_norm') if parse_bool_str(attrs, 'use_global_stats'): - warn_not_used('use_global_stats', 'batch_norm') + from warnings import warn + warn( 'Attribute "use_global_stats" is ignored in operator batch_norm.') # if parse_bool_str(attrs, 'momentum'): # _warn_not_used('momentum', 'batch_norm') op_name, new_attrs = 'batch_norm', {} @@ -54,10 +61,12 @@ def _concat(inputs, attrs): def _conv2d(inputs, attrs): kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d')) if len(kernel) != 2: -
raise_attribute_unimplemented('non 2d kernel', 'conv2d') + raise tvm.error.OpAttributeUnimplemented( 'Non-2D kernels are not supported for operator Conv2D.') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - raise_attribute_unimplemented('layout: ' + layout, 'conv2d') + raise tvm.error.OpAttributeUnimplemented( 'Layout {} is not supported in operator Conv2D.'.format(layout)) if 'kernel_layout' in attrs: kernel_layout = attrs['kernel_layout'] else: @@ -76,13 +85,16 @@ def _conv2d(inputs, attrs): def _conv2d_transpose(inputs, attrs): if 'target_shape' in attrs: - raise_attribute_unimplemented('target_shape', 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( 'Attribute "target_shape" is not supported in operator Conv2D-transpose.') kernel = parse_tshape(required_attr(attrs, 'kernel', 'conv2d_transpose')) if len(kernel) != 2: - raise_attribute_invalid(len(kernel), 'kernel dim', 'conv2d_transpose') + raise tvm.error.OpAttributeInvalid( 'Non-2D kernels are not supported in Conv2D-transpose.') layout = attrs.get('layout', 'NCHW') if layout not in ['NCHW', 'NHWC']: - raise_attribute_unimplemented('layout: ' + layout, 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( 'Layout {} is not supported in operator Conv2D-transpose.'.format(layout)) if 'kernel_layout' in attrs: kernel_layout = attrs['kernel_layout'] else: @@ -138,7 +150,8 @@ def _leaky_relu(inputs, attrs): op_name, new_attrs = 'leaky_relu', {'alpha': str(slope)} sym = get_nnvm_op(op_name)(*inputs, **new_attrs) else: - raise_attribute_unimplemented([act_type]) + raise tvm.error.OpNotImplemented( 'Operator {} is not supported in frontend MXNet.'.format(act_type)) return sym def _activations(inputs, attrs): @@ -149,12 +162,14 @@ def _activations(inputs, attrs): elif act_type == 'softrelu': sym = _sym.log((1 + _sym.exp(*inputs))) else: - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( 'Operator {} is not supported in frontend MXNet.'.format(act_type)) return sym def _reshape(inputs, attrs): if parse_bool_str(attrs, 'reverse'): - raise_attribute_unimplemented('reverse', 'reshape') + raise tvm.error.OpAttributeUnimplemented( 'Attribute "reverse" is not supported in operator Reshape.') op_name, new_attrs = 'reshape', {} new_attrs['shape'] = required_attr(attrs, 'shape', 'reshape') return get_nnvm_op(op_name)(*inputs, **new_attrs) @@ -218,7 +233,7 @@ def _contrib_multibox_detection(inputs, attrs): new_attrs1 = {'return_indices': False, 'iou_threshold': float(nms_threshold), 'force_suppress': force_suppress, 'top_k': int(nms_topk)} data, valid_count = get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1], - inputs[2], **new_attrs0) + inputs[2], **new_attrs0) return get_nnvm_op('non_max_suppression')(data, valid_count, **new_attrs1) def _elemwise_sum(inputs, _): @@ -231,10 +246,12 @@ def _crop_like(inputs, attrs): tuple([float(x.strip()) for x in attrs.get('offsets').strip('()').split(',')]) \ if attrs.get('offsets') is not None else (0, 0) if offsets != (0, 0): - raise_attribute_invalid(offsets, 'offsets', 'crop_like') + raise tvm.error.OpAttributeInvalid( 'crop_like offsets must equal (0,0).') center_crop = parse_bool_str(attrs, 'center_crop', default="False") if center_crop: - raise_attribute_unimplemented('center crop', 'crop_like') + raise tvm.error.OpAttributeUnimplemented( 'Center crop is not supported in operator crop_like.') if len(inputs) < 2: raise RuntimeError("Only support crop_like pattern.") new_attrs["axis"] = [2, 3] @@ -381,7 +398,8 @@ def
_convert_symbol(op_name, inputs, attrs, elif op_name in convert_map: sym = convert_map[op_name](inputs, attrs) else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( 'Operator {} is not supported in frontend MXNet.'.format(op_name)) return sym def _as_list(arr): diff --git a/nnvm/python/nnvm/frontend/onnx.py index 1262bebbb85f..18eb213bab7b 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -397,7 +397,8 @@ def _impl_v7(cls, inputs, attr, params): elif mode == b'linear': method = "BILINEAR" else: - raise_attribute_invalid(mode, 'mode', 'upsample') + raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) return _sym.upsampling(inputs[0], scale=int(scales[-1]), method=method, layout='NCHW') @@ -922,7 +923,8 @@ def _convert_operator(self, elif op_name in convert_map: sym = convert_map[op_name](inputs, attrs, self._params) else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( 'Operator {} is not supported in frontend ONNX.'.format(op_name)) return sym def _fix_outputs(self, op_name, outputs): diff --git a/nnvm/python/nnvm/frontend/tensorflow.py index 140fa900eefa..f2ff60294489 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -11,7 +11,7 @@ from .. import symbol as _sym from .. import graph as _graph from .. compiler import graph_util, build_module -from .common import AttrConverter as AttrConvert +from .common import get_nnvm_op, AttrConverter as AttrConvert __all__ = ['from_tensorflow'] @@ -68,7 +68,8 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise_attribute_unimplemented('non-2d kernel', prefix) + raise tvm.error.OpAttributeUnimplemented( 'Non-2D kernels are not supported for operator {}.'.format(prefix)) return _impl def _dimension_constraint(): @@ -129,7 +130,8 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise_attribute_invalid(attr['data_format'], 'data_format', 'pooling') + msg = 'Value {} in attribute "data_format" of operator Pooling is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": tmp_shape = attr['_input_shapes'][inputs[0]] @@ -158,7 +160,8 @@ def _impl(inputs, attr, params): attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: - raise_attribute_unimplemented(attr['padding'], 'padding', 'pooling') + msg = 'Value {} in attribute "padding" of operator Pooling is not valid.' + raise tvm.error.OpAttributeUnimplemented(msg.format(attr['padding'])) if name == "avg_pool": attr['count_include_pad'] = False @@ -232,7 +235,8 @@ def _impl(inputs, attr, params): attr['dilations'] = (attr['dilations'][2], attr['dilations'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise_attribute_invalid(attr['data_format'], 'data_format', 'conv') + msg = 'Value {} in attribute "data_format" of operator Conv is not valid.'
+ raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) if opname == 'depthwise': @@ -276,7 +280,8 @@ def _impl(inputs, attr, params): attr['padding'] = [0, 0] else: - raise_attribute_invalid(attr['padding'], 'padding', 'conv') + msg = 'Value {} in attribute "padding" of operator Conv is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) if 'kernel_layout' not in attr: if opname == 'conv': @@ -432,7 +437,8 @@ def _impl(inputs, attr, params): op_name="reshape", extras={'shape':tuple(params_new[0].asnumpy().flatten())}, ignores=['Tshape'])(inputs, attr) - raise_attribute_unimplemented('dynamic shape', 'reshape') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "dynamic shape" of operator Reshape is not supported.') return _impl def _bias_add(): @@ -736,7 +742,8 @@ def _impl(inputs, attr, params): if padlist_key in params: padlist = params.pop(padlist_key).asnumpy() else: - raise_attribute_required(padlist_key, 'pad') + raise tvm.error.OpAttributeRequired( + 'Required attribute "{}" not found in operator Pad.'.format(padlist_key)) paddings = tuple([tuple(l) for l in padlist]) attr['pad_width'] = paddings attr['pad_value'] = 0 @@ -1188,7 +1195,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): missing_operators = self._parse_import_prerequisites(graph) if missing_operators: - raise_operator_unimplemented(*missing_operators) + msg = 'The following operators are not supported in frontend TensorFlow: {}' + ops = str(list(missing_operators)).strip('[,]') + raise tvm.error.OpNotImplemented(msg.format(ops)) for node in graph.node: if node.op == 'Placeholder': @@ -1528,7 +1537,8 @@ def _convert_operator(self, op_name, inputs, attrs, self._params, graph, convert_map_rnn) else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend TensorFlow.'.format(op_name)) return sym def _fix_extranodes(self, op_name, attr, inputs): diff --git a/python/tvm/error_handling/__init__.py b/python/tvm/error_handling/__init__.py deleted file mode 100644 index 8616d1ba973a..000000000000 --- a/python/tvm/error_handling/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -import warnings -import traceback -import sys - -def _excepthook(type, value, tb): - print(''.join(traceback.format_exception(type, value, tb))) - -sys.excepthook = _excepthook - -class OperatorError(Exception): - pass - -def _raise_error_helper(exception, msg, *args): - raise exception(msg.format(*args)) - -def raise_attribute_required(key, op_name): - class OperatorAttributeRequired(OperatorError): - pass - msg = 'Required attribute {} not found in operator {}.' - _raise_error_helper(OperatorAttributeRequired, msg, key, op_name) - -def raise_attribute_invalid(val, attr, op_name): - class OperatorAttributeValueNotValid(OperatorError): - pass - msg = 'Value {} in attr {} is not valid in operator {}.' - _raise_error_helper(OperatorAttributeValueNotValid, msg, val, attr, - op_name) - -def raise_operator_unimplemented(*missing_ops): - class OperatorNotImplemented(OperatorError): - pass - missing_ops = str(missing_ops).strip('(,)') - msg = 'The following operators are not supported: {}.' - _raise_error_helper(OperatorNotImplemented, msg, missing_ops) - -def raise_attribute_unimplemented(key, op_name): - class OperatorAttributeNotImplemented(OperatorError): - pass - msg = 'Attribute {} is not supported in operator {}.' 
- _raise_error_helper(OperatorAttributeNotImplemented, msg, key, op_name) - -def warn_not_used(attr, op_name): - msg = '{} is ignored in {}.'.format(attr, op_name) - warnings.warn(msg) diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index 6ba2f0bde12d..dee3999ad3f1 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -14,8 +14,3 @@ from .coreml import from_coreml from .caffe2 import from_caffe2 from .tensorflow import from_tensorflow -from tvm.error_handling import raise_attribute_required, \ - raise_attribute_invalid, \ - raise_operator_unimplemented, \ - raise_attribute_unimplemented, \ - warn_not_used diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index 5ae7a294d306..769740df0be3 100755 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -1,6 +1,7 @@ # pylint: disable=import-self, invalid-name, line-too-long, unused-argument """Caffe2 frontend""" from __future__ import absolute_import as _abs +import tvm from .. import ir_pass from .. import expr as _expr from .. import op as _op @@ -15,7 +16,8 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise_operator_unimplemented('non 2d kernel') + raise tvm.error.OpAttributeUnimplemented( 'Non-2D kernels are not supported for operator {}2d'.format(prefix)) return _impl @@ -27,7 +29,8 @@ def revert_caffe2_pad(pads): elif len(pads) == 2: pass else: - raise_attribute_invalid(str(len(pads)), 'len(pads)', 'padding') + raise tvm.error.OpAttributeInvalid( 'Number of pads must equal 2 or 4.') return pads @@ -103,7 +106,8 @@ def get_converter(cls): if hasattr(cls, '_impl'): return getattr(cls, '_impl') - raise_operator_unimplemented(cls.__name__) + raise tvm.error.OpNotImplemented( 'Operator {} is not supported in frontend Caffe2.'.format(cls.__name__)) _caffe2_internal_args = [ @@ -223,7 +227,8 @@ def _get_axis_from_order_str(order): return 1 if order == 'NHWC': return 3 - raise_attribute_unimplemented(order, 'Concat') + raise tvm.error.OpAttributeUnimplemented( 'Order {} is not supported in operator Concat.'.format(order)) return AttrCvt( op_name='concatenate', @@ -515,7 +520,8 @@ def _convert_operator(self, # Add a sanitizing step to convert all byte strings in args to strings func = convert_map[op_type](inputs, args, self._params) else: - raise_operator_unimplemented(op_type) + raise tvm.error.OpNotImplemented( 'Operator {} is not supported in frontend Caffe2.'.format(op_type)) return func diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 369a5d4bb3a4..963b21f38297 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -1,6 +1,7 @@ # pylint: disable=invalid-name, import-self, unused-argument, unused-variable, inconsistent-return-statements """CoreML frontend.""" from __future__ import absolute_import as _abs +import tvm import numpy as np from .. import ir_pass from ..
import expr as _expr @@ -81,7 +82,8 @@ def _BatchnormLayerParams(op, inexpr, etab): """Get layer of batchnorm parameter""" # this changes the symbol if op.instanceNormalization: - raise_operator_unimplemented('instance normalization') + raise tvm.error.OpNotImplemented( + 'Operator "instance normalization" is not supported in frontend CoreML.') else: params = {'gamma':etab.new_const(list(op.gamma.floatValue)), 'beta':etab.new_const(list(op.beta.floatValue)), @@ -142,7 +144,8 @@ def _ActivationParams(op, inexpr, etab): alpha_expr = etab.new_const(alpha) beta_expr = etab.new_const(beta) return _op.multiply(_op.log(_op.add(_op.exp(inexpr), beta_expr)), alpha_expr) - raise_operator_unimplemented(whichActivation) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend CoreML.'.format(whichActivation)) def _ScaleLayerParams(op, inexpr, etab): @@ -164,7 +167,8 @@ def _PoolingLayerParams(op, inexpr, etab): return _op.nn.global_max_pool2d(inexpr) if op.type == 1: return _op.nn.global_avg_pool2d(inexpr) - raise_operator_unimplemented('pooling (not max or average)') + raise tvm.error.OpNotImplemented( + 'Only Max and Average Pooling are supported in frontend CoreML.') else: params = {'pool_size':list(op.kernelSize), @@ -184,8 +188,9 @@ def _PoolingLayerParams(op, inexpr, etab): params['padding'] = padding params['ceil_mode'] = True else: - raise_attribute_unimplemented(op.WhichOneof('PoolingPaddingType'), - 'PoolingPaddingType', 'pooling') + msg = 'PoolingPaddingType {} is not supported in operator Pooling.' + op_name = op.WhichOneof('PoolingPaddingType') + raise tvm.error.OpAttributeUnimplemented(msg.format(op_name)) # consume padding layer if etab.in_padding: @@ -197,7 +202,8 @@ def _PoolingLayerParams(op, inexpr, etab): return _op.nn.max_pool2d(inexpr, **params) if op.type == 1: return _op.nn.avg_pool2d(inexpr, **params) - raise_operator_unimplemented('pooling (not max or average)') + raise tvm.error.OpNotImplemented( + 'Only Max and Average Pooling are supported in CoreML.') def _SoftmaxLayerParams(op, inexpr, etab): @@ -240,7 +246,8 @@ def _ConcatLayerParams(op, inexpr, etab): if not isinstance(inexpr, list): inexpr = [inexpr] if op.sequenceConcat: - raise_operator_unimplemented('Sequence Concat') + raise tvm.error.OpNotImplemented( + 'Operator Sequence Concat is not supported in frontend CoreML.') ret = _op.concatenate(inexpr, axis=1) return ret @@ -256,14 +263,16 @@ def _PaddingLayerParams(op, inexpr, etab): if op.WhichOneof('PaddingType') == 'constant': constant = op.constant if constant.value != 0: - raise_attribute_unimplemented(constant.value, 'padding value', 'padding') + raise tvm.error.OpAttributeUnimplemented( + '{} is not supported in operator Padding.'.format(constant.value)) padding = [b.startEdgeSize for b in op.paddingAmounts.borderAmounts] padding2 = [b.endEdgeSize for b in op.paddingAmounts.borderAmounts] for i, j in zip(padding, padding2): assert i == j etab.set_padding(padding) else: - raise_operator_unimplemented('non-constant padding') + raise tvm.error.OpNotImplemented( + 'Non-constant padding is not supported in frontend CoreML.') return inexpr @@ -274,7 +283,8 @@ def _PermuteLayerParams(op, inexpr, etab): def _UpsampleLayerParams(op, inexpr, etab): if op.scalingFactor[0] != op.scalingFactor[1]: - raise_attribute_unimplemented('unequal height/width scaling factors', 'upsample') + raise tvm.error.OpAttributeUnimplemented( + 'Upsample height and width must be equal.') interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR' return 
_op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode) @@ -364,7 +374,8 @@ def coreml_op_to_relay(op, inname, outname, etab): """ classname = type(op).__name__ if classname not in _convert_map: - raise_operator_unimplemented(classname) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend CoreML.'.format(classname)) if isinstance(inname, _base.string_types): insym = etab.get_expr(inname) else: diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 2e266852f9dc..bd7cb4f3b110 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -2,6 +2,7 @@ """Keras frontend.""" from __future__ import absolute_import as _abs import sys +import tvm import numpy as np from .. import ir_pass from .. import expr as _expr @@ -91,7 +92,8 @@ def _convert_activation(inexpr, keras_layer, _): x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32') return _op.clip(x, a_min=0., a_max=1.) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(act_type)) def _convert_advanced_activation(inexpr, keras_layer, etab): @@ -118,7 +120,8 @@ def _convert_advanced_activation(inexpr, keras_layer, etab): return _op.multiply(inexpr, _op.greater(inexpr, \ _expr.const(theta, dtype='float32')).astype('float32')) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(act_type)) def _convert_merge(inexpr, keras_layer, _): @@ -136,7 +139,8 @@ def _convert_merge(inexpr, keras_layer, _): ret = _op.add(ret, inexpr[i]) ret = ret / _expr.const(len(inexpr), dtype='float32') else: - raise_operator_unimplemented(merge_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(merge_type)) return ret @@ -150,7 +154,8 @@ def _convert_dense(inexpr, keras_layer, etab): if input_dim > 2: input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0]) if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1: - raise_attribute_invaid(input_shape, 'input shape', 'dense') + raise tvm.error.OpAttributeInvalid( + 'Input shape {} is not valid for operator Dense.'.format(input_shape)) inexpr = _op.squeeze(inexpr, axis=0) out = _op.nn.dense(data=inexpr, **params) if keras_layer.use_bias: @@ -214,7 +219,9 @@ def _convert_convolution(inexpr, keras_layer, etab): inexpr = _op.nn.pad(data=inexpr, pad_width=( (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Padding with {} is not supported for operator Convolution ' \ + 'in frontend Keras.' + raise tvm.error.OpAttributeUnimplemented(msg.format(keras_layer.padding)) if is_deconv: out = _op.nn.conv2d_transpose(data=inexpr, **params) else: @@ -260,7 +267,10 @@ def _convert_separable_convolution(inexpr, keras_layer, etab): inexpr = _op.nn.pad(data=inexpr, pad_width=( (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + msg = 'Padding with {} is not supported for operator Separable ' \ + 'Convolution in frontend Keras.' 
+ raise tvm.error.OpAttributeUnimplemented(msg.format(keras_layer.padding)) + depthconv = _op.nn.conv2d(data=inexpr, **params0) # pointwise conv weight1 = weightList[1].transpose([3, 2, 0, 1]) @@ -313,13 +323,15 @@ def _convert_pooling(inexpr, keras_layer, etab): pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w) params['padding'] = [pad_t, pad_l, pad_b, pad_r] else: - raise_operator_unimplemented('padding with {}'.format(keras_layer.padding)) + raise tvm.error.OpAttributeUnimplemented( + 'Padding with {} is not supported in operator Pooling.'.format(keras_layer.padding)) if pool_type == 'MaxPooling2D': return _op.nn.max_pool2d(inexpr, **params) if pool_type == 'AveragePooling2D': params['count_include_pad'] = False return _op.nn.avg_pool2d(inexpr, **params) - raise_operator_unimplemented('pooling type {}'.format(keras_layer)) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend Keras.'.format(keras_layer)) def _convert_upsample(inexpr, keras_layer, _): @@ -331,7 +343,8 @@ def _convert_upsample(inexpr, keras_layer, _): elif upsample_type == 'UpSampling2D': h, w = keras_layer.size if h != w: - raise_attribute_invalid(keras_layer.size, 'size', 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Height must equal width for operator Upsample.') params = {'scale': h} if hasattr(keras_layer, 'interpolation'): @@ -344,23 +357,24 @@ def _convert_upsample(inexpr, keras_layer, _): elif upsample_type == 'UpSampling3D': h, w, d = keras_layer.size if h != w or w != d: - raise_attribute_invalid(keras_layer.size, 'size', 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Height, width, and depth must all be equal for operator Upsample.') params = {'scale': h} else: - raise_operator_unimplemented(upsample_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend Keras.'.format(upsample_type)) return _op.nn.upsampling(inexpr, **params) def _convert_cropping(inexpr, keras_layer, _): _check_data_format(keras_layer) crop_type = type(keras_layer).__name__ - if crop_type == 'Cropping1D': - raise_operator_unimplemented(crop_type) - elif crop_type == 'Cropping2D': + if crop_type == 'Cropping2D': (_, in_h, in_w, _) = keras_layer.input_shape ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping else: - raise_operator_unimplemented(crop_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend Keras.'.format(crop_type)) int32_max = np.iinfo(np.int32).max return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \ end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r]) @@ -405,14 +419,18 @@ def _convert_padding(inexpr, keras_layer, _): top, bottom = padding[0] left, right = padding[1] else: - raise_attribute_invalid(str(padding), 'padding', 'padding') + msg = 'Value {} in attribute "padding" of operator Padding ' \ + 'is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(str(padding))) else: - raise_attribute_invalid(str(padding), 'padding', 'padding') - elif padding_type == 'ZeroPadding1D': - raise_operator_unimplemented(padding_type) + msg = 'Value {} in attribute "padding" of operator Padding is ' \ + 'not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(str(padding))) else: - raise_operator_unimplemented(padding_type) - return _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0), (top, bottom), (left, right))) + msg = 'Operator {} is not supported in frontend Keras.' 
+ raise tvm.error.OpNotImplemented(msg.format(padding_type)) + return _op.nn.pad(data=inexpr, + pad_width=((0, 0), (0, 0), (top, bottom), (left, right))) def _convert_concat(inexpr, keras_layer, _): @@ -599,8 +617,10 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument def _check_unsupported_layers(model): for layer in model.layers: - if type(layer).__name__ not in _convert_map: - raise_operator_unimplemented(type(layer).__name__) + op_name = type(layer).__name__ + if op_name not in _convert_map: + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend Keras.'.format(op_name)) def keras_op_to_relay(inexpr, keras_layer, outname, etab): @@ -620,9 +640,11 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab): etab : relay.frontend.common.ExprTable The global expression table to be updated. """ - if type(keras_layer).__name__ not in _convert_map: - raise_operator_unimplemented(type(keras_layer).__name__) - outs = _convert_map[type(keras_layer).__name__](inexpr, keras_layer, etab) + op_name = type(keras_layer).__name__ + if op_name not in _convert_map: + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend Keras.'.format(op_name)) + outs = _convert_map[op_name](inexpr, keras_layer, etab) outs = _as_list(outs) for t_idx, out in enumerate(outs): name = outname + ":" + str(t_idx) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index b28558bb25f9..39daaf91063a 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -3,10 +3,12 @@ from __future__ import absolute_import as _abs import json +import tvm from .. import ir_pass from .. import expr as _expr from .. import op as _op from ... import nd as _nd + from .common import StrAttrsDict from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast @@ -41,7 +43,8 @@ def _get_channel_axis(layout, op_name): return 1 if layout == "NHWC": return 3 - raise_attribute_invalid(layout, 'layout', op_name) + raise tvm.error.OpAttributeInvalid( + 'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name)) def _mx_activations(inputs, attrs): @@ -61,7 +64,8 @@ def _stable_softrelu(x): return _op.add(_op.log(_op.add(one, exp_neg_abs_x)), _op.nn.relu(x)) return _stable_softrelu(inputs[0]) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend MXNet.'.format(act_type)) def _mx_compare(new_op, wrapper): @@ -74,7 +78,8 @@ def impl(inputs, attrs): def _mx_conv2d(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: - raise_attribute_invalid(kernel_size, 'kernel size', 'conv2d') + raise tvm.error.OpAttributeInvalid( + 'Non-2D kernels are not supported for operator Conv2D.') data_layout = attrs.get_str("layout", "NCHW") channel_axis = _get_channel_axis(data_layout, "conv2d") @@ -102,10 +107,12 @@ def _mx_conv2d(inputs, attrs): def _mx_conv2d_transpose(inputs, attrs): if "target_shape" in attrs.attrs: - raise_attribute_unimplemented('target_shape', 'conv2d_transpose') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "target_shape" is not supported for operator Conv2D-transpose.') kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: - raise_attribute_invalid(len(kernel_size), 'kernel dimensionality', 'conv2d') + raise tvm.error.OpAttributeInvalid( + 'Non-2D kernels are not supported for operator 
Conv2D-transpose.') data_layout = attrs.get_str("layout", "NCHW") channel_axis = _get_channel_axis(data_layout, "conv2d_transpose") @@ -140,7 +147,8 @@ def _mx_pooling(inputs, attrs): def _pool2d(new_op, is_avg): kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: - raise_attribute_invalid(len(kernel_size), 'kernel dimensionality', 'pool2d') + raise tvm.error.OpAttributeInvalid( + 'Only 2D kernels are supported for operator Pool2D.') new_attrs = {} new_attrs["pool_size"] = kernel_size new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1)) @@ -158,7 +166,8 @@ def _pool2d(new_op, is_avg): if global_pool: return _op.nn.global_avg_pool2d(inputs[0]) return _pool2d(_op.nn.avg_pool2d, True) - raise_operator_unimplemented(pool_type) + raise tvm.error.OpNotImplemented( + 'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize())) def _mx_dropout(inputs, attrs): @@ -172,7 +181,8 @@ def _mx_BlockGrad(inputs, attrs): #pylint: disable=unused-argument def _mx_batch_norm(inputs, attrs): if attrs.get_bool("output_mean_var", False): - raise_attribute_unimplemented('output_mean_var', 'batch_norm') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "output_mean_var" is not supported for operator Batch Norm.') if attrs.get_bool("use_global_stats", False): _warn_not_used("use_global_stats", "batch_norm") new_attrs = {} @@ -189,13 +199,17 @@ def _mx_slice(inputs, attrs): end = attrs.get_int_tuple('end', None) stride = attrs.get_int_tuple('step', None) if begin is None: - raise_attribute_required('begin', 'slice') + raise tvm.error.OpAttributeRequired( + 'Attribute "begin" not found in operator Slice.') if end is None: - raise_attribute_required('end', 'slice') + raise tvm.error.OpAttributeRequired( + 'Attribute "end" not found in operator Slice.') if None in begin: - raise_attribute_unimplemented('None in begin', 'slice') + raise tvm.error.OpAttributeInvalid( + 'Value None in attribute "begin" of operator Slice is not valid.') if None in end: - raise_attribute_unimplemented('None in end', 'slice') + raise tvm.error.OpAttributeInvalid( + 'Value None in attribute "end" of operator Slice is not valid.') new_attrs = {'begin': begin, 'end': end} if stride is not None: new_attrs['strides'] = stride @@ -299,7 +313,8 @@ def _mx_leaky_relu(inputs, attrs): upper_bound = attrs.get_float("upper_bound") alpha = (lower_bound + upper_bound) / 2.0 return _op.nn.leaky_relu(inputs[0], alpha=alpha) - raise_operator_unimplemented(act_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend MXNet.'.format(act_type)) def _mx_make_power(power): @@ -393,7 +408,9 @@ def _mx_batch_dot(inputs, attrs): transpose_a = attrs.get_bool("transpose_a", False) transpose_b = attrs.get_bool("transpose_b", False) if transpose_a is True: - raise_attribute_invalid(transpose_a, 'transpose_a', 'batch_dot') + msg = 'Value {} in attribute "transpose_a" of operator batch_dot ' \ + 'is not valid.' 
+ raise tvm.error.OpAttributeInvalid(msg.format(transpose_a)) if transpose_b is False: b = _op.transpose(b, axes=[0, 2, 1]) return _op.batch_matmul(a, b) @@ -402,7 +419,8 @@ def _mx_batch_dot(inputs, attrs): def _mx_arange(inputs, attrs): assert len(inputs) == 0 if attrs.get_int("repeat", 1) != 1: - raise_attribute_unimplemented('repeat', 'arange') + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "repeat" is not supported in operator arange.') new_attrs = {} new_attrs["start"] = attrs.get_float("start", 0) new_attrs["stop"] = attrs.get_float("stop") @@ -486,15 +504,20 @@ def _mx_box_nms(inputs, attrs): in_format = attrs.get_str('in_format', 'corner') out_format = attrs.get_str('out_format', 'corner') if coord_start != 2: - raise_attribute_invalid(coord_start, 'coord_start', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "coord_start" must equal 2 for operator box_nms.') if score_index != 1: - raise_attribute_invalid(score_index, 'score_index', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "score_index" must equal 1 for operator box_nms.') if id_index != -1 and int(id_index) != 0: - raise_attribute_invalid(id_index, 'id_index', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "id_index" must equal either -1 or 0 for operator box_nms.') if in_format != 'corner': - raise_attribute_invalid(in_format, 'in_format', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "in_format" must equal "corner" for operator box_nms.') if out_format != 'corner': - raise_attribute_invalid(out_format, 'out_format', 'box_nms') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "out_format" must equal "corner" for operator box_nms.') ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh) nms_out = _op.vision.non_max_suppression(ret[1], @@ -512,7 +535,8 @@ def _mx_l2_normalize(inputs, attrs): new_attrs = {} mode = attrs.get_str('mode', 'instance') if mode != 'channel': - raise_attribute_invalid(mode, 'mode', 'l2_normalize') + raise tvm.error.OpAttributeInvalid( + 'Value of attribute "mode" must equal "channel" for operator l2_normalize.') new_attrs['eps'] = attrs.get_float('eps', 1e-10) new_attrs['axis'] = [1] return _op.nn.l2_normalize(inputs[0], **new_attrs) @@ -772,10 +796,11 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): elif isinstance(res, _expr.Expr): res = [res] else: - raise_attribute_invalid(type(res), 'type(res)', op_name) + raise RuntimeError("unexpected type %s" % type(res)) node_map[nid] = res else: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported in frontend MXNet.'.format(op_name)) outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1bffdfd4bcd9..a6851b833931 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3,6 +3,7 @@ from __future__ import absolute_import as _abs import logging +import tvm import numpy as np from ... import nd as _nd from .. import ir_pass @@ -18,7 +19,9 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise_attribute_invalid(len(kernel), 'kernel dimensionality', prefix) + msg = 'Only 2D kernels are supported for operator {}.' 
+ op_name = prefix + '2d' + raise tvm.error.OpAttributeInvalid(msg.format(op_name)) return _impl @@ -29,7 +32,8 @@ def revert_caffe2_pad(pads): elif len(pads) == 2: pass else: - raise_attribute_invalid(len(pads), 'len(pads)', 'padding') + raise tvm.error.OpAttributeInvalid( + 'Number of pads must be either 2 or 4.') return pads def dimension_constraint(): @@ -461,7 +465,8 @@ def _impl_v9(cls, inputs, attr, params): elif mode == b'linear': method = "BILINEAR" else: - raise_attribute_invalid(mode, 'mode', 'upsample') + raise tvm.error.OpAttributeInvalid( + 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)) attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'} return AttrCvt('upsampling')(inputs, attr) @@ -717,7 +722,10 @@ def _impl_v1(cls, inputs, attr, params): if 'input_as_shape' in attr and attr['input_as_shape']: shape = params[get_name(inputs[0])].asnumpy() else: - raise_attribute_required('extra_shape', 'ConstantFill') + if 'extra_shape' in attr: + raise tvm.error.OpAttributeInvalid('Attribute "extra_shape" not ' + 'supported with "fill_like" for ' + 'operator ConstantFill.') return _op.full_like(inputs[0], inputs[1]) if 'extra_shape' in attr: diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f795aa70a596..afeaee7e8f95 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,7 +27,8 @@ def _get_relay_op(op_name): op = getattr(_op.image, op_name) if not op: - raise_operator_unimplemented(op_name) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend TensorFlow.'.format(op_name)) return op class AttrCvt(object): @@ -99,7 +100,8 @@ def __call__(self, inputs, attrs, *args): new_attrs = {} for k in attrs.keys(): if k in self._excludes: - raise_operator_unimplemented(k, op_name) + raise tvm.error.OpAttributeUnimplemented( + 'Attribute {} in operator {} is not supported.'.format(k, op_name)) elif k in self._disables: logging.warning("Attribute %s is disabled in relay.%s", k, op_name) elif k in self._ignores: @@ -148,7 +150,8 @@ def _required_attr(self, attr, key): """Wrapper for getting required attributes.""" assert isinstance(attr, dict) if key not in attr: - raise_attribute_required(key, self._op_name) + raise tvm.error.OpAttributeRequired( + 'Attribute {} not found in operator {}'.format(key, self._op_name)) return attr[key] def _get_pad_pair(input1d, kernel1d, stride1d): @@ -178,7 +181,8 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix - raise_attribute_invalid(len(kernel), 'kernel dimensionality', prefix) + raise tvm.error.OpAttributeInvalid( + 'Only 2D kernels are supported for operator {}'.format(prefix + '2d')) return _impl def _dimension_constraint(): @@ -238,7 +242,9 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: - raise_attribute_invalid(attr['data_format'], 'data_format', 'pooling') + msg = 'Value {} of attribute "data_format" of operator Pooling ' \ + 'is not valid.' 
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))

     if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
         tmp_shape = attr['_input_shapes'][inputs[0]]
@@ -267,7 +273,9 @@ def _impl(inputs, attr, params):
             attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]

         else:
-            raise_attribute_invalid(attr['padding'], 'padding', 'padding')
+            msg = 'Value {} in attribute "padding" of operator Pooling is ' \
+                  'not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))

         if name == "avg_pool":
             attr['count_include_pad'] = False
@@ -341,7 +349,9 @@ def _impl(inputs, attr, params):
             attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
             attr['strides'] = (attr['strides'][2], attr['strides'][3])
         else:
-            raise_attribute_unimplemented(attr['data_format'], 'data_format', 'conv')
+            msg = 'Value {} in attribute "data_format" of operator Conv is ' \
+                  'not valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))

     if opname == 'depthwise':
@@ -386,7 +396,9 @@ def _impl(inputs, attr, params):
             attr['padding'] = [0, 0]

         else:
-            raise_attribute_invalid(attr['padding'], 'padding', 'conv')
+            msg = 'Value {} in attribute "padding" of operator Conv is not ' \
+                  'valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))

         if 'kernel_layout' not in attr:
             if opname == 'conv':
@@ -791,7 +803,8 @@ def _impl(inputs, attr, params):
         if padlist_key in params:
             padlist = params.pop(padlist_key).asnumpy()
         else:
-            raise_attribute_required(padlist_key, 'pad')
+            raise tvm.error.OpAttributeRequired(
+                'Attribute {} not found in operator Pad.'.format(padlist_key))
         paddings = tuple([tuple(l) for l in padlist])
         attr['pad_width'] = paddings
         attr['pad_value'] = 0
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 37f4e1367e53..0e31500fe67d 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -3,6 +3,7 @@
 from __future__ import absolute_import as _abs
 import math
 import numpy as np
+import tvm
 from .. import ir_pass
 from .. import expr as _expr
 from .. import op as _op
@@ -59,7 +60,10 @@ def check_unsupported_ops(self):
                 unsupported_ops_set.add(op_code_str)

         if unsupported_ops_set:
-            raise_operator_unimplemented(*upsupported_ops_set)
+            msg = 'The following operators are not supported in frontend ' \
+                  'TFLite: {}'
+            ops = str(list(unsupported_ops_set)).strip('[,]')
+            raise tvm.error.OpNotImplemented(msg.format(ops))

     def convert_op_to_relay(self):
         """Convert TFLite ops to relay ops"""
@@ -204,7 +208,8 @@ def convert_reshape(self, op):
             # finally convert back if necessary
             in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
         else:
-            raise_attribute_invalid(input_shape_length, 'input shape length', 'reshape')
+            msg = 'Input shape length {} for operator Reshape is not valid.'
+ raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) out = _op.reshape(in_expr, newshape=tuple(target_shape)) @@ -221,7 +226,8 @@ def convert_reshape(self, op): elif len(target_shape) == 4: out = _op.transpose(out, axes=(0, 3, 1, 2)) else: - raise_attribute_invalid(len(target_shape), 'shape length', 'reshape') + raise tvm.error.OpAttributeInvalid( + 'Length of target shape must be between 1 and 5 for operator Reshape.') return out @@ -327,7 +333,8 @@ def convert_squeeze(self, op): # finally convert back if necessary in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1)) else: - raise_attribute_invalid(input_shape_length, 'input shape length', 'squeeze') + msg = 'Input shape length {} for operator Squeeze is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) out = _op.squeeze(in_expr, axis=tuple(squeeze_axis)) @@ -344,7 +351,8 @@ def convert_squeeze(self, op): elif output_shape_length == 4: out = _op.transpose(out, axes=(0, 3, 1, 2)) else: - raise_attribute_invalid(output_shape_length, 'output_shape_length', 'squeeze') + msg = 'Output shape length {} for operator Squeeze is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length)) return out @@ -364,7 +372,8 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): if fused_activation_fn == ActivationFunctionType.TANH: return _op.tanh(in_expr) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] - raise_operator_unimplemented(fused_activation_fn_str) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) def convert_conv(self, op, conv_type): """convolution implementation.""" @@ -403,7 +412,8 @@ def convert_conv(self, op, conv_type): assert depth_multiplier == 1, "TF frontend have transformed it be 1 " \ "no matter original value be set by 0.25, 0.5 or any else" else: - raise_operator_unimplemented(conv_type) + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend TFLite.'.format(conv_type)) stride_h = conv_options.StrideH() stride_w = conv_options.StrideW() @@ -460,7 +470,8 @@ def convert_conv(self, op, conv_type): (pad_top, pad_bottom), (pad_left, pad_right))) else: - raise_attribute_invalid(padding, 'padding format', 'conv') + raise tvm.error.OpAttributeUnimplemented( + 'Padding format {} is not supported for operator Conv.'.format(padding)) out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params) @@ -523,14 +534,16 @@ def convert_pool2d(self, op, pool_type): pad_left, pad_right = get_pad_value(input_w, filter_w, stride_w) params['padding'] = [pad_top, pad_left, pad_bottom, pad_right] else: - raise_attribute_invalid(padding, 'padding', 'pool2d') + raise tvm.error.OpAttributeUnimplemented( + 'Padding format {} for operator Pool2D is not supported.'.format(padding)) if pool_type == "average": out = _op.nn.avg_pool2d(in_expr, **params) elif pool_type == "max": out = _op.nn.max_pool2d(in_expr, **params) else: - raise_operator_unimplemented(pool_type + ' pool') + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool')) # If we have fused activations if fused_activation_fn != ActivationFunctionType.NONE:
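
Note: the frontends above raise OpNotImplemented, OpAttributeRequired, OpAttributeInvalid
and OpAttributeUnimplemented from the tvm.error namespace, but the definition of that
namespace is not part of this diff. The sketch below shows the hierarchy these call sites
assume, plus one illustrative raise; everything in it (including deriving OpError from
RuntimeError rather than TVM's own base error) is an assumption for reference only:

    # Assumed layout of python/tvm/error.py -- not included in this patch.
    # In TVM proper, OpError would presumably derive from the C++-backed TVMError.
    class OpError(RuntimeError):
        """Base class of all operator errors raised by frontend converters."""

    class OpNotImplemented(OpError, NotImplementedError):
        """The operator has no converter in the target frontend."""

    class OpAttributeRequired(OpError, AttributeError):
        """A required attribute is missing from the operator."""

    class OpAttributeInvalid(OpError, AttributeError):
        """An attribute carries a value the converter cannot accept."""

    class OpAttributeUnimplemented(OpError, AttributeError):
        """The attribute is legal but the converter does not support it."""

    # Typical call site in a frontend converter (op_type comes from the model graph):
    op_type = 'UnsupportedOp'
    raise OpNotImplemented(
        'Operator {} is not supported in frontend Caffe2.'.format(op_type))

With the multiple inheritance shown above, callers that already catch NotImplementedError
or AttributeError keep working unchanged while new code can catch the structured
tvm.error types directly.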