Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,27 @@ class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint to the user,
* and is not used for equality.
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
/*!
* \brief type annotaion of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type type_annotation;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("type_annotation", &type_annotation);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static Var make(std::string name_hint);
TVM_DLL static Var make(std::string name_hint,
Type type_annotation);

static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
Expand Down Expand Up @@ -162,32 +172,6 @@ class GlobalVarNode : public ExprNode {

RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);

/*!
* \brief Function parameter declaration.
*/
class Param;
/*! \brief A parameter. */
class ParamNode : public ExprNode {
public:
/*! \brief The variable */
Var var;
/*! \brief The type of the parameter */
Type type;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("type", &type);
v->Visit("span", &span);
}

TVM_DLL static Param make(Var var, Type type);

static constexpr const char* _type_key = "relay.Param";
TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr);

/*!
* \brief Function (subgraph in computational graph)
*/
Expand All @@ -196,7 +180,7 @@ class Function;
class FunctionNode : public ExprNode {
public:
/*! \brief Function parameters */
tvm::Array<Param> params;
tvm::Array<Var> params;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
Expand Down Expand Up @@ -224,10 +208,18 @@ class FunctionNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}

Type fn_type() const;
/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;

TVM_DLL static Function make(tvm::Array<Param> params, Type ret_type,
Expr body, tvm::Array<TypeParam> ty_params);
TVM_DLL static Function make(tvm::Array<Var> params,
Type ret_type,
Expr body,
tvm::Array<TypeParam> ty_params);

static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
Expand Down Expand Up @@ -289,7 +281,7 @@ class CallNode : public ExprNode {
TVM_DLL static Call make(Expr op,
Array<Expr> args,
Attrs attrs = Attrs(),
Array<Type> ty_args = Array<Type>());
Array<Type> type_args = Array<Type>());

static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
Expand Down Expand Up @@ -318,19 +310,16 @@ class LetNode : public ExprNode {
Expr value;
/*! \brief The body of the let binding */
Expr body;
/*! \brief Type annotation of value, this can be null */
Type value_type;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
v->Visit("value_type", &value_type);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type);
TVM_DLL static Let make(Var var, Expr value, Expr body);

static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
Expand Down Expand Up @@ -376,11 +365,11 @@ class IfNode : public ExprNode {

RELAY_DEFINE_NODE_REF(If, IfNode, Expr);

/*! \brief Get a field out of a tuple. */
/*! \brief Get index-th field out of a tuple. */
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
public:
/*! \brief The tuple */
/*! \brief The tuple Expression */
Expr tuple;
/*! \brief which value to get */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

capitalize

int index;
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GlobalVarNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FunctionNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
Expand All @@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode);
RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
Expand All @@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const GlobalVarNode* op) override;
void VisitExpr_(const ConstantNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const ParamNode* op) override;
void VisitExpr_(const FunctionNode* op) override;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const LetNode* op) override;
Expand All @@ -151,7 +148,6 @@ class ExprMutator
Expr VisitExpr_(const GlobalVarNode* op) override;
Expr VisitExpr_(const OpNode* op) override;
Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const ParamNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
Tuple = expr.Tuple
Var = expr.Var
GlobalVar = expr.GlobalVar
Param = expr.Param
Function = expr.Function
Call = expr.Call
Let = expr.Let
Expand Down
Loading