Skip to content
83 changes: 65 additions & 18 deletions python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
# import tvm
# from tvm import relay
# from tvm import nd
# from tvm.runtime import import container as _container
# from tvm.runtime import container as _container
# from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
ast.Import([alias("numpy", None)]),
Expand All @@ -60,7 +60,7 @@ class PythonConverter(ExprFunctor):
def __init__(self, mod, target) -> None:
super().__init__()
self.mod = mod
self.tgt = target
self.tgt = target if isinstance(target, tvm.target.Target) else tvm.target.Target(target)
self.tec = te_compiler.get()
self.fun_no = 0
self.var_no = 0
Expand Down Expand Up @@ -98,15 +98,31 @@ def optimize(self, prog: Expr):
# unwrap tuple wrappers (some op calls produce them)
unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog
assert relay.analysis.well_formed(unwrapped)
mod = self.mod.from_expr(unwrapped, self.mod.functions, self.mod.type_definitions)
# For a lone global var, there is nothing we need to do
if isinstance(unwrapped, relay.GlobalVar):
return unwrapped

# main might be in the mod already and from_expr will not override it if it's there,
# so we need a new name
target_name = self.generate_function_name("target")

wrapped = unwrapped
if not isinstance(unwrapped, relay.Function):
wrapped = relay.Function(relay.analysis.free_vars(unwrapped), unwrapped)

# easiest way to make a deep copy -- note that main will not be overridden if it's present
copy_mod = tvm.IRModule.from_expr(
relay.Tuple([]), self.mod.functions, self.mod.type_definitions
)
copy_mod[target_name] = wrapped

# necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
opts = tvm.transform.Sequential(
[relay.transform.SimplifyInference(), relay.transform.FuseOps(fuse_opt_level=0)]
)
mod = opts(mod)
optimized = mod["main"]
copy_mod = opts(copy_mod)
optimized = copy_mod[target_name]
return optimized if isinstance(unwrapped, Function) else optimized.body

def sanitize(self, name: str) -> str:
Expand Down Expand Up @@ -197,7 +213,7 @@ def convert_func_node(self, func: Function, name_var=None):

var_names = [self.get_var_name(var) for var in func.params]
body, defs = self.visit(func.body)
ret = self.create_def(func_name, var_names, defs + [Return(body)])
ret = self.create_def(func_name, var_names, defs + [Return(body)], register_packed=True)
return (ret, func_name)

def convert_module(self):
Expand All @@ -219,10 +235,25 @@ def create_call(self, func_name: str, arguments):
"""Creates a simple function call."""
return ast.Call(self.parse_name(func_name), arguments, [])

def create_def(self, func_name: str, arguments: [str], body):
"""Wrapper over function definition AST node, whose constructor is inconvenient."""
def create_def(self, func_name: str, arguments: [str], body, register_packed: bool = False):
"""
Wrapper over function definition AST node, whose constructor is inconvenient.

register_packed includes a tvm.register_func decorator on the generated function if true.
This option should be used for Relay functions (warning: clobbers registry!)
"""
inner_args = [ast.arg(argument, None) for argument in arguments]

# add a decorator to register as a PackedFunc so the function will be an ObjectRef
# and will allow for putting functions into tuples or refs
decorator_list = [
ast.Call(
self.parse_name("tvm.register_func"),
[ast.Constant(value=func_name)],
[ast.keyword(arg="override", value=ast.Constant(value=True))],
)
]

global __MAJOR__, __MINOR__
if __MAJOR__ == 3 and __MINOR__ >= 8:
arguments = ast.arguments([], inner_args, None, [], [], None, [])
Expand All @@ -233,10 +264,19 @@ def create_def(self, func_name: str, arguments: [str], body):
func_name,
arguments,
body,
[],
decorator_list if register_packed else [],
None,
)

def create_tuple(self, fields):
"""
Given the ASTs for tuple fields, produce an AST that creates a
tuple value with those fields
"""
# Use the FFI API directly so that PackedFuncs will be correctly converted to ObjectRef.
# Using tvm.runtime.container.tuple_object fails to convert PackedFuncs in Python
return self.create_call("_container._ffi_api.Tuple", fields)

