diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index ad7c4fc6796f..5feb446c8254 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -694,7 +694,7 @@ def _impl(inputs, attr, params): if padlist_key in params: padlist = params.pop(padlist_key).asnumpy() else: - raise RuntimeError("Required parameter {} not fount.".format(padlist_key)) + raise RuntimeError("Required parameter {} not found.".format(padlist_key)) paddings = tuple([tuple(l) for l in padlist]) attr['pad_width'] = paddings attr['pad_value'] = 0 @@ -768,6 +768,23 @@ def _impl(inputs, attr, params): )(inputs, attr) return _impl +def _split(name): + def _impl(inputs, attr, params): + if name == 'Split': + axis = params.pop(inputs[0].list_output_names()[0]).asnumpy()[0] + return AttrCvt(op_name="split", ignores=['Tdim', 'Tidx'], + transforms={'num_split': 'indices_or_sections'}, + extras={'axis': axis})(inputs[1], attr) + elif name == 'SplitV': + indices = params.pop(inputs[1].list_output_names()[0]).asnumpy() + axis = params.pop(inputs[2].list_output_names()[0]).asnumpy()[0] + return AttrCvt(op_name="split", ignores=['Tdim', 'Tidx', 'Tlen', 'num_split'], + extras={'indices_or_sections': tuple(sorted(indices)), + 'axis': axis})(inputs[0], attr) + else: + raise NotImplementedError("Unexpected split type: {}".format(name)) + return _impl + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -834,6 +851,8 @@ def _impl(inputs, attr, params): 'GreaterEqual' : _broadcast('greater_equal'), 'Equal' : _broadcast('equal'), 'NotEqual' : _broadcast('not_equal'), + 'Split' : _split('Split'), + 'SplitV' : _split('SplitV') } # _convert_map_rnn defines maps of rnn operator name to @@ -1131,14 +1150,16 @@ def from_tensorflow(self, graph, layout="NHWC"): node.input[0] = in_name # Fill shapes for all inputs in a list - try: - inputs = [self._nodes[i] for i in node.input] - for i in node.input: - input_shapes[self._nodes[i]] = self._output_shapes[i] - attr['_input_shapes'] = input_shapes - except KeyError: - # TODO: Need to find clean way to handle '^CheckNumerics' - pass + inputs = [] + for i in node.input: + try: + symbol = self._nodes[i] + inputs.append(symbol) + input_shapes[symbol] = self._output_shapes[i] + except KeyError: + # TODO: Need to find clean way to handle '^CheckNumerics' + pass + attr['_input_shapes'] = input_shapes inputs = self._fix_extranodes(node.op, attr, inputs) diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 2ebc7b671ba5..db4525cbe7e5 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -638,6 +638,30 @@ def test_forward_pad(): _test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0) +####################################################################### +# Split +# ----- +def test_forward_split(): + def check_split(ishape, **kwargs): + inp_array = np.random.uniform(size=ishape).astype(np.float32) + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) + tf.split(in1, **kwargs) + compare_tf_with_tvm(inp_array, 'Placeholder:0', 'split:0') + + def check_split_concat(ishape, **kwargs): + inp_array = np.random.uniform(size=ishape).astype(np.float32) + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) + splited = tf.split(in1, **kwargs) + tf.concat(splited, axis=1) + compare_tf_with_tvm(inp_array, 'Placeholder:0', 'concat:0') + + check_split((5, 30), num_or_size_splits=3, axis=1) + check_split((5, 30), num_or_size_splits=[4, 15, 11], axis=1) + check_split_concat((5, 30), num_or_size_splits=[15, 15], axis=1) + + ####################################################################### # Inception V3 # ------------ @@ -1013,6 +1037,7 @@ def test_forward_rel_ops(): test_forward_pad() test_forward_gather() #test_forward_stridedslice() + test_forward_split() # Activations test_forward_sigmoid()