diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index 65dd3619b5b2..1efc46aba85f 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -7,6 +7,7 @@
 import numpy as np
 import tvm
+from nnvm import NNVMError
 from .. import symbol as _sym
 from .. import graph as _graph
 from .. compiler import graph_util
@@ -133,7 +134,7 @@ def _impl(inputs, attr, params):
         attr['strides'] = (attr['strides'][1], attr['strides'][2])
 
         # Fix padding
-        input_shapes = attr['_input_shapes'][inputs[0]]
+        input_shape = attr['_input_shapes'][inputs[0]]
         attr['padding'] = attr['padding'].decode("utf-8")
 
         if attr['padding'] == 'VALID':
@@ -142,11 +143,11 @@ def _impl(inputs, attr, params):
             stride_h, stride_w = attr['strides']
             kernel_h, kernel_w = attr['kernel_shape']
             if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
             else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]
 
             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
@@ -171,28 +172,31 @@ def _impl(inputs, attr, params):
 def _conv(opname):
     def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
-        input_shapes = attr['_input_shapes'][inputs[0]]
+        input_shape = attr['_input_shapes'][inputs[0]]
 
         # Extract kernel shape from params
-        conv_param_weights = params[inputs[1].list_output_names()[0]]
+        if inputs[1] in attr['_input_shapes']:
+            weight_shape = tuple(attr['_input_shapes'][inputs[1]])
+        else:
+            weight_shape = params[inputs[1].list_output_names()[0]].shape
 
         if attr['data_format'] == 'NHWC':
-            kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
+            kernel_h, kernel_w, _, depth_mult = weight_shape
+            attr['kernel_shape'] = (weight_shape[0], weight_shape[1])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[3]
+                attr['channels'] = weight_shape[3]
             else:
-                attr['channels'] = input_shapes[0][3] * depth_mult
+                attr['channels'] = input_shape[3] * depth_mult
 
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
         elif attr['data_format'] == 'NCHW':
-            depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
+            depth_mult, _, kernel_h, kernel_w = weight_shape
+            attr['kernel_shape'] = (weight_shape[2], weight_shape[3])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[1]
+                attr['channels'] = weight_shape[1]
             else:
-                attr['channels'] = input_shapes[0][1] * depth_mult
+                attr['channels'] = input_shape[1] * depth_mult
 
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
@@ -215,11 +219,11 @@ def _impl(inputs, attr, params):
             stride_h, stride_w = attr['strides']
             kernel_h, kernel_w = attr['kernel_shape']
             if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
             else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]
 
             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
@@ -428,10 +432,22 @@ def _impl(inputs, attr, params):
                        ignores=['index_type', 'T'])(new_inputs, attr)
     return _impl
 
+def _split():
+    def _impl(inputs, attr, params):
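+        # The first input of TensorFlow's 'Split' op is the scalar
+        # split_dim (axis) tensor; pop it and read its value from params.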
+        pop_node = inputs.pop(0)
+        axis = params[pop_node.list_output_names()[0]].asnumpy()[0]
+        return AttrCvt(op_name="split",
+                       ignores=['num_split'],
+                       extras={'indices_or_sections': attr['num_split'],
+                               'axis': axis})(inputs, attr)
+    return _impl
+
 def _lrn():
     def _impl(inputs, attr, params):
         attr_new = {}
-        depth_radius = attr.get('depth_radius', 5)
+        depth_radius = attr.get('depth_radius', 2)
         size = (depth_radius * 2) + 1
         attr_new['axis'] = 3 # Fix axis, NHWC format
         attr_new['size'] = size
@@ -493,7 +509,7 @@ def _impl(inputs, attr, params):
         new_axis_mask = int(attr.get('new_axis_mask', 0))
         shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
         data_shape = attr['_input_shapes'][inputs[0]]
-        data_dim = len(data_shape[0])
+        data_dim = len(data_shape)
         stride_dim = len(stride)
 
         def _transform_mask(stride_dim, ellipsis_mask):
@@ -523,26 +539,26 @@ def _transform_mask(stride_dim, ellipsis_mask):
                                     + new_axes_after_ellipsis), data_dim)
                     for i in range(final_index, to_index):
                         m_begin[final_index] = 0
-                        m_end[final_index] = data_shape[0][final_index]
+                        m_end[final_index] = data_shape[final_index]
                         m_stride[final_index] = 1
                         final_index += 1
                 elif not mask & new_axis_mask:
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[0][final_index] \
+                        m_begin[final_index] = data_shape[final_index] \
                              if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
                         m_end[final_index] = 0 if stride[index] < 0 \
-                             else data_shape[0][final_index]
+                             else data_shape[final_index]
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
                     if mask & shrink_axis_mask:
                         #Tensorflow make axis with shrink_axis_mask as dimension 1
-                        m_begin[final_index] = data_shape[0][final_index] + begin[index] \
+                        m_begin[final_index] = data_shape[final_index] + begin[index] \
                              if begin[index] < 0 else begin[index]
                         m_end[final_index] = begin[index] + 1
                         m_stride[final_index] = 1
@@ -603,8 +619,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
         forget_bias = attr.pop('forget_bias')
         input_shape = attr['_input_shapes'][inputs[0]]
         weight_shape = attr['_input_shapes'][inputs[3]]
-        batch_size, input_size = input_shape[0][0], input_shape[0][1]
-        num_hidden_layers = weight_shape[0][1]
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
         num_hidden = num_hidden_layers // 4
 
         in_data = _sym.reshape(in_data,
@@ -695,6 +711,7 @@ def _impl(inputs, attr, params):
     'Fill'                              : _fill(),
     'GatherV2'                          : _gather_v2(),
     'StridedSlice'                      : _stridedSlice(),
+    'Split'                             : _split(),
     'LRN'                               : _lrn(),
     'Pad'                               : _pad('Pad'),
     'PadV2'                             : _pad('PadV2'),
@@ -798,8 +815,8 @@ def _LSTMBlockCellWrapper(inputs, attr, params,
     """LSTM cell warapper to prepare the inputs"""
     input_shape = attr['_input_shapes'][inputs[0]]
     weight_shape = attr['_input_shapes'][inputs[3]]
-    batch_size = input_shape[0][0]
-    num_hidden = weight_shape[0][1] // 4
+    batch_size = input_shape[0]
+    num_hidden = weight_shape[1] // 4
 
     if layer == 0:
         #Create initial states placeholder in case of first layer
@@ -982,24 +999,31 @@ def from_tensorflow(self, graph):
                 # Pass the node name too in attr
                 attr["_node_name"] = node.name
 
-                #ToDo: Some of the tensorflow operators internaly maintain
-                #execution layers and its output name will the layer number along with
-                #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
-                #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
-                #the digit has to be ignored.
- if ":" in node.input[0]: - in_name, _ = node.input[0].split(':') - node.input[0] = in_name - - # Fill shapes for all inputs in a list - try: - inputs = [self._nodes[i] for i in node.input] - for i in node.input: - input_shapes[self._nodes[i]] = self._output_shapes[i] - attr['_input_shapes'] = input_shapes - except KeyError: - # TODO: Need to find clean way to handle '^CheckNumerics' - pass + inputs = [] + for node_input_name in node.input: + node_input_key = node_input_name.split(':') + slot_num = 0 + if len(node_input_key) > 1: + slot_num = int(node_input_key[1]) + node_input_key = node_input_key[0] + + try: + try: + input_sym = self._nodes[node_input_key].__getitem__(slot_num) + except NNVMError: + # TODO: Fancy node name with invalid slot should discard and + # retrieve node input with zero'th(default) index. + # eg: Node name:- 'Model/RNN/cell_0/RnnCell:6', in this case + # name had fancy name convention and discard slot-id. + input_sym = self._nodes[node_input_key].__getitem__(0) + + inputs.append(input_sym) + input_shapes[input_sym] = self._output_shapes[ + node_input_key].__getitem__(slot_num) + attr['_input_shapes'] = input_shapes + except KeyError: + # TODO: Need to find clean way to handle '^CheckNumerics' + pass inputs = self._fix_extranodes(node.op, attr, inputs) diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 6fa020a03444..0c1027d4b602 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -71,8 +71,6 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype) def run_tf_graph(sess, input_data, input_node, output_node): """ Generic function to execute tensorflow """ - tensor = sess.graph.get_tensor_by_name(output_node) - if isinstance(input_data, list): input_dict = {} for i, e in enumerate(input_node): @@ -80,8 +78,12 @@ def run_tf_graph(sess, input_data, input_node, output_node): else: input_dict = {input_node: input_data} - output_data = sess.run(tensor, input_dict) - return output_data + if isinstance(output_node, list): + tensor = [sess.graph.get_tensor_by_name(node_name) for node_name in output_node] + return sess.run(tensor, input_dict) + + tensor = sess.graph.get_tensor_by_name(output_node) + return sess.run(tensor, input_dict) def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False): @@ -787,29 +789,72 @@ def _get_sample(data, state): np.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5) assert(tvm_sample_str == tf_sample_str) -####################################################################### -# LRN (Local Response Normalization) -# ---------------------------------- -def _test_lrn(ishape, size, axis, bias, alpha, beta): - """ testing local response normalization """ - lrn_depth_radius = size / 2 +####################################################################### +# Split +# ------ +def _test_split(ip_shape, num_split, axis): + tf.reset_default_graph() + dtype = 'float32' + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + tf.split(in_data, num_split, axis=axis, name="split") + np_data = np.random.uniform(size=ip_shape).astype(dtype) + with tf.Session() as sess: + final_graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + ['split']) + + tf_out_node_names = ['split:%g' % i for i in range(num_split)] + tf_output = run_tf_graph(sess, [np_data], ['in_data:0'], 
+                                 tf_out_node_names)
+        tvm_out_shapes = [tf_output[i].shape for i in range(num_split)]
+        tvm_out_dtypes = [tf_output[i].dtype for i in range(num_split)]
+        tvm_output = run_tvm_graph(final_graph_def, [np_data], ['in_data'],
+                                   tvm_out_shapes, tvm_out_dtypes)
+        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+        sess.close()
 
-    inp_array = np.random.uniform(size=ishape).astype(np.float32)
+def test_forward_split():
+    '''test split operator'''
+    _test_split((2, 3), 2, axis=0)
+    _test_split((6, 3), 3, axis=0)
+    _test_split((5, 9, 3), 3, axis=1)
+    _test_split((2, 5, 3, 9), 3, axis=3)
 
-    with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
-        nn_ops.local_response_normalization(in1,
-                                            name="lrn",
-                                            depth_radius=lrn_depth_radius,
-                                            bias=bias,
-                                            alpha=alpha,
-                                            beta=beta)
-        compare_tf_with_tvm(inp_array, 'lrn0_data:0', 'lrn:0')
+#######################################################################
+# LRN (Local Response Normalization)
+# ----------------------------------
+def _test_lrn(ip_shape, depth_radius=2, alpha=1e-05, beta=0.75, bias=1.0):
+    tf.reset_default_graph()
+    dtype = 'float32'
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.nn.local_response_normalization(in_data,
+                                       depth_radius=depth_radius,
+                                       alpha=alpha,
+                                       beta=beta,
+                                       bias=bias,
+                                       name="local_response_normalization")
+    np_data = np.random.uniform(size=ip_shape).astype(dtype)
+    with tf.Session() as sess:
+        final_graph_def = tf.graph_util.convert_variables_to_constants(
+            sess,
+            sess.graph.as_graph_def(add_shapes=True),
+            ['local_response_normalization'])
+        tf_output = run_tf_graph(sess, [np_data], ['in_data:0'],
+                                 'local_response_normalization:0')
+        tvm_output = run_tvm_graph(final_graph_def, [np_data], ['in_data'],
+                                   tf_output.shape, dtype)
+        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+        sess.close()
 
 def test_forward_lrn():
-    _test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
+    '''test local_response_normalization operator'''
+    _test_lrn((4, 4, 4, 4))
+    _test_lrn((4, 3, 5, 6), 1, 1e-05, 0.5)
+    _test_lrn((1, 3, 20, 20), 2, 1e-05, 0.5, 1.0)
+    _test_lrn((1, 3, 20, 20), 2, 1e-05, 0.75, 2.0)
+
 
 #######################################################################
 # l2_normalize
@@ -855,6 +900,7 @@ def test_forward_l2_normalize():
     test_forward_resize_bilinear()
     test_forward_pad()
     test_forward_lstm()
+    test_forward_split()
    test_forward_stridedslice()
    test_forward_gather()
    test_forward_ptb()