nnvm/python/nnvm/frontend/tensorflow.py (116 changes: 69 additions & 47 deletions)
@@ -7,6 +7,7 @@
 import numpy as np

 import tvm
+from nnvm import NNVMError
 from .. import symbol as _sym
 from .. import graph as _graph
 from .. compiler import graph_util
@@ -133,7 +134,7 @@ def _impl(inputs, attr, params):
         attr['strides'] = (attr['strides'][1], attr['strides'][2])

         # Fix padding
-        input_shapes = attr['_input_shapes'][inputs[0]]
+        input_shape = attr['_input_shapes'][inputs[0]]
         attr['padding'] = attr['padding'].decode("utf-8")

         if attr['padding'] == 'VALID':
@@ -142,11 +143,11 @@
             stride_h, stride_w = attr['strides']
             kernel_h, kernel_w = attr['kernel_shape']
             if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
             else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]

             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
@@ -171,28 +172,31 @@ def _impl(inputs, attr, params):
 def _conv(opname):
     def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
-        input_shapes = attr['_input_shapes'][inputs[0]]
+        input_shape = attr['_input_shapes'][inputs[0]]

         # Extract kernel shape from params
-        conv_param_weights = params[inputs[1].list_output_names()[0]]
+        if inputs[1] in attr['_input_shapes']:
+            weight_shape = tuple(attr['_input_shapes'][inputs[1]])
+        else:
+            weight_shape = params[inputs[1].list_output_names()[0]].shape

         if attr['data_format'] == 'NHWC':
-            kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
+            kernel_h, kernel_w, _, depth_mult = weight_shape
+            attr['kernel_shape'] = (weight_shape[0], weight_shape[1])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[3]
+                attr['channels'] = weight_shape[3]
             else:
-                attr['channels'] = input_shapes[0][3] * depth_mult
+                attr['channels'] = input_shape[3] * depth_mult

             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
         elif attr['data_format'] == 'NCHW':
-            depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
-            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
+            depth_mult, _, kernel_h, kernel_w = weight_shape
+            attr['kernel_shape'] = (weight_shape[2], weight_shape[3])
             if opname == 'conv':
-                attr['channels'] = conv_param_weights.shape[1]
+                attr['channels'] = weight_shape[1]
             else:
-                attr['channels'] = input_shapes[0][1] * depth_mult
+                attr['channels'] = input_shape[1] * depth_mult

             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
@@ -215,11 +219,11 @@ def _impl(inputs, attr, params):
             stride_h, stride_w = attr['strides']
             kernel_h, kernel_w = attr['kernel_shape']
             if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
+                in_h = input_shape[1]
+                in_w = input_shape[2]
             else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
+                in_h = input_shape[2]
+                in_w = input_shape[3]

             pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
@@ -428,10 +432,20 @@ def _impl(inputs, attr, params):
                        ignores=['index_type', 'T'])(new_inputs, attr)
     return _impl

+def _split():
+    def _impl(inputs, attr, params):
+        pop_node = inputs.pop(0)
Contributor: Ref. Split: TensorFlow has two variations of split, num_split and size_splits. Please handle both if possible, or raise an exception accordingly. If handling size_splits, please also handle the num attribute if possible, or raise an exception accordingly.

Contributor Author: Split has the num_split attribute and SplitV has the size_splits attribute. As of this PR, only the Split operator is supported.
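For reference, a minimal sketch of the two variants in the TF 1.x API; the placeholder shape and names below are illustrative only, not taken from the PR:

```python
import tensorflow as tf

data = tf.placeholder(tf.float32, (4, 6), name="data")

# 'Split' carries num_split: an even split into equal pieces,
# here three (4, 2) tensors along axis 1.
even_parts = tf.split(data, num_or_size_splits=3, axis=1)

# 'SplitV' carries size_splits: explicit, possibly unequal sizes,
# here a (4, 1), a (4, 2) and a (4, 3) tensor along axis 1.
uneven_parts = tf.split(data, num_or_size_splits=[1, 2, 3], axis=1)
```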

+        axis = params[pop_node.list_output_names()[0]].asnumpy()[0]
+        return AttrCvt(op_name="split",
+                       ignores=['num_split'],
Contributor: Suggest using transforms instead of ignores & extras.

Member: Have you reached consensus that the change here is good?
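If the frontend's AttrCvt helper is read correctly here, the suggestion amounts to renaming the TF attribute instead of ignoring it and re-injecting its value through extras; a sketch of _split under that assumption (the axis still has to travel through extras, since it arrives via a popped input rather than an attribute):

```python
def _split():
    def _impl(inputs, attr, params):
        pop_node = inputs.pop(0)
        axis = params[pop_node.list_output_names()[0]].asnumpy()[0]
        # Hypothetical alternative: rename TF's 'num_split' to NNVM's
        # 'indices_or_sections' via transforms, instead of ignoring it
        # and re-injecting its value through extras.
        return AttrCvt(op_name="split",
                       transforms={'num_split': 'indices_or_sections'},
                       extras={'axis': axis})(inputs, attr)
    return _impl
```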

+                       extras={'indices_or_sections':attr['num_split'],
+                               'axis': axis})(inputs, attr)
+    return _impl
+
 def _lrn():
     def _impl(inputs, attr, params):
         attr_new = {}
-        depth_radius = attr.get('depth_radius', 5)
+        depth_radius = attr.get('depth_radius', 2)
         size = (depth_radius * 2) + 1
         attr_new['axis'] = 3  # Fix axis, NHWC format
         attr_new['size'] = size
@@ -493,7 +507,7 @@ def _impl(inputs, attr, params):
         new_axis_mask = int(attr.get('new_axis_mask', 0))
         shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
         data_shape = attr['_input_shapes'][inputs[0]]
-        data_dim = len(data_shape[0])
+        data_dim = len(data_shape)
         stride_dim = len(stride)

         def _transform_mask(stride_dim, ellipsis_mask):
@@ -523,26 +537,26 @@ def _transform_mask(stride_dim, ellipsis_mask):
                                     + new_axes_after_ellipsis), data_dim)
                     for i in range(final_index, to_index):
                         m_begin[final_index] = 0
-                        m_end[final_index] = data_shape[0][final_index]
+                        m_end[final_index] = data_shape[final_index]
                         m_stride[final_index] = 1
                         final_index += 1
                 elif not mask & new_axis_mask:
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[0][final_index] \
+                        m_begin[final_index] = data_shape[final_index] \
                             if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
                         m_end[final_index] = 0 if stride[index] < 0 \
-                            else data_shape[0][final_index]
+                            else data_shape[final_index]
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
                     if mask & shrink_axis_mask:
                         #Tensorflow make axis with shrink_axis_mask as dimension 1
-                        m_begin[final_index] = data_shape[0][final_index] + begin[index] \
+                        m_begin[final_index] = data_shape[final_index] + begin[index] \
                             if begin[index] < 0 else begin[index]
                         m_end[final_index] = begin[index] + 1
                         m_stride[final_index] = 1
@@ -603,8 +617,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
         forget_bias = attr.pop('forget_bias')
         input_shape = attr['_input_shapes'][inputs[0]]
         weight_shape = attr['_input_shapes'][inputs[3]]
-        batch_size, input_size = input_shape[0][0], input_shape[0][1]
-        num_hidden_layers = weight_shape[0][1]
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
         num_hidden = num_hidden_layers // 4

         in_data = _sym.reshape(in_data,
@@ -695,6 +709,7 @@ def _impl(inputs, attr, params):
     'Fill'                              : _fill(),
     'GatherV2'                          : _gather_v2(),
     'StridedSlice'                      : _stridedSlice(),
+    'Split'                             : _split(),
     'LRN'                               : _lrn(),
     'Pad'                               : _pad('Pad'),
     'PadV2'                             : _pad('PadV2'),
@@ -798,8 +813,8 @@ def _LSTMBlockCellWrapper(inputs, attr, params,
"""LSTM cell warapper to prepare the inputs"""
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size = input_shape[0][0]
num_hidden = weight_shape[0][1] // 4
batch_size = input_shape[0]
num_hidden = weight_shape[1] // 4

if layer == 0:
#Create initial states placeholder in case of first layer
@@ -982,24 +997,31 @@ def from_tensorflow(self, graph):
             # Pass the node name too in attr
             attr["_node_name"] = node.name

-            #ToDo: Some of the tensorflow operators internaly maintain
-            #execution layers and its output name will the layer number along with
-            #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
-            #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
-            #the digit has to be ignored.
-            if ":" in node.input[0]:
-                in_name, _ = node.input[0].split(':')
-                node.input[0] = in_name
-
             # Fill shapes for all inputs in a list
-            try:
-                inputs = [self._nodes[i] for i in node.input]
-                for i in node.input:
-                    input_shapes[self._nodes[i]] = self._output_shapes[i]
-                attr['_input_shapes'] = input_shapes
-            except KeyError:
-                # TODO: Need to find clean way to handle '^CheckNumerics'
-                pass
+            inputs = []
+            for node_input_name in node.input:
+                node_input_key = node_input_name.split(':')
+                slot_num = 0
+                if len(node_input_key) > 1:
+                    slot_num = int(node_input_key[1])
+                node_input_key = node_input_key[0]
+
+                try:
+                    try:
+                        input_sym = self._nodes[node_input_key].__getitem__(slot_num)
+                    except NNVMError:
+                        # TODO: Fancy node name with invalid slot should discard and
+                        # retrieve node input with zero'th(default) index.
+                        # eg: Node name:- 'Model/RNN/cell_0/RnnCell:6', in this case
+                        # name had fancy name convention and discard slot-id.
+                        input_sym = self._nodes[node_input_key].__getitem__(0)
+
+                    inputs.append(input_sym)
+                    input_shapes[input_sym] = self._output_shapes[
+                        node_input_key].__getitem__(slot_num)
+                    attr['_input_shapes'] = input_shapes
+                except KeyError:
Contributor: This except is for the same statement below, where self._nodes doesn't have the key node_input_key: input_sym = self._nodes[node_input_key].__getitem__(slot_num). Suggest making this one try block with multiple except clauses instead of nesting.

Contributor Author (@Dayananda-V, Aug 9, 2018): To address the review comment: we need the nested try block to handle the fancy node-name convention logic.

Contributor: It's the same statement that raises the two exceptions in the two different cases; one try on input_sym = self._nodes[node_input_key].__getitem__(slot_num) with two except clauses should work. Please check.

Contributor Author: In the first try block we try to access the slot-based input; after catching NNVMError we retry with the default (slot 0) input. If an error is raised at that point, it should be thrown to the user.
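To make the two shapes being debated concrete, a runnable toy sketch; NNVMError here is a stand-in exception, and FakeSymbol plus the one-entry nodes dict are invented for illustration:

```python
class NNVMError(Exception):
    """Stand-in for nnvm's error type."""

class FakeSymbol:
    """Toy symbol: only output slot 0 exists; other slots raise NNVMError."""
    def __getitem__(self, slot):
        if slot != 0:
            raise NNVMError("invalid output slot %d" % slot)
        return "sym:0"

nodes = {"data": FakeSymbol()}
node_input_key, slot_num = "data", 6   # fancy name 'data:6' with a bad slot

# Reviewer's flat shape: one try, two except clauses. An exception raised
# inside the NNVMError handler is not caught by the KeyError clause below;
# handlers do not guard each other.
try:
    input_sym = nodes[node_input_key][slot_num]
except NNVMError:
    input_sym = nodes[node_input_key][0]
except KeyError:
    pass   # e.g. '^CheckNumerics' inputs with no known node

# Author's nested shape: the slot-0 fallback runs inside the outer try, so a
# KeyError from any statement in the block (the real code also performs a
# shape lookup there) lands in one handler, while a second NNVMError from
# the fallback still propagates to the user.
try:
    try:
        input_sym = nodes[node_input_key][slot_num]
    except NNVMError:
        input_sym = nodes[node_input_key][0]
except KeyError:
    pass

assert input_sym == "sym:0"   # both shapes fell back to slot 0 here
```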

+                    # TODO: Need to find clean way to handle '^CheckNumerics'
+                    pass

             inputs = self._fix_extranodes(node.op, attr, inputs)

nnvm/tests/python/frontend/tensorflow/test_forward.py (88 changes: 67 additions & 21 deletions)
@@ -71,17 +71,19 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype)
 def run_tf_graph(sess, input_data, input_node, output_node):
     """ Generic function to execute tensorflow """

-    tensor = sess.graph.get_tensor_by_name(output_node)
-
     if isinstance(input_data, list):
         input_dict = {}
         for i, e in enumerate(input_node):
             input_dict[e] = input_data[i]
     else:
         input_dict = {input_node: input_data}

-    output_data = sess.run(tensor, input_dict)
-    return output_data
+    if isinstance(output_node, list):
+        tensor = [sess.graph.get_tensor_by_name(node_name) for node_name in output_node]
+        return sess.run(tensor, input_dict)
+
+    tensor = sess.graph.get_tensor_by_name(output_node)
+    return sess.run(tensor, input_dict)


 def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False):
@@ -787,29 +789,72 @@ def _get_sample(data, state):
     np.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
     assert(tvm_sample_str == tf_sample_str)

-#######################################################################
-# LRN (Local Response Normalization)
-# ----------------------------------
-
-def _test_lrn(ishape, size, axis, bias, alpha, beta):
-    """ testing local response normalization """
-    lrn_depth_radius = size / 2
-
-    inp_array = np.random.uniform(size=ishape).astype(np.float32)
-
-    with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
-        nn_ops.local_response_normalization(in1,
-                                            name="lrn",
-                                            depth_radius=lrn_depth_radius,
-                                            bias=bias,
-                                            alpha=alpha,
-                                            beta=beta)
-
-    compare_tf_with_tvm(inp_array, 'lrn0_data:0', 'lrn:0')
+#######################################################################
+# Split
+# ------
+def _test_split(ip_shape, num_split, axis):
+    tf.reset_default_graph()
+    dtype = 'float32'
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.split(in_data, num_split, axis=axis, name="split")
+    np_data = np.random.uniform(size=ip_shape).astype(dtype)
+    with tf.Session() as sess:
+        final_graph_def = tf.graph_util.convert_variables_to_constants(
+            sess,
+            sess.graph.as_graph_def(add_shapes=True),
+            ['split'])
+
+        tf_out_node_names = ['split:%g' % i for i in range(num_split)]
+        tf_output = run_tf_graph(sess, [np_data], ['in_data:0'], tf_out_node_names)
+        tvm_out_shapes = [tf_output[i].shape for i in range(num_split)]
+        tvm_out_dtypes = [tf_output[i].dtype for i in range(num_split)]
+        tvm_output = run_tvm_graph(final_graph_def, [np_data], ['in_data'],
+                                   tvm_out_shapes, tvm_out_dtypes)
+        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+        sess.close()
+
+def test_forward_split():
+    '''test split operator'''
+    _test_split((2, 3), 2, axis=0)
+    _test_split((6, 3), 3, axis=0)
+    _test_split((5, 9, 3), 3, axis=1)
+    _test_split((2, 5, 3, 9), 3, axis=3)
+#######################################################################
+# LRN (Local Response Normalization)
+# ----------------------------------
+def _test_lrn(ip_shape, depth_radius=2, alpha=1e-05, beta=0.75, bias=1.0):
Contributor: Ref.: we don't need to explicitly set the default values here; TF supplies them when they aren't set. Suggest omitting them so the TF defaults play their role.
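For context, a sketch of the behavior the reviewer is pointing at; the defaults spelled out below are the ones documented for tf.nn.local_response_normalization in TF 1.x, quoted from memory rather than from this PR:

```python
import tensorflow as tf

in_data = tf.placeholder('float32', (1, 4, 4, 8), name="in_data")

# Relying on TF's own defaults...
out_default = tf.nn.local_response_normalization(in_data)

# ...should be equivalent to spelling them out explicitly.
out_explicit = tf.nn.local_response_normalization(
    in_data, depth_radius=5, bias=1, alpha=1, beta=0.5)
```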

+    tf.reset_default_graph()
+    dtype = 'float32'
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.nn.local_response_normalization(in_data,
+                                       depth_radius=depth_radius,
+                                       alpha=alpha,
+                                       beta=beta,
+                                       bias=bias,
+                                       name="local_response_normalization")
+    np_data = np.random.uniform(size=ip_shape).astype(dtype)
+    with tf.Session() as sess:
Contributor: #1546 adds compare_tf_with_tvm, which does the TF-vs-TVM comparison. Suggest using it.
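Assuming compare_tf_with_tvm from #1546 behaves like its other uses in this file (build ops on a graph, then pass input data and tensor names), the test could collapse to roughly:

```python
def _test_lrn(ip_shape, depth_radius=2, alpha=1e-05, beta=0.75, bias=1.0):
    np_data = np.random.uniform(size=ip_shape).astype('float32')
    with tf.Graph().as_default():
        in_data = tf.placeholder('float32', ip_shape, name="in_data")
        tf.nn.local_response_normalization(in_data,
                                           depth_radius=depth_radius,
                                           alpha=alpha,
                                           beta=beta,
                                           bias=bias,
                                           name="lrn")
        compare_tf_with_tvm(np_data, 'in_data:0', 'lrn:0')
```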

+        final_graph_def = tf.graph_util.convert_variables_to_constants(
+            sess,
+            sess.graph.as_graph_def(add_shapes=True),
+            ['local_response_normalization'])
+        tf_output = run_tf_graph(sess, [np_data], ['in_data:0'],
+                                 'local_response_normalization:0')
+        tvm_output = run_tvm_graph(final_graph_def, [np_data], ['in_data'],
+                                   tf_output.shape, dtype)
+        np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+        sess.close()

 def test_forward_lrn():
-    _test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
+    '''test local_response_normalization operator'''
+    _test_lrn((4, 4, 4, 4))
+    _test_lrn((4, 3, 5, 6), 1, 1e-05, 0.5)
+    _test_lrn((1, 3, 20, 20), 2, 1e-05, 0.5, 1.0)
+    _test_lrn((1, 3, 20, 20), 2, 1e-05, 0.75, 2.0)


 #######################################################################
 # l2_normalize
@@ -855,6 +900,7 @@ def test_forward_l2_normalize():
     test_forward_resize_bilinear()
     test_forward_pad()
     test_forward_lstm()
+    test_forward_split()
     test_forward_stridedslice()
     test_forward_gather()
     test_forward_ptb()