diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d35e0e1c203d..1e2a2d4f826f 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -37,6 +37,7 @@ from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value +from .common import set_span from .tensorflow_ops import _convert_map from .tensorflow_ops import _need_prelude_for_shape_inference @@ -328,7 +329,7 @@ def _while_loop(self): `while_loop` construct. """ bind_map = {} - wl = tvm.relay.var("while_loop") + wl = set_span(tvm.relay.var("while_loop"), self._loop_name) sb = tvm.relay.scope_builder.ScopeBuilder() lv_list = [] @@ -345,7 +346,7 @@ def _while_loop(self): if lv not in self._lvar2expr[self._loop_name]: var_name = "{}_loop_var_{}".format(self._loop_name, i) var_type = _infer_type(lv, self._mod).checked_type - loop_var = tvm.relay.var(var_name, type_annotation=var_type) + loop_var = set_span(tvm.relay.var(var_name, type_annotation=var_type), var_name) self._lvar2expr[self._loop_name][loop_var] = lv bind_map[lv] = loop_var self.loop_vars[i] = loop_var @@ -358,7 +359,7 @@ def _while_loop(self): self.cond = rewrite_subgraph(self.cond, bind_map) self.body = [rewrite_subgraph(b, bind_map) for b in self.body] - cond = tvm.relay.op.min(self.cond) + cond = set_span(tvm.relay.op.min(self.cond), self.cond.span) for lv, exp in self._lvar2expr[self._loop_name].items(): if lv not in self.loop_vars: @@ -517,8 +518,11 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): self._output_shapes[node.name] = [self._input_shapes[node.name]] attr = self._parse_attr(node.attr) self._nodes[node.name] = [ - _expr.var( - node.name, shape=self._input_shapes[node.name], dtype=attr["dtype"].name + set_span( + _expr.var( + node.name, shape=self._input_shapes[node.name], dtype=attr["dtype"].name + ), + node.name, ) ] @@ -708,16 +712,23 @@ def _parse_param(self, key, value, name, shape): var_shape = shape[name] else: var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape) - self._nodes[name] = [_expr.var(name, shape=var_shape, dtype="uint8")] + self._nodes[name] = [ + set_span(_expr.var(name, shape=var_shape, dtype="uint8"), span=name) + ] return array_ndim = len(np_array.shape) if array_ndim == 0: - self._nodes[name] = [tvm.relay.const(np_array, np_array.dtype)] + self._nodes[name] = [set_span(tvm.relay.const(np_array, np_array.dtype), name)] else: self._params[name] = tvm.nd.array(np_array) self._nodes[name] = [ - _expr.var(name, shape=self._params[name].shape, dtype=self._params[name].dtype) + set_span( + _expr.var( + name, shape=self._params[name].shape, dtype=self._params[name].dtype + ), + name, + ) ] else: if key not in ("dtype", "_output_shapes", "_class"): @@ -998,6 +1009,8 @@ def _convert_operator( ---------- op_name : str Operator name, such as Conv2D, AvgPool + node_name : str + Node name, predefined by user or default setting of TF inputs : list of relay.op List of input symbols. attrs : dict @@ -1028,22 +1041,8 @@ def _convert_operator( else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) - sym = self._set_span(sym, node_name) - - return sym + sym = set_span(sym, node_name) - @staticmethod - def _set_span(sym, node_name): - span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) - if isinstance(sym, _expr.Call) and sym.span is None: - sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) - elif isinstance(sym, _expr.TupleWrapper): - tuple_value = sym.tuple_value - if isinstance(tuple_value, _expr.Call) and tuple_value.span is None: - tuple_value = _expr.Call( - tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span - ) - sym = _expr.TupleWrapper(tuple_value, sym.size) return sym def _licm_construct(self, loop_name, node_name): @@ -1079,7 +1078,7 @@ def _licm_construct(self, loop_name, node_name): if node_name not in self._lname_map[loop_name]: var_name = "{}_loop_var".format(node_name) var_type = _infer_type(actual_expr, self._mod).checked_type - loop_var = tvm.relay.var(var_name, type_annotation=var_type) + loop_var = set_span(tvm.relay.var(var_name, type_annotation=var_type), var_name) try: extra_param = _infer_value(actual_expr, self._params, self._mod) self._params[var_name] = extra_param @@ -1183,10 +1182,13 @@ def _backtrack_construct(self, node_name): if isinstance(op, np.ndarray): self._params[node.name] = tvm.nd.array(op) op = [ - _expr.var( + set_span( + _expr.var( + node.name, + shape=self._params[node.name].shape, + dtype=self._params[node.name].dtype, + ), node.name, - shape=self._params[node.name].shape, - dtype=self._params[node.name].dtype, ) ] diff --git a/tests/python/frontend/tensorflow/test_bn_dynamic.py b/tests/python/frontend/tensorflow/test_bn_dynamic.py index 55555e885a60..df7052008821 100644 --- a/tests/python/frontend/tensorflow/test_bn_dynamic.py +++ b/tests/python/frontend/tensorflow/test_bn_dynamic.py @@ -65,7 +65,11 @@ def verify_fused_batch_norm(shape): if not tvm.testing.device_enabled(device): print("Skip because %s is not enabled" % device) continue - mod, params = relay.frontend.from_tensorflow(constant_graph, outputs=["output"]) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_tensorflow(constant_graph, outputs=["output"]) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_tensorflow(constant_graph, outputs=["output"]) + assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"]) with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target=device, params=params) from tvm.contrib import graph_executor diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index 49dc5170c52f..494deb46835f 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -25,13 +25,17 @@ import tensorflow as tf from tensorflow.python.ops import control_flow_ops import numpy as np -from tvm import nd -from tvm import relay +from tvm import nd, relay, ir, testing from tvm.relay.frontend.tensorflow import from_tensorflow def check_equal(graph, tf_out, input_map=None): - mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) + with testing.disable_span_filling(): + mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) + with testing.enable_span_filling(): + mod_with_span, _ = from_tensorflow(graph.as_graph_def(add_shapes=True)) + assert ir.structural_equal(mod["main"], mod_with_span["main"]) + if input_map is not None: params.update(input_map) relay_out = relay.create_executor("vm", mod=mod).evaluate()(**params) diff --git a/tests/python/frontend/tensorflow/test_debugging.py b/tests/python/frontend/tensorflow/test_debugging.py index 0e08840e56ee..0f7c4dd7d65a 100644 --- a/tests/python/frontend/tensorflow/test_debugging.py +++ b/tests/python/frontend/tensorflow/test_debugging.py @@ -22,12 +22,19 @@ except ImportError: import tensorflow as tf import numpy as np -from tvm import relay +from tvm import relay, ir, testing from tvm.relay.frontend.tensorflow import from_tensorflow def run_relay(graph, shape_dict=None, *vars): - mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True), shape=shape_dict) + with testing.disable_span_filling(): + mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True), shape=shape_dict) + with testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_tensorflow( + graph.as_graph_def(add_shapes=True), shape=shape_dict + ) + assert ir.structural_equal(mod["main"], mod_with_span["main"]) + return relay.create_executor("debug", mod=mod).evaluate()(*vars) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index dce18ee231d3..2fb7c74f60a1 100755 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -31,11 +31,12 @@ import pytest from PIL import Image -from tvm import relay +from tvm import relay, ir from tvm.runtime.vm import VirtualMachine from tvm.relay.frontend.tensorflow import from_tensorflow from tvm.contrib import graph_executor from tvm.contrib import utils +from relay.utils.tag_span import _set_span, _create_span, _verify_structural_equal_with_span import tvm import tvm.relay.testing.tf as tf_testing @@ -149,13 +150,23 @@ def run_tvm_graph( shape_dict = { e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data) } - mod, params = relay.frontend.from_tensorflow( - graph_def, - layout=layout, - shape=shape_dict, - outputs=out_names, - convert_config=convert_config, - ) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_tensorflow( + graph_def, + layout=layout, + shape=shape_dict, + outputs=out_names, + convert_config=convert_config, + ) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_tensorflow( + graph_def, + layout=layout, + shape=shape_dict, + outputs=out_names, + convert_config=convert_config, + ) + assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"], map_free_vars=True) dev = tvm.device(target, 0) if mode == "debug": @@ -1804,9 +1815,15 @@ def test_read_variable_op(target, dev): shape_dict = {e: i.shape for e, i in zip(in_name, in_data)} with pytest.raises(Exception) as execinfo: - _, _ = relay.frontend.from_tensorflow( - final_graph_def, layout=None, shape=shape_dict, outputs=None - ) + with tvm.testing.disable_span_filling(): + mod, _ = relay.frontend.from_tensorflow( + final_graph_def, layout=None, shape=shape_dict, outputs=None + ) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_tensorflow( + final_graph_def, layout=None, shape=shape_dict, outputs=None + ) + assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"]) assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph") @@ -4072,17 +4089,31 @@ def _get_tvm_graph_module(graph_def): # Cell inputs 'c and 'h' consist of all layers values shape_dict = {"Model/Placeholder": (batch_size, num_steps)} - mod, params = relay.frontend.from_tensorflow( - graph_def, - shape=shape_dict, - outputs=[ - "Model/Softmax:0", - "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1", - "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6", - "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1", - "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6", - ], - ) + with tvm.testing.disable_span_filling(): + mod, params = relay.frontend.from_tensorflow( + graph_def, + shape=shape_dict, + outputs=[ + "Model/Softmax:0", + "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1", + "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6", + "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1", + "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6", + ], + ) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_tensorflow( + graph_def, + shape=shape_dict, + outputs=[ + "Model/Softmax:0", + "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1", + "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6", + "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1", + "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6", + ], + ) + assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"]) target = "llvm" with tvm.transform.PassContext(opt_level=0): @@ -5723,7 +5754,12 @@ def test_moments(): mean, variance = tf.nn.moments(A, [1], keep_dims=True) _ = (A - mean) / tf.sqrt(variance + 0.0005) - mod, _ = from_tensorflow(g.as_graph_def(add_shapes=True)) + with tvm.testing.disable_span_filling(): + mod, _ = from_tensorflow(g.as_graph_def(add_shapes=True)) + with tvm.testing.enable_span_filling(): + mod_with_span, _ = from_tensorflow(g.as_graph_def(add_shapes=True)) + assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"], map_free_vars=True) + program = """ def @main(%A: Tensor[(4, 176, 8, 8), float32]) { %527 = mean(%A, axis=[1], keepdims=True) /* moments/mean */; @@ -5834,5 +5870,181 @@ def test_forward_dense_bincount(): _test_dense_bincount((10,), 20, None, binary_output) +####################################################################### +# Test structural_equal and span of a model +# -------------------------------------- +class TestSetSpan: + """Test Structure and span of frequently-used models""" + + def _verify(self, res_fptr, golden_fptr): + with tvm.testing.enable_span_filling(): + with_span = res_fptr() + with tvm.testing.disable_span_filling(): + without_span = res_fptr() + assert tvm.ir.structural_equal(with_span, without_span) + _verify_structural_equal_with_span(with_span, golden_fptr()) + + def test_conv2d_bias_add_span(self): + """Test Structure and span of conv2d and bias add model match to the expected result""" + + def _res(): + in_shape = (1, 5, 5, 1) + kernel_shpae = (2, 2, 1, 2) + kernel_in = np.ones(kernel_shpae) + bias_val_shape = tuple([2]) + bias_val_in = np.ones(bias_val_shape) + + with tf.Graph().as_default() as g: + x = array_ops.placeholder(shape=in_shape, dtype="float32", name="input") + kernel = tf.constant(kernel_in, dtype=tf.float32, name="filter_weight") + bias_val_tensor = tf.constant(bias_val_in, dtype=tf.float32, name="conv2d_bias") + conv2d = tf.nn.conv2d( + x, kernel, strides=[1, 1, 1, 1], padding="VALID", name="conv2d" + ) + _ = tf.nn.bias_add(conv2d, bias_val_tensor, name="bias_add") + + mod, _ = relay.frontend.from_tensorflow( + g.as_graph_def(), shape={"input": in_shape}, outputs=["bias_add"] + ) + return mod["main"] + + def _golden(): + model_in = relay.var( + "input", relay.TensorType([1, 5, 5, 1]), span=_create_span("input") + ) + weight = relay.var( + "filter_weight", relay.TensorType([2, 2, 1, 2]), span=_create_span("filter_weight") + ) + bias = relay.var("conv2d_bias", relay.TensorType([2]), span=_create_span("conv2d_bias")) + conv2d = _set_span( + relay.nn.conv2d( + model_in, + weight, + channels=2, + kernel_size=[2, 2], + data_layout="NHWC", + kernel_layout="HWIO", + ), + "conv2d", + ) + add = _set_span(relay.op.add(conv2d, bias), "bias_add") + mod = ir.IRModule.from_expr(add) + return mod["main"] + + self._verify(_res, _golden) + + def test_fully_connected_bias_add_span(self): + """Test Structure and span of fully connected model match to the expected result""" + + def _res(): + in_shape = (1, 10) + kernel_shpae = (10, 10) + kernel_in = np.ones(kernel_shpae) + bias_val_shape = tuple([10]) + bias_val_in = np.ones(bias_val_shape) + + with tf.Graph().as_default() as g: + x = array_ops.placeholder(shape=in_shape, dtype="float32", name="input") + in_filter = tf.constant(kernel_in, dtype=tf.float32, name="filter_weight") + bias_val_tensor = tf.constant(bias_val_in, dtype=tf.float32, name="dense_bias") + mat_mul = math_ops.mat_mul(x, in_filter, name="dense") + _ = tf.nn.bias_add(mat_mul, bias_val_tensor, name="bias_add") + + mod, _ = relay.frontend.from_tensorflow( + g.as_graph_def(), + shape={"input": in_shape}, + outputs=["bias_add"], + convert_config={"use_dense": True}, + ) + return mod["main"] + + def _golden(): + model_in = relay.var("input", relay.TensorType([1, 10]), span=_create_span("input")) + weight = relay.var( + "filter_weight", relay.TensorType([10, 10]), span=_create_span("filter_weight") + ) + bias = relay.var("dense_bias", relay.TensorType([10]), span=_create_span("dense_bias")) + transpose = _set_span(relay.transpose(weight, [1, 0]), "dense") + dense = _set_span(relay.nn.dense(model_in, transpose, units=10), "dense") + add = _set_span(relay.op.add(dense, bias), "bias_add") + mod = ir.IRModule.from_expr(add) + return mod["main"] + + self._verify(_res, _golden) + + def test_reshape_span(self): + """Test Structure and span of reshape model match to the expected result""" + + def _res(): + in_shape = (1, 10) + output_shape = (2, 5) + + with tf.Graph().as_default() as g: + x = array_ops.placeholder(shape=in_shape, dtype="float32", name="input") + _ = array_ops.reshape(x, output_shape, "reshape") + + mod, _ = relay.frontend.from_tensorflow( + g.as_graph_def(), shape={"input": in_shape}, outputs=["reshape"] + ) + return mod["main"] + + def _golden(): + model_in = relay.var("input", relay.TensorType([1, 10]), span=_create_span("input")) + reshape = _set_span(relay.reshape(model_in, [2, 5]), "reshape") + mod = ir.IRModule.from_expr(reshape) + return mod["main"] + + self._verify(_res, _golden) + + def test_batch_norm_span(self): + """Test Structure and span of batchnorm model match to the expected result""" + + def _res(): + in_shape = (1, 12, 12, 32) + with tf.Graph().as_default() as g: + input_tensor = tf.placeholder(tf.float32, shape=in_shape, name="input") + alpha = tf.constant( + np.ones( + in_shape[-1], + ), + dtype=tf.float32, + name="alpha", + ) + beta = tf.constant( + np.ones( + in_shape[-1], + ), + dtype=tf.float32, + name="beta", + ) + _ = tf.nn.fused_batch_norm(x=input_tensor, offset=beta, scale=alpha, name="bn") + mod, _ = relay.frontend.from_tensorflow( + g.as_graph_def(), shape={"input": in_shape}, outputs=["bn"] + ) + return mod["main"] + + def _golden(): + model_in = relay.var( + "input", relay.TensorType([1, 12, 12, 32]), span=_create_span("input") + ) + alpha = relay.var("alpha", relay.TensorType([32]), span=_create_span("alpha")) + beta = relay.var("beta", relay.TensorType([32]), span=_create_span("beta")) + mean = _set_span(relay.op.mean(model_in, axis=[3], exclude=True), "bn") + variance_mean = _set_span( + relay.op.mean(model_in, axis=[3], keepdims=True, exclude=True), "bn" + ) + variance = _set_span( + relay.op._make._variance(model_in, variance_mean, [3], False, True, False), "bn" + ) + bn = _set_span( + relay.nn.batch_norm(model_in, alpha, beta, mean, variance, axis=3, epsilon=0.001), + "bn", + ) + mod = ir.IRModule.from_expr(bn[0]) + return mod["main"] + + self._verify(_res, _golden) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/frontend/tensorflow/test_no_op.py b/tests/python/frontend/tensorflow/test_no_op.py index 4f8583f71cff..bc6be5c3059c 100644 --- a/tests/python/frontend/tensorflow/test_no_op.py +++ b/tests/python/frontend/tensorflow/test_no_op.py @@ -22,12 +22,17 @@ except ImportError: import tensorflow as tf import numpy as np -from tvm import relay +from tvm import relay, ir, testing from tvm.relay.frontend.tensorflow import from_tensorflow def run_relay(graph): - mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) + with testing.disable_span_filling(): + mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) + with testing.enable_span_filling(): + mod_with_span, _ = relay.frontend.from_tensorflow(graph.as_graph_def(add_shapes=True)) + assert ir.structural_equal(mod["main"], mod_with_span["main"]) + return relay.create_executor("debug", mod=mod).evaluate()(**params)