diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index c6e5573d9413..5e50cfc05e67 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -118,17 +118,27 @@ class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
public:
- /*! \brief The name of the variable, this only acts as a hint to the user,
- * and is not used for equality.
+ /*!
+ * \brief The name of the variable;
+ * this only acts as a hint to the user
+ * and is not used for equality.
*/
std::string name_hint;
+ /*!
+ * \brief Type annotation of the variable.
+ * This field records the user-provided type annotation of the Var.
+ * It is optional and can be None.
+ */
+ Type type_annotation;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
+ v->Visit("type_annotation", &type_annotation);
v->Visit("_checked_type_", &checked_type_);
}
- TVM_DLL static Var make(std::string name_hint);
+ TVM_DLL static Var make(std::string name_hint,
+ Type type_annotation);
static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
@@ -162,32 +172,6 @@ class GlobalVarNode : public ExprNode {
RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);
-/*!
- * \brief Function parameter declaration.
- */
-class Param;
-/*! \brief A parameter. */
-class ParamNode : public ExprNode {
- public:
- /*! \brief The variable */
- Var var;
- /*! \brief The type of the parameter */
- Type type;
-
- void VisitAttrs(tvm::AttrVisitor* v) final {
- v->Visit("var", &var);
- v->Visit("type", &type);
- v->Visit("span", &span);
- }
-
- TVM_DLL static Param make(Var var, Type type);
-
- static constexpr const char* _type_key = "relay.Param";
- TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr);
-
/*!
* \brief Function (subgraph in computational graph)
*/
@@ -196,7 +180,7 @@ class Function;
class FunctionNode : public ExprNode {
public:
/*! \brief Function parameters */
- tvm::Array<Param> params;
+ tvm::Array<Var> params;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
@@ -224,10 +208,18 @@ class FunctionNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
- Type fn_type() const;
+ /*!
+ * \brief Return the derived function annotation of this expression.
+ *
+ * \return The function type annotation.
+ * \note The function type annotation can contain IncompleteType.
+ */
+ TVM_DLL FuncType func_type_annotation() const;
- TVM_DLL static Function make(tvm::Array<Param> params, Type ret_type,
- Expr body, tvm::Array<TypeParam> ty_params);
+ TVM_DLL static Function make(tvm::Array<Var> params,
+ Type ret_type,
+ Expr body,
+ tvm::Array<TypeParam> ty_params);
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
@@ -289,7 +281,7 @@ class CallNode : public ExprNode {
TVM_DLL static Call make(Expr op,
Array<Expr> args,
Attrs attrs = Attrs(),
- Array<Type> ty_args = Array<Type>());
+ Array<Type> type_args = Array<Type>());
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
@@ -318,19 +310,16 @@ class LetNode : public ExprNode {
Expr value;
/*! \brief The body of the let binding */
Expr body;
- /*! \brief Type annotation of value, this can be null */
- Type value_type;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
- v->Visit("value_type", &value_type);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
- TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type);
+ TVM_DLL static Let make(Var var, Expr value, Expr body);
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
@@ -376,11 +365,11 @@ class IfNode : public ExprNode {
RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
-/*! \brief Get a field out of a tuple. */
+/*! \brief Get the index-th field out of a tuple. */
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
public:
- /*! \brief The tuple */
+ /*! \brief The tuple expression */
Expr tuple;
/*! \brief which value to get */
int index;
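For orientation, here is a minimal sketch of the unified `Var` from the Python frontend, mirroring the tests further below (the shape and dtype are arbitrary, illustrative choices):

```python
from tvm import relay

# A plain variable: the type annotation is optional and defaults to None.
x = relay.Var("x")

# A variable carrying a user-provided type annotation.
ty = relay.TensorType((10, 20), "float32")
y = relay.Var("y", ty)

# Function parameters are now the annotated Vars themselves,
# with no intermediate Param(var, type) wrapper.
f = relay.Function([y], ty, y)
```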
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index be174d33b4c8..c10933590f99 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GlobalVarNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FunctionNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
@@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
- RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode);
RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
@@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const GlobalVarNode* op) override;
void VisitExpr_(const ConstantNode* op) override;
void VisitExpr_(const TupleNode* op) override;
- void VisitExpr_(const ParamNode* op) override;
void VisitExpr_(const FunctionNode* op) override;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const LetNode* op) override;
@@ -151,7 +148,6 @@ class ExprMutator
Expr VisitExpr_(const GlobalVarNode* op) override;
Expr VisitExpr_(const OpNode* op) override;
Expr VisitExpr_(const TupleNode* op) override;
- Expr VisitExpr_(const ParamNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 18c02a416d6b..b1085be2e1e2 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -34,7 +34,6 @@
Tuple = expr.Tuple
Var = expr.Var
GlobalVar = expr.GlobalVar
-Param = expr.Param
Function = expr.Function
Call = expr.Call
Let = expr.Let
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 6ed8df0d736b..a71fd329ed5b 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -11,11 +11,11 @@ class Expr(NodeBase):
"""The base type for all Relay expressions."""
@property
def checked_type(self):
- """Get the checked type of relay.
+ """Get the checked type of tvm.relay.Expr.
Returns
-------
- checked_type : relay.Type
+ checked_type : tvm.relay.Type
The checked type.
"""
ret = self._checked_type_
@@ -25,70 +25,97 @@ def checked_type(self):
return ret
def __call__(self, *args):
- converted_args = []
- for arg in args:
- if isinstance(arg, Param):
- converted_args.append(arg.var)
- else:
- converted_args.append(arg)
-
return Call(self, args, None, None)
@register_relay_node
class Constant(Expr):
- """A constant tensor in Relay, see tvm/relay/type.h for more details.
- """
+ """A constant expression in Relay.
+ Parameters
+ ----------
+ data : tvm.nd.NDArray
+ The data content of the constant expression.
+ """
def __init__(self, data):
self.__init_handle_by_constructor__(_make.Constant, data)
@register_relay_node
class Tuple(Expr):
- """A hetereogenous sequence of values.
- see tvm/relay/type.h for more details.
- """
+ """Tuple expression that groups several fields together.
+ Parameters
+ ----------
+ fields : List[tvm.relay.Expr]
+ The fields in the tuple.
+ """
def __init__(self, fields):
self.__init_handle_by_constructor__(_make.Tuple, fields)
@register_relay_node
class Var(Expr):
- """A local variable in Relay."""
+ """A local variable in Tvm.Relay.
- def __init__(self, name_hint):
- self.__init_handle_by_constructor__(_make.Var, name_hint)
+    Local variables can be used to declare input
+    arguments to a function, or as intermediate variables.
+
+ Parameters
+ ----------
+ name_hint: str
+ The name of the variable.
+ This name only acts as a hint, and is not used
+ for equality.
+
+ type_annotation: tvm.relay.Type, optional
+ The type annotation on the variable.
+ """
+ def __init__(self, name_hint, type_annotation=None):
+ self.__init_handle_by_constructor__(
+ _make.Var, name_hint, type_annotation)
@register_relay_node
class GlobalVar(Expr):
- """A global variable in Relay."""
+ """A global variable in Tvm.Relay.
+ GlobalVar is used to refer to the global functions
+ stored in the environment.
+
+ Parameters
+ ----------
+ name_hint: str
+ The name of the variable.
+ """
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
@register_relay_node
-class Param(Expr):
- """A function type in Relay, see tvm/relay/type.h for more details.
- """
+class Function(Expr):
+ """A function declaration expression.
- def __init__(self, var, ty):
- self.__init_handle_by_constructor__(_make.Param, var, ty)
+ Parameters
+ ----------
+ params: List[tvm.relay.Var]
+ List of input parameters to the function.
+ ret_type: tvm.relay.Type
+ The return type annotation of the function.
-@register_relay_node
-class Function(Expr):
- """A function in Relay, see tvm/relay/expr.h for more details."""
+ body: tvm.relay.Expr
+ The body of the function.
+ type_params: Optional[List[tvm.relay.TypeParam]]
+        The additional type parameters; these are only
+        used in advanced use cases such as template functions.
+ """
def __init__(self,
params,
ret_type,
body,
- type_params=None
- ):
+ type_params=None):
if type_params is None:
type_params = convert([])
@@ -98,39 +125,87 @@ def __init__(self,
@register_relay_node
class Call(Expr):
- """A function call in Relay, see tvm/relay/expr.h for more details."""
+ """Function call node in Relay.
+
+    Call node corresponds to the operator application node
+ in computational graph terminology.
+
+ Parameters
+ ----------
+    op: tvm.relay.Op or any tvm.relay.Expr with function type
+ The operation to be called.
- def __init__(self, op, args, attrs, ty_args=None):
- if not ty_args:
- ty_args = []
+ args: List[tvm.relay.Expr]
+ The arguments to the call.
+ attrs: Optional[tvm.Attrs]
+        Attributes of the call; can be None.
+
+ type_args: Optional[List[tvm.relay.Type]]
+        The additional type arguments; these are only
+        used in advanced use cases such as template functions.
+ """
+ def __init__(self, op, args, attrs=None, type_args=None):
+ if not type_args:
+ type_args = []
self.__init_handle_by_constructor__(
- _make.Call, op, args, attrs, ty_args)
+ _make.Call, op, args, attrs, type_args)
@register_relay_node
class Let(Expr):
- """A variable bindings in Relay, see tvm/relay/expr.h for more details."""
+ """Let variable binding expression.
+
+ Parameters
+ ----------
+ var: tvm.relay.Var
+ The local variable to be bound.
+
+ value: tvm.relay.Expr
+ The value to be bound.
- def __init__(self, var, value, body, value_type=None):
+ body: tvm.relay.Expr
+ The body of the let binding.
+ """
+ def __init__(self, var, value, body):
self.__init_handle_by_constructor__(
- _make.Let, var, value, body, value_type)
+ _make.Let, var, value, body)
@register_relay_node
class If(Expr):
- """A conditional expression in Relay, see tvm/relay/expr.h for more details."""
+ """A conditional expression in Relay.
+
+ Parameters
+ ----------
+ cond: tvm.relay.Expr
+ The condition.
- def __init__(self, cond, true_value, false_value):
+ true_branch: tvm.relay.Expr
+ The expression evaluated when condition is true.
+
+ false_branch: tvm.relay.Expr
+ The expression evaluated when condition is false.
+ """
+ def __init__(self, cond, true_branch, false_branch):
self.__init_handle_by_constructor__(
- _make.If, cond, true_value, false_value)
+ _make.If, cond, true_branch, false_branch)
+
@register_relay_node
class TupleGetItem(Expr):
- """An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""
+ """Get index-th item from a tuple.
+
+ Parameters
+ ----------
+ tuple_value: tvm.relay.Expr
+ The input tuple expression.
- def __init__(self, tuple_, index):
+ index: int
+ The index.
+ """
+ def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__(
- _make.TupleGetItem, tuple_, index)
+ _make.TupleGetItem, tuple_value, index)
debug_print = _expr._debug_print
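A hedged sketch of the new Python constructors in action (the constant values are arbitrary; `f(value)` exercises the simplified `__call__`):

```python
import tvm
from tvm import relay

# Let no longer takes a separate value_type; the annotation
# lives on the bound Var instead.
v = relay.Var("v", relay.TensorType((), "float32"))
value = relay.Constant(tvm.nd.array(10))
let = relay.Let(v, value, v)

# __call__ forwards its arguments directly (no Param unwrapping),
# and Call spells its last argument type_args rather than ty_args.
f = relay.Function([v], None, v)
call = f(value)

# TupleGetItem takes the tuple expression and an integer index.
item = relay.TupleGetItem(relay.Tuple([value, value]), 0)
```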
diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py
index accb782659df..a429aea7d5ea 100644
--- a/python/tvm/relay/ir_builder.py
+++ b/python/tvm/relay/ir_builder.py
@@ -7,7 +7,7 @@
import numpy as np
import tvm
from .ty import Type, FuncType, TensorType
-from .expr import Expr, Constant, Let, Var, Param, Function, If
+from .expr import Expr, Constant, Let, Var, Function, If
from .env import Environment
@@ -98,7 +98,7 @@ def __init__(self, params, ret_type, body, type_params):
self.type_params = type_params
def param_ids(self):
- return [p.var for p in self.params]
+ return [p for p in self.params]
def to_func(self):
"""Converts a PartialFunc into a :py:class:`~relay.Function`."""
@@ -113,9 +113,8 @@ def to_func(self):
def _mk_let(bindings, ret_value):
let_expr = ret_value
- for var, (value, ty) in reversed(list(bindings.items())):
- let_expr = Let(var, value, let_expr, ty)
-
+ for var, value in reversed(list(bindings.items())):
+ let_expr = Let(var, value, let_expr)
return let_expr
@@ -168,15 +167,12 @@ def exit_scope(self):
#pylint: disable=invalid-name
def bind(self, name, value, ty):
- lv = Var(name)
+ lv = Var(name, ty)
self.scopes[-1][name] = lv
- self.bindings[-1][lv] = (value, ty)
+ self.bindings[-1][lv] = value
return lv
def let(self, name, value, value_type=None):
- if isinstance(value, Param):
- value = value.var
-
if not isinstance(value, Expr):
value = convert(value)
@@ -185,23 +181,18 @@ def let(self, name, value, value_type=None):
def _convert_params(self, raw_params):
relay_params = []
for raw_param in raw_params:
- if isinstance(raw_param, Param):
- var = raw_param.var
+ if isinstance(raw_param, Var):
param = raw_param
elif isinstance(raw_param, tuple):
var, ty = raw_param
- if isinstance(var, str):
- var = Var(var)
ty = _convert_type(ty)
- param = Param(var, ty)
- elif isinstance(param, str):
- var = Var(raw_param)
- ty = None
- param = Param(var, ty)
+ param = Var(var, ty)
+ elif isinstance(raw_param, str):
+ param = Var(raw_param, None)
else:
raise Exception("unknown parameter type")
- self.scopes[-1][var.name_hint] = var
+ self.scopes[-1][param.name_hint] = param
relay_params.append(param)
return relay_params
@@ -265,7 +256,7 @@ def param(self, name, ty=None):
else:
ty = _convert_type(ty)
- return Param(Var(name), ty)
+ return Var(name, ty)
def global_var(self, name):
# type: (str) -> GlobalVar
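Since `ib.param` now returns the `Var` itself, the op tests below simply drop the `.var` indirection. A small illustrative sketch (`relay.log` stands in for any level-1 op):

```python
from tvm import relay

ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "float32"))
with ib.function(x) as func:
    ib.ret(relay.log(x))   # previously: relay.log(x.var)
ib.ret(func)
```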
diff --git a/src/relay/ir/debug_printer.cc b/src/relay/ir/debug_printer.cc
index 90e82d3b2dd7..cb463ef6975a 100644
--- a/src/relay/ir/debug_printer.cc
+++ b/src/relay/ir/debug_printer.cc
@@ -96,7 +96,9 @@ class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
}
std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
- return MapDocify(arr, [=](const TypeParam& tp) { return Docify(tp); });
+ return MapDocify(arr, [=](const TypeParam& tp) {
+ return Docify(tp);
+ });
}
std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
@@ -188,10 +190,11 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return vec;
}
- std::vector<Doc> DocifyParamArray(const tvm::Array<Param>& arr) {
+ std::vector<Doc> DocifyParamArray(const tvm::Array<Var>& arr) {
   std::vector<Doc> vec;
- for (size_t i = 0; i < arr.size(); ++i) {
- vec.push_back(Docify(arr[i]));
+ for (Var param : arr) {
+ vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)),
+ param->type_annotation));
}
return vec;
}
@@ -212,10 +215,6 @@ class ExprDocifier : private ExprFunctor {
return DocOfStr(g->name_hint);
}
- Doc VisitExpr_(const ParamNode* p) final {
- return TypeAnnotation(Docify(p->var), p->type);
- }
-
Doc VisitExpr_(const FunctionNode* f) final {
return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
DocOfStr("=>") + Sep() +
@@ -227,7 +226,8 @@ class ExprDocifier : private ExprFunctor {
}
Doc VisitExpr_(const LetNode* l) final {
- return Group(DocOfStr("let") + Sep() + TypeAnnotation(Docify(l->var), l->value_type) + Sep() +
+ return Group(DocOfStr("let") + Sep() +
+ TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() +
DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
Docify(l->body));
}
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 6b56cb4e844f..c248ad0de6f7 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -54,20 +54,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "Tuple(" << node->fields << ")";
});
-Var VarNode::make(std::string name_hint) {
+Var VarNode::make(std::string name_hint, Type type_annotation) {
NodePtr<VarNode> n = make_node<VarNode>();
n->name_hint = std::move(name_hint);
+ n->type_annotation = std::move(type_annotation);
return Var(n);
}
TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = VarNode::make(args[0]);
+ *ret = VarNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) {
- p->stream << "Var(" << node->name_hint << ")";
+ p->stream << "Var(" << node->name_hint;
+ if (node->type_annotation.defined()) {
+ p->stream << ", ty=";
+ p->print(node->type_annotation);
+ }
+ p->stream << ")";
});
GlobalVar GlobalVarNode::make(std::string name_hint) {
@@ -86,24 +92,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "GlobalVar(" << node->name_hint << ")";
});
-Param ParamNode::make(Var var, Type type) {
- NodePtr<ParamNode> n = make_node<ParamNode>();
- n->var = std::move(var);
- n->type = std::move(type);
- return Param(n);
-}
-
-TVM_REGISTER_API("relay._make.Param")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = ParamNode::make(args[0], args[1]);
-});
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<ParamNode>([](const ParamNode *node, tvm::IRPrinter *p) {
- p->stream << "Param(" << node->var << ", " << node->type << ")";
-});
-
-Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
+Function FunctionNode::make(tvm::Array<Var> params,
+ Type ret_type,
+ Expr body,
tvm::Array<TypeParam> type_params) {
NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params);
@@ -113,12 +105,11 @@ Function FunctionNode::make(tvm::Array params, Type ret_type, Expr body,
return Function(n);
}
-Type FunctionNode::fn_type() const {
+FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
- param_types.push_back(param->type);
+ param_types.push_back(param->type_annotation);
}
-
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
}
@@ -155,24 +146,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->attrs << ", " << node->type_args << ")";
});
-Let LetNode::make(Var var, Expr value, Expr body, Type value_type) {
+Let LetNode::make(Var var, Expr value, Expr body) {
NodePtr<LetNode> n = make_node<LetNode>();
n->var = std::move(var);
n->value = std::move(value);
n->body = std::move(body);
- n->value_type = std::move(value_type);
return Let(n);
}
TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = LetNode::make(args[0], args[1], args[2], args[3]);
-});
+ *ret = LetNode::make(args[0], args[1], args[2]);
+ });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) {
p->stream << "LetNode(" << node->var << ", " << node->value
- << ", " << node->body << ", " << node->value_type << ")";
+ << ", " << node->body << ")";
});
If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
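The updated `VarNode` printer includes the annotation when one is present. From Python this is observable as follows (the exact printed form of the type is approximate):

```python
from tvm import relay

lv = relay.Var("x", relay.TensorType((10, 20), "float32"))
# Expected to print roughly: Var(x, ty=TensorType(...));
# an unannotated Var still prints as Var(x).
print(lv)
```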
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index 792f99d699dd..c55e4d672b6c 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -24,6 +24,16 @@ Expr ExprMutator::Mutate(const Expr& expr) {
}
Expr ExprMutator::VisitExpr_(const VarNode* op) {
+ // NOTE: a var will only be mutated once, thanks to the memo,
+ // and the mutated result is reused during later rewriting if necessary.
+ if (op->type_annotation.defined()) {
+ auto type = this->VisitType(op->type_annotation);
+ if (!op->type_annotation.same_as(type)) {
+ return VarNode::make(op->name_hint, type);
+ }
+ }
+ // default case return self.
return GetRef<Expr>(op);
}
@@ -55,16 +65,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
}
}
-Expr ExprMutator::VisitExpr_(const ParamNode* op) {
- Var var = Downcast<Var>(this->Mutate(op->var));
- auto type = this->VisitType(op->type);
- if (op->var.same_as(var) && op->type.same_as(type)) {
- return GetRef<Expr>(op);
- } else {
- return ParamNode::make(var, type);
- }
-}
-
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
tvm::Array<TypeParam> ty_params;
bool all_ty_params_changed = true;
@@ -75,10 +75,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
all_ty_params_changed &= new_ty_param.same_as(ty_param);
}
- tvm::Array<Param> params;
+ tvm::Array<Var> params;
bool all_params_changed = true;
for (auto param : op->params) {
- Param new_param = Downcast<Param>(this->Mutate(param));
+ Var new_param = Downcast<Var>(this->Mutate(param));
params.push_back(new_param);
all_params_changed &= param.same_as(new_param);
}
@@ -123,17 +123,15 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
Expr ExprMutator::VisitExpr_(const LetNode* op) {
Var var = Downcast<Var>(this->Mutate(op->var));
- auto type = this->VisitType(op->value_type);
auto value = this->Mutate(op->value);
auto body = this->Mutate(op->body);
if (var.same_as(op->var) &&
- type.same_as(op->value_type) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
- return LetNode::make(var, value, body, type);
+ return LetNode::make(var, value, body);
}
}
@@ -162,6 +160,9 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
+ if (op->type_annotation.defined()) {
+ this->VisitType(op->type_annotation);
+ }
}
void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
@@ -176,10 +177,6 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
}
}
-void ExprVisitor::ExprVisitor::VisitExpr_(const ParamNode* op) {
- this->VisitExpr(op->var);
-}
-
void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
for (auto param : op->params) {
this->VisitExpr(param);
diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc
index 0ed0e3df3056..29d2f87cf04a 100644
--- a/src/relay/pass/alpha_eq.cc
+++ b/src/relay/pass/alpha_eq.cc
@@ -252,15 +252,6 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
}
}
- void VisitExpr_(const ParamNode* p1, const Expr& e2) final {
- if (const ParamNode* p2 = e2.as<ParamNode>()) {
- eq_map.Set(p1->var, p2->var);
- equal = equal && AlphaEqual(p1->type, p2->type);
- } else {
- equal = false;
- }
- }
-
void VisitExpr_(const FunctionNode* func1, const Expr& e2) final {
if (const FunctionNode* func2 = e2.as<FunctionNode>()) {
if (func1->params.size() != func2->params.size()) {
@@ -273,9 +264,10 @@ struct AlphaEq : ExprFunctor {
return;
}
- for (size_t i = 0U; i < func1->params.size(); i++) {
- this->VisitExpr(func1->params[i], func2->params[i]);
+ for (size_t i = 0; i < func1->params.size(); ++i) {
+ MergeVarDecl(func1->params[i], func2->params[i]);
}
+ if (!equal) return;
for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
@@ -332,19 +324,9 @@ struct AlphaEq : ExprFunctor {
void VisitExpr_(const LetNode* op, const Expr& e2) final {
if (const LetNode* let = e2.as<LetNode>()) {
- eq_map.Set(op->var, let->var);
+ MergeVarDecl(op->var, let->var);
this->VisitExpr(op->value, let->value);
this->VisitExpr(op->body, let->body);
-
- // value_type should match as well (including nulls)
- if (op->value_type.defined() != let->value_type.defined()) {
- equal = false;
- return;
- }
-
- if (op->value_type.defined()) {
- equal = equal && AlphaEqual(op->value_type, let->value_type);
- }
} else {
equal = false;
}
@@ -388,6 +370,20 @@ struct AlphaEq : ExprFunctor {
equal = false;
}
}
+
+ private:
+ void MergeVarDecl(const Var& var1, const Var& var2) {
+ if (var1->type_annotation.defined() != var2->type_annotation.defined()) {
+ equal = false;
+ return;
+ }
+ if (var1->type_annotation.defined() &&
+ !AlphaEqual(var1->type_annotation, var2->type_annotation)) {
+ equal = false;
+ return;
+ }
+ eq_map.Set(var1, var2);
+ }
};
bool AlphaEqual(const Expr& e1, const Expr& e2) {
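The new private `MergeVarDecl` helper makes annotations part of alpha equality at each binding site. A sketch distilled from the updated tests below:

```python
from tvm import relay
from tvm.relay.ir_pass import alpha_equal
from tvm.relay.ir_builder import convert

tt = relay.TensorType((), "float32")
v1 = relay.Var("v1", tt)
v2 = relay.Var("v2", tt)

# Matching annotations on the bound vars: the Lets unify.
assert alpha_equal(relay.Let(v1, convert(2), v1),
                   relay.Let(v2, convert(2), v2))

# An unannotated binder does not match an annotated one.
v3 = relay.Var("v3")
assert not alpha_equal(relay.Let(v1, convert(2), v1),
                       relay.Let(v3, convert(2), v3))
```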
diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc
index 05036042a635..2e2eca1f2739 100644
--- a/src/relay/pass/dead_code.cc
+++ b/src/relay/pass/dead_code.cc
@@ -54,12 +54,7 @@ class CalcDep : private ExprMutator {
}
private:
- struct Binder {
- Type t;
- Expr e;
- Binder(const Type& t, const Expr& e) : t(t), e(e) { }
- };
- using VarMap = std::unordered_map<Var, Binder, NodeHash, NodeEqual>;
+ using VarMap = std::unordered_map<Var, Expr, NodeHash, NodeEqual>;
VarMap var_map_;
Expr VisitExpr_(const IfNode* i) final {
@@ -74,9 +69,7 @@ class CalcDep : private ExprMutator {
}
Expr VisitExpr_(const LetNode* l) final {
- var_map_.insert(std::pair<Var, Binder>(l->var,
- Binder(l->value_type,
- Eliminate(l->value))));
+ var_map_[l->var] = Eliminate(l->value);
return VisitExpr(l->body);
}
@@ -92,15 +85,16 @@ class CalcDep : private ExprMutator {
explicit GenLet(const VarMap& var_map) : var_map_(var_map) { }
friend CalcDep;
- void VisitExpr_(const VarNode* vn) final {
- Var v = GetRef<Var>(vn);
- if (var_map_.count(v) != 0) {
- auto val = var_map_.at(v);
- var_map_.erase(v);
+ void VisitExpr_(const VarNode* vnode) final {
+ Var v = GetRef<Var>(vnode);
+ auto it = var_map_.find(v);
+ if (it != var_map_.end()) {
+ Expr expr = it->second;
+ var_map_.erase(it);
// erase before visit to handle letrec
- VisitExpr(val.e);
+ VisitExpr(expr);
// visit before push back so the dependency of dependency is before the dependency
- lets_.Push(v, val.t, val.e);
+ lets_.Push(v, expr);
}
}
};
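With the `Binder` struct gone, dead-code elimination keeps a plain Var-to-Expr map; the observable behavior is unchanged. A sketch mirroring `test_let` below (import path as in the test file):

```python
from tvm import relay
from tvm.relay.ir_pass import dead_code_elimination, alpha_equal

x = relay.Var("x")
y = relay.Var("y")
z = relay.Var("z")

# The unused binding of x is dropped, leaving only the body.
assert alpha_equal(dead_code_elimination(relay.Let(x, y, z)), z)
```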
diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h
index d13358fe0e30..43b8bb8bba1d 100644
--- a/src/relay/pass/let_list.h
+++ b/src/relay/pass/let_list.h
@@ -26,57 +26,46 @@ namespace relay {
*/
class LetList {
public:
- /*! \brief insert a binding.
+ /*!
+ * \brief insert a binding.
*
- * \param pv the var of the binding.
+ * \param pv the var of the binding.
*
- * \param ty the type of the binding.
+ * \param expr the value of the binding.
*
- * \param expr the value of the binding.
- *
- * \return a Var that hold the inserted expr.
+ * \return a Var that holds the inserted expr.
*/
- Var Push(const Var& pv, const Type& ty, const Expr& expr) {
- std::tuple<Var, Type, Expr> tuple(pv, ty, expr);
- lets_.push_back(tuple);
+ Var Push(Var pv, Expr expr) {
+ lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
- /*! \brief insert a binding.
+ /*!
+ * \brief insert a binding.
*
- * \param ty the type of the binding.
+ * \param ty the type of the binding.
*
- * \param expr the value of the binding.
+ * \param expr the value of the binding.
*
- * \return a Var that hold the inserted expr.
- */
- Var Push(const Type& ty, const Expr& expr) {
- return Push(VarNode::make("x"), ty, expr);
- }
-
- /*! \brief insert a binding.
- *
- * \param pv the var of the binding.
- *
- * \param expr the value of the binding.
- *
- * \return a Var that hold the inserted expr.
+ * \return a Var that holds the inserted expr.
*/
- Var Push(const Var& pv, const Expr& expr) {
- return Push(pv, IncompleteTypeNode::make(TypeParamNode::kType), expr);
+ Var Push(Type ty, Expr expr) {
+ return Push(VarNode::make("x", ty), expr);
}
- /*! \brief insert a binding.
+ /*!
+ * \brief insert a binding.
*
* \param expr the value of the binding.
*
* \return a Var that holds the inserted expr.
*/
- Var Push(const Expr& expr) {
+ Var Push(Expr expr) {
return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr);
}
- /*! \brief wrap an expr around the LetList.
+ /*!
+ * \brief wrap an expr around the LetList.
*
* \param body the Expression to be wrapped around.
*
@@ -85,7 +74,7 @@ class LetList {
Expr Get(const Expr& body) const {
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
- ret = LetNode::make(std::get<0>(*rit), std::get<2>(*rit), ret, std::get<1>(*rit));
+ ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
}
return ret;
}
@@ -118,7 +107,7 @@ class LetList {
}
private:
- std::vector<std::tuple<Var, Type, Expr> > lets_;
+ std::vector<std::pair<Var, Expr> > lets_;
};
} // namespace relay
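`LetList` now stores `(Var, Expr)` pairs and folds them into nested `Let`s from the inside out. A Python analogue of `LetList::Get` (the helper name `mk_let` is ours; compare `_mk_let` in ir_builder.py above):

```python
from tvm import relay

def mk_let(bindings, body):
    """Fold [(var, value), ...] into nested Lets, innermost last."""
    ret = body
    for var, value in reversed(bindings):
        ret = relay.Let(var, value, ret)
    return ret
```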
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 72bdaf69f061..1b30865eacb1 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -87,15 +87,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// Visitor logics
Type VisitExpr_(const VarNode* op) final {
- // The type of Var can already been lookedup in type_map_;
- LOG(FATAL) << "Cannot find binding for var " << GetRef(op);
- return Type();
- }
-
- Type VisitExpr_(const ParamNode* op) final {
- // directly handled by Funtion
- LOG(FATAL) << "not reached";
- return Type();
+ if (op->type_annotation.defined()) {
+ return op->type_annotation;
+ } else {
+ return IncompleteTypeNode::make(TypeParamNode::kType);
+ }
}
Type VisitExpr_(const GlobalVarNode* op) final {
@@ -139,11 +135,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type VisitExpr_(const LetNode* op) final {
Type vtype = GetType(op->value);
- if (op->value_type.defined()) {
- vtype = Unify(vtype, op->value_type, op->span);
+ if (op->var->type_annotation.defined()) {
+ vtype = Unify(vtype, op->var->type_annotation, op->span);
}
CHECK(!type_map_.count(op->var));
- // NOTE: no scoping is necessary becase var are unique in program
+ // NOTE: no scoping is necessary because vars are unique in the program
type_map_[op->var] = vtype;
return GetType(op->body);
}
@@ -256,8 +252,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Type VisitExpr_(const FunctionNode* f) final {
for (auto param : f->params) {
- type_map_[param->var] = param->type;
- type_map_[param] = param->type;
+ GetType(param);
}
Type rtype = GetType(f->body);
// Run solver using the currently known information
@@ -265,8 +260,7 @@ class TypeInferencer : private ExprFunctor {
// Trying to resolve
Array<Type> arg_types;
for (size_t i = 0; i < f->params.size(); ++i) {
- Param param = f->params[i];
- Type atype = solver_.Resolve(param->type);
+ Type atype = solver_.Resolve(GetType(f->params[i]));
CHECK(atype.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << i
<< "-th parameter of function at" << f->span;
@@ -311,9 +305,6 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}
- Expr VisitExpr_(const ParamNode* op) final {
- return ExprMutator::VisitExpr_(op);
- }
Expr VisitExpr_(const FunctionNode* op) final {
return AttachCheckedType(op);
@@ -380,7 +371,7 @@ Expr InferType(const Environment& env,
const GlobalVar& var,
const Function& func) {
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
- func_copy->checked_type_ = func_copy->fn_type();
+ func_copy->checked_type_ = func_copy->func_type_annotation();
env->functions.Set(var, func_copy);
Expr func_ret = TypeInferencer(env).Infer(func_copy);
auto map_node = env->functions.CopyOnWrite();
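The inferencer now reads annotations off the `Var` itself: annotated parameters are unified against their uses, while unannotated ones start as `IncompleteType` and get solved. An end-to-end sketch in the style of the op tests below (`relay.exp` is just an example op):

```python
from tvm import relay

ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "float32"))
with ib.function(x) as func:
    ib.ret(relay.exp(x))
ib.ret(func)

# The annotation on x flows through inference to the result type.
func = relay.ir_pass.infer_type(ib.env, func.to_func())
assert func.checked_type.ret_type == relay.TensorType((10, 4), "float32")
```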
diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc
index 5f87c3d4cb89..c845995b2003 100644
--- a/src/relay/pass/util.cc
+++ b/src/relay/pass/util.cc
@@ -50,14 +50,17 @@ class FreeVar : public ExprVisitor {
if (bound_vars.count(var) == 0) {
free_vars.insert(var);
}
+ if (v->type_annotation.defined()) {
+ VisitType(v->type_annotation);
+ }
}
void VisitExpr_(const FunctionNode *f) final {
for (const auto& tp : f->type_params) {
bound_types.insert(tp);
}
- for (const auto& p : f->params) {
- bound_vars.insert(p->var);
+ for (const auto& param : f->params) {
+ bound_vars.insert(param);
}
VisitExpr(f->body);
VisitType(f->ret_type);
@@ -67,7 +70,6 @@ class FreeVar : public ExprVisitor {
bound_vars.insert(l->var);
VisitExpr(l->value);
VisitExpr(l->body);
- VisitType(l->value_type);
}
public:
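`FreeVar` now also walks a variable's type annotation, and function parameters are bound directly rather than through `p->var`. A sketch distilled from `test_free_type_vars` below:

```python
from tvm import relay
from tvm.relay.ir_pass import free_vars

tp = relay.TypeParam("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
x = relay.Var("x", ty)
y = relay.Var("y")
let = relay.Let(x, y, x)

# x is bound by the Let; only y remains free.
fvl = free_vars(let)
assert len(fvl) == 1 and fvl[0] == y
```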
diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc
index a9bce74926bf..e008a72e5d90 100644
--- a/src/relay/pass/well_formed.cc
+++ b/src/relay/pass/well_formed.cc
@@ -34,8 +34,8 @@ class WellFormedChecker : private ExprVisitor {
}
void VisitExpr_(const FunctionNode * f) final {
- for (const Param & p : f->params) {
- Check(p->var);
+ for (const Var & param : f->params) {
+ Check(param);
}
CheckWellFormed(f->body);
}
diff --git a/tests/python/relay/test_ir_builder.py b/tests/python/relay/test_ir_builder.py
index c98f920ca491..165c66f17ac3 100644
--- a/tests/python/relay/test_ir_builder.py
+++ b/tests/python/relay/test_ir_builder.py
@@ -14,7 +14,6 @@ def test_let():
assert var == prog.body
assert isinstance(value, Constant)
assert value.data.asnumpy() == np.array(1)
- assert prog.value_type == None
if __name__ == "__main__":
test_let()
diff --git a/tests/python/relay/test_ir_debug_printer.py b/tests/python/relay/test_ir_debug_printer.py
index e5f9ad2e69cd..b8aa86a87638 100644
--- a/tests/python/relay/test_ir_debug_printer.py
+++ b/tests/python/relay/test_ir_debug_printer.py
@@ -49,18 +49,11 @@ def test_global_var():
show(gv)
-def test_param():
- lv = relay.Var('x')
- ty = None
- param = relay.Param(lv, ty)
- show(lv)
-
-
def test_function():
param_names = ['a', 'b', 'c', 'd']
- params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
+ params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None
- body = params[0].var
+ body = params[0]
type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params)
show(fn)
@@ -76,11 +69,11 @@ def test_call():
def test_let():
- lv = relay.Var('x')
ty = relay.ty.TensorType((10, 20), 'float32')
+ lv = relay.Var('x', ty)
arr = tvm.nd.array(10)
value = relay.Constant(arr)
- let = relay.Let(lv, value, lv, ty)
+ let = relay.Let(lv, value, lv)
show(let)
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index 79883ed225e0..e571f2a9c99a 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -99,10 +99,16 @@ def test_tuple():
def test_local_var():
name_hint = 's'
lv = relay.Var(name_hint)
- lv.name_hint == name_hint
+ assert lv.name_hint == name_hint
+ assert lv.type_annotation is None
# assert lv.span == None todo(@jroesch): what do we do about spans
str(lv)
+ t1 = relay.ty.TensorType((), "float")
+ lv = relay.Var(name_hint, t1)
+ assert lv.name_hint == name_hint
+ assert lv.type_annotation == t1
+
def test_global_var():
name_hint = 'g'
@@ -112,19 +118,9 @@ def test_global_var():
str(gv)
-def test_param():
- lv = relay.Var('x')
- ty = None
- param = relay.Param(lv, ty)
- assert param.var == lv
- assert param.type == ty
- assert param.span == None
- str(param)
-
-
def test_function():
param_names = ['a', 'b', 'c', 'd']
- params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
+ params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None
body = None
type_params = tvm.convert([])
@@ -154,10 +150,9 @@ def test_let():
value = relay.Constant(arr)
# I would prefer that the order of arguments
# matches syntax let x: t = v in b
- let = relay.Let(lv, value, lv, ty)
+ let = relay.Let(lv, value, lv)
assert let.var == lv
assert let.value == value
- assert let.value_type == ty
assert let.body == lv
assert let.span == None
str(let)
@@ -194,7 +189,6 @@ def test_tuple_get_item():
test_tuple()
test_local_var()
test_global_var()
- test_param()
test_function()
test_call()
test_let()
diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py
index c6cb99662bb5..d555c2beb627 100644
--- a/tests/python/relay/test_ir_well_formed.py
+++ b/tests/python/relay/test_ir_well_formed.py
@@ -7,23 +7,22 @@ def test_well_formed():
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
- let = relay.Let(x, v, x, ty)
+ let = relay.Let(x, v, x)
assert well_formed(let)
- assert not well_formed(relay.Let(x, v, let, ty))
- f = relay.Function([relay.Param(x, ty)], ty, x)
+ assert not well_formed(relay.Let(x, v, let))
+ f = relay.Function([x], ty, x)
assert well_formed(f)
# this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other.
assert not well_formed(relay.Let(relay.Var("y"), f,
- relay.Let(relay.Var("z"), f, v, ty), ty))
+ relay.Let(relay.Var("z"), f, v)))
def test_tuple():
x = relay.Var('x')
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
- ty = None
- let = relay.Let(x, v, x, ty)
+ let = relay.Let(x, v, x)
assert well_formed(let)
assert well_formed(relay.Tuple([v, v]))
assert not well_formed(relay.Tuple([let, let]))
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index a90f6eb55ae1..05c02ab5d197 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -27,6 +27,8 @@ def check_single_op(opfunc):
tvm.relay.sigmoid, tvm.relay.tanh]:
check_single_op(opfunc)
+
+
def test_expand_dims_infer_type():
ib = relay.ir_builder.IRBuilder()
n, t, d = tvm.var("n"), tvm.var("t"), 100
@@ -75,12 +77,13 @@ def test_unary_op():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((10, 4), "int32"))
with ib.function(x) as func:
- ib.ret(op(x.var))
+ ib.ret(op(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((10, 4), "int32")
+
def test_binary_op():
def check_binary_op(opfunc):
"""
@@ -94,7 +97,7 @@ def check_binary_op(opfunc):
x = b.param('x', tensor_type(5, 5, 5))
y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func:
- b.ret(opfunc(x.var, y.var))
+ b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
ttype = tensor_type(5, 5, 5)
@@ -118,7 +121,7 @@ def check_binary_broadcast_op(opfunc):
x = b.param('x', tensor_type(10, 4))
y = b.param('y', tensor_type(5, 10, 1))
with b.function(x, y) as func:
- b.ret(opfunc(x.var, y.var))
+ b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index f67faea19be1..d0d02aece06d 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -11,7 +11,7 @@ def test_conv2d_infer_type():
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
- ib.ret(relay.nn.conv2d(x.var, w.var,
+ ib.ret(relay.nn.conv2d(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=2))
@@ -29,7 +29,7 @@ def test_conv2d_infer_type():
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8"))
with ib.function(x, w) as func:
- ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32"))
+ ib.ret(relay.nn.conv2d(x, w, out_dtype="int32"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -42,7 +42,7 @@ def test_conv2d_infer_type():
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
- ib.ret(relay.nn.conv2d(x.var, w.var,
+ ib.ret(relay.nn.conv2d(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=16,
@@ -65,7 +65,7 @@ def test_conv2d_transpose_infer_type():
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
- ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
+ ib.ret(relay.nn.conv2d_transpose(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=15))
@@ -83,7 +83,7 @@ def test_conv2d_transpose_infer_type():
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32"))
with ib.function(x, w) as func:
- ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
+ ib.ret(relay.nn.conv2d_transpose(x, w,
output_padding=(1, 1),
channels=11,
data_layout="NHWC"))
@@ -98,7 +98,7 @@ def test_upsampling_infer_type():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
- ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR"))
+ ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -108,7 +108,7 @@ def test_upsampling_infer_type():
n, c = tvm.var("n"), tvm.var("c")
x = ib.param("x", relay.ty.TensorType((n, c, 100, 200), "float32"))
with ib.function(x) as func:
- ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR"))
+ ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -119,7 +119,7 @@ def _test_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
- ib.ret(opfunc(x.var, pool_size=(1, 1)))
+ ib.ret(opfunc(x, pool_size=(1, 1)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -132,7 +132,7 @@ def _test_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
- ib.ret(opfunc(x.var, pool_size=(ph, pw), strides=(sh, sw)))
+ ib.ret(opfunc(x, pool_size=(ph, pw), strides=(sh, sw)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -144,7 +144,7 @@ def _test_global_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224
x = ib.param("x", relay.ty.TensorType((n, h, w, c), "float32"))
with ib.function(x) as func:
- ib.ret(opfunc(x.var, layout="NHWC"))
+ ib.ret(opfunc(x, layout="NHWC"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -154,7 +154,7 @@ def _test_global_pool2d_infer_type(opfunc):
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
with ib.function(x) as func:
- ib.ret(opfunc(x.var))
+ ib.ret(opfunc(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -172,7 +172,7 @@ def test_flatten_infer_type():
x = ib.param("x", relay.ty.TensorType((d1, d2, d3, d4), "float32"))
with ib.function(x) as func:
- ib.ret(relay.nn.batch_flatten(x.var))
+ ib.ret(relay.nn.batch_flatten(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -181,7 +181,7 @@ def test_flatten_infer_type():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((3, 2, 4, 3), "float32"))
with ib.function(x) as func:
- ib.ret(relay.nn.batch_flatten(x.var))
+ ib.ret(relay.nn.batch_flatten(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -190,7 +190,7 @@ def test_flatten_infer_type():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((d1, 2, d3, 3), "float32"))
with ib.function(x) as func:
- ib.ret(relay.nn.batch_flatten(x.var))
+ ib.ret(relay.nn.batch_flatten(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -202,7 +202,7 @@ def test_pad_infer_type():
n, c, h, w = 1, 2, 3, 4
t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
with ib.function(t) as func:
- ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4))))
+ ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -213,7 +213,7 @@ def test_pad_infer_type():
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
with ib.function(t) as func:
- ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4))))
+ ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -227,4 +227,3 @@ def test_pad_infer_type():
test_flatten_infer_type()
test_pad_infer_type()
test_conv2d_transpose_infer_type()
-
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 9515db87e64a..7d949b21026b 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -17,12 +17,13 @@ def test_zeros_ones():
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((124, 50), "float64")
+
def test_unary_identity():
for op in [relay.zeros_like, relay.ones_like]:
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((8, 9, 4), "int32"))
with ib.function(x) as func:
- ib.ret(op(x.var))
+ ib.ret(op(x))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -33,7 +34,7 @@ def test_clip_type():
ib = relay.ir_builder.IRBuilder()
a = ib.param("a", relay.TensorType((10, 4), "float32"))
with ib.function(a) as func:
- ib.ret(relay.clip(a.var, 1., 4.))
+ ib.ret(relay.clip(a, 1., 4.))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -106,7 +107,7 @@ def verify_take(dshape, indices_shape, oshape, axis=None):
x = ib.param("x", relay.ty.TensorType(dshape, "float32"))
indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32"))
with ib.function(x, indices) as func:
- ib.ret(relay.take(x.var, indices.var, axis=axis))
+ ib.ret(relay.take(x, indices, axis=axis))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -127,7 +128,7 @@ def test_full():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((), "int8"))
with ib.function(x) as func:
- ib.ret(relay.full(x.var, ()))
+ ib.ret(relay.full(x, ()))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -137,7 +138,7 @@ def test_full():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.TensorType((), "float32"))
with ib.function(x) as func:
- ib.ret(relay.full(x.var, (1, 2), "int8"))
+ ib.ret(relay.full(x, (1, 2), "int8"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -150,7 +151,7 @@ def test_full_like():
base = ib.param("base", relay.TensorType((1, 2, 3), "float32"))
fill = ib.param("fill", relay.TensorType((), "float32"))
with ib.function(base, fill) as func:
- ib.ret(relay.full_like(base.var, fill.var))
+ ib.ret(relay.full_like(base, fill))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -162,7 +163,7 @@ def test_full_like():
base = ib.param("base", relay.TensorType((n, c, h, w), "float32"))
fill = ib.param("fill", relay.TensorType((), "float32"))
with ib.function(base, fill) as func:
- ib.ret(relay.full_like(base.var, fill.var))
+ ib.ret(relay.full_like(base, fill))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index 807d3a3a964e..995e15fb9760 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -24,7 +24,7 @@ def test_cmp_type():
x = ib.param("x", relay.TensorType((10, 4), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
with ib.function(x, y) as func:
- ib.ret(op(x.var, y.var))
+ ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -39,7 +39,7 @@ def test_binary_broadcast():
x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
with ib.function(x, y) as func:
- ib.ret(op(x.var, y.var))
+ ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -58,7 +58,7 @@ def check_binary_op(opfunc):
x = b.param('x', tensor_type(5, 5, 5))
y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func:
- b.ret(opfunc(x.var, y.var))
+ b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
ttype = tensor_type(5, 5, 5)
@@ -81,7 +81,7 @@ def check_binary_broadcast_op(opfunc):
x = b.param('x', tensor_type(10, 4))
y = b.param('y', tensor_type(5, 10, 1))
with b.function(x, y) as func:
- b.ret(opfunc(x.var, y.var))
+ b.ret(opfunc(x, y))
b.ret(func)
prog, env = b.get()
@@ -103,7 +103,7 @@ def test_cmp_type():
x = ib.param("x", relay.TensorType((10, 4), "float32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
with ib.function(x, y) as func:
- ib.ret(op(x.var, y.var))
+ ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -118,7 +118,7 @@ def test_binary_broadcast():
x = ib.param("x", relay.TensorType((10, 4), "int32"))
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
with ib.function(x, y) as func:
- ib.ret(op(x.var, y.var))
+ ib.ret(op(x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -131,7 +131,7 @@ def test_where():
x = ib.param("x", relay.TensorType((3, 4), "float32"))
y = ib.param("y", relay.TensorType((3, 4), "float32"))
with ib.function(cond, x, y) as func:
- ib.ret(relay.where(cond.var, x.var, y.var))
+ ib.ret(relay.where(cond, x, y))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index 62da592e8249..8d871e9ef4f5 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -10,7 +10,7 @@ def test_resize_infer_type():
th, tw = tvm.var("th"), tvm.var("tw")
with ib.function(x) as func:
- ib.ret(relay.image.resize(x.var, (th, tw)))
+ ib.ret(relay.image.resize(x, (th, tw)))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
@@ -19,7 +19,7 @@ def test_resize_infer_type():
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
with ib.function(x) as func:
- ib.ret(relay.image.resize(x.var, (100, 200), "NCHW", "BILINEAR", False))
+ ib.ret(relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index dd722399dac4..04ef3cf3da8f 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -1,4 +1,5 @@
import tvm
+import numpy as np
from tvm import relay
from tvm.relay.ir_pass import alpha_equal
from tvm.relay.ir_builder import convert
@@ -179,9 +180,9 @@ def test_var_alpha_equal():
assert not alpha_equal(v1, v2)
# let node allows for setting the eq_map
- l1 = relay.Let(v1, convert(1), v1, None)
- l2 = relay.Let(v2, convert(1), v2, None)
- l3 = relay.Let(v1, convert(1), v2, None)
+ l1 = relay.Let(v1, convert(1), v1)
+ l2 = relay.Let(v2, convert(1), v2)
+ l3 = relay.Let(v1, convert(1), v2)
assert alpha_equal(l1, l2)
assert not alpha_equal(l1, l3)
@@ -209,10 +210,10 @@ def test_tuple_alpha_equal():
assert alpha_equal(tup, same)
# use the eq_map
- let_tup = relay.Let(v1, tup, v1, None)
+ let_tup = relay.Let(v1, tup, v1)
let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3),
relay.Tuple([convert(4)])]),
- v2, None)
+ v2)
assert alpha_equal(let_tup, let_mapped)
more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2])
@@ -242,61 +243,44 @@ def test_tuple_get_item_alpha_equal():
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
-def test_param_alpha_equal():
- # only checks equality of the types
- v1 = relay.Var("v1")
- v2 = relay.Var("v2")
-
- p1 = relay.Param(v1, relay.TensorType((1, 2, 3), "float32"))
- p2 = relay.Param(v2, relay.TensorType((1, 2, 3), "float32"))
- assert alpha_equal(p1, p2)
-
- p3 = relay.Param(v1, relay.TensorType((4, 5, 6), "int8"))
- assert not alpha_equal(p1, p3)
-
- p4 = relay.Param(v1, relay.TupleType([relay.TensorType((1, 2, 3),
- "float32")]))
- assert not alpha_equal(p1, p4)
-
-
def test_function_alpha_equal():
- v1 = relay.Var("v1")
- v2 = relay.Var("v2")
- v3 = relay.Var("v3")
- v4 = relay.Var("v4")
-
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((4, 5, 6), "int8")
tt3 = relay.TupleType([tt1, tt2])
+ v1 = relay.Var("v1", tt1)
+ v2 = relay.Var("v2", tt2)
+ v3 = relay.Var("v3", tt3)
+ v4 = relay.Var("v4", tt2)
+ vret = relay.Constant(tvm.nd.array(np.ones(1)))
+
tp1 = relay.TypeParam("tp1", relay.Kind.Type)
tp2 = relay.TypeParam("tp2", relay.Kind.Type)
tp3 = relay.TypeParam("tp3", relay.Kind.Shape)
tp4 = relay.TypeParam("tp4", relay.Kind.Shape)
- basic_args = [relay.Param(v3, tt1), relay.Param(v4, tt2)]
+ basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
basic_tps = [tp1, tp2]
- func = relay.Function([relay.Param(v1, tt1), relay.Param(v2, tt2)],
- tt2, v2, basic_tps)
- mapped = relay.Function(basic_args, tt2, v4, basic_tps)
+ func = relay.Function([v1, v2],
+ tt2, v1, basic_tps)
+ mapped = relay.Function(basic_args, tt2, basic_args[0], basic_tps)
assert alpha_equal(func, mapped)
- fewer_params = relay.Function([relay.Param(v4, tt2)], tt2, v4, basic_tps)
+ fewer_params = relay.Function([relay.Var("v4", tt2)], tt2, v4, basic_tps)
assert not alpha_equal(func, fewer_params)
- more_params = relay.Function([relay.Param(v3, tt1), relay.Param(v4, tt2),
- relay.Param(v2, tt2)], tt2, v4, basic_tps)
+ more_params = relay.Function([relay.Var("v3", tt1),
+ relay.Var("v4", tt2),
+ relay.Var("v2", tt2)], tt2, v4, basic_tps)
assert not alpha_equal(func, more_params)
- params_unordered = relay.Function([relay.Param(v3, tt2),
- relay.Param(v4, tt1)],
- tt1, v3, basic_tps)
+ params_unordered = relay.Function([v2, v1],
+ tt2, v1, basic_tps)
assert not alpha_equal(func, params_unordered)
- params_mismatch = relay.Function([relay.Param(v3, tt3),
- relay.Param(v4, tt2)],
- tt2, v4, basic_tps)
+ params_mismatch = relay.Function([v1, v3],
+ tt2, v1, basic_tps)
assert not alpha_equal(func, params_mismatch)
# also would not typecheck
@@ -376,7 +360,10 @@ def test_call_alpha_equal():
def test_let_alpha_equal():
+ tt1 = relay.TensorType((), "float32")
+ tt2 = relay.TensorType((), "int8")
v1 = relay.Var("v1")
+ v1_wtype = relay.Var("v1", tt1)
v2 = relay.Var("v2")
v3 = relay.Var("v3")
@@ -394,14 +381,13 @@ def test_let_alpha_equal():
assert not alpha_equal(let, different_body)
# specified types must match
- tt1 = relay.TensorType((), "float32")
- tt2 = relay.TensorType((), "int8")
- let_with_type = relay.Let(v1, convert(2), v1, tt1)
- same_type = relay.Let(v1, convert(2), v1, tt1)
+
+ let_with_type = relay.Let(v1_wtype, convert(2), v1_wtype)
+ same_type = relay.Let(v1_wtype, convert(2), v1_wtype)
assert alpha_equal(let_with_type, same_type)
assert not alpha_equal(let, let_with_type)
-
- different_type = relay.Let(v1, convert(2), v1, tt2)
+ v2 = relay.Var("v1", tt2)
+ different_type = relay.Let(v2, convert(2), v2)
assert not alpha_equal(let_with_type, different_type)
@@ -437,16 +423,13 @@ def test_op_alpha_equal():
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
- test_type_param_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
test_constant_alpha_equal()
- test_var_alpha_equal()
test_global_var_alpha_equal()
test_tuple_alpha_equal()
test_tuple_get_item_alpha_equal()
- test_param_alpha_equal()
test_function_alpha_equal()
test_call_alpha_equal()
test_let_alpha_equal()
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py
index ce9bda3d254f..121cea0081bd 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -28,17 +28,17 @@ def __init__(self):
def test_let():
- orig = relay.Let(e.x, e.y, e.z, e.tt)
+ orig = relay.Let(e.x, e.y, e.z)
assert alpha_equal(dead_code_elimination(orig), e.z)
def test_used_let():
- orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt)
- assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt))
+ orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
+ assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c))
def test_chain_unused_let():
- orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt)
+ orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
assert alpha_equal(dead_code_elimination(orig), e.e)
@@ -56,19 +56,17 @@ def test_recursion():
f(2, 10000);
"""
f = relay.Var("f")
- n = relay.Var("n")
- np = relay.Param(n, e.int32)
- data = relay.Var("data")
- datap = relay.Param(data, e.float32)
+ n = relay.Var("n", e.int32)
+ data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
- value = relay.Function([np, datap], e.float32, funcbody, [])
- orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)), e.float32)
+ value = relay.Function([n, data], e.float32, funcbody, [])
+ orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
assert alpha_equal(dead_code_elimination(orig), orig)
- assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three)
+ assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
def test_op_let():
- assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two))
+ assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
def test_if():
@@ -80,7 +78,7 @@ def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
assert alpha_equal(dead_code_elimination(g), g)
- assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g)
+ assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g)
if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_free_vars.py b/tests/python/relay/test_pass_free_vars.py
index 989c9f8d25db..a4c745de10e0 100644
--- a/tests/python/relay/test_pass_free_vars.py
+++ b/tests/python/relay/test_pass_free_vars.py
@@ -3,16 +3,17 @@
from tvm.relay.ir_pass import free_vars, free_type_vars
def test_free_vars():
- x = relay.Var("x")
+ ty = relay.TensorType([], "int32")
+ x = relay.Var("x", ty)
fvx = free_vars(x)
assert len(fvx) == 1
assert fvx[0] == x
v = relay.Constant(tvm.nd.array(10))
- ty = relay.TensorType([], "int32")
- let = relay.Let(x, v, x, ty)
+
+ let = relay.Let(x, v, x)
fvx = free_vars(let)
assert len(free_vars(let)) == 0
- f = relay.Function([relay.Param(x, ty)], ty, x)
+ f = relay.Function([x], ty, x)
assert len(free_vars(f)) == 0
@@ -29,9 +30,9 @@ def test_tuple():
def test_free_type_vars():
tp = relay.TypeParam("")
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
- x = relay.Var("x")
+ x = relay.Var("x", ty)
y = relay.Var("y")
- let = relay.Let(x, y, x, ty)
+ let = relay.Let(x, y, x)
fvl = free_vars(let)
assert len(fvl) == 1
assert fvl[0] == y