Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,17 @@ class ExprMutator
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};

/*
* \brief Bind function parameters or free variables.
*
* Parameter binding can only happen if expr is a Function.
* binds cannot change internal arguments of internal functions.
*
* \param expr The function to be binded.
* \param binds The map of arguments to
*/
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_FUNCTOR_H_
10 changes: 10 additions & 0 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ enum OpPatternKind {
/*! \brief the operator pattern */
using TOpPattern = int;

/*!
* \brief Whether operator is stateful or contain internal state.
*
* All the primitive ops we registered so far are pure.
* This attribute is left for potential future compatible reasons.
* We can always work around the stateful ops by adding an additional
* handle argument and return it.
*/
using TOpIsStateful = bool;

/*!
* \brief Computation description interface.
*
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,22 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
*/
Expr DeadCodeElimination(const Expr& e);

/*!
* \brief Fold constant expressions.
* \param expr the expression to be optimized.
* \return The optimized expression.
*/
Expr FoldConstant(const Expr& expr);

/*!
* \brief Fuse operations into expr into seperate functions.
* \param expr The expression.
* \param fuse_opt_level Optimization level.
* \return The optimized expression.
*/
Expr FuseOps(const Expr& expr, int fuse_opt_level);


/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
# helper functions
var = expr.var
const = expr.const

bind = expr.bind

# pylint: disable=unused-argument
@register_func("relay.debug")
Expand Down
14 changes: 11 additions & 3 deletions python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, mod, target):
self.target = target
self.nodes = []
self.var_map = {}
self.params = {}
self.compile_engine = compile_engine.get()
self.lowered_funcs = set()
self._name_map = {}
Expand Down Expand Up @@ -162,8 +163,12 @@ def visit_tuple_getitem(self, op):
assert isinstance(vtuple, tuple)
return vtuple[op.index]

def visit_constant(self, _):
raise RuntimeError("constant not supported")
def visit_constant(self, op):
index = len(self.params)
name = "p%d" % index
self.params[name] = op.data
node = InputNode(name, {})
return self.add_node(node, op.checked_type)

def visit_function(self, _):
raise RuntimeError("function not supported")
Expand Down Expand Up @@ -312,6 +317,9 @@ def codegen(self, func):

lowered_funcs : List[tvm.LoweredFunc]
The lowered functions.

params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
# First we convert all the parameters into input nodes.
for param in func.params:
Expand All @@ -324,7 +332,7 @@ def codegen(self, func):
self.heads = self.visit(func.body)
graph_json = self._get_json()
lowered_funcs = list(self.lowered_funcs)
return graph_json, lowered_funcs
return graph_json, lowered_funcs, self.params
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add returns in pydoc above


def _get_unique_name(self, name):
if name not in self._name_map:
Expand Down
56 changes: 45 additions & 11 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from . import expr
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen

# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3,
}

Expand Down Expand Up @@ -95,22 +97,50 @@ def build_config(**kwargs):
return BuildConfig(**kwargs)


def optimize(func):
def _bind_params_by_name(func, params):
"""Bind parameters of function by its name."""
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = expr.const(v)
return expr.bind(func, bind_dict)


def optimize(func, params=None):
"""Perform target invariant optimizations.

Parameters
----------
func : tvm.relay.Function
The input to optimization.

params : Optional[Dict[str, tvm.nd.NDArray]]
Input parameters to the graph that do not change
during inference time. used for constant folding.

Returns
-------
opt_func : tvm.relay.Function
The optimized version of the function.
"""
cfg = BuildConfig.current

if cfg.pass_enabled("FoldScaleAxis"):
# bind expressions
if params:
func = _bind_params_by_name(func, params)

if cfg.pass_enabled("SimplifyInference"):
func = ir_pass.infer_type(func)
func = ir_pass.simplify_inference(func)

Expand All @@ -119,6 +149,10 @@ def optimize(func):
func = ir_pass.backward_fold_scale_axis(func)
func = ir_pass.infer_type(func)
func = ir_pass.forward_fold_scale_axis(func)

