
[RFC] Decompile TensorFlow Control Flow Primitives to Relay #2812

@zhiics


Motivation

TVM currently focuses on models that are free of control flow. To support control-flow operators from other frameworks, such as tf.while_loop, we can map them to Relay expressions, representing them with a combination of Relay's branches (if expressions) and recursive function calls.

This RFC proposes converting TensorFlow (1.x) control-flow constructs to Relay control flow. The challenge is that TensorFlow implements its high-level control-flow operators (i.e. cond and while_loop) with low-level dataflow primitives such as Switch, Merge, Enter, Exit, and NextIteration. Recovering the original control-flow operators from these primitives is not trivial.

Proposal

We propose a decompilation strategy that reconstructs the original high-level control-flow statements by pattern matching on the TensorFlow graph. The reconstruction translates the low-level dataflow into the corresponding Relay constructs: loops become recursive functions and conditions become if expressions. Nested loops, nested conditions, combinations of the two, and multiple levels of nesting complicate the problem. Fortunately, we haven't seen any nested cases in real applications so far, and we believe that supporting nested translation is straightforward in many cases.

Our proposal is based on the following observations:

  1. A TF cond is composed only of Switch and Merge primitives, and for a cond with a single output there is only one Merge.
```python
import tensorflow as tf  # TensorFlow 1.x

x = tf.constant(10)
y = tf.constant(15)
z = tf.constant(20)
r = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
# r is tf.add(x, z) if the predicate x < y is true, else tf.square(y)
```

[Figure: TensorFlow graph of the tf.cond example above (screenshot)]

  2. A while_loop is constructed from the five aforementioned primitives (Enter, Merge, Switch, Exit, and NextIteration) together with LoopCond. Enter, Merge, Switch, Exit, and NextIteration may each occur multiple times, once per loop variable, but there is only one LoopCond.
```python
import tensorflow as tf  # TensorFlow 1.x

i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
# Repeat loop body b while the loop condition c is true
```

[Figure: TensorFlow graph of the tf.while_loop example above (screenshot)]
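Both observations follow from how these primitives behave at runtime. The sketch below is a simplified pure-Python model of the two examples above, not TensorFlow's actual executor; the helpers `switch`, `merge`, `op`, and the `DEAD` marker are hypothetical names. A Switch routes its value to one of two outputs and marks the other as dead, ordinary ops propagate deadness, and a Merge forwards whichever input is alive. In the loop, the Merge joins the Enter and NextIteration edges, and the Switch controlled by LoopCond sends the value either back into the body or out through Exit.

```python
# Simplified model of TF 1.x dataflow control flow, for illustration only.
# Values on an untaken path are "dead"; ordinary ops propagate deadness,
# while Merge forwards whichever input is alive.
DEAD = object()


def switch(data, pred):
    """Switch: returns (output_false, output_true); the untaken output is dead."""
    return (DEAD, data) if pred else (data, DEAD)


def merge(*inputs):
    """Merge: forwards the first live input (at most one is live here)."""
    for value in inputs:
        if value is not DEAD:
            return value
    return DEAD


def op(fn, *args):
    """Any ordinary op: a dead input makes the output dead."""
    return DEAD if any(a is DEAD for a in args) else fn(*args)


def cond_via_primitives(x, y, z):
    # tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)):
    # one Switch per captured tensor, a single Merge for the single output.
    pred = x < y
    _, x_t = switch(x, pred)
    y_f, _ = switch(y, pred)
    _, z_t = switch(z, pred)
    true_out = op(lambda a, b: a + b, x_t, z_t)
    false_out = op(lambda v: v * v, y_f)
    return merge(false_out, true_out)


def while_via_primitives(i):
    # tf.while_loop(lambda i: i < 10, lambda i: i + 1, [i]) has one loop variable,
    # so one Enter, Merge, Switch, NextIteration, and Exit, plus one LoopCond.
    entered = i           # Enter: the initial value flows into the loop frame once
    back_edge = DEAD      # NextIteration has not produced anything yet
    while True:
        merged = merge(entered, back_edge)   # Merge joins Enter and NextIteration
        entered = DEAD                       # Enter does not fire again
        loop_cond = merged < 10              # LoopCond: the single loop predicate
        exit_in, body_in = switch(merged, loop_cond)
        if exit_in is not DEAD:
            return exit_in                   # Exit: the value leaves the frame
        back_edge = op(lambda v: v + 1, body_in)  # loop body feeds NextIteration
```

With these definitions, `cond_via_primitives(10, 15, 20)` takes the true branch and `while_via_primitives(0)` stops once the value reaches 10, matching the TensorFlow examples.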

A cond or while_loop is usually created within an execution frame (similar to a scope). Therefore, by identifying the frame, we can correctly reconstruct the cond and while_loop constructs.
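For while loops, TensorFlow records the frame explicitly: each Enter node carries a `frame_name` attribute, and an Exit node leaves the frame. A minimal sketch of identifying frames, assuming a hypothetical `Node` record rather than TensorFlow's real GraphDef protos, propagates frames forward from Enter nodes and pops them at Exit nodes, skipping the NextIteration back edges that close the loop cycle. The sketch omits constants other than `i`, since ops without data inputs would need extra handling to be placed in the right frame.

```python
from typing import Dict, List, NamedTuple, Tuple


class Node(NamedTuple):
    """Hypothetical, minimal stand-in for a TF GraphDef node."""
    op: str
    inputs: List[str]
    frame_name: str = ""  # only set on Enter, mirroring TF's Enter attribute


def assign_frames(nodes: Dict[str, Node]) -> Dict[str, Tuple[str, ...]]:
    """Map each node to its stack of enclosing loop frames; () is the root frame.

    Assumes `nodes` is ordered so that every node follows its forward inputs
    (inputs from NextIteration are back edges and are skipped), and that the
    first forward input of a node determines its frame.
    """
    frames: Dict[str, Tuple[str, ...]] = {}
    for name, node in nodes.items():
        forward = [src for src in node.inputs if nodes[src].op != "NextIteration"]
        parent = frames[forward[0]] if forward else ()
        if node.op == "Enter":
            frames[name] = parent + (node.frame_name,)  # Enter pushes a frame
        elif node.op == "Exit":
            frames[name] = parent[:-1]                  # Exit pops back out
        else:
            frames[name] = parent
    return frames


# The while_loop example, with constants other than i omitted for brevity.
graph = {
    "i": Node("Const", []),
    "enter": Node("Enter", ["i"], frame_name="while"),
    "merge": Node("Merge", ["enter", "next"]),
    "less": Node("Less", ["merge"]),
    "loop_cond": Node("LoopCond", ["less"]),
    "switch": Node("Switch", ["merge", "loop_cond"]),
    "add": Node("Add", ["switch"]),
    "next": Node("NextIteration", ["add"]),
    "exit": Node("Exit", ["switch"]),
}
frames = assign_frames(graph)
```

Grouping the result by frame yields the nodes from `enter` through `next` inside the `while` frame, while `i` and `exit` stay in the root frame; that group is the unit the reconstruction below works on.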

  1. When we encounter a Merge primitive inside a cond execution frame, we instantiate a branch statement and save the inputs of the Merge to capture the branches. The inputs of the Switch nodes provide the predicate and the variables used by the branches, and the inputs of the Merge correspond to the true and false bodies. Using this information, we can build a Relay if expression.

  2. We generate a loop when we traverse into a while_loop execution frame and see the first occurrence of LoopCond. The inputs of NextIteration give the loop body, the inputs of Switch give the loop variables, LoopCond gives the condition, and Exit marks the completion of the execution frame, from which we extract the loop outputs. Based on the collected information, we construct a specialized while loop using recursion.
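Once these pieces are collected, building the target constructs is mechanical. The sketch below mimics that last step in plain Python rather than Relay: `CondInfo`, `LoopInfo`, `build_if`, and `build_loop` are hypothetical names, and the collected subgraphs are represented as Python callables. The loop becomes a function that checks the condition and either calls itself on the body's results or returns the current values, which is the recursive shape the proposal targets.

```python
from dataclasses import dataclass
from typing import Callable, Tuple

Values = Tuple[int, ...]


@dataclass
class CondInfo:
    """Hypothetical container for what step 1 collects from one cond."""
    pred: Callable[..., bool]          # predicate feeding the Switch nodes
    true_branch: Callable[..., int]    # Merge input computed on the true side
    false_branch: Callable[..., int]   # Merge input computed on the false side


@dataclass
class LoopInfo:
    """Hypothetical container for what step 2 collects from one while_loop frame."""
    loop_vars: Values                  # initial loop variables (one per Switch)
    cond: Callable[..., bool]          # subgraph feeding LoopCond
    body: Callable[..., Values]        # subgraphs feeding the NextIteration nodes


def build_if(info: CondInfo) -> Callable[..., int]:
    """Build the if expression: evaluate exactly one branch based on the predicate."""
    def if_expr(*args):
        if info.pred(*args):
            return info.true_branch(*args)
        return info.false_branch(*args)
    return if_expr


def build_loop(info: LoopInfo) -> Callable[[], Values]:
    """Build a recursive function for the loop, applied to the initial values."""
    def loop(*vars_):
        if info.cond(*vars_):
            return loop(*info.body(*vars_))  # next iteration: recurse on updated vars
        return vars_                          # what the Exit nodes would produce
    return lambda: loop(*info.loop_vars)


cond_fn = build_if(CondInfo(
    pred=lambda x, y, z: x < y,
    true_branch=lambda x, y, z: x + z,
    false_branch=lambda x, y, z: y * y,
))
loop_fn = build_loop(LoopInfo(
    loop_vars=(0,),
    cond=lambda i: i < 10,
    body=lambda i: (i + 1,),
))
```

Here `cond_fn(10, 15, 20)` returns 30 and `loop_fn()` returns `(10,)`, matching the two TensorFlow examples. Python's recursion limit makes this sketch unsuitable for long-running loops; it only illustrates the structure of the translation.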
