From d6e27b9bf5958dd66930867147a58305f296c740 Mon Sep 17 00:00:00 2001 From: Rohan Date: Wed, 26 May 2021 18:59:29 +0000 Subject: [PATCH 01/13] adding tf control flow ops with a different frontend code Co-authored-by: David Huang Co-authored-by: Rohan Mukherjee Co-authored-by: Srinidhi Goud Co-authored-by: Xingyu Zhou Co-authored-by: Xiao --- python/tvm/relay/frontend/tensorflow2.py | 644 ++++++++++++++++++ tests/python/frontend/tensorflow2/common.py | 4 +- .../tensorflow2/test_functional_models.py | 71 ++ 3 files changed, 716 insertions(+), 3 deletions(-) create mode 100644 python/tvm/relay/frontend/tensorflow2.py diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py new file mode 100644 index 000000000000..17a32e34583f --- /dev/null +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -0,0 +1,644 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +"""Tensorflow2.x graph to relay converter. + +If model is constructed using tf2.x API, then use this converter: + from tvm.relay.frontend.tensorflow2 import from_tensorflow +Otherwise use the tf1.x converter: + from tvm.relay.frontend.tensorflow import from_tensorflow + +""" + +import numpy as np + +import tvm +from tvm import relay +from tvm.relay.transform import InferType +from tvm.relay.prelude import Prelude +from tvm.ir import IRModule +from .. import expr as _expr +from .. import analysis +from .. import function as _function +from ..loops import while_loop as _while_loop +from .common import infer_shape as _infer_shape +from .common import infer_type as _infer_type + +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import dtypes + +from .tensorflow import _convert_map as _convert_map_tf1 +from .tensorflow import _need_prelude_for_shape_inference + +from ..ty import Any, TensorType + +__all__ = ["from_tensorflow"] + +def _infer_type_with_prelude(val, prelude): + body = _infer_type(val, prelude.mod) + return body.checked_type + +def set_span(sym, node_name): + span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) + if isinstance(sym, _expr.Call): + sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) + elif isinstance(sym, _expr.TupleWrapper): + tuple_value = sym.tuple_value + if isinstance(tuple_value, _expr.Call): + tuple_value = _expr.Call( + tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span + ) + sym = _expr.TupleWrapper(tuple_value, sym.size) + return sym + + +def convert_const_node(node, shape): + """convert tf const node into relay const or var + """ + + # get the value of the constant + tensor_value = node.attr["value"].tensor + np_array = tensor_util.MakeNdarray(tensor_value) + + if np_array.dtype == np.dtype(object): + # assert False # not tested, maybe tf string type? + if shape and node.name in shape: + var_shape = shape[node.name] + else: + var_shape = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape) + param = None + sym = [_expr.var(node.name, shape=var_shape, dtype="uint8")] + return sym, param + + if len(np_array.shape) == 0: + param = None + sym = [tvm.relay.const(np_array, np_array.dtype)] + else: + param = tvm.nd.array(np_array) + sym = [ + _expr.var(node.name, shape=param.shape, dtype=param.dtype) + ] + + return sym, param + + +def get_attr(buf): + """convert value of a node attribute. node attribute is part of a node in a graph. + // tensorflow/core/framework/attr_value.proto + message AttrValue { + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" } + } + Parameters + ---------- + buf: attrvalue protobuf. + Returns + ------- + The value of the attr, as a Python object. + """ + fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] + + x = buf + + ret = [] + + if not x.WhichOneof("value"): + assert False # not yet tested; why would there be empty attribute value in a node def? + + if x.HasField("list"): + for f in fields: + if getattr(x.list, f): + if f == "type": + ret += [dtypes.as_dtype(x) for x in list(getattr(x.list, f))] + else: + ret += list(getattr(x.list, f)) + else: + for f in fields: + if x.HasField(f): + if f == "type": + ret = dtypes.as_dtype(getattr(x, f)) + else: + ret = getattr(x, f) + return ret + +def parse_attr(attr_proto): + """Convert node attributes (a serialized map of key-value pairs) in a node to a dict + Parameters + ---------- + attr_proto: + attributes of a tf node + protobuf message format: + // tensorflow/core/framework/node_def.proto + message NodeDef { + map attr = 5; + } + Returns + ------- + Dict {string: python object} + Examples + -------- + attributes in following node converted to {'_user_specified_name': b'x', 'dtype': tf.float32 } + node { + name: "x" + op: "Placeholder" + attr { + key: "_user_specified_name" + value { + s: "x" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + """ + attrs = {} + for key, value in attr_proto.items(): + attrs[key] = get_attr(value) + + return attrs + +def convert_place_holder(shape, node, in_type=None): + """ convert tf place holder into relay var. + + Examples + -------- + a tf place holder with name "x" is converted to [Var(x, ty=TensorType([], float32))] + """ + + if shape and node.name in shape: + input_shape = list(shape[node.name]) + # assert False # not yet tested + else: + input_shape = tensor_util.TensorShapeProtoToList( + node.attr["shape"].shape + ) + for idx, dim in enumerate(input_shape): + if dim < 0: + input_shape[idx] = Any() + attr = parse_attr(node.attr) + if in_type is not None: + sym = [ + _expr.var( + node.name, type_annotation=in_type + ) + ] + else: + sym = [ + _expr.var( + node.name, shape=input_shape, dtype=attr["dtype"].name + ) + ] + return input_shape, sym + + +class RelayModule: + """ states related to the entire relay module (multiple functions) after converted from tf graphdef + """ + def __init__(self): + self.mod = IRModule({}) # relay function and type definitions. defined in tvm/ir/module.py + self.params = {} # for constants (weights) in the entire relay module + self.prelude = Prelude(self.mod) # relay.prelude needed for tensorlist ops + +class GraphProto: + """Capturing states when converting a tf graph to a single relay function. + """ + def __init__(self, module): + self._module: RelayModule = module + self._prelude = self._module.prelude + self._params = {} + self._nodes = {} + self._input_shapes = {} + self._output_shapes = {} + self._tf_node_map = {} + self._gdef_lib = {} + + def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None, input_types={}, gdef_lib={}): + self._gdef_lib = gdef_lib + func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs, input_types=input_types) + return func, self._params + + def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types={}): + self._layout = layout + for node in graph.node: + name = node.name + self._tf_node_map[name] = node + if node.op == "Placeholder": + in_type = None + if node.name in input_types: + in_type = input_types[node.name] + self._input_shapes[name], self._nodes[name] = convert_place_holder(shape, node, in_type) + elif node.op == "Const": + sym, param = convert_const_node(node, shape) + self._nodes[node.name] = sym + if param: + self._params[node.name] = param + for node in graph.node: + self._backtrack_construct(graph, node.name) + return self._func(graph, outputs) + + def _func(self, graph, outputs): + out = [] + if outputs is None: + last_node = graph.node[-1] + op = self._nodes[last_node.name.split(":")[0]] + if last_node.op == "Exit": + assert False # not yet tested + else: + out = op + else: + for out_name in outputs: + if ":" in out_name: + out_name = out_name.split(":") + out_name, out_num = out_name[0], out_name[-1] + out_num = int(out_num) + out.append(self._nodes[out_name][out_num]) + else: + out.append(self._nodes[out_name][0]) + + if isinstance(out, _expr.TupleWrapper): + out = out.astuple() + else: + out = out[0] if len(out) == 1 else _expr.Tuple(out) + fvars = analysis.free_vars(out) + func = _function.Function(fvars, out) + final_params = {} + for fv in fvars: + if fv.name_hint in self._params: + final_params[fv.name_hint] = self._params[fv.name_hint] + self._params = final_params + return func + + def _convert_operator(self, graph, op_name, node_name, inputs, attrs): + """Convert from Tensorflow operator to relay operator. + The converter must specify conversions explicitly for incompatible name, and + apply handlers to operator attributes. + + Parameters + ---------- + op_name : str + Operator name, such as Conv2D, AvgPool + inputs : list of relay.op + List of input symbols. + attrs : dict + Dict of operator attributes + + Returns + ------- + sym : relay.op + Converted relay operator + """ + + if op_name in ["PartitionedCall", "StatefulPartitionedCall"]: + sym = _partition_call_operator(self._module, graph, inputs, attrs, self._prelude, gdef_lib=self._gdef_lib) + elif op_name in ["StatelessIf", "If"]: + sym = _convert_if(self._module, graph, inputs, attrs, self._prelude, gdef_lib=self._gdef_lib) + elif op_name in ["StatelessWhile", "While"]: + sym = _convert_loop(self._module, graph, inputs, attrs, node_name, self._tf_node_map, self._prelude, gdef_lib=self._gdef_lib) + elif op_name in _convert_map_tf1: + if _need_prelude_for_shape_inference(op_name): + sym = _convert_map_tf1[op_name](inputs, attrs, self._params, self._prelude) + else: + sym = _convert_map_tf1[op_name](inputs, attrs, self._params, self._module.mod) + else: + raise NotImplementedError("Operator {} not implemented.".format(op_name)) + + sym = set_span(sym, node_name) + return sym + + def _backtrack_construct(self, graph, node_name): + """Convert a specific tensorflow node to relay expression. + + If any of its ancestor node is not converted yet, backtrack as + far as input node and covert all nodes on the path. resurion is used here. + + This is required when parsing control flow nodes, since the parsing + order may not follow the original graph def. + + to discover input node, current tf node's input is iterated: + + tensorflow/core/framework/node_def.proto + message NodeDef { + repeated string input = 3; + } + + a node has many inputs (other nodes). each input has the following format: + data input is "node:src_output". node is the string name. + control input is "^node". + + Parameters + ---------- + node_name : str + node name + + Returns + ------- + op : relay.Expr + Converted relay expression. + + Examples + -------- + tf expression "x+1" is converted to relay expression: + CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), []) + + """ + try: + from tensorflow.python.framework import tensor_util + except ImportError as e: + raise ImportError("Unable to import tensorflow which is required {}".format(e)) + + input_op_name = node_name.split(":")[0].split("^")[-1] + if input_op_name not in self._nodes: + node = self._tf_node_map[input_op_name] + attr = parse_attr(node.attr) + if "_output_shapes" in attr: + self._output_shapes[node.name] = [ + tensor_util.TensorShapeProtoToList(tshape) + for tshape in attr["_output_shapes"] + ] + else: + self._output_shapes[node.name] = [None] + + attr["_output_shapes"] = self._output_shapes[input_op_name] + attr["_node_name"] = node.name + attr["_target_layout"] = self._layout + inputs = [self._backtrack_construct(graph, iname) for iname in node.input] + op = self._convert_operator(graph, node.op, node.name, inputs, attr) + + if isinstance(op, np.ndarray): + self._params[node.name] = tvm.nd.array(op) + op = [ + _expr.var( + node.name, + shape=self._params[node.name].shape, + dtype=self._params[node.name].dtype, + ) + ] + elif isinstance(op, (_expr.Expr, _expr.TupleGetItem)): + op = [op] + self._nodes[input_op_name] = op + + out = self._nodes[input_op_name] + if isinstance(out, _expr.TupleWrapper): + tn = node_name.split(":") + tensor_slot = int(tn[1]) if len(tn) > 1 else 0 + return out[tensor_slot] + + return out[0] + +def _partition_call_operator(module, graph, inputs, attr, prelude, gdef_lib): + """ convert tf PartitionedCall node to a relay function call """ + node_func_name = attr.get("f").name + return _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef_lib=gdef_lib) + +def _convert_if(module, graph, inputs, attr, prelude, gdef_lib): + """ Convert tf If/StatelessIf to Relay If """ + cond_expr = inputs[0] + branch_names = [attr.get(x).name for x in ["then_branch", "else_branch"]] + then_fn, else_fn = [ + _convert_function(module, graph, inputs[1:], attr, name, prelude, gdef_lib=gdef_lib) for name in branch_names + ] + out = _expr.If(cond_expr, then_fn, else_fn) + return out + +def _convert_loop(module, graph, inputs, attr, node_name, nodes, prelude, gdef_lib): + """ convert tf while_loop to Relay loop """ + input_size = len(inputs) + cond_fn_name, body_fn_name = [attr.get(x).name for x in ["cond", "body"]] + + def convert_vars(loop_inputs, input_signature): + """ convert inputs to relay vars to be used as loop variables + Loop inputs are packed as: + [iteration_number, max_iterations, loop_variables...] + """ + new_vars = [] + for i, v in enumerate(loop_inputs): + if isinstance(v, _expr.Constant): + vtype = _infer_type(v).checked_type.dtype + new_vars.append(_expr.var(input_signature[i].name, shape=(), dtype=vtype)) + else: + vtype = _infer_type_with_prelude(v, prelude) + new_vars.append(_expr.var(input_signature[i].name, type_annotation=vtype)) + return new_vars + + while_func = next( + (f for f in graph.library.function if f.signature.name == attr["body"].name), + None, + ) + loop_inputs = convert_vars(inputs, while_func.signature.input_arg) + # in_shapes = nodes[node_name].attr["output_shapes"].list.shape + + def cond_fn(*loop_inputs): + return _convert_function(module, graph, loop_inputs, attr, cond_fn_name, prelude, gdef_lib=gdef_lib) + + # Define the loop body, in this function we need to unpack loop inputs, + # convert the loop subgraph, and pack outputs for the next iteration. + def body_fn(*loop_inputs): + # Increment loop iteration counter + loop_count = loop_inputs[0] + _expr.const(1, dtype='int32') + max_count = loop_inputs[1] + fn = _convert_function(module, graph, loop_inputs, attr, body_fn_name, prelude, gdef_lib=gdef_lib) + + # Repack loop variables + out = [loop_count, max_count] + [_expr.TupleGetItem(fn, i) for i in range(2, input_size)] + return out + + loop = _while_loop(cond_fn, loop_inputs, body_fn) + outputs = loop(*inputs) + outputs = _expr.TupleWrapper( + _expr.Tuple( + [ + _expr.TupleGetItem(outputs, i) + for i in range(input_size) + ] + ), + input_size + ) + return outputs + +def _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef_lib, in_shapes=None): + """ Convert given tf node to a relay function call + + Parameters + ---------- + module : IRModule + where converted function is stored + + graph: + top level tf graphdef + + inputs : List[tvm.relay.Expr] + List of input symbols. Parameters for the function. + + attrs : Dict[tvm.Attrs] + Dict of operator attributes. + + node_func_name : str + Name of tf2 node to be converted + + Returns + ------- + op : tvm.relay.Expr + + + Examples + -------- + a tf function "x+1", is implemented as a subgraph in the libary section of the graph. this subgraph is converted + to a relay function such as + fn (%x: float32) { + add(%x, 1f) /* Identity */ + } + + the subgraph has a function name such as __inference_add_95 + the tf function call operator is returned as relay expression, such as: + free_var %x: float32; + @func___inference_add_95(%x) + + """ + func = next( + (f for f in graph.library.function if f.signature.name == node_func_name), + None, + ) + if func is None: + raise Exception("Function not found - {}".format(node_func_name)) + devices = set(node.device for node in func.node_def) + if len(devices) > 1: + raise Exception("node_def in function {} contains > 1 types of devices {}".format(node_func_name, devices)) + + subgraph = gdef_lib[node_func_name] + # preserve library functions in subgraphs to make them available to nested functions + for fn in graph.library.function: + subgraph.library.function.add().CopyFrom(fn) + + # Computing subgraph's input shape and type dictionaries + input_expr_dict = {} + input_types = {} + for f_arg, input in zip(func.signature.input_arg, inputs): + input_expr_dict[f_arg.name] = input + input_types[f_arg.name] = _infer_type_with_prelude(input, prelude) + + func_name = "func_{}".format(func.signature.name) + try: + global_func = module.mod[func_name] + sub_func = global_func + sub_params = module.params + except ValueError: + # Construct relay nodes from the subgraph + g1 = GraphProto(module) + output_sig = [func.ret[f.name] for f in func.signature.output_arg] + + # TODO: unify prelude and main IRModules + sub_func, sub_params = g1.from_tensorflow(subgraph, outputs=output_sig, input_types=input_types, gdef_lib=gdef_lib) + module.params.update(sub_params) + func_expr = _function.Function(sub_func.params, sub_func.body) + global_func = tvm.relay.GlobalVar(func_name) + module.mod[global_func] = func_expr + module.mod = InferType()(module.mod) + prelude.mod = module.mod + + param_exprs = [] + for param_expr in sub_func.params: + # sub_params is subset of sub_func.params + param_name = param_expr.vid.name_hint + if param_name in input_expr_dict.keys(): + param_exprs.append(input_expr_dict[param_name]) + elif param_name in sub_params.keys(): + param_exprs.append(param_expr) + else: + raise Exception("Input parameter {} not found".format(param_name)) + + sb = tvm.relay.scope_builder.ScopeBuilder() + loop_ret = global_func(*param_exprs) + sb.ret(loop_ret) + ret = sb.get() + return ret + +def from_tensorflow(graph_def, layout="NHWC", shape=None, outputs=None): + """convert tensorflow2.x graph into relay function. + + Parameters + ---------- + graph_def : must be frozen graph (no variables allowed). + Placeholders are assumed to be inputs to the graph. + + tensorflow/core/framework/graph.proto + message GraphDef { + repeated NodeDef node = 1; + FunctionDefLibrary library = 2; + } + tensorflow/core/framework/function.proto + message FunctionDef { + repeated NodeDef node_def = 3; + } + + layout : str + The layout for the model. + + shape : List[str, List[int]] + Input to the model. It is a key and shape vector mapping. Applies to placeholders. + + outputs : List[str] + The list of output nodes. The last node is treated as the output if not + specified. + + Returns + ------- + mod : tvm.IRModule + The module that optimizations will be performed on. + + params : dict of str to tvm.nd.NDArray + Dict of converted parameters stored in tvm.nd.NDArray format. + + Examples + -------- + "x+1" tf module where x has a shape of (2,2) is converted as follows: + + mod : tvm.IRModule + def @func___inference_add_95(%x: Tensor[(2, 2), float32], %add/y: Tensor[(2, 2), float32]) -> Tensor[(2, 2), float32] { + add(%x, %add/y) /* Identity */ /* ty=Tensor[(2, 2), float32] */ + } + + def @main(%x1: Tensor[(2, 2), float32], %add/y1: Tensor[(2, 2), float32]) { + @func___inference_add_95(%x1, %add/y1) /* Identity */ + } + + params : dict of str to tvm.nd.NDArray + {'add/y': + + """ + + # Subgraph graph_defs are cached here to avoid a TF error when parsing after prelude init + graph_def_library = {} + for func in graph_def.library.function: + inshape = func.attr["_input_shapes"].list.shape + graph_def_library[func.signature.name], _ = function_def_to_graph.function_def_to_graph_def(func, inshape) + module = RelayModule() + g = GraphProto(module) + func, params = g.from_tensorflow(graph_def, layout, shape, outputs, gdef_lib=graph_def_library) + module.mod["main"] = func + module.params.update(params) + return module.mod, module.params \ No newline at end of file diff --git a/tests/python/frontend/tensorflow2/common.py b/tests/python/frontend/tensorflow2/common.py index e30ee7b0c993..df701675bb45 100644 --- a/tests/python/frontend/tensorflow2/common.py +++ b/tests/python/frontend/tensorflow2/common.py @@ -23,8 +23,7 @@ from tvm.runtime.vm import VirtualMachine import tvm.contrib.graph_executor as runtime -from tvm.relay.frontend.tensorflow import from_tensorflow - +from tvm.relay.frontend.tensorflow2 import from_tensorflow import tvm.testing from tvm.relay.testing.tf import vmobj_to_list as vmobj_to_list @@ -101,5 +100,4 @@ def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None): tvm_out = run_graph_executor(lib, input_) else: raise RuntimeError("Runtime input not supported: %s" % runtime) - tvm.testing.assert_allclose(output_, tvm_out, atol=1e-5) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index 40d42a28025a..1c2af20d9ccc 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -357,5 +357,76 @@ def func(self, x): run_all(ConcatV2) +def test_if() + class If(tf.Module): + def get_input(self): + return np.ones((2,2), dtype='float32') + + def expected_ops(self): + return ['Placeholder', 'If', 'Identity', 'Const'] + + def expected_lib_ops(self): + return ['If', 'Identity', 'Const', 'Mul'] + + @tf.function(input_signature=[tf.TensorSpec(shape=(2,2), dtype=tf.float32)]) + def func(self, x): + @tf.function(input_signature=[tf.TensorSpec(shape=(2,2), dtype=tf.float32)]) + def double(x): + return 2*x + + @tf.function(input_signature=[tf.TensorSpec(shape=(2,2), dtype=tf.float32)]) + def triple(x): + return 3*x + + cond = True + output = tf.raw_ops.If(cond=cond, input=[x], Tout=[tf.float32], output_shapes=[(2,2)], + then_branch=double.get_concrete_function(), else_branch=triple.get_concrete_function()) + return output[0] + + run_model_graph(If) + run_func_graph(If, use_vm=True) + + +def test_stateless_while(): + class StatelessWhile(tf.Module): + def get_input(self): + return np.array([6], dtype='float32') + + def expected_ops(self): + return ['Identity', 'StatelessWhile', 'Const', 'Placeholder'] + + def expected_lib_ops(self): + return ['StatelessWhile', 'Squeeze', 'Const', 'Less', 'Add', 'AddV2', 'Identity'] + + @tf.function(input_signature=[tf.TensorSpec(shape=(1,), dtype=tf.float32)]) + def func(self, x): + i = tf.constant(3.) + cond = lambda i: tf.less(i, x) + body = lambda i: (tf.add(i, 2),) + r = tf.while_loop(cond, body, [i]) + return r[0] + + run_model_graph(StatelessWhile) + + + +def test_stateless_while_2var(): + class StatelessWhile2Var(StatelessWhile): + def get_input(self): + return np.array([20], dtype='float32') + + @tf.function(input_signature=[tf.TensorSpec(shape=(1,), dtype=tf.float32)]) + def func(self, x): + i = tf.constant(3.) + j = tf.constant(5.) + cond = lambda i,j: tf.less(i+j, x) + body = lambda i,j: (tf.add(i, 2), tf.add(j, 3)) + r = tf.while_loop(cond, body, [i, j]) + return r + + run_model_graph(StatelessWhile2Var) + + + if __name__ == "__main__": pytest.main([__file__]) From d2a7fd9fb03af965002dee9bf8b2786549f9ee84 Mon Sep 17 00:00:00 2001 From: Rohan Date: Wed, 26 May 2021 19:59:02 +0000 Subject: [PATCH 02/13] Some minor fixes --- .../frontend/tensorflow2/test_functional_models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index 1c2af20d9ccc..ab4f06f907cc 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -49,13 +49,13 @@ def _model_graph(TestClass): return gdef, input_, output -def run_all(TestClass): - def run_func_graph(TestClass, runtime="vm"): - compare_tf_tvm(*_function_graph(TestClass), runtime=runtime) +def run_func_graph(TestClass, runtime="vm"): + compare_tf_tvm(*_function_graph(TestClass), runtime=runtime) - def run_model_graph(TestClass): - compare_tf_tvm(*_model_graph(TestClass), runtime="vm") +def run_model_graph(TestClass): + compare_tf_tvm(*_model_graph(TestClass), runtime="vm") +def run_all(TestClass): run_model_graph(TestClass) for runtime_ in ["vm", "graph"]: run_func_graph(TestClass, runtime=runtime_) @@ -357,7 +357,7 @@ def func(self, x): run_all(ConcatV2) -def test_if() +def test_if(): class If(tf.Module): def get_input(self): return np.ones((2,2), dtype='float32') From 4a75a5e4be08c61448cbe93e20e9182a6b57304b Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Tue, 1 Jun 2021 21:09:16 +0000 Subject: [PATCH 03/13] Fixing output order in TF2 outputs Co-authored-by: David Huang Co-authored-by: Rohan Mukherjee Co-authored-by: Srinidhi Goud Co-authored-by: Xingyu Zhou Co-authored-by: Xiao --- python/tvm/relay/frontend/tensorflow2.py | 25 +++++++------- tests/python/frontend/tensorflow2/common.py | 25 +++++++------- .../tensorflow2/test_functional_models.py | 33 ++++++++++--------- 3 files changed, 44 insertions(+), 39 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 17a32e34583f..1266db6f83e7 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -192,7 +192,6 @@ def convert_place_holder(shape, node, in_type=None): if shape and node.name in shape: input_shape = list(shape[node.name]) - # assert False # not yet tested else: input_shape = tensor_util.TensorShapeProtoToList( node.attr["shape"].shape @@ -258,7 +257,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ if param: self._params[node.name] = param for node in graph.node: - self._backtrack_construct(graph, node.name) + self._backtrack_construct(graph, node.name, outputs=outputs) return self._func(graph, outputs) def _func(self, graph, outputs): @@ -284,6 +283,7 @@ def _func(self, graph, outputs): out = out.astuple() else: out = out[0] if len(out) == 1 else _expr.Tuple(out) + fvars = analysis.free_vars(out) func = _function.Function(fvars, out) final_params = {} @@ -293,7 +293,7 @@ def _func(self, graph, outputs): self._params = final_params return func - def _convert_operator(self, graph, op_name, node_name, inputs, attrs): + def _convert_operator(self, graph, op_name, node_name, inputs, attrs, outputs=None): """Convert from Tensorflow operator to relay operator. The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. @@ -314,11 +314,12 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs): """ if op_name in ["PartitionedCall", "StatefulPartitionedCall"]: - sym = _partition_call_operator(self._module, graph, inputs, attrs, self._prelude, gdef_lib=self._gdef_lib) + sym = _partition_call_operator(self._module, graph, inputs, attrs, self._prelude, gdef_lib=self._gdef_lib, outputs=outputs) elif op_name in ["StatelessIf", "If"]: sym = _convert_if(self._module, graph, inputs, attrs, self._prelude, gdef_lib=self._gdef_lib) elif op_name in ["StatelessWhile", "While"]: - sym = _convert_loop(self._module, graph, inputs, attrs, node_name, self._tf_node_map, self._prelude, gdef_lib=self._gdef_lib) + sym = _convert_loop(self._module, graph, inputs, attrs, node_name, self._tf_node_map, + self._prelude, gdef_lib=self._gdef_lib) elif op_name in _convert_map_tf1: if _need_prelude_for_shape_inference(op_name): sym = _convert_map_tf1[op_name](inputs, attrs, self._params, self._prelude) @@ -330,7 +331,7 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs): sym = set_span(sym, node_name) return sym - def _backtrack_construct(self, graph, node_name): + def _backtrack_construct(self, graph, node_name, outputs=None): """Convert a specific tensorflow node to relay expression. If any of its ancestor node is not converted yet, backtrack as @@ -386,8 +387,8 @@ def _backtrack_construct(self, graph, node_name): attr["_output_shapes"] = self._output_shapes[input_op_name] attr["_node_name"] = node.name attr["_target_layout"] = self._layout - inputs = [self._backtrack_construct(graph, iname) for iname in node.input] - op = self._convert_operator(graph, node.op, node.name, inputs, attr) + inputs = [self._backtrack_construct(graph, iname, outputs=outputs) for iname in node.input] + op = self._convert_operator(graph, node.op, node.name, inputs, attr, outputs=outputs) if isinstance(op, np.ndarray): self._params[node.name] = tvm.nd.array(op) @@ -410,10 +411,10 @@ def _backtrack_construct(self, graph, node_name): return out[0] -def _partition_call_operator(module, graph, inputs, attr, prelude, gdef_lib): +def _partition_call_operator(module, graph, inputs, attr, prelude, gdef_lib, outputs=None): """ convert tf PartitionedCall node to a relay function call """ node_func_name = attr.get("f").name - return _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef_lib=gdef_lib) + return _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef_lib=gdef_lib, outputs=outputs) def _convert_if(module, graph, inputs, attr, prelude, gdef_lib): """ Convert tf If/StatelessIf to Relay If """ @@ -480,7 +481,7 @@ def body_fn(*loop_inputs): ) return outputs -def _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef_lib, in_shapes=None): +def _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef_lib, in_shapes=None, outputs=None): """ Convert given tf node to a relay function call Parameters @@ -549,7 +550,7 @@ def _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef except ValueError: # Construct relay nodes from the subgraph g1 = GraphProto(module) - output_sig = [func.ret[f.name] for f in func.signature.output_arg] + output_sig = [func.ret[f.name] for f in func.signature.output_arg] if outputs is None else outputs # TODO: unify prelude and main IRModules sub_func, sub_params = g1.from_tensorflow(subgraph, outputs=output_sig, input_types=input_types, gdef_lib=gdef_lib) diff --git a/tests/python/frontend/tensorflow2/common.py b/tests/python/frontend/tensorflow2/common.py index df701675bb45..76f4038e6e4d 100644 --- a/tests/python/frontend/tensorflow2/common.py +++ b/tests/python/frontend/tensorflow2/common.py @@ -33,20 +33,21 @@ def run_tf_code(func, input_): if type(func) is Function: - out = func(input_) - if isinstance(out, list): - a = [x.numpy() for x in out] + f_out = func(input_) + if isinstance(f_out, (list, tuple)): + np_out = [x.numpy() for x in f_out] else: - a = [out.numpy()] + np_out = [f_out.numpy()] else: - a = func(tf.constant(input_)) - if type(a) is dict: - a = [x.numpy() for x in a.values()] - elif type(a) is list: - a = [x.numpy() for x in a] + f_out = func(tf.constant(input_)) + if type(f_out) is dict: + np_out = [f_out[k].numpy() for k in sorted(f_out.keys())] + + elif type(f_out) is list: + np_out = [x.numpy() for x in f_out] else: - a = a.numpy() - return a + np_out = f_out.numpy() + return np_out def compile_graph_executor(mod, params, target="llvm", target_host="llvm", opt_level=3): @@ -91,6 +92,7 @@ def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None): output_tensors : List of output tensor names (Optional) if not specified then the last node is assumed as graph output. """ + mod, params = from_tensorflow(gdef, outputs=output_tensors) if runtime == "vm": exec_ = compile_vm(mod, params) @@ -100,4 +102,5 @@ def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None): tvm_out = run_graph_executor(lib, input_) else: raise RuntimeError("Runtime input not supported: %s" % runtime) + tvm.testing.assert_allclose(output_, tvm_out, atol=1e-5) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index ab4f06f907cc..547a64d5e706 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -34,7 +34,6 @@ def _function_graph(TestClass): output = run_tf_code(f, input_) return gdef, input_, output - def _model_graph(TestClass): model = TestClass() with tempfile.TemporaryDirectory() as model_path: @@ -48,7 +47,6 @@ def _model_graph(TestClass): output = run_tf_code(f, input_) return gdef, input_, output - def run_func_graph(TestClass, runtime="vm"): compare_tf_tvm(*_function_graph(TestClass), runtime=runtime) @@ -357,16 +355,23 @@ def func(self, x): run_all(ConcatV2) -def test_if(): - class If(tf.Module): +def test_multi_output(): + + class MultiOutput(tf.Module): def get_input(self): return np.ones((2,2), dtype='float32') - def expected_ops(self): - return ['Placeholder', 'If', 'Identity', 'Const'] + @tf.function(input_signature=[tf.TensorSpec(shape=(2,2), dtype=tf.float32)]) + def func(self, x): + y = 2*x + return x, y - def expected_lib_ops(self): - return ['If', 'Identity', 'Const', 'Mul'] + run_model_graph(MultiOutput) + +def test_if(): + class If(tf.Module): + def get_input(self): + return np.ones((2,2), dtype='float32') @tf.function(input_signature=[tf.TensorSpec(shape=(2,2), dtype=tf.float32)]) def func(self, x): @@ -383,8 +388,9 @@ def triple(x): then_branch=double.get_concrete_function(), else_branch=triple.get_concrete_function()) return output[0] + run_func_graph(If, runtime="vm") run_model_graph(If) - run_func_graph(If, use_vm=True) + def test_stateless_while(): @@ -392,12 +398,6 @@ class StatelessWhile(tf.Module): def get_input(self): return np.array([6], dtype='float32') - def expected_ops(self): - return ['Identity', 'StatelessWhile', 'Const', 'Placeholder'] - - def expected_lib_ops(self): - return ['StatelessWhile', 'Squeeze', 'Const', 'Less', 'Add', 'AddV2', 'Identity'] - @tf.function(input_signature=[tf.TensorSpec(shape=(1,), dtype=tf.float32)]) def func(self, x): i = tf.constant(3.) @@ -406,12 +406,13 @@ def func(self, x): r = tf.while_loop(cond, body, [i]) return r[0] + run_func_graph(StatelessWhile, runtime="vm") run_model_graph(StatelessWhile) def test_stateless_while_2var(): - class StatelessWhile2Var(StatelessWhile): + class StatelessWhile2Var(tf.Module): def get_input(self): return np.array([20], dtype='float32') From ba8517df31dcc203a5a38e6d1882ed02ceaa67e6 Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Tue, 1 Jun 2021 21:19:33 +0000 Subject: [PATCH 04/13] Using black --- .../tensorflow2/test_functional_models.py | 53 +++++++++++-------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index 547a64d5e706..d3273f6a4455 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -34,6 +34,7 @@ def _function_graph(TestClass): output = run_tf_code(f, input_) return gdef, input_, output + def _model_graph(TestClass): model = TestClass() with tempfile.TemporaryDirectory() as model_path: @@ -47,12 +48,15 @@ def _model_graph(TestClass): output = run_tf_code(f, input_) return gdef, input_, output + def run_func_graph(TestClass, runtime="vm"): compare_tf_tvm(*_function_graph(TestClass), runtime=runtime) + def run_model_graph(TestClass): compare_tf_tvm(*_model_graph(TestClass), runtime="vm") + def run_all(TestClass): run_model_graph(TestClass) for runtime_ in ["vm", "graph"]: @@ -61,7 +65,7 @@ def run_all(TestClass): def test_add_one(): class AddOne(tf.Module): - """ simple function to test x=x+1; scalar as input""" + """simple function to test x=x+1; scalar as input""" def get_input(self): return np.array(1.0, dtype="float32") @@ -356,51 +360,56 @@ def func(self, x): def test_multi_output(): - class MultiOutput(tf.Module): def get_input(self): - return np.ones((2,2), dtype='float32') + return np.ones((2, 2), dtype="float32") - @tf.function(input_signature=[tf.TensorSpec(shape=(2,2), dtype=tf.float32)]) + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) def func(self, x): - y = 2*x + y = 2 * x return x, y run_model_graph(MultiOutput) + def test_if(): class If(tf.Module): def get_input(self): - return np.ones((2,2), dtype='float32') + return np.ones((2, 2), dtype="float32") - @tf.function(input_signature=[tf.TensorSpec(shape=(2,2), dtype=tf.float32)]) + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) def func(self, x): - @tf.function(input_signature=[tf.TensorSpec(shape=(2,2), dtype=tf.float32)]) + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) def double(x): - return 2*x + return 2 * x - @tf.function(input_signature=[tf.TensorSpec(shape=(2,2), dtype=tf.float32)]) + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) def triple(x): - return 3*x + return 3 * x cond = True - output = tf.raw_ops.If(cond=cond, input=[x], Tout=[tf.float32], output_shapes=[(2,2)], - then_branch=double.get_concrete_function(), else_branch=triple.get_concrete_function()) + output = tf.raw_ops.If( + cond=cond, + input=[x], + Tout=[tf.float32], + output_shapes=[(2, 2)], + then_branch=double.get_concrete_function(), + else_branch=triple.get_concrete_function(), + ) return output[0] run_func_graph(If, runtime="vm") run_model_graph(If) - def test_stateless_while(): class StatelessWhile(tf.Module): def get_input(self): - return np.array([6], dtype='float32') + return np.array([6], dtype="float32") @tf.function(input_signature=[tf.TensorSpec(shape=(1,), dtype=tf.float32)]) def func(self, x): - i = tf.constant(3.) + i = tf.constant(3.0) cond = lambda i: tf.less(i, x) body = lambda i: (tf.add(i, 2),) r = tf.while_loop(cond, body, [i]) @@ -410,24 +419,22 @@ def func(self, x): run_model_graph(StatelessWhile) - def test_stateless_while_2var(): class StatelessWhile2Var(tf.Module): def get_input(self): - return np.array([20], dtype='float32') + return np.array([20], dtype="float32") @tf.function(input_signature=[tf.TensorSpec(shape=(1,), dtype=tf.float32)]) def func(self, x): - i = tf.constant(3.) - j = tf.constant(5.) - cond = lambda i,j: tf.less(i+j, x) - body = lambda i,j: (tf.add(i, 2), tf.add(j, 3)) + i = tf.constant(3.0) + j = tf.constant(5.0) + cond = lambda i, j: tf.less(i + j, x) + body = lambda i, j: (tf.add(i, 2), tf.add(j, 3)) r = tf.while_loop(cond, body, [i, j]) return r run_model_graph(StatelessWhile2Var) - if __name__ == "__main__": pytest.main([__file__]) From d720fe3442ab329d08f1851a824a5dcd2c05846b Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Tue, 1 Jun 2021 21:33:24 +0000 Subject: [PATCH 05/13] Refactoring Co-authored-by: David Huang Co-authored-by: Rohan Mukherjee Co-authored-by: Srinidhi Goud Co-authored-by: Xingyu Zhou Co-authored-by: Xiao --- python/tvm/relay/frontend/tensorflow2.py | 222 +++++++++++++---------- 1 file changed, 128 insertions(+), 94 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 1266db6f83e7..3e11327660d3 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -43,16 +43,18 @@ from tensorflow.python.framework import dtypes from .tensorflow import _convert_map as _convert_map_tf1 -from .tensorflow import _need_prelude_for_shape_inference +from .tensorflow import _need_prelude_for_shape_inference from ..ty import Any, TensorType __all__ = ["from_tensorflow"] + def _infer_type_with_prelude(val, prelude): body = _infer_type(val, prelude.mod) return body.checked_type + def set_span(sym, node_name): span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) if isinstance(sym, _expr.Call): @@ -68,15 +70,13 @@ def set_span(sym, node_name): def convert_const_node(node, shape): - """convert tf const node into relay const or var - """ + """convert tf const node into relay const or var""" # get the value of the constant tensor_value = node.attr["value"].tensor np_array = tensor_util.MakeNdarray(tensor_value) if np_array.dtype == np.dtype(object): - # assert False # not tested, maybe tf string type? if shape and node.name in shape: var_shape = shape[node.name] else: @@ -90,9 +90,7 @@ def convert_const_node(node, shape): sym = [tvm.relay.const(np_array, np_array.dtype)] else: param = tvm.nd.array(np_array) - sym = [ - _expr.var(node.name, shape=param.shape, dtype=param.dtype) - ] + sym = [_expr.var(node.name, shape=param.shape, dtype=param.dtype)] return sym, param @@ -110,7 +108,7 @@ def get_attr(buf): TensorShapeProto shape = 7; // "shape" TensorProto tensor = 8; // "tensor" ListValue list = 1; // any "list(...)" } - } + } Parameters ---------- buf: attrvalue protobuf. @@ -125,7 +123,7 @@ def get_attr(buf): ret = [] if not x.WhichOneof("value"): - assert False # not yet tested; why would there be empty attribute value in a node def? + return ret if x.HasField("list"): for f in fields: @@ -142,7 +140,8 @@ def get_attr(buf): else: ret = getattr(x, f) return ret - + + def parse_attr(attr_proto): """Convert node attributes (a serialized map of key-value pairs) in a node to a dict Parameters @@ -182,50 +181,42 @@ def parse_attr(attr_proto): return attrs + def convert_place_holder(shape, node, in_type=None): - """ convert tf place holder into relay var. - + """convert tf place holder into relay var. + Examples - -------- + -------- a tf place holder with name "x" is converted to [Var(x, ty=TensorType([], float32))] """ if shape and node.name in shape: input_shape = list(shape[node.name]) else: - input_shape = tensor_util.TensorShapeProtoToList( - node.attr["shape"].shape - ) + input_shape = tensor_util.TensorShapeProtoToList(node.attr["shape"].shape) for idx, dim in enumerate(input_shape): if dim < 0: input_shape[idx] = Any() attr = parse_attr(node.attr) if in_type is not None: - sym = [ - _expr.var( - node.name, type_annotation=in_type - ) - ] + sym = [_expr.var(node.name, type_annotation=in_type)] else: - sym = [ - _expr.var( - node.name, shape=input_shape, dtype=attr["dtype"].name - ) - ] + sym = [_expr.var(node.name, shape=input_shape, dtype=attr["dtype"].name)] return input_shape, sym class RelayModule: - """ states related to the entire relay module (multiple functions) after converted from tf graphdef - """ + """states related to the entire relay module (multiple functions) after converted from tf graphdef""" + def __init__(self): - self.mod = IRModule({}) # relay function and type definitions. defined in tvm/ir/module.py - self.params = {} # for constants (weights) in the entire relay module + self.mod = IRModule({}) # relay function and type definitions. defined in tvm/ir/module.py + self.params = {} # for constants (weights) in the entire relay module self.prelude = Prelude(self.mod) # relay.prelude needed for tensorlist ops + class GraphProto: - """Capturing states when converting a tf graph to a single relay function. - """ + """Capturing states when converting a tf graph to a single relay function.""" + def __init__(self, module): self._module: RelayModule = module self._prelude = self._module.prelude @@ -236,11 +227,15 @@ def __init__(self, module): self._tf_node_map = {} self._gdef_lib = {} - def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None, input_types={}, gdef_lib={}): + def from_tensorflow( + self, graph, layout="NHWC", shape=None, outputs=None, input_types={}, gdef_lib={} + ): self._gdef_lib = gdef_lib - func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs, input_types=input_types) + func = self._get_relay_func( + graph, layout=layout, shape=shape, outputs=outputs, input_types=input_types + ) return func, self._params - + def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types={}): self._layout = layout for node in graph.node: @@ -250,7 +245,9 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ in_type = None if node.name in input_types: in_type = input_types[node.name] - self._input_shapes[name], self._nodes[name] = convert_place_holder(shape, node, in_type) + self._input_shapes[name], self._nodes[name] = convert_place_holder( + shape, node, in_type + ) elif node.op == "Const": sym, param = convert_const_node(node, shape) self._nodes[node.name] = sym @@ -259,14 +256,14 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ for node in graph.node: self._backtrack_construct(graph, node.name, outputs=outputs) return self._func(graph, outputs) - + def _func(self, graph, outputs): out = [] if outputs is None: last_node = graph.node[-1] op = self._nodes[last_node.name.split(":")[0]] if last_node.op == "Exit": - assert False # not yet tested + out = [op[0].tuple_value] else: out = op else: @@ -283,7 +280,7 @@ def _func(self, graph, outputs): out = out.astuple() else: out = out[0] if len(out) == 1 else _expr.Tuple(out) - + fvars = analysis.free_vars(out) func = _function.Function(fvars, out) final_params = {} @@ -314,12 +311,30 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs, outputs=No """ if op_name in ["PartitionedCall", "StatefulPartitionedCall"]: - sym = _partition_call_operator(self._module, graph, inputs, attrs, self._prelude, gdef_lib=self._gdef_lib, outputs=outputs) + sym = _partition_call_operator( + self._module, + graph, + inputs, + attrs, + self._prelude, + gdef_lib=self._gdef_lib, + outputs=outputs, + ) elif op_name in ["StatelessIf", "If"]: - sym = _convert_if(self._module, graph, inputs, attrs, self._prelude, gdef_lib=self._gdef_lib) + sym = _convert_if( + self._module, graph, inputs, attrs, self._prelude, gdef_lib=self._gdef_lib + ) elif op_name in ["StatelessWhile", "While"]: - sym = _convert_loop(self._module, graph, inputs, attrs, node_name, self._tf_node_map, - self._prelude, gdef_lib=self._gdef_lib) + sym = _convert_loop( + self._module, + graph, + inputs, + attrs, + node_name, + self._tf_node_map, + self._prelude, + gdef_lib=self._gdef_lib, + ) elif op_name in _convert_map_tf1: if _need_prelude_for_shape_inference(op_name): sym = _convert_map_tf1[op_name](inputs, attrs, self._params, self._prelude) @@ -339,27 +354,27 @@ def _backtrack_construct(self, graph, node_name, outputs=None): This is required when parsing control flow nodes, since the parsing order may not follow the original graph def. - - to discover input node, current tf node's input is iterated: + + to discover input node, current tf node's input is iterated: tensorflow/core/framework/node_def.proto message NodeDef { repeated string input = 3; } - a node has many inputs (other nodes). each input has the following format: - data input is "node:src_output". node is the string name. - control input is "^node". + a node has many inputs (other nodes). each input has the following format: + data input is "node:src_output". node is the string name. + control input is "^node". Parameters ---------- node_name : str - node name + node name Returns ------- op : relay.Expr - Converted relay expression. + Converted relay expression. Examples -------- @@ -378,8 +393,7 @@ def _backtrack_construct(self, graph, node_name, outputs=None): attr = parse_attr(node.attr) if "_output_shapes" in attr: self._output_shapes[node.name] = [ - tensor_util.TensorShapeProtoToList(tshape) - for tshape in attr["_output_shapes"] + tensor_util.TensorShapeProtoToList(tshape) for tshape in attr["_output_shapes"] ] else: self._output_shapes[node.name] = [None] @@ -387,7 +401,9 @@ def _backtrack_construct(self, graph, node_name, outputs=None): attr["_output_shapes"] = self._output_shapes[input_op_name] attr["_node_name"] = node.name attr["_target_layout"] = self._layout - inputs = [self._backtrack_construct(graph, iname, outputs=outputs) for iname in node.input] + inputs = [ + self._backtrack_construct(graph, iname, outputs=outputs) for iname in node.input + ] op = self._convert_operator(graph, node.op, node.name, inputs, attr, outputs=outputs) if isinstance(op, np.ndarray): @@ -399,7 +415,7 @@ def _backtrack_construct(self, graph, node_name, outputs=None): dtype=self._params[node.name].dtype, ) ] - elif isinstance(op, (_expr.Expr, _expr.TupleGetItem)): + elif isinstance(op, (_expr.Expr, _expr.TupleGetItem)): op = [op] self._nodes[input_op_name] = op @@ -411,30 +427,36 @@ def _backtrack_construct(self, graph, node_name, outputs=None): return out[0] + def _partition_call_operator(module, graph, inputs, attr, prelude, gdef_lib, outputs=None): - """ convert tf PartitionedCall node to a relay function call """ + """convert tf PartitionedCall node to a relay function call""" node_func_name = attr.get("f").name - return _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef_lib=gdef_lib, outputs=outputs) + return _convert_function( + module, graph, inputs, attr, node_func_name, prelude, gdef_lib=gdef_lib, outputs=outputs + ) + def _convert_if(module, graph, inputs, attr, prelude, gdef_lib): - """ Convert tf If/StatelessIf to Relay If """ + """Convert tf If/StatelessIf to Relay If""" cond_expr = inputs[0] branch_names = [attr.get(x).name for x in ["then_branch", "else_branch"]] then_fn, else_fn = [ - _convert_function(module, graph, inputs[1:], attr, name, prelude, gdef_lib=gdef_lib) for name in branch_names + _convert_function(module, graph, inputs[1:], attr, name, prelude, gdef_lib=gdef_lib) + for name in branch_names ] out = _expr.If(cond_expr, then_fn, else_fn) return out + def _convert_loop(module, graph, inputs, attr, node_name, nodes, prelude, gdef_lib): - """ convert tf while_loop to Relay loop """ + """convert tf while_loop to Relay loop""" input_size = len(inputs) cond_fn_name, body_fn_name = [attr.get(x).name for x in ["cond", "body"]] def convert_vars(loop_inputs, input_signature): - """ convert inputs to relay vars to be used as loop variables - Loop inputs are packed as: - [iteration_number, max_iterations, loop_variables...] + """convert inputs to relay vars to be used as loop variables + Loop inputs are packed as: + [iteration_number, max_iterations, loop_variables...] """ new_vars = [] for i, v in enumerate(loop_inputs): @@ -454,15 +476,19 @@ def convert_vars(loop_inputs, input_signature): # in_shapes = nodes[node_name].attr["output_shapes"].list.shape def cond_fn(*loop_inputs): - return _convert_function(module, graph, loop_inputs, attr, cond_fn_name, prelude, gdef_lib=gdef_lib) + return _convert_function( + module, graph, loop_inputs, attr, cond_fn_name, prelude, gdef_lib=gdef_lib + ) # Define the loop body, in this function we need to unpack loop inputs, # convert the loop subgraph, and pack outputs for the next iteration. def body_fn(*loop_inputs): # Increment loop iteration counter - loop_count = loop_inputs[0] + _expr.const(1, dtype='int32') + loop_count = loop_inputs[0] + _expr.const(1, dtype="int32") max_count = loop_inputs[1] - fn = _convert_function(module, graph, loop_inputs, attr, body_fn_name, prelude, gdef_lib=gdef_lib) + fn = _convert_function( + module, graph, loop_inputs, attr, body_fn_name, prelude, gdef_lib=gdef_lib + ) # Repack loop variables out = [loop_count, max_count] + [_expr.TupleGetItem(fn, i) for i in range(2, input_size)] @@ -471,26 +497,23 @@ def body_fn(*loop_inputs): loop = _while_loop(cond_fn, loop_inputs, body_fn) outputs = loop(*inputs) outputs = _expr.TupleWrapper( - _expr.Tuple( - [ - _expr.TupleGetItem(outputs, i) - for i in range(input_size) - ] - ), - input_size + _expr.Tuple([_expr.TupleGetItem(outputs, i) for i in range(input_size)]), input_size ) return outputs -def _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef_lib, in_shapes=None, outputs=None): - """ Convert given tf node to a relay function call + +def _convert_function( + module, graph, inputs, attr, node_func_name, prelude, gdef_lib, in_shapes=None, outputs=None +): + """Convert given tf node to a relay function call Parameters ---------- - module : IRModule + module : IRModule where converted function is stored - graph: - top level tf graphdef + graph: + top level tf graphdef inputs : List[tvm.relay.Expr] List of input symbols. Parameters for the function. @@ -503,33 +526,37 @@ def _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef Returns ------- - op : tvm.relay.Expr + op : tvm.relay.Expr Examples -------- - a tf function "x+1", is implemented as a subgraph in the libary section of the graph. this subgraph is converted + a tf function "x+1", is implemented as a subgraph in the libary section of the graph. this subgraph is converted to a relay function such as - fn (%x: float32) { - add(%x, 1f) /* Identity */ - } + fn (%x: float32) { + add(%x, 1f) /* Identity */ + } - the subgraph has a function name such as __inference_add_95 + the subgraph has a function name such as __inference_add_95 the tf function call operator is returned as relay expression, such as: free_var %x: float32; @func___inference_add_95(%x) - + """ func = next( (f for f in graph.library.function if f.signature.name == node_func_name), None, ) if func is None: - raise Exception("Function not found - {}".format(node_func_name)) + raise Exception("Function not found - {}".format(node_func_name)) devices = set(node.device for node in func.node_def) if len(devices) > 1: - raise Exception("node_def in function {} contains > 1 types of devices {}".format(node_func_name, devices)) - + raise Exception( + "node_def in function {} contains > 1 types of devices {}".format( + node_func_name, devices + ) + ) + subgraph = gdef_lib[node_func_name] # preserve library functions in subgraphs to make them available to nested functions for fn in graph.library.function: @@ -550,10 +577,14 @@ def _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef except ValueError: # Construct relay nodes from the subgraph g1 = GraphProto(module) - output_sig = [func.ret[f.name] for f in func.signature.output_arg] if outputs is None else outputs + output_sig = ( + [func.ret[f.name] for f in func.signature.output_arg] if outputs is None else outputs + ) # TODO: unify prelude and main IRModules - sub_func, sub_params = g1.from_tensorflow(subgraph, outputs=output_sig, input_types=input_types, gdef_lib=gdef_lib) + sub_func, sub_params = g1.from_tensorflow( + subgraph, outputs=output_sig, input_types=input_types, gdef_lib=gdef_lib + ) module.params.update(sub_params) func_expr = _function.Function(sub_func.params, sub_func.body) global_func = tvm.relay.GlobalVar(func_name) @@ -578,12 +609,13 @@ def _convert_function(module, graph, inputs, attr, node_func_name, prelude, gdef ret = sb.get() return ret + def from_tensorflow(graph_def, layout="NHWC", shape=None, outputs=None): """convert tensorflow2.x graph into relay function. Parameters ---------- - graph_def : must be frozen graph (no variables allowed). + graph_def : must be frozen graph (no variables allowed). Placeholders are assumed to be inputs to the graph. tensorflow/core/framework/graph.proto @@ -600,7 +632,7 @@ def from_tensorflow(graph_def, layout="NHWC", shape=None, outputs=None): The layout for the model. shape : List[str, List[int]] - Input to the model. It is a key and shape vector mapping. Applies to placeholders. + Input to the model. It is a key and shape vector mapping. Applies to placeholders. outputs : List[str] The list of output nodes. The last node is treated as the output if not @@ -612,7 +644,7 @@ def from_tensorflow(graph_def, layout="NHWC", shape=None, outputs=None): The module that optimizations will be performed on. params : dict of str to tvm.nd.NDArray - Dict of converted parameters stored in tvm.nd.NDArray format. + Dict of converted parameters stored in tvm.nd.NDArray format. Examples -------- @@ -636,10 +668,12 @@ def @main(%x1: Tensor[(2, 2), float32], %add/y1: Tensor[(2, 2), float32]) { graph_def_library = {} for func in graph_def.library.function: inshape = func.attr["_input_shapes"].list.shape - graph_def_library[func.signature.name], _ = function_def_to_graph.function_def_to_graph_def(func, inshape) + graph_def_library[func.signature.name], _ = function_def_to_graph.function_def_to_graph_def( + func, inshape + ) module = RelayModule() g = GraphProto(module) func, params = g.from_tensorflow(graph_def, layout, shape, outputs, gdef_lib=graph_def_library) module.mod["main"] = func module.params.update(params) - return module.mod, module.params \ No newline at end of file + return module.mod, module.params From 95ce44397cf56e92ecf517d1eb82736c20657a6c Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Tue, 1 Jun 2021 22:27:34 +0000 Subject: [PATCH 06/13] resolving a bug with passing output tensors for Functional Graphs --- python/tvm/relay/frontend/tensorflow2.py | 7 ++++--- tests/python/frontend/tensorflow2/common.py | 1 - .../frontend/tensorflow2/test_functional_models.py | 12 ++++++++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 3e11327660d3..550d927e1465 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -218,7 +218,7 @@ class GraphProto: """Capturing states when converting a tf graph to a single relay function.""" def __init__(self, module): - self._module: RelayModule = module + self._module = module self._prelude = self._module.prelude self._params = {} self._nodes = {} @@ -255,6 +255,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ self._params[node.name] = param for node in graph.node: self._backtrack_construct(graph, node.name, outputs=outputs) + return self._func(graph, outputs) def _func(self, graph, outputs): @@ -288,6 +289,7 @@ def _func(self, graph, outputs): if fv.name_hint in self._params: final_params[fv.name_hint] = self._params[fv.name_hint] self._params = final_params + return func def _convert_operator(self, graph, op_name, node_name, inputs, attrs, outputs=None): @@ -309,7 +311,6 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs, outputs=No sym : relay.op Converted relay operator """ - if op_name in ["PartitionedCall", "StatefulPartitionedCall"]: sym = _partition_call_operator( self._module, @@ -578,7 +579,7 @@ def _convert_function( # Construct relay nodes from the subgraph g1 = GraphProto(module) output_sig = ( - [func.ret[f.name] for f in func.signature.output_arg] if outputs is None else outputs + ([func.ret[f.name] for f in func.signature.output_arg]) if outputs is None else outputs ) # TODO: unify prelude and main IRModules diff --git a/tests/python/frontend/tensorflow2/common.py b/tests/python/frontend/tensorflow2/common.py index 76f4038e6e4d..4ab37d7018a2 100644 --- a/tests/python/frontend/tensorflow2/common.py +++ b/tests/python/frontend/tensorflow2/common.py @@ -92,7 +92,6 @@ def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None): output_tensors : List of output tensor names (Optional) if not specified then the last node is assumed as graph output. """ - mod, params = from_tensorflow(gdef, outputs=output_tensors) if runtime == "vm": exec_ = compile_vm(mod, params) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index d3273f6a4455..9d16d81ff319 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -49,12 +49,12 @@ def _model_graph(TestClass): return gdef, input_, output -def run_func_graph(TestClass, runtime="vm"): - compare_tf_tvm(*_function_graph(TestClass), runtime=runtime) +def run_func_graph(TestClass, runtime="vm", outputs=None): + compare_tf_tvm(*_function_graph(TestClass), runtime=runtime, output_tensors=outputs) -def run_model_graph(TestClass): - compare_tf_tvm(*_model_graph(TestClass), runtime="vm") +def run_model_graph(TestClass, outputs=None): + compare_tf_tvm(*_model_graph(TestClass), runtime="vm", output_tensors=outputs) def run_all(TestClass): @@ -369,6 +369,7 @@ def func(self, x): y = 2 * x return x, y + run_func_graph(MultiOutput, runtime="vm", outputs=["Identity:output:0", "Identity_1:output:0"]) run_model_graph(MultiOutput) @@ -433,6 +434,9 @@ def func(self, x): r = tf.while_loop(cond, body, [i, j]) return r + run_func_graph( + StatelessWhile2Var, runtime="vm", outputs=["Identity:output:0", "Identity_1:output:0"] + ) run_model_graph(StatelessWhile2Var) From b07a58fae53de28734123b28f640ad5619856141 Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Wed, 2 Jun 2021 00:43:32 +0000 Subject: [PATCH 07/13] fixing multi output for graph runtime --- python/tvm/relay/frontend/tensorflow2.py | 3 +-- tests/python/frontend/tensorflow2/common.py | 3 +-- tests/python/frontend/tensorflow2/test_functional_models.py | 3 +++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 550d927e1465..86ee7ec7ab74 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -261,6 +261,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ def _func(self, graph, outputs): out = [] if outputs is None: + last_node = graph.node[-1] op = self._nodes[last_node.name.split(":")[0]] if last_node.op == "Exit": @@ -289,7 +290,6 @@ def _func(self, graph, outputs): if fv.name_hint in self._params: final_params[fv.name_hint] = self._params[fv.name_hint] self._params = final_params - return func def _convert_operator(self, graph, op_name, node_name, inputs, attrs, outputs=None): @@ -581,7 +581,6 @@ def _convert_function( output_sig = ( ([func.ret[f.name] for f in func.signature.output_arg]) if outputs is None else outputs ) - # TODO: unify prelude and main IRModules sub_func, sub_params = g1.from_tensorflow( subgraph, outputs=output_sig, input_types=input_types, gdef_lib=gdef_lib diff --git a/tests/python/frontend/tensorflow2/common.py b/tests/python/frontend/tensorflow2/common.py index 4ab37d7018a2..9686909ff31f 100644 --- a/tests/python/frontend/tensorflow2/common.py +++ b/tests/python/frontend/tensorflow2/common.py @@ -42,7 +42,6 @@ def run_tf_code(func, input_): f_out = func(tf.constant(input_)) if type(f_out) is dict: np_out = [f_out[k].numpy() for k in sorted(f_out.keys())] - elif type(f_out) is list: np_out = [x.numpy() for x in f_out] else: @@ -72,7 +71,7 @@ def run_graph_executor(lib, input_, ctx=tvm.cpu(0)): mod = runtime.GraphModule(lib["default"](ctx)) mod.set_input(0, input_) mod.run() - return [mod.get_output(0).asnumpy()] + return [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())] def compare_tf_tvm(gdef, input_, output_, runtime="vm", output_tensors=None): diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index 9d16d81ff319..af384b7fb1bc 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -370,6 +370,9 @@ def func(self, x): return x, y run_func_graph(MultiOutput, runtime="vm", outputs=["Identity:output:0", "Identity_1:output:0"]) + run_func_graph( + MultiOutput, runtime="graph", outputs=["Identity:output:0", "Identity_1:output:0"] + ) run_model_graph(MultiOutput) From 897af9f57d5702db9288a3169c46e8a5f9bdd22d Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Wed, 2 Jun 2021 00:54:08 +0000 Subject: [PATCH 08/13] adding docstring edits --- python/tvm/relay/frontend/tensorflow2.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 86ee7ec7ab74..97171aa3ad1b 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -369,9 +369,15 @@ def _backtrack_construct(self, graph, node_name, outputs=None): Parameters ---------- + graph : + TF2 frozen graph def + node_name : str node name + outputs : List[str] + List of output nodes + Returns ------- op : relay.Expr @@ -525,6 +531,9 @@ def _convert_function( node_func_name : str Name of tf2 node to be converted + outputs : List[str] + The list of output nodes. + Returns ------- op : tvm.relay.Expr From 54fcdbdfca06b3e1d706189285e5b0459be2e66a Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Wed, 2 Jun 2021 01:41:06 +0000 Subject: [PATCH 09/13] linting + black --- python/tvm/relay/frontend/tensorflow2.py | 26 ++++++++++-------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 97171aa3ad1b..ccb4f257e56a 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -25,9 +25,12 @@ """ import numpy as np +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import dtypes + import tvm -from tvm import relay from tvm.relay.transform import InferType from tvm.relay.prelude import Prelude from tvm.ir import IRModule @@ -35,17 +38,12 @@ from .. import analysis from .. import function as _function from ..loops import while_loop as _while_loop -from .common import infer_shape as _infer_shape from .common import infer_type as _infer_type -from tensorflow.python.framework import function_def_to_graph -from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import dtypes - from .tensorflow import _convert_map as _convert_map_tf1 from .tensorflow import _need_prelude_for_shape_inference -from ..ty import Any, TensorType +from ..ty import Any __all__ = ["from_tensorflow"] @@ -206,7 +204,8 @@ def convert_place_holder(shape, node, in_type=None): class RelayModule: - """states related to the entire relay module (multiple functions) after converted from tf graphdef""" + """states related to the entire relay module (multiple functions) + after converted from tf graphdef""" def __init__(self): self.mod = IRModule({}) # relay function and type definitions. defined in tvm/ir/module.py @@ -389,10 +388,6 @@ def _backtrack_construct(self, graph, node_name, outputs=None): CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), []) """ - try: - from tensorflow.python.framework import tensor_util - except ImportError as e: - raise ImportError("Unable to import tensorflow which is required {}".format(e)) input_op_name = node_name.split(":")[0].split("^")[-1] if input_op_name not in self._nodes: @@ -541,8 +536,8 @@ def _convert_function( Examples -------- - a tf function "x+1", is implemented as a subgraph in the libary section of the graph. this subgraph is converted - to a relay function such as + a tf function "x+1", is implemented as a subgraph in the libary section of the graph. + this subgraph is converted to a relay function such as fn (%x: float32) { add(%x, 1f) /* Identity */ } @@ -660,7 +655,8 @@ def from_tensorflow(graph_def, layout="NHWC", shape=None, outputs=None): "x+1" tf module where x has a shape of (2,2) is converted as follows: mod : tvm.IRModule - def @func___inference_add_95(%x: Tensor[(2, 2), float32], %add/y: Tensor[(2, 2), float32]) -> Tensor[(2, 2), float32] { + def @func___inference_add_95(%x: Tensor[(2, 2), float32], %add/y: Tensor[(2, 2), float32]) + -> Tensor[(2, 2), float32] { add(%x, %add/y) /* Identity */ /* ty=Tensor[(2, 2), float32] */ } From d051b46381126c3f8ce8172ca24e07bdc3c0d4b0 Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Wed, 2 Jun 2021 01:56:37 +0000 Subject: [PATCH 10/13] linting + black --- python/tvm/relay/frontend/tensorflow2.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index ccb4f257e56a..de61dd418494 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -54,6 +54,8 @@ def _infer_type_with_prelude(val, prelude): def set_span(sym, node_name): + """set span of symbol""" + span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) if isinstance(sym, _expr.Call): sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) @@ -227,15 +229,24 @@ def __init__(self, module): self._gdef_lib = {} def from_tensorflow( - self, graph, layout="NHWC", shape=None, outputs=None, input_types={}, gdef_lib={} + self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None ): + if input_types is None: + input_types = {} + + if gdef_lib is None: + gdef_lib = {} + self._gdef_lib = gdef_lib func = self._get_relay_func( graph, layout=layout, shape=shape, outputs=outputs, input_types=input_types ) return func, self._params - def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types={}): + def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types=None): + if input_types is None: + input_types = {} + self._layout = layout for node in graph.node: name = node.name @@ -570,9 +581,9 @@ def _convert_function( # Computing subgraph's input shape and type dictionaries input_expr_dict = {} input_types = {} - for f_arg, input in zip(func.signature.input_arg, inputs): - input_expr_dict[f_arg.name] = input - input_types[f_arg.name] = _infer_type_with_prelude(input, prelude) + for f_arg, input_ in zip(func.signature.input_arg, inputs): + input_expr_dict[f_arg.name] = input_ + input_types[f_arg.name] = _infer_type_with_prelude(input_, prelude) func_name = "func_{}".format(func.signature.name) try: From 9b26ba2862bbc578fed208c7513210ab00b02963 Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Wed, 2 Jun 2021 02:03:02 +0000 Subject: [PATCH 11/13] linting + black --- python/tvm/relay/frontend/tensorflow2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index de61dd418494..6774b267d36b 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -231,6 +231,9 @@ def __init__(self, module): def from_tensorflow( self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None ): + """Wrapper to _get_relay_func which converts Tensorflow graph to Relay function + which is used as main function for the Relay module + """ if input_types is None: input_types = {} From fdb72c50423aee137e30a118804ed57c717b43f3 Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Wed, 2 Jun 2021 03:39:23 +0000 Subject: [PATCH 12/13] removing unnecessary output propagation across function --- python/tvm/relay/frontend/tensorflow2.py | 30 ++++++------------- .../tensorflow2/test_functional_models.py | 4 +-- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 6774b267d36b..90651556dc8b 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -267,14 +267,13 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ if param: self._params[node.name] = param for node in graph.node: - self._backtrack_construct(graph, node.name, outputs=outputs) + self._backtrack_construct(graph, node.name) return self._func(graph, outputs) def _func(self, graph, outputs): out = [] if outputs is None: - last_node = graph.node[-1] op = self._nodes[last_node.name.split(":")[0]] if last_node.op == "Exit": @@ -305,7 +304,7 @@ def _func(self, graph, outputs): self._params = final_params return func - def _convert_operator(self, graph, op_name, node_name, inputs, attrs, outputs=None): + def _convert_operator(self, graph, op_name, node_name, inputs, attrs): """Convert from Tensorflow operator to relay operator. The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. @@ -332,7 +331,6 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs, outputs=No attrs, self._prelude, gdef_lib=self._gdef_lib, - outputs=outputs, ) elif op_name in ["StatelessIf", "If"]: sym = _convert_if( @@ -360,7 +358,7 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs, outputs=No sym = set_span(sym, node_name) return sym - def _backtrack_construct(self, graph, node_name, outputs=None): + def _backtrack_construct(self, graph, node_name): """Convert a specific tensorflow node to relay expression. If any of its ancestor node is not converted yet, backtrack as @@ -388,9 +386,6 @@ def _backtrack_construct(self, graph, node_name, outputs=None): node_name : str node name - outputs : List[str] - List of output nodes - Returns ------- op : relay.Expr @@ -417,10 +412,8 @@ def _backtrack_construct(self, graph, node_name, outputs=None): attr["_output_shapes"] = self._output_shapes[input_op_name] attr["_node_name"] = node.name attr["_target_layout"] = self._layout - inputs = [ - self._backtrack_construct(graph, iname, outputs=outputs) for iname in node.input - ] - op = self._convert_operator(graph, node.op, node.name, inputs, attr, outputs=outputs) + inputs = [self._backtrack_construct(graph, iname) for iname in node.input] + op = self._convert_operator(graph, node.op, node.name, inputs, attr) if isinstance(op, np.ndarray): self._params[node.name] = tvm.nd.array(op) @@ -444,11 +437,11 @@ def _backtrack_construct(self, graph, node_name, outputs=None): return out[0] -def _partition_call_operator(module, graph, inputs, attr, prelude, gdef_lib, outputs=None): +def _partition_call_operator(module, graph, inputs, attr, prelude, gdef_lib): """convert tf PartitionedCall node to a relay function call""" node_func_name = attr.get("f").name return _convert_function( - module, graph, inputs, attr, node_func_name, prelude, gdef_lib=gdef_lib, outputs=outputs + module, graph, inputs, attr, node_func_name, prelude, gdef_lib=gdef_lib ) @@ -519,7 +512,7 @@ def body_fn(*loop_inputs): def _convert_function( - module, graph, inputs, attr, node_func_name, prelude, gdef_lib, in_shapes=None, outputs=None + module, graph, inputs, attr, node_func_name, prelude, gdef_lib, in_shapes=None ): """Convert given tf node to a relay function call @@ -540,9 +533,6 @@ def _convert_function( node_func_name : str Name of tf2 node to be converted - outputs : List[str] - The list of output nodes. - Returns ------- op : tvm.relay.Expr @@ -596,9 +586,7 @@ def _convert_function( except ValueError: # Construct relay nodes from the subgraph g1 = GraphProto(module) - output_sig = ( - ([func.ret[f.name] for f in func.signature.output_arg]) if outputs is None else outputs - ) + output_sig = [func.ret[f.name] for f in func.signature.output_arg] # TODO: unify prelude and main IRModules sub_func, sub_params = g1.from_tensorflow( subgraph, outputs=output_sig, input_types=input_types, gdef_lib=gdef_lib diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index af384b7fb1bc..8b6af4611f0b 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -373,7 +373,7 @@ def func(self, x): run_func_graph( MultiOutput, runtime="graph", outputs=["Identity:output:0", "Identity_1:output:0"] ) - run_model_graph(MultiOutput) + run_model_graph(MultiOutput, outputs=["Identity:output:0"]) def test_if(): @@ -440,7 +440,7 @@ def func(self, x): run_func_graph( StatelessWhile2Var, runtime="vm", outputs=["Identity:output:0", "Identity_1:output:0"] ) - run_model_graph(StatelessWhile2Var) + run_model_graph(StatelessWhile2Var, outputs=["Identity:output:0"]) if __name__ == "__main__": From c683c960ae7b17cef4265d795f6ef5979333dc63 Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Thu, 3 Jun 2021 17:32:37 +0000 Subject: [PATCH 13/13] addressed comments in PR --- python/tvm/relay/frontend/tensorflow2.py | 98 +++++++------------ .../tensorflow2/test_functional_models.py | 54 +++++----- 2 files changed, 64 insertions(+), 88 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 90651556dc8b..121cb12d4715 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -25,10 +25,7 @@ """ import numpy as np -from tensorflow.python.framework import function_def_to_graph -from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import dtypes - +from tensorflow.python.framework import function_def_to_graph, tensor_util, dtypes import tvm from tvm.relay.transform import InferType @@ -40,7 +37,7 @@ from ..loops import while_loop as _while_loop from .common import infer_type as _infer_type -from .tensorflow import _convert_map as _convert_map_tf1 +from .tensorflow import _convert_map as _convert_map_common from .tensorflow import _need_prelude_for_shape_inference from ..ty import Any @@ -97,83 +94,55 @@ def convert_const_node(node, shape): def get_attr(buf): """convert value of a node attribute. node attribute is part of a node in a graph. - // tensorflow/core/framework/attr_value.proto - message AttrValue { - oneof value { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - DataType type = 6; // "type" - TensorShapeProto shape = 7; // "shape" - TensorProto tensor = 8; // "tensor" - ListValue list = 1; // any "list(...)" } - } + Parameters ---------- buf: attrvalue protobuf. + Returns ------- The value of the attr, as a Python object. + + Raises: + ------- + ValueError: If this op does not have an attr with the given `name`. """ - fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] - x = buf + fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] ret = [] - if not x.WhichOneof("value"): + if not buf.WhichOneof("value"): return ret - if x.HasField("list"): + if buf.HasField("list"): for f in fields: - if getattr(x.list, f): + if getattr(buf.list, f): if f == "type": - ret += [dtypes.as_dtype(x) for x in list(getattr(x.list, f))] + ret += [dtypes.as_dtype(x) for x in list(getattr(buf.list, f))] else: - ret += list(getattr(x.list, f)) + ret += list(getattr(buf.list, f)) else: for f in fields: - if x.HasField(f): + if buf.HasField(f): if f == "type": - ret = dtypes.as_dtype(getattr(x, f)) + ret = dtypes.as_dtype(getattr(buf, f)) else: - ret = getattr(x, f) + ret = getattr(buf, f) return ret def parse_attr(attr_proto): """Convert node attributes (a serialized map of key-value pairs) in a node to a dict + Parameters ---------- attr_proto: - attributes of a tf node - protobuf message format: - // tensorflow/core/framework/node_def.proto - message NodeDef { - map attr = 5; - } + Returns ------- Dict {string: python object} - Examples - -------- - attributes in following node converted to {'_user_specified_name': b'x', 'dtype': tf.float32 } - node { - name: "x" - op: "Placeholder" - attr { - key: "_user_specified_name" - value { - s: "x" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } + """ attrs = {} for key, value in attr_proto.items(): @@ -182,12 +151,12 @@ def parse_attr(attr_proto): return attrs -def convert_place_holder(shape, node, in_type=None): - """convert tf place holder into relay var. +def convert_placeholder(shape, node, in_type=None): + """convert tf placeholder into relay var. - Examples + Example -------- - a tf place holder with name "x" is converted to [Var(x, ty=TensorType([], float32))] + a tf placeholder with name "x" is converted to [Var(x, ty=TensorType([], float32))] """ if shape and node.name in shape: @@ -210,9 +179,9 @@ class RelayModule: after converted from tf graphdef""" def __init__(self): - self.mod = IRModule({}) # relay function and type definitions. defined in tvm/ir/module.py - self.params = {} # for constants (weights) in the entire relay module - self.prelude = Prelude(self.mod) # relay.prelude needed for tensorlist ops + self.mod = IRModule({}) + self.params = {} + self.prelude = Prelude(self.mod) class GraphProto: @@ -258,7 +227,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ in_type = None if node.name in input_types: in_type = input_types[node.name] - self._input_shapes[name], self._nodes[name] = convert_place_holder( + self._input_shapes[name], self._nodes[name] = convert_placeholder( shape, node, in_type ) elif node.op == "Const": @@ -311,8 +280,12 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs): Parameters ---------- + graph: + TF2 frozen graph def op_name : str Operator name, such as Conv2D, AvgPool + node_name: str + Name of the node in TF2 graph, such as Identity:0 inputs : list of relay.op List of input symbols. attrs : dict @@ -347,11 +320,11 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs): self._prelude, gdef_lib=self._gdef_lib, ) - elif op_name in _convert_map_tf1: + elif op_name in _convert_map_common: if _need_prelude_for_shape_inference(op_name): - sym = _convert_map_tf1[op_name](inputs, attrs, self._params, self._prelude) + sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude) else: - sym = _convert_map_tf1[op_name](inputs, attrs, self._params, self._module.mod) + sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod) else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) @@ -482,7 +455,6 @@ def convert_vars(loop_inputs, input_signature): None, ) loop_inputs = convert_vars(inputs, while_func.signature.input_arg) - # in_shapes = nodes[node_name].attr["output_shapes"].list.shape def cond_fn(*loop_inputs): return _convert_function( diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index 8b6af4611f0b..b3504ff38328 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -377,33 +377,37 @@ def func(self, x): def test_if(): - class If(tf.Module): - def get_input(self): - return np.ones((2, 2), dtype="float32") - - @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) - def func(self, x): - @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) - def double(x): - return 2 * x + def create_if_class(_condition=True): + class If(tf.Module): + def get_input(self): + return np.ones((2, 2), dtype="float32") @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) - def triple(x): - return 3 * x - - cond = True - output = tf.raw_ops.If( - cond=cond, - input=[x], - Tout=[tf.float32], - output_shapes=[(2, 2)], - then_branch=double.get_concrete_function(), - else_branch=triple.get_concrete_function(), - ) - return output[0] - - run_func_graph(If, runtime="vm") - run_model_graph(If) + def func(self, x): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def double(x): + return 2 * x + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) + def triple(x): + return 3 * x + + output = tf.raw_ops.If( + cond=_condition, + input=[x], + Tout=[tf.float32], + output_shapes=[(2, 2)], + then_branch=double.get_concrete_function(), + else_branch=triple.get_concrete_function(), + ) + return output[0] + + return If + + for cond in [True, False]: + if_class = create_if_class(_condition=cond) + run_func_graph(if_class, runtime="vm") + run_model_graph(if_class) def test_stateless_while():