if cfg.pass_enabled("FoldConstant"):
func = ir_pass.fold_constant(func)

return func


Expand Down Expand Up @@ -147,8 +181,7 @@ def build(func,

params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for pre-compute
folding optimization.
during inference time. Used for constant folding.

Returns
-------
Expand Down Expand Up @@ -176,14 +209,14 @@ def build(func,
cfg = BuildConfig.current

with tophub_context:
func = optimize(func)
func = optimize(func, params)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
graph_json, lowered_funcs = graph_gen.codegen(func)
graph_json, lowered_funcs, params = graph_gen.codegen(func)
mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host)
return graph_json, mod, params

Expand All @@ -210,21 +243,22 @@ def __init__(self, mod, ctx, target):
self.target = target

def _make_executor(self, func):
graph_json, mod, params = build(func, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params:
gmodule.set_input(*params)
def _graph_wrapper(*args):
graph_json, mod, params = build(func, target=self.target)
assert params is None
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
# Create map of inputs.
for i, arg in enumerate(args):
gmodule.set_input(i, arg)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)
# make a copy so multiple invocation won't hurt perf.
return gmodule.get_output(0).copyto(_nd.cpu(0))

return _graph_wrapper



def create_executor(kind="debug",
mod=None,
ctx=None,
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
Expand Down Expand Up @@ -577,3 +578,24 @@ def const(value, dtype=None):
if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray")
return Constant(value)


def bind(expr, binds):
"""Bind an free variables in expr or function arguments.

We can bind parameters expr if it is a function.

Parameters
----------
expr : tvm.relay.Expr
The input expression.

binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]]
The specific bindings.

Returns
-------
result : tvm.relay.Expr
The expression or function after binding.
"""
return _expr.Bind(expr, binds)
16 changes: 16 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,22 @@ def structural_hash(value):
raise TypeError(msg)


def fold_constant(expr):
"""Fold the constant expression in expr.

Parameters
----------
expr : tvm.relay.Expr
The input expression.

Returns
-------
transformed_expr : tvm.relay.Expr
The transformed expression.
"""
return _ir_pass.FoldConstant(expr)


def fuse_ops(expr, opt_level=1):
"""Fuse operators in expr together.

Expand Down
71 changes: 70 additions & 1 deletion src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/

#include <tvm/relay/expr_functor.h>
#include "type_functor.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -228,5 +228,74 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {

void ExprVisitor::VisitType(const Type& t) { return; }

// Implement bind.
class ExprBinder : public ExprMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
: args_map_(args_map) {
}

Expr VisitExpr_(const LetNode* op) final {
CHECK(!args_map_.count(op->var))
<< "Cannot bind an internel variable in let";
return ExprMutator::VisitExpr_(op);
}

Expr VisitExpr_(const FunctionNode* op) final {
for (Var param : op->params) {
CHECK(!args_map_.count(param))
<< "Cannnot bind an internal function parameter";
}
return ExprMutator::VisitExpr_(op);
}

Expr VisitExpr_(const VarNode* op) final {
auto id = GetRef<Var>(op);
auto it = args_map_.find(id);
if (it != args_map_.end()) {
return (*it).second;
} else {
return id;
}
}

private:
const tvm::Map<Var, Expr>& args_map_;
};

Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).Mutate(func->body);
Array<Var> new_params;
for (Var param : func->params) {
if (!args_map.count(param)) {
new_params.push_back(param);
}
}
if (new_body.same_as(func->body) &&
new_params.size() == func->params.size()) {
return expr;
}
return FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
} else {
return ExprBinder(args_map).Mutate(expr);
}
}


TVM_REGISTER_API("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef input = args[0];
if (input->derived_from<ExprNode>()) {
*ret = Bind(Downcast<Expr>(input), args[1]);
} else {
CHECK(input->derived_from<TypeNode>());
*ret = Bind(Downcast<Type>(input), args[1]);
}
});
} // namespace relay
} // namespace tvm
2 changes: 0 additions & 2 deletions src/relay/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
#include <memory>
#include <mutex>

#include "./../pass/type_subst.h"

namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);
Expand Down
Loading