From 819f73cada6e0ec1701083d6aa82fd548331cc89 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 16 Mar 2019 04:38:00 +0000 Subject: [PATCH 1/5] decompile tf control flow --- python/tvm/relay/frontend/tensorflow.py | 157 ++++++++++- tests/python/relay/test_tf_loop_to_relay.py | 298 ++++++++++++++++++++ 2 files changed, 452 insertions(+), 3 deletions(-) create mode 100644 tests/python/relay/test_tf_loop_to_relay.py diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 0efebe3cfec9..4c30f0777ddd 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -5,6 +5,7 @@ import logging import warnings +from collections import defaultdict # Numpy support import numpy as np @@ -1270,6 +1271,100 @@ def _get_abs_layer_name(node): params, num_layers) return sym +_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond'] + +class Branch: + """A class that contains the components that are used to build up a Relay if + node. + """ + def __init__(self): + self._if = None + self.cond_vars = set() + self.cond = None + self.true_branch = None + self.false_branch = None + + def _if_node(self): + from tvm import relay + + cond_vars = [] + bind_map = {} + for i, var in enumerate(list(self.cond_vars)): + if not isinstance(var, _expr.Var): + raise TypeError("var is expected to be _expr.Var type, but " + "received {}".format(repr(var))) + v = relay.var("cond_var" + str(i), + type_annotation=var.type_annotation) + cond_vars.append(v) + bind_map[var] = v + + self.cond = relay.bind(self.cond, bind_map) + cond = relay.op.min(self.cond) + self.true_branch = relay.bind(self.true_branch, bind_map) + self.false_branch = relay.bind(self.false_branch, bind_map) + + return relay.If(cond, self.true_branch, self.false_branch) + + def if_node(self): + """Create an if node if it hasn't been created yet.""" + if self._if is None: + self._if = self._if_node() + return self._if + return self._if + + +class Loop: + """A class that contains the components that are used to build up a Relay + recursive call. + """ + def __init__(self): + self.loop_vars = [] + self.cond = None + self.body = [] + self._loop = None + + def _while_loop(self): + from tvm import relay + wl = relay.var('while_loop') + sb = relay.scope_builder.ScopeBuilder() + + loop_vars = [] + bind_map = {} + for i, var in enumerate(self.loop_vars): + assert isinstance(var, _expr.Var), repr(var) + v = relay.var("loop_var" + str(i), + type_annotation=var.type_annotation) + loop_vars.append(v) + bind_map[var] = v + + self.cond = relay.bind(self.cond, bind_map) + self.body = [relay.bind(b, bind_map) for b in self.body] + + cond = relay.op.min(self.cond) + + with sb.if_scope(cond): + sb.ret(wl(*self.body)) + with sb.else_scope(): + sb.ret(relay.Tuple(loop_vars)) + + loop_fn = relay.Function(loop_vars, sb.get()) + sb = relay.scope_builder.ScopeBuilder() + sb.let(wl, loop_fn) + sb.ret(wl(*self.loop_vars)) + return sb.get() + + def while_loop(self): + if self._loop is None: + self._loop = self._while_loop() + return self._loop + return self._loop + + +def _in_while_loop(control_flow_node_map, op_name): + return op_name in control_flow_node_map and \ + "LoopCond" in control_flow_node_map[op_name] + + class GraphProto(object): """ A helper class for handling relay graph copying from Tensorflow GraphDef.
Definition: @@ -1284,6 +1379,9 @@ def __init__(self): self._num_rnn_layer = False self._outputs_are_0d = {} self._input_shapes = {} + self._loops = {} + self._branches = {} + # self.module = relay.Module({}) def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. @@ -1332,7 +1430,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): raise NotImplementedError( \ "The following operators are not implemented: {}".format(missing_operators)) + control_flow_node_map = defaultdict(set) for node in graph.node: + node_name_prefix = node.name.rsplit('/', 1)[0] + control_flow_node_map[node_name_prefix].add(node.op) if node.op == 'Placeholder': if shape and node.name in shape: self._input_shapes[node.name] = list(shape[node.name]) @@ -1451,8 +1552,53 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): attr['_input_shapes'] = input_shapes attr['_input_0d_mismatch'] = input_0d_mismatch + node_name_prefix = node.name.rsplit('/', 1)[0] - op = self._convert_operator(node.op, inputs, attr, graph) + if node.op == "Merge": + if _in_while_loop(control_flow_node_map, node_name_prefix): + op = self._nodes[node.input[0]] + self._loops[node_name_prefix] = Loop() + else: + if len(self._branches) == 0: + raise RuntimeError("Cannot find a created " + "conditional for merge node") + branch = self._branches[node_name_prefix] + false_br = self._nodes[node.input[0]] + true_br = self._nodes[node.input[1]] + assert len(true_br) == 1 + assert len(false_br) == 1 + branch.true_branch = true_br[0] + branch.false_branch = false_br[0] + op = [branch.if_node()] + # del self._branches[node_name_prefix] + elif node.op == "Exit": + loop = self._loops[node_name_prefix] + exit_name = node.name.split('/')[-1] + assert str.startswith(exit_name, 'Exit') + exit_number = int("0" + exit_name[4:]) + expr = loop.while_loop() + op = _expr.TupleGetItem(expr, exit_number) + elif node.op == "Enter": + op = self._nodes[node.input[0]] + elif node.op == "LoopCond": + op = self._nodes[node.input[0]] + assert len(op) == 1 + self._loops[node_name_prefix].cond = op[0] + elif node.op == "Switch": + op = self._nodes[node.input[0]] + assert len(op) == 1 + if _in_while_loop(control_flow_node_map, node_name_prefix): + self._loops[node_name_prefix].loop_vars.append(op[0]) + else: + if node_name_prefix not in self._branches: + self._branches[node_name_prefix] = Branch() + self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0]) + elif node.op == "NextIteration": + op = self._nodes[node.input[0]] + assert len(op) == 1 + self._loops[node_name_prefix].body.append(op[0]) + else: + op = self._convert_operator(node.op, inputs, attr, graph) # Check if op is converted to param if isinstance(op, np.ndarray): @@ -1493,7 +1639,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): out = [] if outputs is None: - out = op + if node.op == "Exit": + out = [op[0].tuple_value] + else: + out = op else: for out_name in outputs: if ":" in out_name: @@ -1529,7 +1678,9 @@ def _parse_import_prerequisites(self, graph): elif node.op == "Const": pass else: - if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]): + if any([node.op in t for t in [_identity_list, _convert_map, + _convert_map_rnn, + _control_flow_nodes]]): pass else: missing_operators.add(node.op) diff --git a/tests/python/relay/test_tf_loop_to_relay.py b/tests/python/relay/test_tf_loop_to_relay.py new file mode 100644 
index 000000000000..49196123274c --- /dev/null +++ b/tests/python/relay/test_tf_loop_to_relay.py @@ -0,0 +1,298 @@ +"""Unit tests for converting TensorFlow control flow op to Relay.""" +import tensorflow as tf +import numpy as np +from tvm import relay +from tvm.relay.frontend.tensorflow import from_tensorflow + + +def check_equal(graph, tf_out): + expr, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) + ex = relay.create_executor('debug') + relay_out = ex.evaluate(expr)(**params) + if isinstance(relay_out, relay.backend.interpreter.TensorValue): + np.testing.assert_allclose(tf_out, relay_out.asnumpy()) + else: + if not isinstance(tf_out, list): + tf_out = [tf_out] + for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]): + np.testing.assert_allclose(x, y) + + +def vanilla_loop(): + graph = tf.Graph() + with graph.as_default(): + i = tf.constant(0) + + def c(i): return tf.less(i, 10) + + def b(i): return tf.add(i, 1) + r = tf.while_loop(c, b, [i]) + + with tf.Session() as sess: + tf_out = sess.run(r) + + check_equal(graph, tf_out) + + +def loop_2_vars(): + graph = tf.Graph() + with graph.as_default(): + i0 = tf.constant(0) + j0 = tf.ones([2, 2]) + + def c(i, j): return i < 10 + + def b(i, j): return [tf.add(i, 1), j] + i1, i2 = tf.while_loop(c, b, loop_vars=[i0, j0]) + i1 += tf.constant(1337) + + with tf.Session() as sess: + tf_out = sess.run(i1) + + check_equal(graph, tf_out) + + +def loop_3_vars(): + graph = tf.Graph() + with graph.as_default(): + i0 = tf.constant(1) + j0 = tf.constant(2) + k0 = tf.constant(4) + + def c(i, j, k): return i < 10 + + def b(i, j, k): return [i+1, j * k, k + i] + r = tf.while_loop(c, b, loop_vars=[i0, j0, k0]) + + with tf.Session() as sess: + tf_out = sess.run(r) + + check_equal(graph, tf_out) + + +def loop_conditions(): + graph = tf.Graph() + with graph.as_default(): + i = tf.constant(1) + j = tf.constant(1) + k = tf.constant(5) + + def c(i, j, k): return \ + tf.equal(tf.not_equal(tf.less(i + j, 10), + tf.less(j * k, 100)), + tf.greater_equal(k, i + j)) + + def b(i, j, k): return [i+j, j+k, k+1] + r = tf.while_loop(c, b, loop_vars=[i, j, k]) + with tf.Session() as sess: + tf_out = sess.run(r) + + check_equal(graph, tf_out) + + +def loop_bodies(): + graph = tf.Graph() + with graph.as_default(): + def body(x): + a = tf.constant(np.array([[5, 6], [7, 8]]), dtype=tf.int32) + b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32) + c = a + b + return tf.nn.relu(x + c) + + def condition(x): + return tf.reduce_sum(x) < 100 + x = tf.constant(0, shape=[2, 2]) + r = tf.while_loop(condition, body, [x]) + with tf.Session() as sess: + tf_out = sess.run(r) + + check_equal(graph, tf_out) + + +def nested_loop(): + graph = tf.Graph() + with graph.as_default(): + + def body(x): + def nest_body(c): + return tf.multiply(c, 2) + def cd(c): return tf.less(c, 10) + c = tf.constant(2) + res = tf.while_loop(cd, nest_body, loop_vars=[c]) + return tf.nn.relu(x + res) + + def condition(x): + return tf.greater(x, 100) + x = tf.constant(3) + r = tf.while_loop(condition, body, loop_vars=[x]) + with tf.Session() as sess: + tf_out = sess.run(r) + + check_equal(graph, tf_out) + + +def vanilla_cond(): + graph = tf.Graph() + with graph.as_default(): + i = tf.constant(1) + j = tf.constant(4) + + def f1(): + return tf.multiply(1, 17) + + def f2(): + return tf.add(4, 23) + r = tf.cond(tf.less(i, j), f1, f2) + + with tf.Session(graph=graph) as sess: + tf_out = sess.run(r) + + check_equal(graph, tf_out) + + +def multiple_cond_vars(): + graph = tf.Graph() + with 
graph.as_default(): + x1 = tf.constant(7) + x2 = tf.constant(12) + z = tf.constant(20) + r = tf.cond(tf.less(tf.add(x1, x2), 10), + lambda: tf.add(10, 2), lambda: tf.square(5)) + with tf.Session() as sess: + tf_out = sess.run(r) + + check_equal(graph, tf_out) + + +def cond_fn_parameters(): + graph = tf.Graph() + with graph.as_default(): + def fn1(x, y): + return tf.multiply(5, 6) + + def fn2(x, y): + return tf.add(3, 4) + + i = tf.constant(1) + j = tf.constant(2) + k = tf.constant(3) + r = tf.cond(tf.less(i, j), lambda: fn1(i, k), lambda: fn2(j, k)) + with tf.Session() as sess: + tf_out = sess.run(r, feed_dict={i: 1, j: 2, k: 3}) + + check_equal(graph, tf_out) + + +def nested_cond(): + graph = tf.Graph() + with graph.as_default(): + def fn1(a, b): + def nest_fn1(): + return tf.add(1, 2) + + def nest_fn2(): + return tf.subtract(10, 5) + + res = tf.cond(tf.less(1, 2), nest_fn1, nest_fn2) + return tf.multiply(tf.add(87, res), 10) + + def fn2(a, b): + return tf.add(10, 10) + + x = tf.constant(5) + y = tf.constant(6) + z = tf.constant(7) + pred = tf.less(x, y) + r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z)) + with tf.Session() as sess: + tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True}) + + check_equal(graph, tf_out) + + +def loop_in_cond(): + graph = tf.Graph() + with graph.as_default(): + def fn1(a, b): + i = tf.constant(0) + + def cd(i): return tf.less(i, 10) + + def bd(i): return tf.add(i, 1) + res = tf.while_loop(cd, bd, [i]) + return tf.multiply(tf.add(20, res), 10) + + def fn2(a, b): + return tf.add(10, 20) + + x = tf.constant(7) + y = tf.constant(20) + z = tf.constant(10) + pred = tf.less(x, y) + r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z)) + + with tf.Session() as sess: + tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True}) + + check_equal(graph, tf_out) + + +def cond_in_loop(): + graph = tf.Graph() + with graph.as_default(): + def body(x): + x = tf.constant(7) + z = tf.constant(20) + res = tf.cond(tf.less(x, 10), lambda: tf.add( + 10, 20), lambda: tf.square(10)) + return tf.multiply(res, x) + + x = tf.constant(21) + def condition(x): + return tf.less(x, 100) + + r = tf.while_loop(condition, body, loop_vars=[x]) + with tf.Session() as sess: + tf_out = sess.run(r) + + check_equal(graph, tf_out) + + +def loop_lambda_placeholder(): + graph = tf.Graph() + with graph.as_default(): + c = lambda i, j: tf.equal(tf.less(i, 17), tf.greater(j, 7)) + b = lambda i, j: [i + 3, j - 13] + + i = tf.placeholder(tf.float32) + j = tf.placeholder(tf.float32) + r = tf.while_loop(c, b, loop_vars=[i, j]) + + with tf.Session() as sess: + tf_out = sess.run(r, feed_dict={i: -203, j: 107}) + + check_equal(graph, tf_out) + + +if __name__ == "__main__": + + # tf.while_loop + vanilla_loop() + loop_2_vars() + loop_3_vars() + loop_conditions() + loop_bodies() + + # tf.cond + vanilla_cond() + multiple_cond_vars() + cond_fn_parameters() + + # nested cases + nested_loop() + nested_cond() + loop_in_cond() + cond_in_loop() + + # w/ placeholder and lambda + loop_lambda_placeholder() From ca3d6dc5e726a1bfa785910aca6db752a537043a Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 16 Mar 2019 23:27:03 +0000 Subject: [PATCH 2/5] Add docs --- python/tvm/relay/frontend/tensorflow.py | 305 +++++++++++++++----- src/relay/backend/interpreter.cc | 36 ++- tests/python/relay/test_tf_loop_to_relay.py | 73 ++--- 3 files changed, 291 insertions(+), 123 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4c30f0777ddd..f1272f9b7dd3 
100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -10,6 +10,7 @@ import numpy as np import tvm +from tvm import relay from topi.util import get_const_tuple from .. import ir_pass from .. import expr as _expr @@ -1271,42 +1272,109 @@ def _get_abs_layer_name(node): params, num_layers) return sym +# An internal list to contain all the control flow primitives used in TensorFlow +# 1.x. _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond'] +def _in_while_loop(control_flow_node_map, op_name): + """ + Check if a given control flow operator is part of a while loop execution + frame. This is based on the fact that there is only one occurrence of + `LoopCond` for a loop execution frame and it is only present in the loop + construct. + + Parameters + ---------- + control_flow_node_map : Dict[str, Set[str]] + A dictionary that maps a unique control flow execution frame name to + a set of primitive operators. + + op_name : str + The name of a control flow primitive. + + Returns + ------- + ret : bool + Return true if the operator is in a while loop execution frame; + otherwise, return false. + """ + return op_name in control_flow_node_map and \ + "LoopCond" in control_flow_node_map[op_name] + + class Branch: """A class that contains the components that are used to build up a Relay if node. + + Parameters + ---------- + cond : tvm.relay.Expr + The condition of an if node. + + true_branch : tvm.relay.Expr + The body of the true branch of an if expression. + + false_branch: tvm.relay.Expr + The body of the false branch of an if expression. + + _if : tvm.relay.Expr + An internal variable that indicates whether an if expression has already + been created for a matched TF condition construct. + + Examples + -------- + The following is a cond statement written in TensorFlow: + + .. code-block:: python + + def vanilla_cond(): + i = tf.constant(1) + j = tf.constant(4) + + def f1(): + return tf.multiply(1, 17) + + def f2(): + return tf.add(4, 23) + r = tf.cond(tf.less(i, j), f1, f2) + + This condition statement should be converted into Relay in the following + form: + + .. code-block:: python + + fn (%Const: Tensor[(1,), int32], + %Const_1: Tensor[(1,), int32], + %cond/Mul/x: Tensor[(1,), int32], + %cond/Mul/y: Tensor[(1,), int32], + %cond/Add/x: Tensor[(1,), int32], + %cond/Add/y: Tensor[(1,), int32]) { + %0 = less(%Const, %Const_1) # ty=Tensor[(1,), bool] + %1 = min(%0) + if (%1) { + %2 = multiply(%cond/Mul/x, %cond/Mul/y) + %2 + } else { + %3 = add(%cond/Add/x, %cond/Add/y) + %3 + } + } """ def __init__(self): self._if = None - self.cond_vars = set() self.cond = None self.true_branch = None self.false_branch = None def _if_node(self): - from tvm import relay - - cond_vars = [] - bind_map = {} - for i, var in enumerate(list(self.cond_vars)): - if not isinstance(var, _expr.Var): - raise TypeError("var is expected to be _expr.Var type, but " - "received {}".format(repr(var))) - v = relay.var("cond_var" + str(i), - type_annotation=var.type_annotation) - cond_vars.append(v) - bind_map[var] = v - - self.cond = relay.bind(self.cond, bind_map) + """An internal API to create a Relay if node from the matched TF + condition construct.
+ """ cond = relay.op.min(self.cond) - self.true_branch = relay.bind(self.true_branch, bind_map) - self.false_branch = relay.bind(self.false_branch, bind_map) - return relay.If(cond, self.true_branch, self.false_branch) def if_node(self): - """Create a if node if it hasn't been created yet.""" + """Create an tvm.relay.If node if it hasn't been created yet.""" if self._if is None: self._if = self._if_node() return self._if @@ -1314,8 +1382,60 @@ def if_node(self): class Loop: - """A class contains the components that are used to build up a Relay + """ + A class contains the components that are used to build up a Relay recursive call. + + Parameters + ---------- + loop_vars : List[tvm.relay.Expr] + The loop variables that used in a while loop. + + cond : tvm.relay.Expr + The condition of a while loop. + + body : tvm.relay.Expr + The body of a matched while loop. + + _loop : tvm.relay.Expr + An internal variable indicates where a recursive call is already created + for a matched TF while loop construct. + + Examples + -------- + The following is a vanilla loop from TensorFlow: + + .. code-block:: python + + i = tf.constant(0) + c = lambda i: tf.less(i, 10) + b = lambda i: tf.add(i, 1) + r = tf.while_loop(c, b, [i]) + + It will be converted to the following recursive call in Relay: + + .. code-block:: python + + fn (%while/Less/y: Tensor[(1,), int32], + %while/Add/y: Tensor[(1,), int32], + %Const: Tensor[(1,), int32]) { + %0 = fn(%loop_var0: Tensor[(1,), int32]) { + %1 = less(%loop_var0, %while/Less/y) + %2 = min(%1) + if (%2) { + %3 = add(%loop_var0, %while/Add/y) + free_var %while_loop + %4 = %while_loop(%3) + %4 + } else { + %5 = (%loop_var0,) + %5 + } + } + let %while_loop1 = %0 + %6 = %while_loop1(%Const) + %6 + } """ def __init__(self): self.loop_vars = [] @@ -1324,8 +1444,11 @@ def __init__(self): self._loop = None def _while_loop(self): - from tvm import relay + """An internal API to create a Relay recurisve call for a matched TF + `while_loop` construct. + """ wl = relay.var('while_loop') + sb = relay.scope_builder.ScopeBuilder() loop_vars = [] @@ -1354,17 +1477,13 @@ def _while_loop(self): return sb.get() def while_loop(self): + """Instantiate a while loop if it has not been created yet.""" if self._loop is None: self._loop = self._while_loop() return self._loop return self._loop -def _in_while_loop(control_flow_node_map, op_name): - return op_name in control_flow_node_map and \ - "LoopCond" in control_flow_node_map[op_name] - - class GraphProto(object): """ A helper class for handling relay graph copying from Tensorflow GraphDef. Definition: @@ -1381,7 +1500,6 @@ def __init__(self): self._input_shapes = {} self._loops = {} self._branches = {} - # self.module = relay.Module({}) def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. @@ -1548,55 +1666,15 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # This means the node is 1d in Relay and 0d in TF. # See `_expand_dims_0d_aware`. 
if self._outputs_are_0d[node_name][tensor_slot] and input_shape: - input_0d_mismatch.add(in_sym) + input_0d_mismatch.add(in_sym[0]) attr['_input_shapes'] = input_shapes attr['_input_0d_mismatch'] = input_0d_mismatch - node_name_prefix = node.name.rsplit('/', 1)[0] - if node.op == "Merge": - if _in_while_loop(control_flow_node_map, node_name_prefix): - op = self._nodes[node.input[0]] - self._loops[node_name_prefix] = Loop() - else: - if len(self._branches) == 0: - raise RuntimeError("Cannot find a created " - "conditional for merge node") - branch = self._branches[node_name_prefix] - false_br = self._nodes[node.input[0]] - true_br = self._nodes[node.input[1]] - assert len(true_br) == 1 - assert len(false_br) == 1 - branch.true_branch = true_br[0] - branch.false_branch = false_br[0] - op = [branch.if_node()] - # del self._branches[node_name_prefix] - elif node.op == "Exit": - loop = self._loops[node_name_prefix] - exit_name = node.name.split('/')[-1] - assert str.startswith(exit_name, 'Exit') - exit_number = int("0" + exit_name[4:]) - expr = loop.while_loop() - op = _expr.TupleGetItem(expr, exit_number) - elif node.op == "Enter": - op = self._nodes[node.input[0]] - elif node.op == "LoopCond": - op = self._nodes[node.input[0]] - assert len(op) == 1 - self._loops[node_name_prefix].cond = op[0] - elif node.op == "Switch": - op = self._nodes[node.input[0]] - assert len(op) == 1 - if _in_while_loop(control_flow_node_map, node_name_prefix): - self._loops[node_name_prefix].loop_vars.append(op[0]) - else: - if node_name_prefix not in self._branches: - self._branches[node_name_prefix] = Branch() - self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0]) - elif node.op == "NextIteration": - op = self._nodes[node.input[0]] - assert len(op) == 1 - self._loops[node_name_prefix].body.append(op[0]) + if node.op in _control_flow_nodes: + op = self._convert_control_flow_operator(node, inputs, + attr, + control_flow_node_map) else: op = self._convert_operator(node.op, inputs, attr, graph) @@ -1807,6 +1885,89 @@ def _convert_rnn_operator(self, op_name, inputs, sym = self.rnn.process_op(op_name, inputs, attrs, params) return sym + def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map): + """ + Convert a TensorFlow control flow primitive into the corresponding + component of a Relay control flow construct, i.e. `tf.cond` and + `tf.while_loop` are converted into a Relay `If` and a recursive call, + respectively. + + Parameters + ---------- + node: TensorFlow graph node object. + A TensorFlow graph node object. + + inputs : List[tvm.relay.Expr] + List of input symbols. + + attrs : Dict[tvm.Attrs] + Dict of operator attributes. + + control_flow_node_map : Dict[str, Set[str]] + A dictionary that maps an execution frame name to a set of + primitive operators. + + Returns + ------- + op : tvm.relay.Expr + Converted Relay expression.
+ """ + node_name_prefix = node.name.rsplit('/', 1)[0] + if node.op == "Merge": + if _in_while_loop(control_flow_node_map, node_name_prefix): + op = self._nodes[node.input[0]] + self._loops[node_name_prefix] = Loop() + else: + if len(self._branches) == 0: + raise RuntimeError("Cannot find a created " + "conditional for merge node") + branch = self._branches[node_name_prefix] + false_br = self._nodes[node.input[0]] + true_br = self._nodes[node.input[1]] + assert len(true_br) == 1 + assert len(false_br) == 1 + branch.true_branch = true_br[0] + branch.false_branch = false_br[0] + op = [branch.if_node()] + elif node.op == "Exit": + loop = self._loops[node_name_prefix] + exit_name = node.name.split('/')[-1] + assert str.startswith(exit_name, 'Exit') + + # TensorFlow has differen naming convention on different + # versions. + if '_' in exit_name: + exit_number = int("0" + exit_name[5:]) + else: + exit_number = int("0" + exit_name[4:]) + + expr = loop.while_loop() + op = _expr.TupleGetItem(expr, exit_number) + elif node.op == "Enter": + op = self._nodes[node.input[0]] + elif node.op == "LoopCond": + op = self._nodes[node.input[0]] + assert len(op) == 1 + self._loops[node_name_prefix].cond = op[0] + elif node.op == "Switch": + op = self._nodes[node.input[0]] + assert len(op) == 1 + if _in_while_loop(control_flow_node_map, node_name_prefix): + self._loops[node_name_prefix].loop_vars.append(op[0]) + else: + if node_name_prefix not in self._branches: + self._branches[node_name_prefix] = Branch() + self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0]) + elif node.op == "NextIteration": + op = self._nodes[node.input[0]] + assert len(op) == 1 + self._loops[node_name_prefix].body.append(op[0]) + else: + raise Exception("Cannot identify control flow operator: " + + "{}".format(node.op)) + + return op + + def _convert_operator(self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None): """Convert from Tensorflow operator to relay operator. diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 3128d2a71159..cf0c6fa91a8b 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -270,16 +270,30 @@ class Interpreter : return TupleValueNode::make(values); } - Value VisitExpr_(const FunctionNode* func_node) final { - auto func = GetRef(func_node); + // TODO(@jroesch): this doesn't support mutual letrec. + Value MakeClosure(const Function& func, const Var& letrec_name = Var()) { tvm::Map captured_mod; Array free_vars = FreeVars(func); for (const auto& var : free_vars) { - captured_mod.Set(var, Eval(var)); + // Evaluate the free var (which could be a function call) if it hasn't + // shown up in a letting binding that has invoked the function. + if (!letrec_name.defined() || letrec_name != var) { + captured_mod.Set(var, Eval(var)); + } } - return ClosureNode::make(captured_mod, func); + // We must use mutation here to build a self referential closure. 
+ auto closure = ClosureNode::make(captured_mod, func); + auto mut_closure = + static_cast<ClosureNode*>(const_cast<Node*>(closure.get())); + mut_closure->env.Set(letrec_name, closure); + return closure; + } + + Value VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef<Function>(func_node); + return MakeClosure(func); } Value InvokePrimitiveOp(Function func, @@ -438,10 +452,16 @@ class Interpreter : } } - Value VisitExpr_(const LetNode* op) final { - auto value = Eval(op->value); - this->extend(op->var, value); - return Eval(op->body); + Value VisitExpr_(const LetNode* let) final { + if (auto func = let->value.as<FunctionNode>()) { + auto clo = MakeClosure(GetRef<Function>(func), let->var); + this->extend(let->var, clo); + } else { + auto value = Eval(let->value); + this->extend(let->var, value); + } + + return Eval(let->body); } Value VisitExpr_(const TupleGetItemNode* op) final { diff --git a/tests/python/relay/test_tf_loop_to_relay.py b/tests/python/relay/test_tf_loop_to_relay.py index 49196123274c..c5b38c319467 100644 --- a/tests/python/relay/test_tf_loop_to_relay.py +++ b/tests/python/relay/test_tf_loop_to_relay.py @@ -18,7 +18,7 @@ def check_equal(graph, tf_out): np.testing.assert_allclose(x, y) -def vanilla_loop(): +def test_vanilla_loop(): graph = tf.Graph() with graph.as_default(): i = tf.constant(0) @@ -26,6 +26,7 @@ def vanilla_loop(): def c(i): return tf.less(i, 10) def b(i): return tf.add(i, 1) + r = tf.while_loop(c, b, [i]) with tf.Session() as sess: @@ -34,7 +35,7 @@ def b(i): return tf.add(i, 1) check_equal(graph, tf_out) -def loop_2_vars(): +def test_loop_2_vars(): graph = tf.Graph() with graph.as_default(): i0 = tf.constant(0) @@ -43,6 +44,7 @@ def loop_2_vars(): def c(i, j): return i < 10 def b(i, j): return [tf.add(i, 1), j] + i1, i2 = tf.while_loop(c, b, loop_vars=[i0, j0]) i1 += tf.constant(1337) @@ -52,7 +54,7 @@ def b(i, j): return [tf.add(i, 1), j] check_equal(graph, tf_out) -def loop_3_vars(): +def test_loop_3_vars(): graph = tf.Graph() with graph.as_default(): i0 = tf.constant(1) @@ -70,7 +72,7 @@ def b(i, j, k): return [i+1, j * k, k + i] check_equal(graph, tf_out) -def loop_conditions(): +def test_loop_conditions(): graph = tf.Graph() with graph.as_default(): i = tf.constant(1) @@ -90,7 +92,7 @@ def b(i, j, k): return [i+j, j+k, k+1] check_equal(graph, tf_out) -def loop_bodies(): +def test_loop_bodies(): graph = tf.Graph() with graph.as_default(): def body(x): @@ -109,7 +111,7 @@ def condition(x): check_equal(graph, tf_out) -def nested_loop(): +def test_nested_loop(): graph = tf.Graph() with graph.as_default(): @@ -125,13 +127,14 @@ def condition(x): return tf.greater(x, 100) x = tf.constant(3) r = tf.while_loop(condition, body, loop_vars=[x]) + with tf.Session() as sess: tf_out = sess.run(r) check_equal(graph, tf_out) -def vanilla_cond(): +def test_vanilla_cond(): graph = tf.Graph() with graph.as_default(): i = tf.constant(1) @@ -150,7 +153,7 @@ def f2(): check_equal(graph, tf_out) -def multiple_cond_vars(): +def test_multiple_cond_vars(): graph = tf.Graph() with graph.as_default(): x1 = tf.constant(7) @@ -158,13 +161,14 @@ def multiple_cond_vars(): z = tf.constant(20) r = tf.cond(tf.less(tf.add(x1, x2), 10), lambda: tf.add(10, 2), lambda: tf.square(5)) + with tf.Session() as sess: tf_out = sess.run(r) check_equal(graph, tf_out) -def cond_fn_parameters(): +def test_cond_fn_parameters(): graph = tf.Graph() with graph.as_default(): def fn1(x, y): @@ -177,13 +181,14 @@ def fn2(x, y): j = tf.constant(2) k = tf.constant(3) r = tf.cond(tf.less(i, j), lambda: fn1(i, k), lambda: fn2(j, k)) + with
tf.Session() as sess: tf_out = sess.run(r, feed_dict={i: 1, j: 2, k: 3}) check_equal(graph, tf_out) -def nested_cond(): +def test_nested_cond(): graph = tf.Graph() with graph.as_default(): def fn1(a, b): @@ -204,13 +209,14 @@ def fn2(a, b): z = tf.constant(7) pred = tf.less(x, y) r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z)) + with tf.Session() as sess: tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True}) check_equal(graph, tf_out) -def loop_in_cond(): +def test_loop_in_cond(): graph = tf.Graph() with graph.as_default(): def fn1(a, b): @@ -237,7 +243,7 @@ def fn2(a, b): check_equal(graph, tf_out) -def cond_in_loop(): +def test_cond_in_loop(): graph = tf.Graph() with graph.as_default(): def body(x): @@ -258,41 +264,22 @@ def condition(x): check_equal(graph, tf_out) -def loop_lambda_placeholder(): - graph = tf.Graph() - with graph.as_default(): - c = lambda i, j: tf.equal(tf.less(i, 17), tf.greater(j, 7)) - b = lambda i, j: [i + 3, j - 13] - - i = tf.placeholder(tf.float32) - j = tf.placeholder(tf.float32) - r = tf.while_loop(c, b, loop_vars=[i, j]) - - with tf.Session() as sess: - tf_out = sess.run(r, feed_dict={i: -203, j: 107}) - - check_equal(graph, tf_out) - - if __name__ == "__main__": # tf.while_loop - vanilla_loop() - loop_2_vars() - loop_3_vars() - loop_conditions() - loop_bodies() + test_vanilla_loop() + test_loop_2_vars() + test_loop_3_vars() + test_loop_conditions() + test_loop_bodies() # tf.cond - vanilla_cond() - multiple_cond_vars() - cond_fn_parameters() + test_vanilla_cond() + test_multiple_cond_vars() + test_cond_fn_parameters() # nested cases - nested_loop() - nested_cond() - loop_in_cond() - cond_in_loop() - - # w/ placeholder and lambda - loop_lambda_placeholder() + test_nested_loop() + test_nested_cond() + test_loop_in_cond() + test_cond_in_loop() From 6d46bff4495963e11d8bec4768416c4e3a9e963d Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 21 Mar 2019 19:47:54 +0000 Subject: [PATCH 3/5] remove import relay --- python/tvm/relay/frontend/tensorflow.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f1272f9b7dd3..255546735d9b 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -10,7 +10,6 @@ import numpy as np import tvm -from tvm import relay from topi.util import get_const_tuple from .. import ir_pass from .. import expr as _expr @@ -1370,8 +1369,8 @@ def _if_node(self): """An internal API to create a Relay if node from the matched TF condition construct. """ - cond = relay.op.min(self.cond) - return relay.If(cond, self.true_branch, self.false_branch) + cond = tvm.relay.op.min(self.cond) + return tvm.relay.If(cond, self.true_branch, self.false_branch) def if_node(self): """Create a tvm.relay.If node if it hasn't been created yet.""" if self._if is None: self._if = self._if_node() @@ -1447,31 +1446,31 @@ def _while_loop(self): """An internal API to create a Relay recursive call for a matched TF `while_loop` construct.
""" - wl = relay.var('while_loop') + wl = tvm.relay.var('while_loop') - sb = relay.scope_builder.ScopeBuilder() + sb = tvm.relay.scope_builder.ScopeBuilder() loop_vars = [] bind_map = {} for i, var in enumerate(self.loop_vars): assert isinstance(var, _expr.Var), repr(var) - v = relay.var("loop_var" + str(i), - type_annotation=var.type_annotation) + v = tvm.relay.var("loop_var" + str(i), + type_annotation=var.type_annotation) loop_vars.append(v) bind_map[var] = v - self.cond = relay.bind(self.cond, bind_map) - self.body = [relay.bind(b, bind_map) for b in self.body] + self.cond = tvm.relay.bind(self.cond, bind_map) + self.body = [tvm.relay.bind(b, bind_map) for b in self.body] - cond = relay.op.min(self.cond) + cond = tvm.relay.op.min(self.cond) with sb.if_scope(cond): sb.ret(wl(*self.body)) with sb.else_scope(): - sb.ret(relay.Tuple(loop_vars)) + sb.ret(tvm.relay.Tuple(loop_vars)) - loop_fn = relay.Function(loop_vars, sb.get()) - sb = relay.scope_builder.ScopeBuilder() + loop_fn = tvm.relay.Function(loop_vars, sb.get()) + sb = tvm.relay.scope_builder.ScopeBuilder() sb.let(wl, loop_fn) sb.ret(wl(*self.loop_vars)) return sb.get() From cbb408ad7ebf9ea1224216cbd2df48a3714faaff Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 21 Mar 2019 20:06:03 +0000 Subject: [PATCH 4/5] move tests under tensorflow frontend --- .../tensorflow/test_control_flow.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/python/{relay/test_tf_loop_to_relay.py => frontend/tensorflow/test_control_flow.py} (100%) diff --git a/tests/python/relay/test_tf_loop_to_relay.py b/tests/python/frontend/tensorflow/test_control_flow.py similarity index 100% rename from tests/python/relay/test_tf_loop_to_relay.py rename to tests/python/frontend/tensorflow/test_control_flow.py From c8ab6937061311e44e8d4addacd695375f38d331 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 22 Mar 2019 17:23:48 +0000 Subject: [PATCH 5/5] minor fix --- python/tvm/relay/frontend/tensorflow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 255546735d9b..304c5e11f1a5 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1369,6 +1369,9 @@ def _if_node(self): """An internal API to create a relay if node from the matched TF condition construct. """ + # `cond` returns a tensor that contains boolean values. We add a `min` + # operator to checks if there is any false value. If so, this condition + # doesn't not hold. cond = tvm.relay.op.min(self.cond) return tvm.relay.If(cond, self.true_branch, self.false_branch) @@ -1376,7 +1379,6 @@ def if_node(self): """Create an tvm.relay.If node if it hasn't been created yet.""" if self._if is None: self._if = self._if_node() - return self._if return self._if