Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,13 @@ class FunctionNode : public ExprNode {
*/
TVM_DLL FuncType func_type_annotation() const;

/*!
* \brief Check whether the function is a primitive function.
*
* \return Whether the function is primitive or not.
*/
bool IsPrimitive() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was cherry-picking changes from an old branch 😄 got distracted at TVM conference while working on this PR.


TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from . import base
from . import ty
from . import expr
from . import expr_functor
from . import module
from . import ir_pass
from .build_module import build, build_config, create_executor
Expand Down Expand Up @@ -53,6 +54,10 @@
If = expr.If
TupleGetItem = expr.TupleGetItem

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator

# helper functions
var = expr.var
const = expr.const
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from . import _backend
from . import compile_engine
from ..op import Op
from ..expr import Function, GlobalVar, ExprFunctor
from ..expr import Function, GlobalVar
from ..expr_functor import ExprFunctor
from ..ty import TupleType, TensorType


Expand Down Expand Up @@ -251,6 +252,9 @@ def visit_call(self, call):
op_name, inputs, {})
return self.add_node(op_node, call)

def visit_op(self, _):
raise Exception("can not compile op in non-eta expanded form")

def _get_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
Expand Down
130 changes: 3 additions & 127 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,13 @@ def __init__(self,
params,
body,
ret_type=None,
type_params=None):
type_params=None,
attrs=None):
if type_params is None:
type_params = convert([])

self.__init_handle_by_constructor__(
_make.Function, params, body, ret_type, type_params)
_make.Function, params, body, ret_type, type_params, attrs)

def __call__(self, *args):
"""Invoke the gobal function.
Expand Down Expand Up @@ -343,131 +344,6 @@ def realize(self):
return _expr.TempExprRealize(self)


class ExprFunctor(object):
"""
An abstract visitor defined over Expr.

Defines the default dispatch over expressions, and
implements memoization.
"""
def __init__(self):
self.memo_map = {}

# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found

if isinstance(expr, Function):
res = self.visit_function(expr)
elif isinstance(expr, Call):
res = self.visit_call(expr)
elif isinstance(expr, Let):
res = self.visit_let(expr)
elif isinstance(expr, Var):
res = self.visit_var(expr)
elif isinstance(expr, GlobalVar):
res = self.visit_global_var(expr)
elif isinstance(expr, If):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, TupleGetItem):
res = self.visit_tuple_getitem(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))

self.memo_map[expr] = res
return res

def visit_function(self, _):
raise NotImplementedError()

def visit_let(self, _):
raise NotImplementedError()

def visit_call(self, _):
raise NotImplementedError()

def visit_var(self, _):
raise NotImplementedError()

def visit_type(self, typ):
return typ

def visit_if(self, _):
raise NotImplementedError()

def visit_tuple(self, _):
raise NotImplementedError()

def visit_tuple_getitem(self, _):
raise NotImplementedError()

def visit_constant(self, _):
raise NotImplementedError()

def visit_global_var(self, _):
raise NotImplementedError()


class ExprMutator(ExprFunctor):
"""
A functional visitor over Expr.

The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
list(fn.params),
fn.ret_type, new_body,
fn.type_params)

def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body)

def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)

def visit_var(self, rvar):
return rvar

def visit_global_id(self, global_var):
return global_var

def visit_if(self, ite):
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))

def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])

def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
if not tuple_value.same_as(op.tuple_value):
return TupleGetItem(tuple_value, op.index)
return op

def visit_global_var(self, gvar):
return gvar

def visit_constant(self, rconst):
return rconst


class TupleWrapper(object):
"""TupleWrapper.

Expand Down
155 changes: 155 additions & 0 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression functor of Relay."""

from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant
from .op import Op

class ExprFunctor:
"""
An abstract visitor defined over Expr.

Defines the default dispatch over expressions, and
implements memoization.
"""
def __init__(self):
self.memo_map = {}

# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found

if isinstance(expr, Function):
res = self.visit_function(expr)
elif isinstance(expr, Call):
res = self.visit_call(expr)
elif isinstance(expr, Let):
res = self.visit_let(expr)
elif isinstance(expr, Var):
res = self.visit_var(expr)
elif isinstance(expr, GlobalVar):
res = self.visit_global_var(expr)
elif isinstance(expr, If):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, TupleGetItem):
res = self.visit_tuple_getitem(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
elif isinstance(expr, Op):
res = self.visit_op(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))

self.memo_map[expr] = res

return res

def visit_function(self, _):
raise NotImplementedError()

def visit_let(self, _):
raise NotImplementedError()

def visit_call(self, _):
raise NotImplementedError()

def visit_var(self, _):
raise NotImplementedError()

def visit_type(self, typ):
return typ

def visit_if(self, _):
raise NotImplementedError()

def visit_tuple(self, _):
raise NotImplementedError()

def visit_tuple_getitem(self, _):
raise NotImplementedError()

def visit_global_var(self, _):
raise NotImplementedError()

def visit_op(self, _):
raise NotImplementedError()

def visit_constant(self, _):
raise NotImplementedError()


class ExprMutator(ExprFunctor):
"""
A functional visitor over Expr.

The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
list(fn.params),
new_body,
fn.ret_type,
fn.type_params,
fn.attrs)

def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body)

def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)

def visit_var(self, rvar):
return rvar

def visit_global_id(self, global_var):
return global_var

def visit_if(self, ite):
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))

def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])

def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
if not tuple_value.same_as(op.tuple_value):
return TupleGetItem(tuple_value, op.index)
return op

def visit_global_var(self, gvar):
return gvar

def visit_op(self, op):
return op

def visit_constant(self, const):
return const

def visit_constructor(self, con):
return con

def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])

def visit_ref_new(self, r):
return RefNew(self.visit(r.value))

def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))

def visit_ref_read(self, r):
return RefRead(self.visit(r.ref))
8 changes: 4 additions & 4 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,14 @@ class ScheduleGetter :

int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) {
CHECK(!master_op_.defined() || master_op_patetrn_ < kCommReduce)
CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce)
<< "Two complicated op in a primitive function "
<< " master=" << master_op_ << " current=" << op;
}
if (op_pattern >= master_op_patetrn_) {
if (op_pattern >= master_op_pattern_) {
master_op_ = op;
master_attrs_ = call_node->attrs;
master_op_patetrn_ = op_pattern;
master_op_pattern_ = op_pattern;
}
if (outputs.size() != 1) {
const auto* tuple_type =
Expand Down Expand Up @@ -213,7 +213,7 @@ class ScheduleGetter :
tvm::Target target_;
Op master_op_;
Attrs master_attrs_;
int master_op_patetrn_{0};
int master_op_pattern_{0};
std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
};
Expand Down
Loading