def create_op_call(self, op: Function, relay_args, py_args):
"""Lowers the passed primitive function, registers it in TVM's
global compiler, and produces a call to the lowered function in
Expand Down Expand Up @@ -290,8 +330,7 @@ def convert_output(ret_type):
assignments += inner_assignments
extra_args += inner_args
fields.append(inner_output)
fields = [ast.List(fields, Load())]
return (assignments, extra_args, self.create_call("_container.tuple_object", fields))
return (assignments, extra_args, self.create_tuple(fields))

# create a function to wrap the call of the lowered op and return
# a call to that function
Expand Down Expand Up @@ -418,7 +457,9 @@ def visit_var(self, var: Expr):
def visit_global_var(self, gvar: Expr):
# we don't need to add numbers to global var names because
# the *names* are checked for uniqueness in the mod
return (Name(str(gvar.name_hint), Load()), [])
func_name = str(gvar.name_hint)
# load in the packed func
return (self.create_call("tvm.get_global_func", [ast.Constant(value=func_name)]), [])

def visit_let(self, letexp: Expr):
# To properly account for scoping and ensure that the entire node produces an expression,
Expand Down Expand Up @@ -456,8 +497,7 @@ def let_thunk(var):

def visit_tuple(self, tup: Expr):
fields, ret_defs = self.convert_fields(tup.fields)
fields = [ast.List(fields, Load())]
return (self.create_call("_container.tuple_object", fields), ret_defs)
return (self.create_tuple(fields), ret_defs)

def visit_tuple_getitem(self, tgi: Expr):
tup, tup_defs = self.visit(tgi.tuple_value)
Expand All @@ -471,7 +511,7 @@ def visit_if(self, if_block: Expr):

# need to get the value out of a NDArray to check the condition
# equvialent to: val.numpy()
cond_check = ast.Call(ast.Attribute(cond_body, "asnumpy", Load()), [], [])
cond_check = ast.Call(ast.Attribute(cond_body, "numpy", Load()), [], [])
ret = ast.IfExp(cond_check, true_body, false_body)
return (ret, cond_defs + true_defs + false_defs)

Expand All @@ -490,7 +530,11 @@ def visit_constant(self, constant: Expr):
def visit_function(self, func: Expr):
# Python's lambdas are very restrictive, so we do "name" inline functions
converted_func, func_name = self.convert_func_node(func)
return (Name(func_name, Load()), [converted_func])
# load in the PackedFunc
return (
self.create_call("tvm.get_global_func", [ast.Constant(value=func_name)]),
[converted_func],
)

def visit_call(self, call: Expr):
"""For calls, we must distinguish between ordinary functions,
Expand Down Expand Up @@ -546,7 +590,7 @@ def visit_ref_write(self, write: Expr):
+ val_defs
+ [
Assign([ast.Attribute(ref, "value", Store())], val),
Return(self.create_call("_container.tuple_object", [])),
Return(self.create_tuple([])),
],
)
return (self.create_call(thunk_name, []), [thunk])
Expand Down Expand Up @@ -602,7 +646,10 @@ def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):

def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
"""Converts the given Relay expression into a Python script and
executes it."""
executes it.

Note that closures will be returned as PackedFuncs
"""
mod = mod if mod is not None else tvm.IRModule()
py_ast = to_python(expr, mod, target)
code = compile(py_ast, "<string>", "exec")
Expand Down
1 change: 1 addition & 0 deletions src/runtime/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,6 @@ TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple sh
ICHECK_LT(idx, shape.size());
return shape[idx];
});

} // namespace runtime
} // namespace tvm
65 changes: 63 additions & 2 deletions tests/python/relay/test_py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import tvm
from tvm import te
from tvm import relay
from tvm.relay.testing import to_python, run_as_python
from tvm.relay.testing import run_as_python
from tvm.relay.prelude import Prelude
from tvm.runtime.container import ADT
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
Expand Down Expand Up @@ -70,7 +70,6 @@ def test_create_empty_tuple():
def test_create_scalar():
scalar = relay.const(1)
tensor_val = run_as_python(scalar)
print(type(tensor_val))
assert_tensor_value(tensor_val, 1)


Expand Down Expand Up @@ -611,3 +610,65 @@ def reference(x, gamma, beta, moving_mean, moving_var):
verify_batch_norm([(20, 10), (10,), (10,), (10,), (10,)])
verify_batch_norm([(10, 50), (50,), (50,), (50,), (50,)])
verify_batch_norm([(30, 40), (40,), (40,), (40,), (40,)])


def test_return_global_var():
tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)
mod = tvm.IRModule()
mod["main"] = identity
main_var = mod.get_global_var("main")
main_func = run_as_python(main_var, mod=mod)

arg = tvm.nd.array(np.array([0.0], dtype="float32"))
res = main_func(arg)
assert arg.numpy() == res.numpy()


def test_closure_in_tuple():
tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)
tup = relay.Tuple([identity, identity])
index = relay.TupleGetItem(tup, 0)

func = run_as_python(index)
arg = tvm.nd.array(np.array([0.0], dtype="float32"))
res = func(arg)
assert arg.numpy() == res.numpy()


def test_closure_in_ref():
tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)
gv = relay.GlobalVar("id")

r = relay.Var("r")
seq = relay.Let(
r,
relay.RefCreate(gv),
relay.Call(relay.RefRead(r), [relay.const(np.array([0.0], dtype="float32"))]),
)

mod = tvm.IRModule()
mod[gv] = identity
res = run_as_python(seq, mod=mod)
assert res.numpy() == np.array([0.0], dtype="float32")


def test_compiling_with_main():
unit_type = relay.TupleType([])
unit = relay.Function([], relay.Tuple([]), ret_type=unit_type)

x = relay.Var("x", type_annotation=unit_type)
identity = relay.Function([x], x, ret_type=unit_type)

mod = tvm.IRModule()
mod["unit"] = unit
mod["main"] = identity

res = run_as_python(mod.get_global_var("main")(mod.get_global_var("unit")()), mod=mod)
assert isinstance(res, ADT)
assert len(res) == 0