Motivation
TVM currently focuses on models that are free of control flow. To support control-flow operations from other frameworks, we can map those frameworks’ control-flow operators, such as tf.while_loop, to Relay expressions. We can represent these operations using a combination of Relay’s branches and recursive function calls.
This RFC proposes converting TensorFlow (1.x) control-flow constructs to Relay control flow. The challenge is that TensorFlow uses low-level data-flow primitives, such as Merge, Exit, Switch, NextIteration, and Enter, to implement control-flow operators (i.e. cond and while_loop). It is not trivial to revert these primitives to the original control-flow operators.
Proposal
We propose a decompilation strategy which reconstructs the original high-level control-flow statements via pattern matching on the TensorFlow graph. The reconstruction translates low-level dataflow into the corresponding Relay constructs, i.e. loops to recursive functions and conditions to if expressions. Nested loops, nested conditions, mixtures of the two, and multi-level nesting complicate the problem. Fortunately, we haven’t seen any nested cases in real applications. Furthermore, we believe it is straightforward to support nested translation in many cases.
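To make the target of the translation concrete, here is a minimal sketch in plain Python (not the Relay API) of how a while loop can be written as a recursive function whose body is a single if expression. The helper name `while_loop` and its tuple-based calling convention are assumptions made for this illustration only.

```python
def while_loop(cond, body, loop_vars):
    """Run `body` while `cond` holds, using recursion instead of iteration."""
    def loop(*current):
        # The whole loop is one branch: recurse with the updated loop
        # variables, or stop and return the current ones.
        if cond(*current):
            return loop(*body(*current))
        return current
    return loop(*loop_vars)


# Same computation as a counter that increments while it is below 10.
result = while_loop(lambda i: i < 10, lambda i: (i + 1,), (0,))
print(result)  # (10,)
```

The recursive call takes the place of the back edge of the loop, and the if takes the place of the loop condition check, which is the shape the proposal aims to produce in Relay.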
Our proposal is based on the following observations:
- A TF cond is composed only of Merge and Switch primitives, and there is only one Merge.
import tensorflow as tf
x = tf.constant(10)
y = tf.constant(15)
z = tf.constant(20)
r = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
# returned value r is tf.add(x, z) if pred "x < y" is true else tf.square(y)
- A while_loop is constructed from the five aforementioned primitives. There can be multiple occurrences of Enter, Merge, Exit, Switch and NextIteration, depending on the number of loop variables, but there is only one occurrence of LoopCond.
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
# Repeat loop body b while the loop condition c is true
A cond or while_loop is usually created inside an execution frame (similar to a scope). Therefore, by identifying the frame, we can correctly create the cond and while_loop constructs.
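The following sketch illustrates how the observations above could be used to identify frames. It assumes, purely for illustration, that a frame can be approximated by the top-level name scope of a node (for example `cond/...` or `while/...`); the node list, the helper names `group_by_frame` and `classify_frame`, and this scope heuristic are all hypothetical and not taken from TensorFlow or TVM.

```python
from collections import defaultdict

# Toy node list of (name, op) pairs. Names and scopes are made up.
NODES = [
    ("x", "Const"),
    ("cond/Switch", "Switch"),
    ("cond/Add", "Add"),
    ("cond/Square", "Square"),
    ("cond/Merge", "Merge"),
    ("while/Enter", "Enter"),
    ("while/Merge", "Merge"),
    ("while/Less", "Less"),
    ("while/LoopCond", "LoopCond"),
    ("while/Switch", "Switch"),
    ("while/Add", "Add"),
    ("while/NextIteration", "NextIteration"),
    ("while/Exit", "Exit"),
]


def group_by_frame(nodes):
    """Group op types by top-level name scope; unscoped nodes go under ''."""
    frames = defaultdict(list)
    for name, op in nodes:
        scope = name.split("/")[0] if "/" in name else ""
        frames[scope].append(op)
    return dict(frames)


def classify_frame(ops):
    """Label a frame: exactly one LoopCond means a loop, one Merge plus Switch means a cond."""
    if ops.count("LoopCond") == 1:
        return "while_loop"
    if ops.count("Merge") == 1 and "Switch" in ops:
        return "cond"
    return "plain"


frames = group_by_frame(NODES)
kinds = {scope: classify_frame(ops) for scope, ops in frames.items()}
print(kinds)  # {'': 'plain', 'cond': 'cond', 'while': 'while_loop'}
```

A real converter would also have to handle frames with several outputs, since the toy classifier relies on the single-Merge observation as stated.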
- We should instantiate a branch statement when we meet a Merge primitive in a cond execution frame and save the inputs of the Merge to capture the branches. The input of Switch contains the used variables, and the inputs of Merge indicate the true and false bodies. Using this information, we can build a Relay if expression.
- We can generate a loop when we traverse into a while_loop execution frame and see the first occurrence of LoopCond. The input of NextIteration indicates the loop body, the inputs of Switch indicate the loop variables, Exit marks the completion of an execution frame, from which we can extract the frame's output, and LoopCond gives us the condition. Based on the collected information, we can construct a specialized while loop using recursion.
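The cond reconstruction described in the first bullet can be sketched on a toy graph. This is a simplified illustration, not the converter itself: the graph encoding, the node names, the nested-tuple expression form, and the helpers `build_expr` and `evaluate` are all hypothetical. In particular, the toy graph lists the true branch as the first Merge input; a real converter would need to work out which input belongs to which branch.

```python
# Toy graph for cond(x < y, x + z, y * y): name -> (op, input names).
GRAPH = {
    "x": ("Const", []),
    "y": ("Const", []),
    "z": ("Const", []),
    "pred": ("Less", ["x", "y"]),
    "cond/Switch_x": ("Switch", ["x", "pred"]),
    "cond/Switch_y": ("Switch", ["y", "pred"]),
    "cond/Switch_z": ("Switch", ["z", "pred"]),
    "cond/Add": ("Add", ["cond/Switch_x", "cond/Switch_z"]),
    "cond/Square": ("Square", ["cond/Switch_y"]),
    # Toy convention (an assumption): true branch first, false branch second.
    "cond/Merge": ("Merge", ["cond/Add", "cond/Square"]),
}


def build_expr(graph, name, preds):
    """Turn the subgraph rooted at `name` into a nested-tuple expression."""
    op, inputs = graph[name]
    if op == "Const":
        return ("var", name)
    if op == "Switch":
        # A Switch forwards its data input; remember which predicate guards it.
        data, pred = inputs
        preds.add(pred)
        return build_expr(graph, data, preds)
    if op == "Merge":
        # The Merge inputs are the two branch bodies.
        branch_preds = set()
        true_body, false_body = [build_expr(graph, i, branch_preds) for i in inputs]
        (pred,) = branch_preds  # this toy expects exactly one guarding predicate
        return ("if", build_expr(graph, pred, preds), true_body, false_body)
    return (op,) + tuple(build_expr(graph, i, preds) for i in inputs)


OPS = {
    "Less": lambda a, b: a < b,
    "Add": lambda a, b: a + b,
    "Square": lambda a: a * a,
}


def evaluate(expr, env):
    """Evaluate a nested-tuple expression against variable bindings."""
    tag = expr[0]
    if tag == "var":
        return env[expr[1]]
    if tag == "if":
        _, pred, true_body, false_body = expr
        return evaluate(true_body if evaluate(pred, env) else false_body, env)
    return OPS[tag](*(evaluate(arg, env) for arg in expr[1:]))


expr = build_expr(GRAPH, "cond/Merge", set())
print(evaluate(expr, {"x": 10, "y": 15, "z": 20}))  # 30
print(evaluate(expr, {"x": 15, "y": 10, "z": 20}))  # 100
```

Starting from the Merge and walking backwards through the Switch nodes recovers the predicate and both bodies, which is the information needed to emit an if expression. The loop case would follow the same pattern, collecting the NextIteration, Switch, Exit and LoopCond roles before emitting a recursive function.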

