TensorFlow Split and LRN operator support for Alexnet Model #1572
First changed file: the NNVM TensorFlow frontend converter.

```diff
@@ -7,6 +7,7 @@
 import numpy as np

 import tvm
+from nnvm import NNVMError
 from .. import symbol as _sym
 from .. import graph as _graph
 from .. compiler import graph_util
```
```diff
@@ -133,7 +134,7 @@ def _impl(inputs, attr, params):
         attr['strides'] = (attr['strides'][1], attr['strides'][2])

         # Fix padding
-        input_shapes = attr['_input_shapes'][inputs[0]]
+        input_shape = attr['_input_shapes'][inputs[0]]
         attr['padding'] = attr['padding'].decode("utf-8")

         if attr['padding'] == 'VALID':
@@ -142,11 +143,11 @@ def _impl(inputs, attr, params):
             stride_h, stride_w = attr['strides']
             kernel_h, kernel_w = attr['kernel_shape']
             if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
             else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]

             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
```
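`_get_pad_pair` itself is outside this diff; as a rough sketch (an assumption, not the file's actual code), it computes TensorFlow's documented SAME-padding split along one spatial axis, with the odd leftover pixel going to the trailing side:

```python
def get_pad_pair(in_dim, kernel_dim, stride_dim):
    """Sketch of TF's SAME-padding rule along one axis: pad just enough
    that ceil(in/stride) windows fit, splitting before/after."""
    if in_dim % stride_dim == 0:
        pad = max(kernel_dim - stride_dim, 0)
    else:
        pad = max(kernel_dim - (in_dim % stride_dim), 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]

print(get_pad_pair(224, 7, 2))  # -> [2, 3] for a 7x7/stride-2 kernel on 224 input
```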
```diff
@@ -171,28 +172,31 @@ def _impl(inputs, attr, params):
 def _conv(opname):
     def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
-        input_shapes = attr['_input_shapes'][inputs[0]]
+        input_shape = attr['_input_shapes'][inputs[0]]

         # Extract kernel shape from params
-        conv_param_weights = params[inputs[1].list_output_names()[0]]
+        if inputs[1] in attr['_input_shapes']:
+            weight_shape = tuple(attr['_input_shapes'][inputs[1]])
+        else:
+            weight_shape = params[inputs[1].list_output_names()[0]].shape

         if attr['data_format'] == 'NHWC':
-            kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
+            kernel_h, kernel_w, _, depth_mult = weight_shape
+            attr['kernel_shape'] = (weight_shape[0], weight_shape[1])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[3]
+                attr['channels'] = weight_shape[3]
             else:
-                attr['channels'] = input_shapes[0][3] * depth_mult
+                attr['channels'] = input_shape[3] * depth_mult

             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
         elif attr['data_format'] == 'NCHW':
-            depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
+            depth_mult, _, kernel_h, kernel_w = weight_shape
+            attr['kernel_shape'] = (weight_shape[2], weight_shape[3])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[1]
+                attr['channels'] = weight_shape[1]
             else:
-                attr['channels'] = input_shapes[0][1] * depth_mult
+                attr['channels'] = input_shape[1] * depth_mult

             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
```
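For intuition, the unpacking above follows TensorFlow's weight layouts: an NHWC graph stores kernels as (kernel_h, kernel_w, in_channels, channel_multiplier or out_channels). A small numpy illustration with made-up shapes (not taken from the PR):

```python
import numpy as np

# NHWC/HWIO depthwise weights: (kernel_h, kernel_w, in_channels, channel_multiplier)
weights = np.zeros((3, 3, 16, 2), dtype='float32')
kernel_h, kernel_w, in_channels, depth_mult = weights.shape

kernel_shape = (kernel_h, kernel_w)        # what goes into attr['kernel_shape']
depthwise_out = in_channels * depth_mult   # attr['channels'] for depthwise conv
# For a regular conv the last axis is the output-channel count instead,
# i.e. attr['channels'] = weights.shape[3].
print(kernel_shape, depthwise_out)         # (3, 3) 32
```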
```diff
@@ -215,11 +219,11 @@ def _impl(inputs, attr, params):
             stride_h, stride_w = attr['strides']
             kernel_h, kernel_w = attr['kernel_shape']
             if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
             else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]

             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
```
```diff
@@ -428,10 +432,20 @@ def _impl(inputs, attr, params):
                        ignores=['index_type', 'T'])(new_inputs, attr)
     return _impl

+def _split():
+    def _impl(inputs, attr, params):
+        pop_node = inputs.pop(0)
+        axis = params[pop_node.list_output_names()[0]].asnumpy()[0]
+        return AttrCvt(op_name="split",
+                       ignores=['num_split'],
```
Review thread on `_split`:

> **Contributor:** Suggest to use `transforms` instead of `ignores` & `extras`.
>
> **Member:** Have you reached consensus that the changes here are good?
```diff
+                       extras={'indices_or_sections': attr['num_split'],
+                               'axis': axis})(inputs, attr)
+    return _impl
```
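In the TensorFlow graph, `Split` takes the split axis as its first input tensor and carries an integer `num_split` attribute; the converter pops the axis constant out of `params` and hands NNVM's `split` the same information as `indices_or_sections` and `axis`. The intended numeric behavior matches `np.split`:

```python
import numpy as np

data = np.arange(18, dtype='float32').reshape(6, 3)

# TF: Split(axis=0, value=data) with num_split=3
# NNVM: split(data, indices_or_sections=3, axis=0)
pieces = np.split(data, 3, axis=0)
print([p.shape for p in pieces])  # [(2, 3), (2, 3), (2, 3)]
```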
```diff

 def _lrn():
     def _impl(inputs, attr, params):
         attr_new = {}
-        depth_radius = attr.get('depth_radius', 5)
+        depth_radius = attr.get('depth_radius', 2)
         size = (depth_radius * 2) + 1
         attr_new['axis'] = 3  # Fix axis, NHWC format
         attr_new['size'] = size
```
```diff
@@ -493,7 +507,7 @@ def _impl(inputs, attr, params):
         new_axis_mask = int(attr.get('new_axis_mask', 0))
         shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
         data_shape = attr['_input_shapes'][inputs[0]]
-        data_dim = len(data_shape[0])
+        data_dim = len(data_shape)
         stride_dim = len(stride)

         def _transform_mask(stride_dim, ellipsis_mask):
```
```diff
@@ -523,26 +537,26 @@ def _transform_mask(stride_dim, ellipsis_mask):
                                        + new_axes_after_ellipsis), data_dim)
                     for i in range(final_index, to_index):
                         m_begin[final_index] = 0
-                        m_end[final_index] = data_shape[0][final_index]
+                        m_end[final_index] = data_shape[final_index]
                         m_stride[final_index] = 1
                         final_index += 1
                 elif not mask & new_axis_mask:
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[0][final_index] \
+                        m_begin[final_index] = data_shape[final_index] \
                             if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
                         m_end[final_index] = 0 if stride[index] < 0 \
-                            else data_shape[0][final_index]
+                            else data_shape[final_index]
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
                     if mask & shrink_axis_mask:
                         #Tensorflow make axis with shrink_axis_mask as dimension 1
-                        m_begin[final_index] = data_shape[0][final_index] + begin[index] \
+                        m_begin[final_index] = data_shape[final_index] + begin[index] \
                             if begin[index] < 0 else begin[index]
                         m_end[final_index] = begin[index] + 1
                         m_stride[final_index] = 1
```
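The `shrink_axis_mask` branch forces a one-element slice (`begin` to `begin + 1`) that a later step squeezes away, mirroring how TF turns an integer index into a sliced-then-collapsed axis. In numpy terms:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# StridedSlice with begin=[1], end=[2], stride=[1] and shrink_axis_mask=1
# behaves like x[1] (axis removed), not x[1:2] (axis kept):
print(x[1].shape)    # (3, 4)
print(x[1:2].shape)  # (1, 3, 4) -- the one-element slice before squeezing
```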
```diff
@@ -603,8 +617,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
         forget_bias = attr.pop('forget_bias')
         input_shape = attr['_input_shapes'][inputs[0]]
         weight_shape = attr['_input_shapes'][inputs[3]]
-        batch_size, input_size = input_shape[0][0], input_shape[0][1]
-        num_hidden_layers = weight_shape[0][1]
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
         num_hidden = num_hidden_layers // 4

         in_data = _sym.reshape(in_data,
```
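The `// 4` works because an LSTM block cell stores its four gate blocks (input, cell, forget, output) side by side in one kernel, so the kernel's second dimension is `4 * num_hidden`. A quick shape check with assumed sizes:

```python
# LSTMBlockCell kernel layout: (input_size + num_hidden, 4 * num_hidden)
input_size, num_hidden = 10, 32
weight_shape = (input_size + num_hidden, 4 * num_hidden)
assert weight_shape[1] // 4 == num_hidden
```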
```diff
@@ -695,6 +709,7 @@ def _impl(inputs, attr, params):
     'Fill' : _fill(),
     'GatherV2' : _gather_v2(),
     'StridedSlice' : _stridedSlice(),
+    'Split' : _split(),
     'LRN' : _lrn(),
     'Pad' : _pad('Pad'),
     'PadV2' : _pad('PadV2'),
```
```diff
@@ -798,8 +813,8 @@ def _LSTMBlockCellWrapper(inputs, attr, params,
     """LSTM cell warapper to prepare the inputs"""
     input_shape = attr['_input_shapes'][inputs[0]]
     weight_shape = attr['_input_shapes'][inputs[3]]
-    batch_size = input_shape[0][0]
-    num_hidden = weight_shape[0][1] // 4
+    batch_size = input_shape[0]
+    num_hidden = weight_shape[1] // 4

     if layer == 0:
         #Create initial states placeholder in case of first layer
```
```diff
@@ -982,24 +997,31 @@ def from_tensorflow(self, graph):
             # Pass the node name too in attr
             attr["_node_name"] = node.name

-            #ToDo: Some of the tensorflow operators internaly maintain
-            #execution layers and its output name will the layer number along with
-            #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
-            #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
-            #the digit has to be ignored.
-            if ":" in node.input[0]:
-                in_name, _ = node.input[0].split(':')
-                node.input[0] = in_name
-
             # Fill shapes for all inputs in a list
-            try:
-                inputs = [self._nodes[i] for i in node.input]
-                for i in node.input:
-                    input_shapes[self._nodes[i]] = self._output_shapes[i]
-                attr['_input_shapes'] = input_shapes
-            except KeyError:
-                # TODO: Need to find clean way to handle '^CheckNumerics'
-                pass
+            inputs = []
+            for node_input_name in node.input:
+                node_input_key = node_input_name.split(':')
+                slot_num = 0
+                if len(node_input_key) > 1:
+                    slot_num = int(node_input_key[1])
+                node_input_key = node_input_key[0]
+
+                try:
+                    try:
+                        input_sym = self._nodes[node_input_key].__getitem__(slot_num)
+                    except NNVMError:
+                        # TODO: Fancy node name with invalid slot should discard and
+                        # retrieve node input with zero'th(default) index.
+                        # eg: Node name:- 'Model/RNN/cell_0/RnnCell:6', in this case
+                        # name had fancy name convention and discard slot-id.
+                        input_sym = self._nodes[node_input_key].__getitem__(0)
+
+                    inputs.append(input_sym)
+                    input_shapes[input_sym] = self._output_shapes[
+                        node_input_key].__getitem__(slot_num)
+                    attr['_input_shapes'] = input_shapes
+                except KeyError:
```
Review thread on the nested `try`:

> **Contributor:** This `except` is for the same statement below. Suggest to make it one `try` block with multiple `except` clauses instead of nesting.
>
> **Author:** To address the review comment: we need the nested `try` block to handle the fancy node-name-convention logic.
>
> **Contributor:** It's the same statement which causes two exceptions in two different cases, one `try` on …
>
> **Author:** In the first `try` block we try to access the slot-based input; after catching `NNVMError` we fall back to the default (0) slot. If an error is caught at that point, it should be thrown to the user.
```diff
+                    # TODO: Need to find clean way to handle '^CheckNumerics'
+                    pass

             inputs = self._fix_extranodes(node.op, attr, inputs)
```
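The new loop's name handling boils down to splitting a `'node:slot'` output name and defaulting the slot to zero. A standalone sketch of that parsing (a hypothetical helper, not code from this PR):

```python
def parse_tensor_name(name):
    """Split 'Model/RNN/cell_0/RnnCell:0' into ('Model/RNN/cell_0/RnnCell', 0);
    a bare node name gets slot 0."""
    parts = name.split(':')
    if len(parts) > 1:
        return parts[0], int(parts[1])
    return name, 0

assert parse_tensor_name('Model/RNN/cell_0/RnnCell:0') == ('Model/RNN/cell_0/RnnCell', 0)
assert parse_tensor_name('in_data') == ('in_data', 0)
```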
Second changed file: the TensorFlow frontend tests.

```diff
@@ -71,17 +71,19 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype)
 def run_tf_graph(sess, input_data, input_node, output_node):
     """ Generic function to execute tensorflow """

-    tensor = sess.graph.get_tensor_by_name(output_node)
-
     if isinstance(input_data, list):
         input_dict = {}
         for i, e in enumerate(input_node):
             input_dict[e] = input_data[i]
     else:
         input_dict = {input_node: input_data}

-    output_data = sess.run(tensor, input_dict)
-    return output_data
+    if isinstance(output_node, list):
+        tensor = [sess.graph.get_tensor_by_name(node_name) for node_name in output_node]
+        return sess.run(tensor, input_dict)
+
+    tensor = sess.graph.get_tensor_by_name(output_node)
+    return sess.run(tensor, input_dict)


 def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False):
```
```diff
@@ -787,29 +789,72 @@ def _get_sample(data, state):
     np.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
     assert(tvm_sample_str == tf_sample_str)

 #######################################################################
-# LRN (Local Response Normalization)
-# ----------------------------------
-
-def _test_lrn(ishape, size, axis, bias, alpha, beta):
-    """ testing local response normalization """
-    lrn_depth_radius = size / 2
-
-    inp_array = np.random.uniform(size=ishape).astype(np.float32)
-
-    with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
-        nn_ops.local_response_normalization(in1,
-                                            name="lrn",
-                                            depth_radius=lrn_depth_radius,
-                                            bias=bias,
-                                            alpha=alpha,
-                                            beta=beta)
-
-    compare_tf_with_tvm(inp_array, 'lrn0_data:0', 'lrn:0')
+# Split
+# -----
+def _test_split(ip_shape, num_split, axis):
+    tf.reset_default_graph()
+    dtype = 'float32'
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.split(in_data, num_split, axis=axis, name="split")
+    np_data = np.random.uniform(size=ip_shape).astype(dtype)
+    with tf.Session() as sess:
+        final_graph_def = tf.graph_util.convert_variables_to_constants(
+            sess,
+            sess.graph.as_graph_def(add_shapes=True),
+            ['split'])
+
+        tf_out_node_names = ['split:%g' % i for i in range(num_split)]
+        tf_output = run_tf_graph(sess, [np_data], ['in_data:0'], tf_out_node_names)
+        tvm_out_shapes = [tf_output[i].shape for i in range(num_split)]
+        tvm_out_dtypes = [tf_output[i].dtype for i in range(num_split)]
+        tvm_output = run_tvm_graph(final_graph_def, [np_data], ['in_data'],
+                                   tvm_out_shapes, tvm_out_dtypes)
+        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+        sess.close()
+
+def test_forward_split():
+    '''test split operator'''
+    _test_split((2, 3), 2, axis=0)
+    _test_split((6, 3), 3, axis=0)
+    _test_split((5, 9, 3), 3, axis=1)
+    _test_split((2, 5, 3, 9), 3, axis=3)
+
+#######################################################################
+# LRN (Local Response Normalization)
+# ----------------------------------
+def _test_lrn(ip_shape, depth_radius=2, alpha=1e-05, beta=0.75, bias=1.0):
```
> **Contributor:** Ref: we don't need to explicitly set the default values; TF can handle them if they are not set. Suggest leaving them out so the TF defaults play their role here.
```diff
+    tf.reset_default_graph()
+    dtype = 'float32'
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.nn.local_response_normalization(in_data,
+                                       depth_radius=depth_radius,
+                                       alpha=alpha,
+                                       beta=beta,
+                                       bias=bias,
+                                       name="local_response_normalization")
+    np_data = np.random.uniform(size=ip_shape).astype(dtype)
+    with tf.Session() as sess:
```
> **Contributor:** #1546 adds …; suggest to use it.
```diff
+        final_graph_def = tf.graph_util.convert_variables_to_constants(
+            sess,
+            sess.graph.as_graph_def(add_shapes=True),
+            ['local_response_normalization'])
+        tf_output = run_tf_graph(sess, [np_data], ['in_data:0'],
+                                 'local_response_normalization:0')
+        tvm_output = run_tvm_graph(final_graph_def, [np_data], ['in_data'],
+                                   tf_output.shape, dtype)
+        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+        sess.close()

 def test_forward_lrn():
-    _test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
+    '''test local_response_normalization operator'''
+    _test_lrn((4, 4, 4, 4))
+    _test_lrn((4, 3, 5, 6), 1, 1e-05, 0.5)
+    _test_lrn((1, 3, 20, 20), 2, 1e-05, 0.5, 1.0)
+    _test_lrn((1, 3, 20, 20), 2, 1e-05, 0.75, 2.0)


 #######################################################################
 # l2_normalize
@@ -855,6 +900,7 @@ def test_forward_l2_normalize():
     test_forward_resize_bilinear()
     test_forward_pad()
     test_forward_lstm()
+    test_forward_split()
     test_forward_stridedslice()
     test_forward_gather()
     test_forward_ptb()
```
Review discussion on `Split` coverage:

> **Contributor:** Ref. Split: TensorFlow has two variations of split, via `num_split` and `size_split`. Please handle both if possible, or raise an exception accordingly. If handling `size_split`, please also handle the `num` attribute if possible, or raise an exception accordingly.
>
> **Author:** `Split` has the `num_split` attribute and `SplitV` has the `size_splits` attribute. As of this PR, only the `Split` operator is supported.
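For reference, a sketch (TF 1.x API, matching the tests above) of the two graph ops the thread distinguishes: an integer second argument to `tf.split` emits a `Split` node, while a list of sizes emits `SplitV`:

```python
import numpy as np
import tensorflow as tf  # TF 1.x API, as used in this PR's tests

tf.reset_default_graph()
x = tf.placeholder('float32', (6, 4), name='x')

equal = tf.split(x, 3, axis=0)           # 'Split': num_split=3, equal pieces
uneven = tf.split(x, [1, 2, 3], axis=0)  # 'SplitV': size_splits=[1, 2, 3]

with tf.Session() as sess:
    data = np.random.uniform(size=(6, 4)).astype('float32')
    print([p.shape for p in sess.run(equal, {x: data})])   # [(2, 4)] * 3
    print([p.shape for p in sess.run(uneven, {x: data})])  # (1, 4), (2, 4), (3, 4)
```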