diff --git a/CMakeLists.txt b/CMakeLists.txt
index 572f4aef1432..65a7d9e36e2d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -104,6 +104,12 @@ file(GLOB COMPILER_SRCS
     src/schedule/*.cc
     )
+file(GLOB_RECURSE RELAY_SRCS
+    src/relay/*.cc
+    )
+list(APPEND COMPILER_SRCS ${RELAY_SRCS})
+
+
 if(NOT MSVC)
   file(GLOB COMPILER_VERILOG_SRCS src/codegen/verilog/*.cc)
   list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS})
diff --git a/docs/conf.py b/docs/conf.py
index e3f7f6a82c24..717003824703 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -33,7 +33,7 @@
 # General information about the project.
 project = u'tvm'
 author = u'%s developers' % project
-copyright = u'2017, %s' % author
+copyright = u'2018, %s' % author
 github_doc_root = 'https://github.com/tqchen/tvm/tree/master/docs/'
 
 # add markdown parser
diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h
new file mode 100644
index 000000000000..7c66d2c2de43
--- /dev/null
+++ b/include/tvm/relay/base.h
@@ -0,0 +1,203 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/base.h
+ * \brief Base classes for the Relay IR.
+ */
+#ifndef TVM_RELAY_BASE_H_
+#define TVM_RELAY_BASE_H_
+
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+/*!
+ * \brief Relay: a high level functional IR for TVM.
+ *
+ * This namespace contains the abstract syntax tree, and other
+ * essential data structures for the Relay IR.
+ *
+ * You can find more about Relay by reading the language reference.
+ */
+namespace relay {
+/*!
+ * \brief We always use NodeRef for referencing nodes.
+ *
+ * By default, NodeRef is a std::shared_ptr of node.
+ */
+using NodeRef = tvm::NodeRef;
+
+/*!
+ * \brief Content data type.
+ */
+using DataType = ::tvm::Type;
+
+/*!
+ * \brief Symbolic expression for tensor shape.
+ */
+using ShapeExpr = ::tvm::Expr;
+
+/*!
+ * \brief Hash function for nodes.
+ * e.g. std::unordered_map
+ */
+using NodeHash = ::tvm::NodeHash;
+/*!
+ * \brief Equality check function for nodes.
+ */
+using NodeEqual = ::tvm::NodeEqual;
+
+/*!
+ * \brief Macro to make it easy to define a node ref type given the node.
+ * \param TypeName The name of the reference type.
+ * \param NodeName The internal container name.
+ * \param NodeRefBase The base type.
+ */
+#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase)            \
+  class TypeName : public NodeRefBase {                                   \
+   public:                                                                \
+    TypeName() {}                                                         \
+    explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \
+    const NodeName* operator->() const {                                  \
+      return static_cast<const NodeName*>(node_.get());                   \
+    }                                                                     \
+    operator bool() { return this->defined(); }                           \
+    using ContainerType = NodeName;                                       \
+  };
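+
+// A minimal sketch of what the macro above generates for a hypothetical
+// FooNode (names illustrative only):
+//
+//   class Foo : public NodeRef {
+//    public:
+//     Foo() {}
+//     explicit Foo(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {}
+//     const FooNode* operator->() const {
+//       return static_cast<const FooNode*>(node_.get());
+//     }
+//     operator bool() { return this->defined(); }
+//     using ContainerType = FooNode;
+//   };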
+
+/*!
+ * \brief The source name in the Span
+ * \sa SourceNameNode, Span
+ */
+class SourceName;
+/*!
+ * \brief The name of a source fragment.
+ */
+class SourceNameNode : public Node {
+ public:
+  /*! \brief The source name. */
+  std::string name;
+  // override attr visitor
+  void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); }
+
+  TVM_DLL static SourceName make(std::string name);
+
+  static constexpr const char* _type_key = "relay.SourceName";
+  TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node);
+};
+
+/*!
+ * \brief The source name of a file span.
+ * \sa SourceNameNode, Span
+ */
+class SourceName : public NodeRef {
+ public:
+  /*! \brief default constructor */
+  SourceName() {}
+
+  /*! \brief constructor from node pointer */
+  explicit SourceName(std::shared_ptr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const SourceNameNode* operator->() const;
+
+  /*!
+   * \brief Get a SourceName for a given source name string.
+   *        Will raise an error if the source name has not been registered.
+   * \param name Name of the source fragment.
+   * \return Reference to a SourceName valid throughout program lifetime.
+   */
+  TVM_DLL static const SourceName& Get(const std::string& name);
+
+  /*! \brief specify container node */
+  using ContainerType = SourceNameNode;
+};
+
+/*!
+ * \brief Span information for debugging purposes
+ */
+class Span;
+/*!
+ * \brief Stores locations in frontend source that generated a node.
+ */
+class SpanNode : public Node {
+ public:
+  /*! \brief The source name */
+  SourceName source;
+  /*! \brief Line number */
+  int lineno;
+  /*! \brief column offset */
+  int col_offset;
+  // override attr visitor
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("source", &source);
+    v->Visit("lineno", &lineno);
+    v->Visit("col_offset", &col_offset);
+  }
+
+  TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
+
+  static constexpr const char* _type_key = "relay.Span";
+  TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node);
+};
+
+RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef);
+
+/*!
+ * \brief This is the base node container of all relay structures.
+ */
+class RelayNode : public Node {
+ public:
+  /*! \brief The location of the program in a SourceFragment; can be null,
+   *  check with span.defined() */
+  mutable Span span;
+
+  static constexpr const char* _type_key = "relay.Node";
+  TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
+};
+
+/*!
+ * \brief Get a reference type from a Node ptr type
+ *
+ * It is always important to get a reference type
+ * if we want to return a value as reference or keep
+ * the node alive beyond the scope of the function.
+ *
+ * \param ptr The node pointer
+ * \tparam RefType The reference type
+ * \tparam NodeType The node type
+ * \return The corresponding RefType
+ */
+template <typename RefType, typename NodeType>
+RefType GetRef(const NodeType* ptr) {
+  static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value,
+                "Can only cast to the ref of same container type");
+  return RefType(const_cast<NodeType*>(ptr)->shared_from_this());
+}
+
+// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR
+template <typename T>
+inline const T* As(const NodeRef& node) {
+  const Node* ptr = static_cast<const Node*>(node.get());
+  if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
+    return static_cast<const T*>(ptr);
+  }
+  return nullptr;
+}
+
+template <typename SubRef, typename BaseRef>
+SubRef Downcast(BaseRef ref) {
+  CHECK(ref->template is_type<typename SubRef::ContainerType>())
+      << "Downcast from " << ref->type_key() << " to "
+      << SubRef::ContainerType::_type_key << " failed.";
+  return SubRef(ref.node_);
+}
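+
+// A short usage sketch of the helpers above (`call` stands for an assumed
+// borrowed pointer already in scope; names are illustrative only):
+//
+//   const CallNode* call = ...;
+//   Call ref = GetRef<Call>(call);        // keeps the node alive
+//   Expr e = ref;
+//   const CallNode* p = As<CallNode>(e);  // nullptr if the cast fails
+//   Call c = Downcast<Call>(e);           // CHECK-fails if the cast fails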
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_BASE_H_
diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h
new file mode 100644
index 000000000000..7e07dc01eab4
--- /dev/null
+++ b/include/tvm/relay/environment.h
@@ -0,0 +1,121 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/environment.h
+ * \brief The global environment: contains information needed to
+ * compile & optimize Relay programs.
+ */
+#ifndef TVM_RELAY_ENVIRONMENT_H_
+#define TVM_RELAY_ENVIRONMENT_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relay {
+
+struct Environment;
+
+/*! \brief The global environment of Relay programs.
+ *
+ *  The global environment contains the global
+ *  information needed to compile a Relay program.
+ *
+ *  It contains all global functions and configuration
+ *  options.
+ *
+ *  Many operations require access to the global
+ *  Environment. We pass the Environment by value
+ *  in a functional style as an explicit argument,
+ *  but we mutate the Environment while optimizing
+ *  Relay programs.
+ *
+ *  The functional style allows users to construct custom
+ *  environments easily, for example each thread can store
+ *  an Environment while auto-tuning.
+ */
+class EnvironmentNode : public RelayNode {
+ public:
+  /*! \brief A map from ids to all global functions. */
+  tvm::Map<GlobalVar, Function> functions;
+
+  EnvironmentNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("functions", &functions);
+    v->Visit("global_map_", &global_map_);
+  }
+
+  TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);
+
+  /*! \brief Add a function to the global environment.
+   * \param var The name of the global function.
+   * \param func The function.
+   * \param update Controls whether you can replace a definition in the
+   * environment.
+   */
+  void Add(const GlobalVar& var, const Function& func, bool update = false);
+
+  /*! \brief Update a function in the global environment.
+   * \param var The name of the global function to update.
+   * \param func The new function.
+   */
+  void Update(const GlobalVar& var, const Function& func);
+
+  /*! \brief Remove a function from the global environment.
+   * \param var The name of the global function to remove.
+   */
+  void Remove(const GlobalVar& var);
+
+  /*! \brief Look up a global variable by its unique string name.
+   * \param str The unique string specifying the global variable.
+   * \returns The global variable.
+   */
+  GlobalVar GetGlobalVar(const std::string& str);
+
+  /*! \brief Look up a global function by its variable.
+   * \param var The global var to look up.
+   * \returns The function named by the variable argument.
+   */
+  Function Lookup(const GlobalVar& var);
+
+  /*! \brief Look up a global function by its string name.
+   * \param name The name of the function.
+   * \returns The function named by the argument.
+   */
+  Function Lookup(const std::string& name);
+
+  /*! \brief Combine with another Environment.
+   * \param other The other environment.
+   */
+  void Merge(const Environment& other);
+
+  static constexpr const char* _type_key = "relay.Environment";
+  TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
+
+ private:
+  /*! \brief A map from string names to global variables that
+   * ensures global uniqueness.
+   */
+  tvm::Map<std::string, GlobalVar> global_map_;
+};
+
+struct Environment : public NodeRef {
+  Environment() {}
+  explicit Environment(std::shared_ptr<Node> p) : NodeRef(p) {}
+
+  inline EnvironmentNode* operator->() const {
+    return static_cast<EnvironmentNode*>(node_.get());
+  }
+
+  using ContainerType = EnvironmentNode;
+};
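+
+// A minimal usage sketch (illustrative; `func` stands for an assumed
+// relay::Function already in scope):
+//
+//   GlobalVar gv = GlobalVarNode::make("main");
+//   Environment env = EnvironmentNode::make({{gv, func}});
+//   Function f = env->Lookup("main");
+//   env->Update(gv, f);   // replace the definition of "main"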
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_ENVIRONMENT_H_
diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h
new file mode 100644
index 000000000000..8ce73a027ca0
--- /dev/null
+++ b/include/tvm/relay/error.h
@@ -0,0 +1,36 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file error.h
+ * \brief The set of errors raised by Relay.
+ */
+#ifndef TVM_RELAY_ERROR_H_
+#define TVM_RELAY_ERROR_H_
+
+#include
+#include "./base.h"
+
+namespace tvm {
+namespace relay {
+
+struct Error : dmlc::Error {
+  explicit Error(const std::string &msg) : dmlc::Error(msg) {}
+};
+
+struct InternalError : Error {
+  explicit InternalError(const std::string &msg) : Error(msg) {}
+};
+
+// TODO(@jroesch): we should change spanned errors to report
+// errors against the Environment, inverting control to error definition.
+struct FatalTypeError : dmlc::Error {
+  explicit FatalTypeError(const std::string &s) : dmlc::Error(s) {}
+};
+
+struct TypecheckerError : public Error {
+  explicit TypecheckerError(const std::string &msg) : Error(msg) {}
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_ERROR_H_
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
new file mode 100644
index 000000000000..6388e8367bf6
--- /dev/null
+++ b/include/tvm/relay/expr.h
@@ -0,0 +1,378 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/expr.h
+ * \brief Relay expression language.
+ */
+#ifndef TVM_RELAY_EXPR_H_
+#define TVM_RELAY_EXPR_H_
+
+#include
+#include
+#include "./base.h"
+#include "./type.h"
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A Relay expression.
+ */
+class Expr;
+/*!
+ * \brief Base type of the Relay expression hierarchy.
+ */
+class ExprNode : public RelayNode {
+ public:
+  /*!
+   * \brief Stores the result of type inference (type checking).
+   *
+   * \note This can be undefined before type inference.
+   *       This value is discarded during serialization.
+   */
+  mutable Type checked_type_ = Type(nullptr);
+  /*!
+   * \return The checked_type
+   */
+  const Type& checked_type() const {
+    CHECK(checked_type_.defined())
+        << "internal error: the type checker has not populated the "
+           "checked_type field for this node";
+    return this->checked_type_;
+  }
+
+  static constexpr const char* _type_key = "relay.Expr";
+  TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
+};
+
+RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef);
+
+/*!
+ * \brief Constant tensor, backed by an NDArray on the cpu(0) device.
+ *
+ * \note Scalar constants are represented by rank-0 const tensor.
+ *       Constant folding is handled uniformly via Tensor types.
+ */
+class Constant;
+/*!
+ * \brief Constant tensor type.
+ */
+class ConstantNode : public ExprNode {
+ public:
+  /*! \brief The data of the tensor */
+  runtime::NDArray data;
+
+  /*! \return The corresponding tensor type of the data */
+  TensorType tensor_type() const;
+
+  /*! \return Whether it is scalar (rank-0 tensor) */
+  bool is_scalar() const { return data->ndim == 0; }
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("data", &data);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static Constant make(runtime::NDArray data);
+
+  static constexpr const char* _type_key = "relay.Constant";
+  TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr);
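+
+// A short sketch (illustrative; `data` stands for an assumed cpu(0)
+// NDArray already in scope):
+//
+//   Constant c = ConstantNode::make(data);
+//   // rank-0 data yields a scalar constant
+//   CHECK_EQ(c->is_scalar(), data->ndim == 0);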
+
+/*! \brief Tuple of multiple Exprs */
+class Tuple;
+/*! \brief Tuple container */
+class TupleNode : public ExprNode {
+ public:
+  /*! \brief the fields of the tuple */
+  tvm::Array<Expr> fields;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("fields", &fields);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static Tuple make(tvm::Array<Expr> fields);
+
+  static constexpr const char* _type_key = "relay.Tuple";
+  TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
+
+/*!
+ * \brief Local variables used in the let expression.
+ *
+ * Its semantics are similar to tvm.Var node used in TVM's low level
+ * tensor expression language.
+ *
+ * \note Each Var is bound only once and is immutable.
+ */
+class Var;
+/*! \brief Container for Var */
+class VarNode : public ExprNode {
+ public:
+  /*! \brief The name of the variable, this only acts as a hint to the user,
+   * and is not used for equality.
+   */
+  std::string name_hint;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("name_hint", &name_hint);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static Var make(std::string name_hint);
+
+  static constexpr const char* _type_key = "relay.Var";
+  TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
+
+/*!
+ * \brief Global variable that lives in the top-level environment.
+ * This is used to enable recursive calls between functions.
+ *
+ * \note A GlobalVar may only point to functions.
+ */
+class GlobalVar;
+/*! \brief Container for GlobalVar */
+class GlobalVarNode : public ExprNode {
+ public:
+  /*! \brief The name of the variable, this only acts as a hint. */
+  std::string name_hint;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("name_hint", &name_hint);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static GlobalVar make(std::string name_hint);
+
+  static constexpr const char* _type_key = "relay.GlobalVar";
+  TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);
+
+/*!
+ * \brief Function parameter declaration.
+ */
+class Param;
+/*! \brief A parameter. */
+class ParamNode : public ExprNode {
+ public:
+  /*! \brief The variable */
+  Var var;
+  /*! \brief The type of the parameter */
+  Type type;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("var", &var);
+    v->Visit("type", &type);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static Param make(Var var, Type type);
+
+  static constexpr const char* _type_key = "relay.Param";
+  TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr);
+
+/*!
+ * \brief Function (subgraph in computational graph)
+ */
+class Function;
+/*! \brief Function container */
+class FunctionNode : public ExprNode {
+ public:
+  /*! \brief Function parameters */
+  tvm::Array<Param> params;
+  /*! \brief User annotated return type of the function. */
+  Type ret_type;
+  /*!
+   * \brief
+   * The expression which represents the computation of the function,
+   * the expression may reference the parameters, and the type of it
+   * or sub-expressions may reference the type variables.
+   */
+  Expr body;
+  /*!
+   * \brief Type parameters of the function.
+   * Enables the function to vary its type based on these.
+   * This corresponds to template parameters in C++'s terminology.
+   *
+   * \note This is usually empty for non-polymorphic functions.
+   */
+  tvm::Array<TypeParam> type_params;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("params", &params);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("body", &body);
+    v->Visit("type_params", &type_params);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  Type fn_type() const;
+
+  TVM_DLL static Function make(tvm::Array<Param> params, Type ret_type,
+                               Expr body, tvm::Array<TypeParam> ty_params);
+
+  static constexpr const char* _type_key = "relay.Function";
+  TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
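+
+// A minimal sketch (illustrative; `t` stands for an assumed relay::Type):
+// the identity function over `t`, with no type parameters.
+//
+//   Var x = VarNode::make("x");
+//   Param p = ParamNode::make(x, t);
+//   Function id = FunctionNode::make({p}, t, x, {});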
+
+/*!
+ * \brief Call corresponds to operator invocation.
+ *  Corresponds to the operator in computational graph terminology.
+ */
+class Call;
+/*! \brief Call container. */
+class CallNode : public ExprNode {
+ public:
+  /*!
+   * \brief The operator (function) being invoked
+   *
+   *  - It can be relay::Op which corresponds to the primitive operators.
+   *  - It can also be user defined functions (Function, GlobalVar, Var).
+   */
+  Expr op;
+
+  /*! \brief The arguments (inputs) of the call */
+  tvm::Array<Expr> args;
+
+  /*! \brief The additional attributes */
+  Attrs attrs;
+
+  /*!
+   * \brief The type arguments passed to a polymorphic (template) function.
+   *
+   * This is an advanced feature that is only used when the function is
+   * polymorphic. It is safe to be ignored in most cases. For example, in the
+   * following code, the type_args of the addone call is [int].
+   *
+   * \code
+   *
+   *  template<typename T>
+   *  T addone(T a) { return a + 1; }
+   *
+   *  void main() {
+   *    int x = addone<int>(10);
+   *  }
+   *
+   * \endcode
+   */
+  tvm::Array<Type> type_args;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("op", &op);
+    v->Visit("args", &args);
+    v->Visit("attrs", &attrs);
+    v->Visit("type_args", &type_args);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static Call make(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
+                           Array<Type> ty_args = Array<Type>());
+
+  static constexpr const char* _type_key = "relay.Call";
+  TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Call, CallNode, Expr);
+
+/*!
+ * \brief Let binding that binds a local var and optionally a type annotation.
+ *
+ * \note Let is useful to transform the program to be A-normal form,
+ *  where each expression corresponds to a let binding.
+ *
+ *  For developers who are familiar with the computational graph,
+ *  each let binding can be viewed as an operator node in the
+ *  computational graph. Traversing the list of let bindings is similar to
+ *  running a PostDFS-order (topo-order) traversal on the computational graph.
+ */
+class Let;
+/*! \brief A binding of a sub-network. */
+class LetNode : public ExprNode {
+ public:
+  /*! \brief The variable we bind to */
+  Var var;
+  /*! \brief The value we bind var to */
+  Expr value;
+  /*! \brief The body of the let binding */
+  Expr body;
+  /*! \brief Type annotation of value, this can be null */
+  Type value_type;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+    v->Visit("body", &body);
+    v->Visit("value_type", &value_type);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type);
+
+  static constexpr const char* _type_key = "relay.Let";
+  TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);
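+
+// A minimal sketch of the A-normal form mentioned above (illustrative;
+// `f` and `x` stand for assumed expressions, and the annotation is null):
+//
+//   Var v = VarNode::make("v");
+//   Expr value = CallNode::make(f, {x});
+//   Let let = LetNode::make(v, value, /*body=*/v, Type(nullptr));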
+
+/*!
+ * \brief Condition expression
+ *
+ * Unlike traditional statement `if`s, the if evaluates
+ * to the result of the branch taken.
+ *
+ * let x = if (true) { 1 } else { 0 }; // x is 1
+ * let y = if (false) { 1 } else { 0 }; // y is 0
+ *
+ * \note This is similar to C's ternary operator.
+ */
+class If;
+/*! \brief container of If */
+class IfNode : public ExprNode {
+ public:
+  /*! \brief The condition */
+  Expr cond;
+  /*! \brief The expression evaluated when condition is true. */
+  Expr true_branch;
+  /*! \brief The expression evaluated when condition is false */
+  Expr false_branch;
+
+  IfNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("cond", &cond);
+    v->Visit("true_branch", &true_branch);
+    v->Visit("false_branch", &false_branch);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch);
+
+  static constexpr const char* _type_key = "relay.If";
+  TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_EXPR_H_
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
new file mode 100644
index 000000000000..8ad0537ad68b
--- /dev/null
+++ b/include/tvm/relay/expr_functor.h
@@ -0,0 +1,170 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/expr_functor.h
+ * \brief A more powerful visitor which enables defining arbitrary function
+ * signatures with type based dispatch on first argument.
+ */
+#ifndef TVM_RELAY_EXPR_FUNCTOR_H_
+#define TVM_RELAY_EXPR_FUNCTOR_H_
+
+#include
+#include
+#include "./expr.h"
+#include "./op.h"
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A dynamical functor that dispatches on the first Expr argument.
+ *  You can use this as a more powerful Visitor, since it allows you to
+ *  define the function signatures of the Visit functions.
+ *
+ * \sa tvm/ir_functor.h
+ *
+ * \tparam FType function signature
+ *  This type is only defined for FType with function signature R(const Expr&,
+ *  Args...)
+ */
+template <typename FType>
+class ExprFunctor;
+
+// functions to be overridden.
+#define EXPR_FUNCTOR_DEFAULT \
+  { return VisitExprDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                                \
+  vtable.template set_dispatch<OP>(                                    \
+      [](const NodeRef& n, TSelf* self, Args... args) {                \
+        return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
+                                std::forward<Args>(args)...);          \
+      });
+
+template <typename R, typename... Args>
+class ExprFunctor<R(const Expr& n, Args...)> {
+ private:
+  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
+  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~ExprFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const Expr& n, Args... args) {
+    return VisitExpr(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitExpr(const Expr& n, Args... args) {
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overriden by subclass
+  virtual R VisitExpr_(const ConstantNode* op,
+                       Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const TupleNode* op,
+                       Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const VarNode* op,
+                       Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const GlobalVarNode* op,
+                       Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const FunctionNode* op,
+                       Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const IfNode* op,
+                       Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const OpNode* op,
+                       Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExprDefault_(const Node* op, Args...) {
+    throw dmlc::Error(std::string("Do not have a default for ") +
+                      op->type_key());
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
+    return vtable;
+  }
+};
+
+/*! \brief A simple visitor wrapper around ExprFunctor.
+ *
+ * Exposes two visitors with default traversal strategies, one
+ * which doesn't compute a result but can mutate internal state,
+ * and another which functionally builds a new Expr.
+ */
+class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
+ public:
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const GlobalVarNode* op) override;
+  void VisitExpr_(const ConstantNode* op) override;
+  void VisitExpr_(const TupleNode* op) override;
+  void VisitExpr_(const ParamNode* op) override;
+  void VisitExpr_(const FunctionNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const LetNode* op) override;
+  void VisitExpr_(const IfNode* op) override;
+  void VisitExpr_(const OpNode* op) override;
+  virtual void VisitType(const Type& t);
+};
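+
+// A small sketch (illustrative): a subclass that counts Call nodes while
+// reusing the default traversal of ExprVisitor.
+//
+//   struct CallCounter : ExprVisitor {
+//     int count = 0;
+//     void VisitExpr_(const CallNode* op) override {
+//       ++count;
+//       ExprVisitor::VisitExpr_(op);  // continue the default traversal
+//     }
+//   };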
+
+/*! \brief A wrapper around ExprFunctor which functionally updates the AST.
+ *
+ * ExprMutator uses memoization and self return in order to amortize
+ * the cost of using functional updates.
+ */
+class ExprMutator
+    : public ::tvm::relay::ExprFunctor<Expr(const Expr&, const Expr&)> {
+ public:
+  Expr Mutate(const Expr& expr);
+  Expr VisitExpr_(const VarNode* op, const Expr& e) override;
+  Expr VisitExpr_(const ConstantNode* op, const Expr& e) override;
+  Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override;
+  Expr VisitExpr_(const OpNode* op, const Expr& expr) override;
+  Expr VisitExpr_(const TupleNode* op, const Expr& e) override;
+  Expr VisitExpr_(const ParamNode* op, const Expr& e) override;
+  Expr VisitExpr_(const FunctionNode* op, const Expr& e) override;
+  Expr VisitExpr_(const CallNode* call_node, const Expr& e) override;
+  Expr VisitExpr_(const LetNode* op, const Expr& e) override;
+  Expr VisitExpr_(const IfNode* op, const Expr& e) override;
+  /*! \brief Used to visit the types inside of expressions.
+   *
+   * Can be overloaded to transform the types in arbitrary
+   * ways, one way would be to define a sub-class of type
+   * visitor for types which transform them appropriately.
+   */
+  virtual Type VisitType(const Type& t);
+
+ private:
+  /*! \brief Internal map used for memoization. */
+  tvm::Map<Expr, Expr> memo_;
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_EXPR_FUNCTOR_H_
diff --git a/include/tvm/relay/logging.h b/include/tvm/relay/logging.h
new file mode 100644
index 000000000000..c53cd15ee72e
--- /dev/null
+++ b/include/tvm/relay/logging.h
@@ -0,0 +1,33 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/logging.h
+ * \brief A wrapper around dmlc-core/logging.h which adds the ability
+ * to toggle logging via an environment variable.
+ */
+
+#ifndef TVM_RELAY_LOGGING_H_
+#define TVM_RELAY_LOGGING_H_
+
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relay {
+
+static bool logging_enabled() {
+  if (auto var = std::getenv("RELAY_LOG")) {
+    std::string is_on(var);
+    return is_on == "1";
+  } else {
+    return false;
+  }
+}
+
+#define RELAY_LOG(severity) LOG_IF(severity, logging_enabled())
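+
+// Usage sketch: the message below is emitted only when the RELAY_LOG
+// environment variable is set to "1".
+//
+//   RELAY_LOG(INFO) << "visiting " << expr;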
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_LOGGING_H_
diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h
new file mode 100644
index 000000000000..49661fec5731
--- /dev/null
+++ b/include/tvm/relay/op.h
@@ -0,0 +1,469 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/op.h
+ * \brief Primitive operator definition.
+ */
+#ifndef TVM_RELAY_OP_H_
+#define TVM_RELAY_OP_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../attrs.h"
+#include "./base.h"
+#include "./expr.h"
+#include "./type.h"
+
+namespace tvm {
+namespace relay {
+
+// forward declare name.
+template <typename ValueType>
+class OpMap;
+class GenericOpMap;
+class OpRegistry;
+
+/*!
+ * \brief Node container of operator structure.
+ */
+class OpNode : public relay::ExprNode {
+ public:
+  /*! \brief name of the operator */
+  std::string name;
+  /*! \brief the type of the operator */
+  mutable FuncType op_type;
+  /*!
+   * \brief detailed description of the operator
+   *  This can be used to generate docstring automatically for the operator.
+   */
+  std::string description;
+  /*! \brief Information of input arguments to the operator */
+  Array<AttrFieldInfo> arguments;
+  /*!
+   * \brief The type key of the attribute field
+   *  This can be empty, in which case it defaults to
+   */
+  std::string attrs_type_key;
+  /*!
+   * \brief number of input arguments to the operator,
+   * -1 means it is variable length
+   */
+  int32_t num_inputs = -1;
+  /*!
+   * \brief support level of the operator.
+   *  The lower the level, the higher the priority.
+   *  This is analogous to BLAS levels.
+   */
+  int32_t support_level = 10;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("op_type", &op_type);
+    v->Visit("description", &description);
+    v->Visit("arguments", &arguments);
+    v->Visit("attrs_type_key", &attrs_type_key);
+    v->Visit("num_inputs", &num_inputs);
+    v->Visit("support_level", &support_level);
+  }
+
+  static constexpr const char* _type_key = "relay.Op";
+  TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
+
+ private:
+  // friend class
+  friend class GenericOpMap;
+  friend class OpRegistry;
+  // Program internal unique index of operator.
+  // Used to help index the program.
+  uint32_t index_{0};
+};
+
+/*!
+ * \brief Operator reference class.
+ */
+class Op : public relay::Expr {
+ public:
+  /*! \brief default constructor */
+  Op() {}
+  /*! \brief constructor from node pointer */
+  explicit Op(std::shared_ptr<Node> n) : Expr(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const OpNode* operator->() const;
+  /*!
+   * \brief Get additional registered attributes about operators.
+   *  If nothing has been registered, an empty OpMap will be returned.
+   * \param attr_name The name of the attribute.
+   * \return An OpMap of specified attr_name.
+   * \tparam ValueType The type of the attribute.
+   */
+  template <typename ValueType>
+  inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
+  /*!
+   * \brief Get an Op for a given operator name.
+   *  Will raise an error if the op has not been registered.
+   * \param op_name Name of the operator.
+   * \return Reference to an Op, valid throughout program lifetime.
+   */
+  TVM_DLL static const Op& Get(const std::string& op_name);
+
+  /*! \brief specify container node */
+  using ContainerType = OpNode;
+
+ private:
+  /*!
+   * \brief Get generic attrmap given attr name
+   * \param key The attribute key
+   * \return reference to GenericOpMap
+   */
+  TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
+};
+
+/*! \brief Helper structure to register operators */
+class OpRegistry {
+ public:
+  /*! \return the operator */
+  const Op& op() const { return op_; }
+  /*!
+   * \brief setter function during registration
+   *  Set the description of operator
+   * \param descr the description string.
+   * \return reference to self.
+   */
+  inline OpRegistry& describe(const std::string& descr);  // NOLINT(*)
+  /*!
+   * \brief Add argument information to the function.
+   * \param name Name of the argument.
+   * \param type Type of the argument.
+   * \param description Description of the argument.
+   * \return reference to self.
+   */
+  inline OpRegistry& add_argument(const std::string& name,
+                                  const std::string& type,
+                                  const std::string& description);
+  /*!
+   * \brief Attach the type function corresponding to the return type.
+   * \param rel_name The type relation name to register.
+   * \param type_rel_func The backing relation function which can solve an
+   *  arbitrary relation on variables.
+   * \return reference to self.
+   */
+  inline OpRegistry& add_type_rel(
+      const std::string& rel_name,
+      std::function<Array<Type>(const Array<Type>&, int)> type_rel_func);
+
+  /*!
+   * \brief Set the type key of attributes.
+   * \param type_key The type key of the attrs field.
+   * \return reference to self.
+   */
+  inline OpRegistry& set_attrs_type_key(const std::string& type_key);
+  /*!
+   * \brief Set the num_inputs
+   * \param n The number of inputs to be set.
+   * \return reference to self.
+   */
+  inline OpRegistry& set_num_inputs(int32_t n);  // NOLINT(*)
+  /*!
+   * \brief Set the support level of op.
+   * \param level The support level.
+   * \return reference to self.
+   */
+  inline OpRegistry& set_support_level(int32_t level);  // NOLINT(*)
+  /*!
+   * \brief Register additional attributes to operator.
+   * \param attr_name The name of the attribute.
+   * \param value The value to be set.
+   * \param plevel The priority level of this set;
+   *  a higher priority level attribute
+   *  will replace a lower priority level attribute.
+   *  Must be bigger than 0.
+   *
+   *  Cannot set with same plevel twice in the code.
+   *
+   * \tparam ValueType The type of the value to be set.
+   */
+  template <typename ValueType>
+  inline OpRegistry& set_attr(const std::string& attr_name,  // NOLINT(*)
+                              const ValueType& value, int plevel = 10);
+
+  // set the name of the op to be the same as registry
+  inline OpRegistry& set_name() {  // NOLINT(*)
+    if (get()->name.length() == 0) {
+      get()->name = name;
+    }
+    return *this;
+  }
+  /*! \return The global single registry */
+  TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry();
+
+ private:
+  friend class ::dmlc::Registry<OpRegistry>;
+  // the name
+  std::string name;
+  /*! \brief The operator */
+  Op op_;
+  // private constructor
+  OpRegistry();
+  // return internal pointer to op.
+  inline OpNode* get();
+  // update the attribute OpMap
+  TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
+                          int plevel);
+};
+
+/*!
+ * \brief Generic map to store additional information of Op.
+ */
+class GenericOpMap {
+ public:
+  /*!
+   * \brief Check if the map has op as key.
+   * \param op The key to the map
+   * \return 1 if op is contained in map, 0 otherwise.
+   */
+  inline int count(const Op& op) const;
+  /*!
+   * \brief get the corresponding value element at op
+   * \param op The key to the map
+   * \return the const reference to the content value.
+   */
+  inline const TVMRetValue& operator[](const Op& op) const;
+  /*!
+   * \brief get the corresponding value element at op with default value.
+   * \param op The key to the map
+   * \param def_value The default value when the key does not exist.
+   * \return the const reference to the content value.
+   * \tparam ValueType The content value type.
+   */
+  template <typename ValueType>
+  inline ValueType get(const Op& op, ValueType def_value) const;
+
+ private:
+  friend class OpRegistry;
+  // the attribute field.
+  std::string attr_name_;
+  // internal data
+  std::vector<std::pair<TVMRetValue, int> > data_;
+  // private constructor
+  GenericOpMap() = default;
+};
+
+/*!
+ * \brief Map used to store meta-information about Op.
+ * \tparam ValueType The type of the value stored in map.
+ */
+template <typename ValueType>
+class OpMap {
+ public:
+  /*!
+   * \brief Check if the map has op as key.
+   * \param op The key to the map
+   * \return 1 if op is contained in map, 0 otherwise.
+   */
+  inline int count(const Op& op) const;
+  /*!
+   * \brief get the corresponding value element at op
+   * \param op The key to the map
+   * \return the const reference to the content value.
+   */
+  inline ValueType operator[](const Op& op) const;
+  /*!
+   * \brief get the corresponding value element at op with default value.
+   * \param op The key to the map
+   * \param def_value The default value when the key does not exist.
+   * \return the const reference to the content value.
+   */
+  inline ValueType get(const Op& op, ValueType def_value) const;
+
+ private:
+  friend class Op;
+  // constructor
+  explicit OpMap(const GenericOpMap& map) : map_(map) {}
+  /*! \brief The internal map field */
+  const GenericOpMap& map_;
+};
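+
+// A small sketch (illustrative; "TOpPattern" is a hypothetical attribute
+// name assumed to have been registered elsewhere via OpRegistry::set_attr):
+//
+//   OpMap<int> pattern = Op::GetAttr<int>("TOpPattern");
+//   const Op& add = Op::Get("add");
+//   int p = pattern.get(add, /*def_value=*/0);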
+
+// internal macros to make
+#define RELAY_REGISTER_VAR_DEF \
+  static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp
+
+/*!
+ * \def RELAY_REGISTER_OP
+ * \brief Register a new operator, or set attribute of the corresponding op.
+ *
+ * \param OpName The name of registry
+ *
+ * \code
+ *
+ *  RELAY_REGISTER_OP("add")
+ *  .describe("add two inputs together")
+ *  .set_num_inputs(2)
+ *  .set_attr<OpKernel>("gpu_kernel", AddKernel);
+ *
+ * \endcode
+ */
+#define RELAY_REGISTER_OP(OpName)                        \
+  DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \
+      ::tvm::relay::OpRegistry::Registry()               \
+          ->__REGISTER_OR_GET__(OpName)                  \
+          .set_name()
+
+// implementations
+inline const OpNode* Op::operator->() const {
+  return static_cast<const OpNode*>(node_.get());
+}
+
+template <typename ValueType>
+inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
+  return OpMap<ValueType>(Op::GetGenericAttr(key));
+}
+
+inline OpNode* OpRegistry::get() {
+  return const_cast<OpNode*>(op_.operator->());
+}
+
+inline OpRegistry& OpRegistry::describe(
+    const std::string& descr) {  // NOLINT(*)
+  get()->description = descr;
+  return *this;
+}
+
+inline OpRegistry& OpRegistry::add_argument(const std::string& name,
+                                            const std::string& type,
+                                            const std::string& description) {
+  std::shared_ptr<AttrFieldInfoNode> n = std::make_shared<AttrFieldInfoNode>();
+  n->name = name;
+  n->type_info = type;
+  n->description = description;
+  get()->arguments.push_back(AttrFieldInfo(n));
+  return *this;
+}
+
+inline OpRegistry& OpRegistry::add_type_rel(
+    const std::string& rel_name,
+    std::function<Array<Type>(const Array<Type>&, int)> type_rel_func) {
+  auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
+
+  TypedEnvFunc<Array<Type>(const Array<Type>&, int)> env_type_rel_func;
+
+  if (runtime::Registry::Get(func_name)) {
+    auto env_func = EnvFunc::Get(func_name);
+    env_type_rel_func = env_func;
+  } else {
+    runtime::Registry::Register(func_name)
+        .set_body_typed<Array<Type>(const Array<Type>&, int)>(type_rel_func);
+    auto env_func = EnvFunc::Get(func_name);
+    env_type_rel_func = env_func;
+  }
+
+  std::vector<TypeParam> type_params;
+  std::vector<Type> arg_types;
+
+  // Add inputs.
+  std::string input_name_prefix = "in";
+  for (int i = 0; i < get()->num_inputs; i++) {
+    auto name = input_name_prefix + std::to_string(i);
+    auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType);
+    type_params.push_back(param);
+    arg_types.push_back(param);
+  }
+
+  auto ty_call_args = Array<Type>(arg_types);
+
+  // Add output type.
+  auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
+  type_params.push_back(out_param);
+  ty_call_args.push_back(out_param);
+
+  TypeConstraint type_rel =
+      TypeRelationNode::make(rel_name, env_type_rel_func, ty_call_args);
+
+  auto func_type =
+      FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
+
+  get()->op_type = func_type;
+
+  return *this;
+}
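+
+// A hedged sketch of a relation function usable with add_type_rel above.
+// The exact callback contract is defined by the type solver, so treat the
+// body as hypothetical: it receives the input types plus an output slot
+// and returns the resolved list.
+//
+//   Array<Type> IdentityRel(const Array<Type>& types, int num_args) {
+//     // one input, one output: the output type equals the input type
+//     return Array<Type>{types[0], types[0]};
+//   }
+//
+//   RELAY_REGISTER_OP("copy")
+//   .set_num_inputs(1)
+//   .add_type_rel("Identity", IdentityRel);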
+
+inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) {  // NOLINT(*)
+  get()->num_inputs = n;
+  return *this;
+}
+
+inline OpRegistry& OpRegistry::set_attrs_type_key(  // NOLINT(*)
+    const std::string& type_key) {
+  get()->attrs_type_key = type_key;
+  return *this;
+}
+
+inline OpRegistry& OpRegistry::set_support_level(int32_t n) {  // NOLINT(*)
+  get()->support_level = n;
+  return *this;
+}
+
+template <typename ValueType>
+inline OpRegistry& OpRegistry::set_attr(  // NOLINT(*)
+    const std::string& attr_name, const ValueType& value, int plevel) {
+  CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
+  TVMRetValue rv;
+  rv = value;
+  UpdateAttr(attr_name, rv, plevel);
+  return *this;
+}
+
+// member functions of OpMap
+inline int GenericOpMap::count(const Op& op) const {
+  if (op.defined()) {
+    const uint32_t idx = op->index_;
+    return idx < data_.size() ? (data_[idx].second != 0) : 0;
+  } else {
+    return 0;
+  }
+}
+
+inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
+  CHECK(op.defined());
+  const uint32_t idx = op->index_;
+  CHECK(idx < data_.size() && data_[idx].second != 0)
+      << "Attribute " << attr_name_ << " has not been registered for Operator "
+      << op->name;
+  return data_[idx].first;
+}
+
+template <typename ValueType>
+inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
+  CHECK(op.defined());
+  const uint32_t idx = op->index_;
+  if (idx < data_.size() && data_[idx].second != 0) {
+    return data_[idx].first;
+  } else {
+    return value;
+  }
+}
+
+template <typename ValueType>
+inline int OpMap<ValueType>::count(const Op& op) const {
+  return map_.count(op);
+}
+
+template <typename ValueType>
+inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
+  return map_[op];
+}
+template <typename ValueType>
+inline ValueType OpMap<ValueType>::get(const Op& op,
+                                       ValueType def_value) const {
+  return map_.get(op, def_value);
+}
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_OP_H_
diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h
new file mode 100644
index 000000000000..e956097780bb
--- /dev/null
+++ b/include/tvm/relay/pass.h
@@ -0,0 +1,85 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/pass.h
+ * \brief The set of Relay passes written in C++.
+ */
+#ifndef TVM_RELAY_PASS_H_
+#define TVM_RELAY_PASS_H_
+
+#include
+#include
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Infer the type of an expression with the provided environment.
+ *
+ * The result of type checking is a new expression with unambiguous
+ * type information filled in, as well as its checked_type field
+ * populated with the result type.
+ *
+ * \param env The environment used for global settings and referencing
+ * global functions.
+ *
+ * \param e The expression to type check.
+ *
+ * \return A type checked expression with its checked_type field populated.
+ */
+Expr InferType(const Environment& env, const Expr& e);
+Expr InferType(const Environment& env, const GlobalVar& v, const Function& e);
+
+/*!
+ * \brief Check that types are well formed by applying "kinding rules".
+ *
+ * This pass ensures we do not do things that violate the design of the
+ * type system when writing down types.
+ *
+ * For example tensors are not allowed to contain functions in Relay.
+ *
+ * We check this by ensuring the `dtype` field of a Tensor always contains
+ * a data type such as `int`, `float`, `uint`.
+ *
+ * \param env The global environment.
+ * \param t The type to check.
+ * \return true if the rules are satisfied otherwise false
+ */
+bool KindCheck(const Environment& env, const Type& t);
+
+/*! \brief Compare two expressions for structural equivalence.
+ *
+ * This comparison operator respects scoping and compares
+ * expressions without regard to variable choice.
+ *
+ * For example: `let x = 1 in x` is equal to `let y = 1 in y`.
+ *
+ * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
+ * for more details.
+ *
+ * \param e1 The left hand expression.
+ * \param e2 The right hand expression.
+ *
+ * \return true if equal, otherwise false
+ */
+bool AlphaEqual(const Expr& e1, const Expr& e2);
+
+/*! \brief Compare two types for structural equivalence.
+ *
+ * This comparison operator respects scoping and compares
+ * expressions without regard to variable choice.
+ *
+ * For example: `forall s, Tensor[f32, s]` is equal to
+ * `forall w, Tensor[f32, w]`.
+ *
+ * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
+ * for more details.
+ *
+ * \param t1 The left hand type.
+ * \param t2 The right hand type.
+ *
+ * \return true if equal, otherwise false
+ */
+bool AlphaEqual(const Type& t1, const Type& t2);
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_PASS_H_
diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h
new file mode 100644
index 000000000000..44030ad8d97f
--- /dev/null
+++ b/include/tvm/relay/type.h
@@ -0,0 +1,276 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/type.h
+ * \brief Relay typed AST nodes.
+ */
+#ifndef TVM_RELAY_TYPE_H_
+#define TVM_RELAY_TYPE_H_
+
+#include
+#include
+#include
+#include
+
+#include "./base.h"
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Base type of the Relay type hierarchy. */
+class TypeNode : public RelayNode {
+ public:
+  static constexpr const char* _type_key = "relay.Type";
+  TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node);
+};
+
+/*!
+ * \brief Type is the base type of the relay type hierarchy.
+ *
+ * Relay's type system contains the following two key concepts:
+ *
+ * - TensorType: type of certain Tensor values in the expression.
+ * - FunctionType: the type of the function.
+ *
+ * There are also advanced types to support generic (polymorphic) types,
+ * which can be ignored when first reading the code base.
+ */
+class Type : public NodeRef {
+ public:
+  Type() {}
+  explicit Type(std::shared_ptr<Node> p) : NodeRef(p) {}
+
+  using ContainerType = TypeNode;
+};
+
+/*!
+ * \brief Base of all Tensor types
+ *  This container can hold TensorType or GenericTensorType.
+ */
+class BaseTensorTypeNode : public TypeNode {
+ public:
+  static constexpr const char* _type_key = "relay.BaseTensorType";
+  TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type);
+
+/*!
+ * \brief This is the most commonly used type in relay.
+ *  A TensorType has a fixed shape and data type.
+ *
+ *  The elements of shape can be either IntImm (constant integer),
+ *  or any symbolic integer expression.
+ *  The symbolic integer allows generic shape inference in certain cases.
+ * \sa TensorTypeNode The container class of TensorType.
+ */
+class TensorType;
+/*! \brief TensorType container node */
+class TensorTypeNode : public BaseTensorTypeNode {
+ public:
+  /*!
+   * \brief The shape of the tensor,
+   *  represented by ShapeExpr(tvm::Expr).
+   */
+  Array<ShapeExpr> shape;
+  /*! \brief The content data type */
+  DataType dtype;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static TensorType make(Array<ShapeExpr> shape, DataType dtype);
+
+  /*! \brief Construct a scalar containing elements of dtype. */
+  TVM_DLL static TensorType Scalar(DataType dtype);
+
+  static constexpr const char* _type_key = "relay.TensorType";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
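+
+// A small sketch (illustrative): a float32 tensor with the symbolic
+// shape (n, 4), where `n` is a shape variable.
+//
+//   tvm::Var n("n");
+//   TensorType tt = TensorTypeNode::make({n, 4}, ::tvm::Float(32));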
+
+/*!
+ * \brief Type parameter in the function.
+ *  This can be viewed as template parameter in C++ template function.
+ *
+ * For example, in the following pseudo code,
+ * the TypeParam of f is TypeParam(kind=kShapeVar, var=n).
+ * This function can take in a Tensor with shape=(3, 3) and
+ * returns a Tensor with shape=(9,)
+ *
+ * \code
+ *
+ *  template<i32 n>
+ *  f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
+ *
+ * \endcode
+ * \sa TypeParamNode The actual container class of TypeParam
+ */
+class TypeParam;
+/*! \brief TypeParam container node */
+class TypeParamNode : public TypeNode {
+ public:
+  /*! \brief possible kinds of TypeParam */
+  enum Kind : int {
+    /*! \brief template variable in shape expression */
+    kShapeVar = 0,
+    kShape = 1,
+    kBaseType = 2,
+    kType = 3
+  };
+  /*!
+   * \brief The variable itself is only meaningful when
+   *  kind is ShapeVar; otherwise, we only use the name.
+   */
+  tvm::Var var;
+  /*! \brief The kind of type parameter */
+  Kind kind;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("var", &var);
+    v->Visit("kind", &kind);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static TypeParam make(std::string name, Kind kind);
+
+  static constexpr const char* _type_key = "relay.TypeParam";
+  TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
+
+/*!
+ * \brief Potential Constraints in the type.
+ * \note This is reserved for future use.
+ */
+class TypeConstraint;
+/*! \brief TypeConstraint container node. */
+class TypeConstraintNode : public TypeNode {
+ public:
+  static constexpr const char* _type_key = "relay.TypeConstraint";
+  TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type);
+
+class FuncType;
+/*!
+ * \brief Function type in Relay.
+ *
+ * Relay supports polymorphic function types.
+ * This can be roughly viewed as template functions in C++.
+ *
+ * \sa TypeParam, TypeConstraint
+ */
+class FuncTypeNode : public TypeNode {
+ public:
+  /*! \brief The types of the arguments */
+  tvm::Array<Type> arg_types;
+  /*! \brief The type of return value. */
+  Type ret_type;
+  // The following fields are used in polymorphic (template) functions.
+  // For normal functions, the following two fields will be empty.
+  /*! \brief The type parameters of the function */
+  tvm::Array<TypeParam> type_params;
+  /*!
+   * \brief potential constraints the type needs to obey
+   * \note this field is reserved for further purposes.
+   */
+  tvm::Array<TypeConstraint> type_constraints;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("arg_types", &arg_types);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("type_params", &type_params);
+    v->Visit("type_constraints", &type_constraints);
+    v->Visit("span", &span);
+  }
+
+  TVM_DLL static FuncType make(tvm::Array<Type> arg_types, Type ret_type,
+                               tvm::Array<TypeParam> type_params,
+                               tvm::Array<TypeConstraint> type_constraints);
+
+  static constexpr const char* _type_key = "relay.FuncType";
+  TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type);
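+
+// A small sketch (illustrative; `tt` stands for an assumed tensor type):
+// the type of a unary function over `tt`, with no type parameters or
+// constraints.
+//
+//   FuncType ft = FuncTypeNode::make({tt}, tt, {}, {});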
+
+using TypeRelationFn =
+    TypedEnvFunc<Array<Type>(const Array<Type>&, int)>;
+
+/*!
+ * \brief An opaque type relation: an input-output relation on types.
+ */
+class TypeRelation;
+/*!
+ * \brief TypeRelation container.
+ * \note This node is not directly serializable.
+ *  The type function needs to be looked up in the environment.
+ */
+class TypeRelationNode : public TypeConstraintNode {
+ public:
+  /*! \brief The name of the function */
+  std::string name;
+
+  /*!
+   * \brief The function on input and output variables.
+   *  It is not directly serializable and
+   *  needs to be looked up in the environment.
+   */
+  TypeRelationFn func_;
+
+  /*! \brief The type arguments to the type function. */
+  tvm::Array<Type> args;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("name", &name);
+  }
+
+  TVM_DLL static TypeRelation make(std::string name, TypeRelationFn func_,
+                                   Array<Type> args);
+
+  static constexpr const char* _type_key = "relay.TypeRelation";
+  TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode);
+};
+
+RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);
+
+/*!
+ * \brief The type of tuple values.
+ */
+class TupleType;
+/*!
+ * \brief TupleType container.
+ */
+class TupleTypeNode : public TypeNode {
+ public:
+  /*! \brief The type of each field in the tuple. */
+  tvm::Array<Type> fields;
+
+  TupleTypeNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
+
+  TVM_DLL static TupleType make(tvm::Array<Type> fields);
+
+  static constexpr const char* _type_key = "relay.TypeTuple";
+  TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
+};
+
+RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
+
+// The following declarations contain advanced typing.
+// Only the class names are kept, reserved for future usage.
+class GenericTensorType;
+// stores a DataType.
+class GenericDataType;
+// stores a shape.
+class GenericShape;
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_TYPE_H_
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
new file mode 100644
index 000000000000..18a53be92815
--- /dev/null
+++ b/python/tvm/relay/__init__.py
@@ -0,0 +1,34 @@
+# pylint: disable=wildcard-import
+"""The Relay IR namespace containing the IR definition and compiler."""
+from . import base
+from . import ty
+from . import expr
+from . import env
+from . import ir_pass
+from . import ir_builder
+# Operators
+from .op import Op
+from .op.tensor import *
+
+# Span
+Span = base.Span
+
+# Type
+Type = ty.Type
+TensorType = ty.TensorType
+Kind = ty.Kind
+TypeParam = ty.TypeParam
+TypeConstraint = ty.TypeConstraint
+FuncType = ty.FuncType
+
+# Expr
+Constant = expr.Constant
+Tuple = expr.Tuple
+Var = expr.Var
+GlobalVar = expr.GlobalVar
+Param = expr.Param
+Function = expr.Function
+Call = expr.Call
+Let = expr.Let
+If = expr.If
diff --git a/python/tvm/relay/_env.py b/python/tvm/relay/_env.py
new file mode 100644
index 000000000000..25b8715a7816
--- /dev/null
+++ b/python/tvm/relay/_env.py
@@ -0,0 +1,5 @@
+# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
+"""The interface to the Environment exposed from C++."""
+from tvm._ffi.function import _init_api
+
+_init_api("relay._env", __name__)
diff --git a/python/tvm/relay/_env.pyi b/python/tvm/relay/_env.pyi
new file mode 100644
index 000000000000..c6b5d0f6c4bd
--- /dev/null
+++ b/python/tvm/relay/_env.pyi
@@ -0,0 +1,5 @@
+from typing import Union, Tuple, Dict, List
+from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
+from relay.ir import ShapeExtension, Operator, Defn
+
+class Environment(NodeBase): ...
\ No newline at end of file
diff --git a/python/tvm/relay/_ir_pass.py b/python/tvm/relay/_ir_pass.py
new file mode 100644
index 000000000000..61fdcfa38c2f
--- /dev/null
+++ b/python/tvm/relay/_ir_pass.py
@@ -0,0 +1,5 @@
+"""FFI exposing the Relay type inference and checking."""
+
+from tvm._ffi.function import _init_api
+
+_init_api("relay._ir_pass", __name__)
diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi
new file mode 100644
index 000000000000..1bb42ab854c2
--- /dev/null
+++ b/python/tvm/relay/_ir_pass.pyi
@@ -0,0 +1,6 @@
+from .env import Environment
+from . import ir
+
+def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
+def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
+def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
diff --git a/python/tvm/relay/_make.py b/python/tvm/relay/_make.py
new file mode 100644
index 000000000000..20a582e76d6a
--- /dev/null
+++ b/python/tvm/relay/_make.py
@@ -0,0 +1,9 @@
+"""
+The constructors for all Relay AST nodes exposed from C++.
+
+This module includes MyPy type signatures for all of the
+exposed modules.
+"""
+from .._ffi.function import _init_api
+
+_init_api("relay._make", __name__)
diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py
new file mode 100644
index 000000000000..d683c96739cd
--- /dev/null
+++ b/python/tvm/relay/base.py
@@ -0,0 +1,26 @@
+# pylint: disable=no-else-return, unidiomatic-typecheck
+"""The base node types for the Relay language."""
+from __future__ import absolute_import as _abs
+from .._ffi.node import NodeBase, register_node as _register_tvm_node
+from . import _make
+
+NodeBase = NodeBase
+
+def register_relay_node(type_key=None):
+    """register relay node type
+
+    Parameters
+    ----------
+    type_key : str or cls
+        The type key of the node
+    """
+    if not isinstance(type_key, str):
+        return _register_tvm_node(
+            "relay." + type_key.__name__)(type_key)
+    return _register_tvm_node(type_key)
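+
+# Usage sketch (illustrative): a class decorated this way is registered
+# under the key "relay.MyNode" on the C++ side.
+#
+#   @register_relay_node
+#   class MyNode(NodeBase):
+#       pass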
+
+
+@register_relay_node
+class Span(NodeBase):
+    def __init__(self, source, lineno, col_offset):
+        self.__init_handle_by_constructor__(_make.Span, source, lineno,
+                                            col_offset)
diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py
new file mode 100644
index 000000000000..62afef76425a
--- /dev/null
+++ b/python/tvm/relay/env.py
@@ -0,0 +1,84 @@
+# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
+"""A global environment storing everything needed to interpret or compile a Relay program."""
+from .base import register_relay_node, NodeBase
+from . import _make
+from . import _env
+
+
+@register_relay_node
+class Environment(NodeBase):
+    """The global Relay environment containing functions,
+    options and more.
+    """
+
+    def __init__(self, funcs):
+        """Construct an environment.
+
+        Parameters
+        ----------
+        funcs: list of relay.Function
+
+        Returns
+        -------
+        env: Environment
+            A new environment containing the given functions.
+        """
+        self.__init_handle_by_constructor__(_make.Environment, funcs)
+
+    def add(self, var, func):
+        """Add a function to the environment.
+
+        Parameters
+        ----------
+        var: GlobalVar
+            The global variable which names the function.
+
+        func: Function
+            The function.
+        """
+        if isinstance(var, str):
+            var = _env.Environment_GetGlobalVar(self, var)
+
+        _env.Environment_Add(self, var, func)
+
+    def merge(self, other):
+        """Merge two environments.
+
+        Parameters
+        ----------
+        other: Environment
+            The environment to merge into the current Environment.
+        """
+        return _env.Environment_Merge(self, other)
+
+    def global_var(self, name):
+        """Get a global variable by name.
+
+        Parameters
+        ----------
+        name: str
+            The name of the global variable.
+
+        Returns
+        -------
+        global_var: GlobalVar
+            The global variable mapped to :code:`name`.
+        """
+        return _env.Environment_GetGlobalVar(self, name)
+
+    def __getitem__(self, var):
+        """Lookup a global function by name or by variable.
+
+        Parameters
+        ----------
+        var: str or GlobalVar
+            The name or global variable.
+
+        Returns
+        -------
+        func: Function
+            The function referenced by :code:`var`.
+        """
+        if isinstance(var, str):
+            return _env.Environment_Lookup_str(self, var)
+        else:
+            return _env.Environment_Lookup(self, var)
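+
+# Usage sketch (illustrative; `func` stands for some relay.Function):
+#
+#   env = Environment({})
+#   env.add("main", func)
+#   main = env.global_var("main")
+#   main_fn = env["main"]       # lookup by name, or env[main] by variable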
+ """ + if isinstance(var, str): + return _env.Environment_Lookup_str(self, var) + else: + return _env.Environment_Lookup(self, var) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py new file mode 100644 index 000000000000..3bddbc89b56e --- /dev/null +++ b/python/tvm/relay/expr.py @@ -0,0 +1,115 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The expression nodes of Relay.""" +from __future__ import absolute_import +from .base import NodeBase, register_relay_node +from ._ir_pass import _get_checked_type +from . import _make +from .. import convert + + +class Expr(NodeBase): + """The base type for all Relay expressions.""" + + def checked_type(self): + return _get_checked_type(self) + + def __call__(self, *args): + converted_args = [] + for arg in args: + if isinstance(arg, Param): + converted_args.append(arg.var) + else: + converted_args.append(arg) + + return Call(self, args, None, None) + + +@register_relay_node +class Constant(Expr): + """A constant tensor in Relay, see tvm/relay/type.h for more details. + """ + + def __init__(self, data): + self.__init_handle_by_constructor__(_make.Constant, data) + + +@register_relay_node +class Tuple(Expr): + """A hetereogenous sequence of values. + see tvm/relay/type.h for more details. + """ + + def __init__(self, fields): + self.__init_handle_by_constructor__(_make.Tuple, fields) + + +@register_relay_node +class Var(Expr): + """A local variable in Relay.""" + + def __init__(self, name_hint): + self.__init_handle_by_constructor__(_make.Var, name_hint) + + +@register_relay_node +class GlobalVar(Expr): + """A global variable in Relay.""" + + def __init__(self, name_hint): + self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) + + +@register_relay_node +class Param(Expr): + """A function type in Relay, see tvm/relay/type.h for more details. + """ + + def __init__(self, var, ty): + self.__init_handle_by_constructor__(_make.Param, var, ty) + + +@register_relay_node +class Function(Expr): + """A function in Relay, see tvm/relay/expr.h for more details.""" + + def __init__(self, + params, + ret_type, + body, + type_params=None + ): + if type_params is None: + type_params = convert([]) + + self.__init_handle_by_constructor__( + _make.Function, params, ret_type, body, type_params) + + +@register_relay_node +class Call(Expr): + """A function call in Relay, see tvm/relay/expr.h for more details.""" + + def __init__(self, op, args, attrs, ty_args=None): + if not ty_args: + ty_args = [] + + self.__init_handle_by_constructor__( + _make.Call, op, args, attrs, ty_args) + + +@register_relay_node +class Let(Expr): + """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" + + def __init__(self, var, value, body, value_type): + self.__init_handle_by_constructor__( + _make.Let, var, value, body, value_type) + + +@register_relay_node +class If(Expr): + """A conditional expression in Relay, see tvm/relay/expr.h for more details.""" + + def __init__(self, cond, true_value, false_value): + self.__init_handle_by_constructor__( + _make.If, cond, true_value, false_value) diff --git a/python/tvm/relay/expr.pyi b/python/tvm/relay/expr.pyi new file mode 100644 index 000000000000..fd30e3ed25cf --- /dev/null +++ b/python/tvm/relay/expr.pyi @@ -0,0 +1,114 @@ +from typing import List +import tvm +from .base import Span, NodeBase +from .ty import Type, TypeParam +from ._ir_pass import _get_checked_type + + +class Expr(NodeBase): + def checked_type(self): + ... + + def __call__(self, *args): + ... 
+
+
+class Constant(Expr):
+    data = ...  # type: tvm.nd.NDArray
+
+    def __init__(self, data):
+        # type: (tvm.nd.NDArray) -> None
+        ...
+
+
+class Tuple(Expr):
+    fields = ...  # type: List[Expr]
+
+    def __init__(self, fields):
+        # type: (List[Expr]) -> None
+        ...
+
+
+class Var(Expr):
+    """A local variable in Relay."""
+    name_hint = ...  # type: str
+
+    def __init__(self, name_hint):
+        # type: (str) -> None
+        ...
+
+
+class GlobalVar(Expr):
+    name_hint = ...  # type: str
+
+    def __init__(self, name_hint):
+        # type: (str) -> None
+        ...
+
+
+class Param(Expr):
+    var = ...  # type: Var
+    type = ...  # type: Type
+
+    def __init__(self, var, ty):
+        # type: (Var, Type) -> None
+        ...
+
+
+class Function(Expr):
+    """A function in Relay, see tvm/relay/expr.h for more details."""
+    type_params = ...  # type: List[TypeParam]
+    params = ...  # type: List[Param]
+    ret_type = ...  # type: Type
+    body = ...  # type: Expr
+
+    def __init__(self,
+                 params,  # type: List[Param]
+                 ret_type,  # type: Type
+                 body,  # type: Expr
+                 type_params=None,  # type: List[TypeParam]
+                ):
+        # type: (...) -> None
+        ...
+
+
+class Call(Expr):
+    """A function call in Relay, see tvm/relay/expr.h for more details."""
+    op = ...  # type: Expr
+    args = ...  # type: List[Expr]
+    # todo(@jroesch): add attrs
+
+    def __init__(self, op, args, attrs, ty_args=None):
+        # type: (Expr, List[Expr], Optional[List[Type]]) -> None
+        ...
+
+
+class Let(Expr):
+    """A variable binding in Relay, see tvm/relay/expr.h for more details."""
+    var = ...  # type: Var
+    value = ...  # type: Expr
+    body = ...  # type: Expr
+    value_type = ...  # type: Type
+
+    def __init__(self, var, value, body, value_type):
+        # type: (Var, Expr, Expr, Type) -> None
+        ...
+
+
+class If(Expr):
+    """A conditional expression in Relay, see tvm/relay/expr.h for more details."""
+    cond = ...  # type: Expr
+    true_value = ...  # type: Expr
+    false_value = ...  # type: Expr
+    span = ...  # type: Span
+
+    def __init__(self, cond, true_value, false_value):
+        # type: (Expr, Expr, Expr) -> None
+        ...
diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py
new file mode 100644
index 000000000000..6e52f209d0c6
--- /dev/null
+++ b/python/tvm/relay/ir_builder.py
@@ -0,0 +1,394 @@
+# pylint: disable=no-else-return
+"""IR builder for the Relay IR.
+
+Enables users to construct Relay programs with a Python API.
+"""
+from collections import OrderedDict
+import numpy as np
+import tvm
+from .ty import Type, FuncType, TensorType
+from .expr import Expr, Constant, Tuple, Let, Var, Param, Function, If
+from .env import Environment
+
+
+def _convert_to_value(arg, ctxt=tvm.cpu(0)):
+    # type: (Any, tvm.Context) -> tvm.nd.NDArray
+    """Convert Python values into the appropriate types
+    for the Relay evaluator.
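+
+    For example (a sketch of the mapping)::
+
+        _convert_to_value(1)                # 0-d int32 NDArray
+        _convert_to_value(np.ones((2, 2)))  # 2x2 NDArray, passed through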
+ """ + if isinstance(arg, int): + return tvm.nd.array(np.array(arg, dtype='int32'), ctxt) + elif isinstance(arg, float): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, bool): + return tvm.nd.array(np.array(arg, dtype='float32'), ctxt) + elif isinstance(arg, np.ndarray): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, tvm.ndarray.NDArray): + return arg + else: + # raise Exception(f"can't convert {type(arg)} to a Relay AST") + raise Exception("unsupported argument type {0}".format(type(arg))) + + +def _convert_type(rtype): + if isinstance(rtype, str): + return scalar_type(rtype) + elif isinstance(rtype, Type): + return rtype + else: + raise Exception( + "unsupported conversion to Relay type {0}".format(type(rtype))) + + +def convert(arg): + # type: (Any) -> Expr + """Convert some Python objects into a Relay AST fragment. + + Parameters + ---------- + arg: Any + The Python object + + Returns + ------- + expr: relay.Expr + The converted expression. + """ + if isinstance(arg, Expr): + return arg + elif isinstance(arg, tuple): + return relay.Tuple([convert(el) for el in arg]) + elif isinstance(arg, PartialFunc): + return arg.to_func() + else: + value = _convert_to_value(arg) + return Constant(value) + + +class WithScope(object): + """A wrapper for builder methods which introduce scoping.""" + + def __init__(self, enter_value, exit_cb): + self._enter_value = enter_value + self._exit_cb = exit_cb + + def __enter__(self): + return self._enter_value + + def __exit__(self, ptype, value, trace): + if value: + raise value + else: + self._exit_cb() + + +class PartialFunc(object): + """A wrapper around functions while they are being built. + + Used by the builder as a user is building up a function, + allows Function nodes which contain partially initialized + state. + """ + + def __init__(self, params, ret_type, body, type_params): + self.params = params + self.ret_type = ret_type + self.body = body + self.type_params = type_params + + def param_ids(self): + return [p.var for p in self.params] + + def to_func(self): + """Converts a PartialFunc into a :py:class:`~relay.Function`.""" + return Function( + self.params, + self.ret_type, + self.body, + self.type_params) + +#pylint: disable=invalid-name + + +def _mk_let(bindings, ret_value): + let_expr = ret_value + for var, (value, ty) in reversed(list(bindings.items())): + let_expr = Let(var, value, let_expr, ty) + + return let_expr + + +class IRBuilder(object): + """The IRBuilder class. + + Enables users to build up a Relay environment and program. 
+
+    Examples
+    --------
+
+    Program::
+
+        fn (x : Tensor[f32, (10, 10)]) {
+            let t1 = log(x);
+            let t2 = add(t1, x);
+            return t2;
+        }
+
+    .. code-block:: python
+
+        b = IRBuilder()
+        with b.function(('x', tensor_type(10, 10))) as func:
+            x, = func.param_ids()
+            t1 = b.let('t1', log(x))
+            t2 = b.let('t2', add(t1, x))
+            b.ret(t2)
+    """
+
+    def __init__(self):
+        self.bindings = [OrderedDict({})]
+        self.scopes = [OrderedDict({})]
+        self.params = []
+        self.ret_values = [None]
+        self.env = Environment({})
+
+    def enter_scope(self, params=None):
+        if not params:
+            params = []
+
+        self.bindings.append(OrderedDict({}))
+        self.scopes.append(OrderedDict({}))
+        self.params.append(params)
+        self.ret_values.append(None)
+
+    def exit_scope(self):
+        bindings = self.bindings.pop()
+        scopes = self.scopes.pop()
+        params = self.params.pop()
+        ret_value = self.ret_values.pop()
+        return bindings, scopes, params, ret_value
+
+    #pylint: disable=invalid-name
+    def bind(self, name, value, ty):
+        lv = Var(name)
+        self.scopes[-1][name] = lv
+        self.bindings[-1][lv] = (value, ty)
+        return lv
+
+    def let(self, name, value, value_type=None):
+        if isinstance(value, Param):
+            value = value.var
+
+        if not isinstance(value, Expr):
+            value = convert(value)
+
+        return self.bind(name, value, value_type)
+
+    def _convert_params(self, raw_params):
+        relay_params = []
+        for raw_param in raw_params:
+            if isinstance(raw_param, Param):
+                var = raw_param.var
+                param = raw_param
+            elif isinstance(raw_param, tuple):
+                var, ty = raw_param
+                if isinstance(var, str):
+                    var = Var(var)
+                ty = _convert_type(ty)
+                param = Param(var, ty)
+            elif isinstance(raw_param, str):
+                var = Var(raw_param)
+                ty = None
+                param = Param(var, ty)
+            else:
+                raise Exception("unknown parameter type")
+
+            self.scopes[-1][var.name_hint] = var
+            relay_params.append(param)
+
+        return relay_params
+
+    def function(self, *params):
+        """Construct a Relay function."""
+
+        relay_params = self._convert_params(params)
+
+        self.enter_scope()
+
+        pfunc = PartialFunc(relay_params, None, None, [])
+
+        def _on_exit():
+            bindings, _, _, ret_value = self.exit_scope()
+            body = _mk_let(bindings, ret_value)
+            pfunc.body = body
+
+        return WithScope(pfunc, _on_exit)
+
+    def ret(self, x):
+        """Set `x` to be the return value of the current function."""
+        if not self.ret_values[-1]:
+            self.ret_values[-1] = convert(x)
+        else:
+            raise Exception(
+                "return value already set, a function can only have one return value")
+
+    def if_scope(self, cond):
+        """Construct the true branch of an if expression with scoping."""
+        self.enter_scope()
+
+        def _on_exit():
+            bindings, _, _, ret_value = self.exit_scope()
+            assert self.ret_values[-1] is None
+            true_branch = _mk_let(bindings, ret_value)
+            self.ret_values[-1] = If(cond, true_branch, None)
+
+        # the enter value is unused; only the scoping matters here
+        return WithScope(10, _on_exit)
+
+    def else_scope(self):
+        """Construct the else branch of an if expression with scoping."""
+        self.enter_scope()
+
+        def _on_exit():
+            bindings, _, _, ret_value = self.exit_scope()
+            partial_if = self.ret_values[-1]
+            assert isinstance(
+                partial_if, If) and partial_if.false_branch is None
+            false_branch = _mk_let(bindings, ret_value)
+            self.ret_values[-1] = If(
+                partial_if.cond,
+                partial_if.true_branch,
+                false_branch)
+
+        return WithScope(10, _on_exit)
+
+    def param(self, name, ty=None):
+        if not ty:
+            ty = scalar_type('float32')
+        else:
+            ty = _convert_type(ty)
+
+        return Param(Var(name), ty)
+
+    def global_var(self, name):
+        # type: (str) -> GlobalVar
+        """Construct a global var with `name` as its name hint.
+ + Parameters + ---------- + name: str + The name of the global variable. + + Returns + ------- + global_var: relay.GlobalVar + The global variable with `name`. + + """ + return self.env.global_var(name) + + def decl(self, name, *params, **kwargs): + """Create a global function. + + Parameters + ---------- + name: str or GlobalVar + The name of the function. + params: params + The parameters of the function. + + Returns + ------- + with_scope: Scope for the function. + """ + + ret_type = kwargs.get('ret_type', None) + + self.enter_scope() + + def _on_exit(): + bindings, _, _, ret_value = self.exit_scope() + exp = _mk_let(bindings, ret_value) + self.env.add(name, Function(params, ret_type, exp)) + + return WithScope(10, _on_exit) + + def get(self): + """Get the full program. + + Returns + ---------- + (prog, env) : (relay.Expr, relay.Environment) + A pair of the partial program, and the modified environment. + """ + bindings = self.bindings.pop() + scope = self.scopes.pop() + + if self.bindings: + raise Exception("IRBuilder: binding error") + + if self.scopes: + raise Exception("IRBuilder: scoping error") + + if bindings and scope and not self.ret_values: + raise Exception("IRBuilder: no return value set") + + return _mk_let(bindings, self.ret_values[-1]), self.env + + +def scalar_type(dtype): + """Construct a Relay scalar type. + + Parameters + ---------- + dtype: dtype + The dtype of the scalar type. + + Returns: + scalar_type: relay.Type + The scalar type. + """ + return TensorType(tvm.convert([]), dtype) + + +def tensor_type(*shape, **kwargs): + """Construct a Relay Tensor type. + + Parameters + ---------- + shape: list of tvm.Expr + The shape of the Tensor type. + dtype: dtype + The dtype of the Tensor type. + + Returns + ------- + tensor_type: relay.Type + The resulting tensor types. + """ + dtype = kwargs.get('dtype', 'float32') + + return TensorType(tvm.convert(shape), dtype) + + +def func_type(args, ret_type, type_params=None): + """Construct a Relay function type. + + Parameters + ---------- + args: list of relay.Type + The argument types. + + ret_type: relay.Type + The return type. + + type_params: list of relay.TypeParam + The type parameters. + + Returns + ------- + func_type: The function type. + """ + if not type_params: + type_params = [] + + args = [_convert_type(arg) for arg in args] + ret_type = _convert_type(ret_type) + return FuncType(args, ret_type, type_params, []) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py new file mode 100644 index 000000000000..bbc294b59f5b --- /dev/null +++ b/python/tvm/relay/ir_pass.py @@ -0,0 +1,12 @@ +# pylint: disable=no-else-return, +# pylint: disable=unidiomatic-typecheck +"""The set of passes for Relay. + +Exposes an interface for configuring the passes and scripting +them in Python. +""" +from . import _ir_pass + +# Expose checking expression, should rename to infer_type. +# pylint: disable=invalid-name +check_expr = _ir_pass.check_expr diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py new file mode 100644 index 000000000000..0646a8326db6 --- /dev/null +++ b/python/tvm/relay/op/__init__.py @@ -0,0 +1,12 @@ +#pylint: disable=wildcard-import +"""Relay core operators.""" +# operator defs +from .op import get, register, Op + +# Operators +from .tensor import * + +# operator registry +from . 
import _tensor +from ..expr import Expr +from ..base import register_relay_node diff --git a/python/tvm/relay/op/_make.py b/python/tvm/relay/op/_make.py new file mode 100644 index 000000000000..79c86cbb0254 --- /dev/null +++ b/python/tvm/relay/op/_make.py @@ -0,0 +1,4 @@ +"""Constructor APIs""" +from ..._ffi.function import _init_api + +_init_api("relay.op._make", __name__) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py new file mode 100644 index 000000000000..0bc2054cebdf --- /dev/null +++ b/python/tvm/relay/op/_tensor.py @@ -0,0 +1,2 @@ +#pylint: disable=invalid-name +"""Backend compiler related feature registration""" diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py new file mode 100644 index 000000000000..f1130b52e7ce --- /dev/null +++ b/python/tvm/relay/op/op.py @@ -0,0 +1,77 @@ +"""The base node types for the Relay language.""" +from ..._ffi.function import _init_api + +from ..base import register_relay_node +from ..expr import Expr + + +@register_relay_node +class Op(Expr): + """A Relay operator definition.""" + + def __init__(self): + raise RuntimeError("Cannot create op, use get instead") + + def get_attr(self, attr_name): + """Get additional attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name. + + Returns + ------- + value : object + The attribute value + """ + return _OpGetAttr(self, attr_name) + + +def get(op_name): + """Get the Op for a given name + + Parameters + ---------- + op_name : str + The operator name + + Returns + ------- + op : Op + The op of the corresponding name + """ + return _GetOp(op_name) + + +def register(op_name, attr_key, value=None, level=10): + """Register an operator property of an operator. + + + Parameters + ---------- + op_name : str + The name of operator + + attr_key : str + The attribute name. + + value : object, optional + The value to set + + level : int, optional + The priority level + + Returns + ------- + fregister : function + Register function if value is not specified. + """ + def _register(v): + """internal register function""" + _Register(op_name, attr_key, v, level) + return v + return _register(value) if value else _register + + +_init_api("relay.op", __name__) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py new file mode 100644 index 000000000000..fa54d8b53dd8 --- /dev/null +++ b/python/tvm/relay/op/tensor.py @@ -0,0 +1,114 @@ +"""Basic tensor operations.""" +from __future__ import absolute_import as _abs +from . import _make +from ..expr import Tuple + +# We create a wrapper function for each operator in the +# python side to call into the positional _make.OpName function. +# +# We make this decision so that we can: +# - Have declare python docstring for each function +# - Enable keyword arguments easily +# - Not put too much burden on FFI to support complicated features +# like default value and keyword arguments + + +def log(data): + """Take log of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.log(data) + + +def exp(data): + """Take exp of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.exp(data) + + +def sqrt(data): + """Take sqrt of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. 
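+
+    Examples
+    --------
+    A usage sketch (``x`` is any Relay expression)::
+
+        x = relay.Var("x")
+        y = relay.sqrt(x)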
+ """ + return _make.sqrt(data) + + +def add(lhs, rhs): + """Elementwise addition. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.add(lhs, rhs) + + +def subtract(lhs, rhs): + """Elementwise subtraction. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.add(lhs, rhs) + +def equal(lhs, rhs): + return _make.equal(lhs, rhs) + +def concat(*args): + """Concatenate the input tensors along the zero axis. + + Parameters + ---------- + args: list of Tensor + + Returns + ------- + tensor: The concatenated tensor. + """ + tup = Tuple(list(args)) + return _make.concat(tup) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py new file mode 100644 index 000000000000..10e267a53977 --- /dev/null +++ b/python/tvm/relay/ty.py @@ -0,0 +1,138 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The type nodes of the Relay language.""" +from enum import IntEnum +from .base import NodeBase, register_relay_node +from . import _make + + +class Type(NodeBase): + """The base type for all Relay types.""" + + def __eq__(self, other): + """Compare two Relay types for structural equivalence using + alpha equivalence. + """ + return bool(_make._type_alpha_eq(self, other)) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Compares two Relay types by referential equality.""" + return super().__eq__(other) + + +@register_relay_node +class TensorType(Type): + """A concrete TensorType in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to tensor's with a known dype and shape. For + example a tensor of `float32` and `(5, 5)`. + """ + + def __init__(self, shape, dtype): + """Construct a tensor type. + + Parameters + ---------- + shape: list of tvm.Expr + dtype: str + + Returns + ------- + tensor_type: The TensorType + """ + self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) + + +class Kind(IntEnum): + """The kind of a type parameter, represents a variable shape, + base type, type, or dimension. + + This controls what a type parameter is allowed to be instantiated + with. For example one's of kind BaseType can only be `float32`, `int32`, + and so on. + """ + ShapeVar = 0 + Shape = 1 + BaseType = 2 + Type = 3 + + +@register_relay_node +class TypeParam(Type): + """A type parameter used for generic types in Relay, + see tvm/relay/type.h for more details. + + A type parameter represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. + """ + + def __init__(self, var, kind): + """Construct a TypeParam. + + Parameters + ---------- + var: tvm.expr.Var + The tvm.Var which backs the type parameter. + + kind: Kind + The kind of the type parameter. + + Returns + ------- + type_param: TypeParam + The type parameter. + """ + self.__init_handle_by_constructor__(_make.TypeParam, var, kind) + + +@register_relay_node +class TypeConstraint(Type): + """Abstract class representing a type constraint.""" + pass + + +@register_relay_node +class FuncType(Type): + """A function type in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to functions in Relay. 
They consist of + a list of type parameters which enable the definition of generic + fucntions, a set of type constraints which we omit for the time + being, a sequence of argument types, and a return type. + + We informally write them as: + `forall (type_params), (arg_types) -> ret_type where type_constraints` + """ + def __init__(self, + arg_types, + ret_type, + type_params, + type_constraints + ): + """Construct a function type. + + Parameters + ---------- + arg_types: list of Type + ret_type: Type + type_params: list of TypeParam + type_constraints: list of TypeConstraint + + Returns + ------- + func_type: FuncType + The function type. + """ + self.__init_handle_by_constructor__( + _make.FuncType, arg_types, ret_type, type_params, type_constraints) + + +@register_relay_node +class IncompleteType(Type): + """An incomplete type.""" + + def __init__(self, kind): + self.__init_handle_by_constructor__(_make.IncompleteType, kind) diff --git a/python/tvm/relay/ty.pyi b/python/tvm/relay/ty.pyi new file mode 100644 index 000000000000..0581847598d4 --- /dev/null +++ b/python/tvm/relay/ty.pyi @@ -0,0 +1,139 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The type nodes of the Relay language.""" +from enum import IntEnum +from .base import NodeBase, register_relay_node +from . import _make + + +class Type(NodeBase): + """The base type for all Relay types.""" + + def __eq__(self, other): + """Compare two Relay types for structural equivalence using + alpha equivalence. + """ + return bool(_make._type_alpha_eq(self, other)) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Compares two Relay types by referential equality.""" + return super().__eq__(other) + + +@register_relay_node +class TensorType(Type): + """A concrete TensorType in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to tensor's with a known dype and shape. For + example a tensor of `float32` and `(5, 5)`. + """ + + def __init__(self, shape, dtype): + """Construct a tensor type. + + Parameters + ---------- + shape: list of tvm.Expr + dtype: str + + Returns + ------- + tensor_type: The TensorType + """ + self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) + + +class Kind(IntEnum): + """The kind of a type parameter, represents a variable shape, + base type, type, or dimension. + + This controls what a type parameter is allowed to be instantiated + with. For example one's of kind BaseType can only be `float32`, `int32`, + and so on. + """ + ShapeVar = 0 + Shape = 1 + BaseType = 2 + Type = 3 + + +@register_relay_node +class TypeParam(Type): + """A type parameter used for generic types in Relay, + see tvm/relay/type.h for more details. + + A type parameter represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. + """ + + def __init__(self, var, kind): + """Construct a TypeParam. + + Parameters + ---------- + var: tvm.expr.Var + The tvm.Var which backs the type parameter. + + kind: Kind + The kind of the type parameter. + + Returns + ------- + type_param: TypeParam + The type parameter. + """ + self.__init_handle_by_constructor__(_make.TypeParam, var, kind) + + +@register_relay_node +class TypeConstraint(Type): + """Abstract class representing a type constraint.""" + pass + + +@register_relay_node +class FuncType(Type): + """A function type in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to functions in Relay. 
They consist of + a list of type parameters which enable the definition of generic + fucntions, a set of type constraints which we omit for the time + being, a sequence of argument types, and a return type. + + We informally write them as: + `forall (type_params), (arg_types) -> ret_type where type_constraints` + """ + + def __init__(self, + arg_types, + ret_type, + type_params, + type_constraints, + ): + """Construct a function type. + + Parameters + ---------- + arg_types: list of Type + ret_type: Type + type_params: list of TypeParam + type_constraints: list of TypeConstraint + + Returns + ------- + func_type: FuncType + The function type. + """ + self.__init_handle_by_constructor__( + _make.FuncType, arg_types, ret_type, type_params, type_constraints) + + +@register_relay_node +class IncompleteType(Type): + """An incomplete type.""" + + def __init__(self, kind): + self.__init_handle_by_constructor__(_make.IncompleteType, kind) diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index f169ff1b64ac..f0d60f514a37 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -6,8 +6,10 @@ from . import make as _make from . import expr as _expr + class TensorSlice(NodeGeneric, _expr.ExprOp): """Auxiliary data structure for enable slicing syntax from tensor.""" + def __init__(self, tensor, indices): if not isinstance(indices, tuple): indices = (indices,) @@ -31,9 +33,11 @@ def dtype(self): itervar_cls = None + @register_node class Tensor(NodeBase, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" + def __call__(self, *indices): ndim = self.ndim if len(indices) != ndim: @@ -104,6 +108,7 @@ def name(self): class Operation(NodeBase): """Represent an operation that generate a tensor""" + def output(self, index): """Get the index-th output of the operation diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc new file mode 100644 index 000000000000..7e7fb71f6d6c --- /dev/null +++ b/src/relay/ir/base.cc @@ -0,0 +1,77 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file base.cc + * \brief The core base types for Relay. 
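+ *
+ * (Note: SourceName objects are interned. SourceName::Get keeps a
+ * process-lifetime table, so equal names share a single node and the
+ * returned references remain valid for the program's lifetime.)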
+ */
+#include <tvm/relay/base.h>
+#include <tvm/api_registry.h>
+
+namespace tvm {
+namespace relay {
+
+using tvm::IRPrinter;
+using namespace tvm::runtime;
+
+SourceName SourceNameNode::make(std::string name) {
+  std::shared_ptr<SourceNameNode> n = std::make_shared<SourceNameNode>();
+  n->name = std::move(name);
+  return SourceName(n);
+}
+
+std::shared_ptr<Node> CreateSourceName(const std::string& name) {
+  SourceName sn = SourceName::Get(name);
+  CHECK(sn.defined()) << "Cannot find source name \'" << name << '\'';
+  std::shared_ptr<Node> node = sn.node_;
+  return std::dynamic_pointer_cast<SourceNameNode>(node);
+}
+
+const SourceName& SourceName::Get(const std::string& name) {
+  static std::unordered_map<std::string, SourceName> source_map;
+
+  auto sn = source_map.find(name);
+  if (sn == source_map.end()) {
+    auto source_name = SourceNameNode::make(name);
+    source_map.insert({name, source_name});
+    return source_map.at(name);
+  } else {
+    return sn->second;
+  }
+}
+
+TVM_REGISTER_API("relay._make.SourceName")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) {
+      *ret = SourceNameNode::make(args[0]);
+    });
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+    .set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) {
+      p->stream << "SourceNameNode(" << node->name << ", " << node << ")";
+    });
+
+TVM_REGISTER_NODE_TYPE(SourceNameNode)
+.set_creator(CreateSourceName)
+.set_global_key([](const Node* n) {
+    return static_cast<const SourceNameNode*>(n)->name;
+  });
+
+Span SpanNode::make(SourceName source, int lineno, int col_offset) {
+  std::shared_ptr<SpanNode> n = std::make_shared<SpanNode>();
+  n->source = std::move(source);
+  n->lineno = lineno;
+  n->col_offset = col_offset;
+  return Span(n);
+}
+
+TVM_REGISTER_API("relay._make.Span")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = SpanNode::make(args[0], args[1], args[2]);
+  });
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<SpanNode>([](const SpanNode *node, tvm::IRPrinter *p) {
+    p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
+              << node->col_offset << ")";
+  });
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc
new file mode 100644
index 000000000000..47c9789ab5ae
--- /dev/null
+++ b/src/relay/ir/environment.cc
@@ -0,0 +1,147 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file environment.cc
+ * \brief The global environment in Relay.
+ */
+#include <tvm/relay/environment.h>
+#include <tvm/relay/pass.h>
+#include <tvm/api_registry.h>
+#include "./../pass/resolve.h"
+
+namespace tvm {
+namespace relay {
+
+using tvm::IRPrinter;
+using namespace runtime;
+
+Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
+  std::shared_ptr<EnvironmentNode> n = std::make_shared<EnvironmentNode>();
+  n->functions = std::move(global_funcs);
+  return Environment(n);
+}
+
+GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) {
+  auto global_id = global_map_.find(str);
+  if (global_id != global_map_.end()) {
+    return (*global_id).second;
+  } else {
+    // Create and remember a fresh global var for unseen names.
+    auto id = GlobalVarNode::make(str);
+    this->global_map_.Set(str, id);
+    return id;
+  }
+}
+
+/*! \brief Add a new item to the global environment
+ * \note if the update flag is not set adding a duplicate
+ * definition will trigger an exception, otherwise we will
+ * update the definition if and only if it is type compatible.
+ */
+void EnvironmentNode::Add(const GlobalVar &var, const Function &func,
+                          bool update) {
+  // Type check the item before we add it to the environment.
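+  // The flow below (as the code suggests): wrap `this` into an
+  // Environment handle, run type inference on the candidate function,
+  // and only accept a redefinition when `update` is set and the new
+  // checked type is alpha-equal to the old one.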
+ auto env = GetRef(this); + + Expr checked_expr = InferType(env, var, func); + + if (const FunctionNode *func_node = checked_expr.as()) { + auto checked_func = GetRef(func_node); + auto type = checked_func->checked_type(); + + CHECK(IsFullyResolved(type)); + + if (functions.find(var) != functions.end()) { + if (!update) { + throw dmlc::Error("already have definition for XXXX."); + } + + auto old_type = functions[var].as()->checked_type(); + + if (!AlphaEqual(type, old_type)) { + throw dmlc::Error( + "Environment#update changes type, not possible in this mode."); + } + + this->functions.Set(var, checked_func); + } else { + this->functions.Set(var, checked_func); + } + } else { + throw Error("internal error: unknown item type, unreachable code"); + } +} + +void EnvironmentNode::Update(const GlobalVar &var, const Function &func) { + this->Add(var, func, true); +} + +void EnvironmentNode::Remove(const GlobalVar & var) { + auto functions_node = this->functions.CopyOnWrite(); + functions_node->data.erase(var.node_); +} + +Function EnvironmentNode::Lookup(const GlobalVar &var) { + auto func = functions.find(var); + if (func != functions.end()) { + return (*func).second; + } else { + throw Error(std::string("there is no definition of ") + var->name_hint); + } +} + +Function EnvironmentNode::Lookup(const std::string &str) { + GlobalVar id = this->GetGlobalVar(str); + return this->Lookup(id); +} + +void EnvironmentNode::Merge(const Environment &env) { + for (auto pair : env->functions) { + this->functions.Set(pair.first, pair.second); + } +} + +TVM_REGISTER_API("relay._make.Environment") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = EnvironmentNode::make(args[0]); + }); + +TVM_REGISTER_API("relay._env.Environment_Add") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Add(args[1], args[2], false); + }); + +TVM_REGISTER_API("relay._env.Environment_GetGlobalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + *ret = env->GetGlobalVar(args[1]); + }); + +TVM_REGISTER_API("relay._env.Environment_Lookup") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + GlobalVar var = args[1]; + *ret = env->Lookup(var); + }); + +TVM_REGISTER_API("relay._env.Environment_Lookup_str") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string var_name = args[1]; + auto var = env->GetGlobalVar(var_name); + *ret = env->Lookup(var); + }); + +TVM_REGISTER_API("relay._env.Environment_Merge") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Merge(args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const EnvironmentNode *node, + tvm::IRPrinter *p) { + p->stream << "EnvironmentNode( " << node->functions << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc new file mode 100644 index 000000000000..f4363f5312c4 --- /dev/null +++ b/src/relay/ir/expr.cc @@ -0,0 +1,201 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/ir/expr.cc + * \brief The expression AST nodes of Relay. 
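+ *
+ * (Each node below follows the same pattern: a static make()
+ * constructor, a relay._make.* FFI registration, and an IRPrinter
+ * dispatch used for debug printing.)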
+ */ +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +Constant ConstantNode::make(runtime::NDArray data) { + std::shared_ptr n = std::make_shared(); + n->data = std::move(data); + return Constant(n); +} + +TVM_REGISTER_API("relay._make.Constant") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ConstantNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const ConstantNode *node, + tvm::IRPrinter *p) { + p->stream << "ConstantNode(TODO)"; + }); + +TensorType ConstantNode::tensor_type() const { + auto dtype = TVMType2Type(data->dtype); + + Array shape; + for (int i = 0; i < data->ndim; i++) { + shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), data->shape[i])); + } + + return TensorTypeNode::make(shape, dtype); +} + +Tuple TupleNode::make(tvm::Array fields) { + std::shared_ptr n = std::make_shared(); + n->fields = std::move(fields); + return Tuple(n); +} + +TVM_REGISTER_API("relay._make.Tuple") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TupleNode *node, tvm::IRPrinter *p) { + p->stream << "TupleNode(" << node->fields << ")"; + }); + +Var VarNode::make(std::string name_hint) { + std::shared_ptr n = std::make_shared(); + n->name_hint = std::move(name_hint); + return Var(n); +} + +TVM_REGISTER_API("relay._make.Var") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = VarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const VarNode *node, + tvm::IRPrinter *p) { + p->stream << "VarNode(" << node->name_hint << ")"; + }); + +GlobalVar GlobalVarNode::make(std::string name_hint) { + std::shared_ptr n = std::make_shared(); + n->name_hint = std::move(name_hint); + return GlobalVar(n); +} + +TVM_REGISTER_API("relay._make.GlobalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = GlobalVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const GlobalVarNode *node, + tvm::IRPrinter *p) { + p->stream << "GlobalVarNode(" << node->name_hint << ")"; + }); + +Param ParamNode::make(Var var, Type type) { + std::shared_ptr n = std::make_shared(); + n->var = std::move(var); + n->type = std::move(type); + return Param(n); +} + +TVM_REGISTER_API("relay._make.Param") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ParamNode::make(args[0], args[1]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const ParamNode *node, tvm::IRPrinter *p) { + p->stream << "ParamNode(" << node->var << ", " << node->type << ")"; +}); + +Function FunctionNode::make(tvm::Array params, Type ret_type, Expr body, + tvm::Array type_params) { + std::shared_ptr n = std::make_shared(); + n->params = std::move(params); + n->ret_type = std::move(ret_type); + n->body = std::move(body); + n->type_params = std::move(type_params); + return Function(n); +} + +Type FunctionNode::fn_type() const { + Array param_types; + for (auto param : this->params) { + param_types.push_back(param->type); + } + + return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); +} + +TVM_REGISTER_API("relay._make.Function") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const FunctionNode *node, + tvm::IRPrinter *p) 
{ + p->stream << "FunctionNode(" << node->params << ", " << node->ret_type + << ", " << node->body << ", " << node->type_params << ")"; +}); + +Call CallNode::make(Expr op, Array args, Attrs attrs, + Array type_args) { + std::shared_ptr n = std::make_shared(); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->type_args = std::move(type_args); + return Call(n); +} + +TVM_REGISTER_API("relay._make.Call") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CallNode::make(args[0], args[1], args[2], args[3]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const CallNode *node, tvm::IRPrinter *p) { + p->stream << "CallNode(" << node->op << ", " << node->args << ", " + << node->attrs << ", " << node->type_args << ")"; +}); + +Let LetNode::make(Var var, Expr value, Expr body, Type value_type) { + std::shared_ptr n = std::make_shared(); + n->var = std::move(var); + n->value = std::move(value); + n->body = std::move(body); + n->value_type = std::move(value_type); + return Let(n); +} + +TVM_REGISTER_API("relay._make.Let") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LetNode::make(args[0], args[1], args[2], args[3]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { + p->stream << "LetNode(" << node->var << ", " << node->value + << ", " << node->body << ", " << node->value_type << ")"; +}); + +If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { + std::shared_ptr n = std::make_shared(); + n->cond = std::move(cond); + n->true_branch = std::move(true_branch); + n->false_branch = std::move(false_branch); + return If(n); +} + +TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IfNode::make(args[0], args[1], args[2]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { + p->stream << "IfNode(" << node->cond << ", " << node->true_branch + << node->false_branch << ")"; +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc new file mode 100644 index 000000000000..85ae5ffa694e --- /dev/null +++ b/src/relay/ir/expr_functor.cc @@ -0,0 +1,205 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/expr_mutator.cc + * \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. 
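+ *
+ * (Sketch of the contract implied by the code: Mutate() consults
+ * memo_ first, so shared subtrees are rewritten only once, and the
+ * default VisitExpr_ cases return the original expression when no
+ * child changed, preserving sharing instead of copying.)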
+ */ + +#include + +namespace tvm { +namespace relay { + +Expr ExprMutator::Mutate(const Expr& expr) { + auto cached_expr = this->memo_.find(expr); + if (cached_expr != this->memo_.end()) { + return (*cached_expr).second; + } else { + auto new_expr = this->ExprMutator::VisitExpr(expr, expr); + this->memo_.Set(expr, new_expr); + return new_expr; + } +} + +Expr ExprMutator::VisitExpr_(const VarNode* op, const Expr& expr) { + return expr; +} + +Expr ExprMutator::VisitExpr_(const ConstantNode* op, const Expr& expr) { + return expr; +} + +Expr ExprMutator::VisitExpr_(const GlobalVarNode* op, const Expr& expr) { + return expr; +} + +Expr ExprMutator::VisitExpr_(const OpNode* op, const Expr& expr) { + return expr; +} + +Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) { + tvm::Array fields; + bool all_fields_unchanged = true; + for (auto field : op->fields) { + auto new_field = this->Mutate(field); + fields.push_back(new_field); + all_fields_unchanged &= new_field.same_as(field); + } + + if (all_fields_unchanged) { + return e; + } else { + return TupleNode::make(fields); + } +} + +Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) { + Var var = Downcast(this->Mutate(op->var)); + auto type = this->VisitType(op->type); + if (var == op->var && type == op->type) { + return e; + } else { + return ParamNode::make(var, type); + } +} + +Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) { + tvm::Array ty_params; + bool all_ty_params_changed = true; + + for (auto ty_param : op->type_params) { + TypeParam new_ty_param = Downcast(VisitType(ty_param)); + ty_params.push_back(new_ty_param); + all_ty_params_changed &= new_ty_param.same_as(ty_param); + } + + tvm::Array params; + bool all_params_changed = true; + for (auto param : op->params) { + Param new_param = Downcast(this->Mutate(param)); + params.push_back(new_param); + all_params_changed &= param.same_as(new_param); + } + + auto ret_type = this->VisitType(op->ret_type); + auto body = this->Mutate(op->body); + + if (ty_params.same_as(op->type_params) && params.same_as(op->params) && + ret_type.same_as(op->ret_type) && body.same_as(op->body)) { + return e; + } else { + return FunctionNode::make(params, ret_type, body, ty_params); + } +} + +Expr ExprMutator::VisitExpr_(const CallNode* call_node, const Expr& e) { + auto op = this->Mutate(call_node->op); + + tvm::Array ty_args; + bool all_ty_args_unchanged = true; + for (auto ty_arg : call_node->type_args) { + auto new_ty_arg = this->VisitType(ty_arg); + ty_args.push_back(new_ty_arg); + all_ty_args_unchanged &= new_ty_arg.same_as(ty_arg); + } + + tvm::Array call_args; + bool all_args_unchanged = true; + for (auto arg : call_node->args) { + auto new_arg = this->Mutate(arg); + call_args.push_back(new_arg); + all_args_unchanged &= new_arg.same_as(arg); + } + + if (all_ty_args_unchanged && all_args_unchanged && + call_node->op.same_as(op)) { + return e; + } else { + return CallNode::make(op, call_args, call_node->attrs, ty_args); + } +} + +Expr ExprMutator::VisitExpr_(const LetNode* op, const Expr& e) { + Var var = Downcast(this->Mutate(op->var)); + auto type = this->VisitType(op->value_type); + auto value = this->Mutate(op->value); + auto body = this->Mutate(op->body); + + if (var.same_as(op->var) && type.same_as(op->value_type) && + value.same_as(op->value) && body.same_as(op->body)) { + return e; + } else { + return LetNode::make(var, value, body, type); + } +} + +Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) { + auto guard = this->Mutate(op->cond); 
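+  // As in the other cases, rebuild the If only when some child
+  // actually changed; otherwise the original node is returned.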
+ auto true_b = this->Mutate(op->true_branch); + auto false_b = this->Mutate(op->false_branch); + if (op->cond == guard && true_b == op->true_branch && + false_b == op->false_branch) { + return e; + } else { + return IfNode::make(guard, true_b, false_b); + } +} + +Type ExprMutator::VisitType(const Type& t) { return t; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { return; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { return; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { return; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { + for (auto field : op->fields) { + this->VisitExpr(field); + } +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const ParamNode* op) { + this->VisitExpr(op->var); +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { + for (auto param : op->params) { + this->VisitExpr(param); + } + + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitExpr(op->op); + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (auto arg : op->args) { + this->VisitExpr(arg); + } +} + +void ExprVisitor::VisitExpr_(const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitExpr(op->cond); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); +} + +void ExprVisitor::VisitExpr_(const OpNode* op) { return; } + +void ExprVisitor::VisitType(const Type& t) { return; } + +} // namespace relay +} // namespace tvm + diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc new file mode 100644 index 000000000000..d1a9dd072d31 --- /dev/null +++ b/src/relay/ir/op.cc @@ -0,0 +1,155 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/op.cc + * \brief Resolve incomplete types to complete types. + */ +#include +#include +#include +#include + +#include +#include + +#include "./../pass/type_subst.h" + +namespace dmlc { +// enable registry +DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); +} // namespace dmlc + +namespace tvm { +namespace relay { + +::dmlc::Registry* OpRegistry::Registry() { + return ::dmlc::Registry::Get(); +} + +// single manager of operator information. +struct OpManager { + // mutex to avoid registration from multiple threads. + std::mutex mutex; + // global operator counter + std::atomic op_counter{0}; + // storage of additional attribute table. 
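+  // (keyed by attribute name; each GenericOpMap is indexed by the
+  //  op's integer index, storing a priority level with every entry)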
+ std::unordered_map> attr; + // frontend functions + std::vector frontend_funcs; + // get singleton of the op manager + static OpManager* Global() { + static OpManager inst; + return &inst; + } +}; + +// find operator by name +const Op& Op::Get(const std::string& name) { + const OpRegistry* reg = dmlc::Registry::Find(name); + CHECK(reg != nullptr) << "Operator " << name << " is not registered"; + return reg->op(); +} + +OpRegistry::OpRegistry() { + OpManager* mgr = OpManager::Global(); + std::shared_ptr n = std::make_shared(); + n->index_ = mgr->op_counter++; + op_ = Op(n); +} + +// Get attribute map by key +const GenericOpMap& Op::GetGenericAttr(const std::string& key) { + OpManager* mgr = OpManager::Global(); + std::lock_guard lock(mgr->mutex); + auto it = mgr->attr.find(key); + if (it == mgr->attr.end()) { + LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered"; + } + return *it->second.get(); +} + +void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, + int plevel) { + OpManager* mgr = OpManager::Global(); + std::lock_guard lock(mgr->mutex); + std::unique_ptr& op_map = mgr->attr[key]; + if (op_map == nullptr) { + op_map.reset(new GenericOpMap()); + } + uint32_t index = op_->index_; + if (op_map->data_.size() <= index) { + op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); + } + std::pair& p = op_map->data_[index]; + CHECK(p.second != plevel) + << "Attribute " << key << " of operator " << this->name + << " is already registered with same plevel=" << plevel; + if (p.second < plevel) { + op_map->data_[index] = std::make_pair(value, plevel); + } +} + +// Frontend APIs +TVM_REGISTER_API("relay.op._ListOpNames") + .set_body_typed()>([]() { + Array ret; + for (const std::string& name : + dmlc::Registry::ListAllNames()) { + ret.push_back(tvm::Expr(name)); + } + return ret; + }); + +TVM_REGISTER_API("relay.op._GetOp").set_body_typed(Op::Get); + +TVM_REGISTER_API("relay.op._OpGetAttr") + .set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto op_map = Op::GetAttr(attr_name); + if (op_map.count(op)) { + *rv = op_map[op]; + } + }); + +TVM_REGISTER_API("relay.op._Register") + .set_body([](TVMArgs args, TVMRetValue* rv) { + std::string op_name = args[0]; + std::string attr_key = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = + OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); + // enable resgiteration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + reg.set_attrs_type_key(value); + } else { + // normal attr table override. + if (args[2].type_code() == kFuncHandle) { + // do an eager copy of the PackedFunc + PackedFunc f = args[2]; + // If we get a function from frontend, avoid deleting it. 
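+          // (The copy is parked in a global vector for the process
+          //  lifetime; presumably this sidesteps destructor-ordering
+          //  problems with the frontend at shutdown.)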
+ OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); + reg.set_attr(attr_key, f, plevel); + } else { + reg.set_attr(attr_key, args[2], plevel); + } + } + }); + +std::shared_ptr CreateOp(const std::string& name) { + auto op = Op::Get(name); + CHECK(!op.defined()) << "Cannot find op \'" << name << '\''; + std::shared_ptr node = op.node_; + return std::dynamic_pointer_cast(node); +} + +TVM_REGISTER_NODE_TYPE(OpNode) +.set_creator(CreateOp) +.set_global_key([](const Node* n) { + return static_cast(n)->name; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc new file mode 100644 index 000000000000..c13fea26dacd --- /dev/null +++ b/src/relay/ir/type.cc @@ -0,0 +1,121 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/ir/type.cc + * \brief The type system AST nodes of Relay. + */ +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +TensorType TensorTypeNode::make(Array shape, DataType dtype) { + std::shared_ptr n = std::make_shared(); + n->shape = std::move(shape); + n->dtype = std::move(dtype); + return TensorType(n); +} + +TensorType TensorTypeNode::Scalar(DataType dtype) { + return TensorTypeNode::make({}, dtype); +} + +TVM_REGISTER_API("relay._make.TensorType") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Array shape = args[0]; + *ret = TensorTypeNode::make(shape, args[1]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TensorTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape << ")"; +}); + +TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { + std::shared_ptr n = std::make_shared(); + n->var = tvm::Var(name); + n->kind = std::move(kind); + return TypeParam(n); +} + +TVM_REGISTER_API("relay._make.TypeParam") +.set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[1]; + *ret = + TypeParamNode::make(args[0], static_cast(kind)); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TypeParamNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeParamNode(" << node->var->name_hint << ", " + << node->kind << ")"; +}); + +FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, + tvm::Array type_params, + tvm::Array type_constraints) { + std::shared_ptr n = std::make_shared(); + n->arg_types = std::move(arg_types); + n->ret_type = std::move(ret_type); + n->type_params = std::move(type_params); + n->type_constraints = std::move(type_constraints); + return FuncType(n); +} + +TVM_REGISTER_API("relay._make.FuncType") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const FuncTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "FuncTypeNode(" << node->type_params << ", " + << node->arg_types << ", " << node->ret_type << ", " + << node->type_constraints << ")"; +}); + +TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array args) { + std::shared_ptr n = std::make_shared(); + n->name = std::move(name); + n->func_ = std::move(func); + n->args = std::move(args); + return TypeRelation(n); +} + +TVM_REGISTER_API("relay._make.TypeRelation") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TypeRelationNode::make(args[0], args[1], args[2]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const 
TypeRelationNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeRelationNode(" << node->name << ", " << node->args + << ")"; +}); + +TupleType TupleTypeNode::make(Array fields) { + std::shared_ptr n = std::make_shared(); + n->fields = std::move(fields); + return TupleType(n); +} + +TVM_REGISTER_API("relay._make.TupleType") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleTypeNode::make(args[0]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const TupleTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "TupleTypeNode(" << node->fields << ")"; +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc new file mode 100644 index 000000000000..8c1823114f44 --- /dev/null +++ b/src/relay/op/tensor/elemwise.cc @@ -0,0 +1,137 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file elemwise.cc + * \brief Elementwise operators. + */ +#include +#include +#include "../type_relations.h" + +namespace tvm { +namespace relay { + +// Quick helper macro +// - Expose a positional make function to construct the node. +// - Register op to the registry. +// +// We make the decision to always only expose positional argument. +// We will do rewrapping in the frontend to support language +// sugars such as keyword arguments and default value. +// +#define RELAY_REGISTER_UNARY_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." OpName) \ + .set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") + + +RELAY_REGISTER_UNARY_OP("log") +.describe(R"code(Returns the log input array, computed element-wise. + +.. math:: + log(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.add_type_rel("Identity", IdentityRel); + +// data : Tensor[shape, dtype] +// result: Tensor[shape, dtype] + + +RELAY_REGISTER_UNARY_OP("exp") +.describe(R"code(Returns the exp input array, computed element-wise. + +.. math:: + \exp(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.add_type_rel("Identity", IdentityRel); + + +RELAY_REGISTER_UNARY_OP("sqrt") +.describe(R"code(Returns the sqrt input array, computed element-wise. + +.. math:: + sqrt(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.add_type_rel("Identity", IdentityRel); + +// Addition +TVM_REGISTER_API("relay.op._make.add") + .set_body_typed([](Expr lhs, Expr rhs) { + static const Op& op = Op::Get("add"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("add") + .set_num_inputs(2) + .add_argument("lhs", "Tensor", "The left hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") + .set_support_level(1) + .add_type_rel("Broadcast", BroadcastRel); + + // def broadcast(s1, s2): + // ... + // + // input1: Tensor[dtype, s1] + // input2: Tensor[dtype, s2] + // output: Tensor[dtype, broadcast(s1, s2)] + +// Addition +TVM_REGISTER_API("relay.op._make.subtract") + .set_body_typed([](Expr lhs, Expr rhs) { + static const Op& op = Op::Get("subtract"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("subtract") + .set_num_inputs(2) + .add_argument("lhs", "Tensor", "The left hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") + .set_support_level(1) + .add_type_rel("Broadcast", BroadcastRel); + + // def broadcast(s1, s2): + // ... 
+  //
+  // input1: Tensor[dtype, s1]
+  // input2: Tensor[dtype, s2]
+  // output: Tensor[dtype, broadcast(s1, s2)]
+
+// Equality Comparison
+TVM_REGISTER_API("relay.op._make.equal")
+  .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
+    static const Op& op = Op::Get("equal");
+    return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+  });
+
+RELAY_REGISTER_OP("equal")
+  .set_num_inputs(2)
+  .add_argument("lhs", "Tensor", "The left hand side tensor.")
+  .add_argument("rhs", "Tensor", "The right hand side tensor.")
+  .set_support_level(1)
+  .add_type_rel("BroadcastComp", BroadcastCompRel);
+
+// Concat
+TVM_REGISTER_API("relay.op._make.concat")
+  .set_body_typed<Expr(Expr)>([](Expr tuple) {
+    static const Op& op = Op::Get("concat");
+    return CallNode::make(op, { tuple }, Attrs(), {});
+  });
+
+RELAY_REGISTER_OP("concat")
+  .set_num_inputs(1)
+  .add_argument("tuple", "Tuple", "The tupled tensor arguments.")
+  .set_support_level(1)
+  .add_type_rel("Concat", ConcatRel);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
new file mode 100644
index 000000000000..94550dbd5075
--- /dev/null
+++ b/src/relay/op/type_relations.cc
@@ -0,0 +1,206 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file type_relations.cc
+ * \brief A set of utilities and common functionality
+ * for type relations.
+ */
+#include
+#include
+#include
+#include
+#include "../pass/incomplete_type.h"
+#include "./type_relations.h"
+
+namespace tvm {
+namespace relay {
+
+TensorType ToTensorType(const Type& t) {
+  if (auto tt_node = t.as<TensorTypeNode>()) {
+    return GetRef<TensorType>(tt_node);
+  } else {
+    return TensorType(nullptr);
+  }
+}
+
+// TODO(@jroesch) what size value do we extract, 64bit or 32bit?
+int ToInt(const tvm::Expr& e) {
+  CHECK(e.defined());
+  auto imm = e.as<tvm::ir::IntImm>();
+  // Note: imm is null exactly when this check fires, so the message must
+  // not dereference it.
+  CHECK(imm) << "expected an integer immediate, but got: " << e;
+  return imm->value;
+}
+
+Array<Type> IdentityRel(const Array<Type>& types, int num_args) {
+  CHECK_EQ(types.size(), 2);
+  auto t1 = ToTensorType(types[0]);
+  if (t1 && types[1].as<IncompleteTypeNode>()) {
+    return {t1, t1};
+  } else {
+    return types;
+  }
+}
+
+static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2,
+                              DataType output_dtype) {
+  RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2
+                  << std::endl;
+  auto sh1 = t1->shape;
+  auto sh2 = t2->shape;
+  RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2
+                  << std::endl;
+  if (sh1.size() == 0 && sh2.size() == 0) {
+    return TensorTypeNode::make({}, output_dtype);
+    // We have non-zero shapes so broadcast rules apply.
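+    // For example, the branch below matches shapes right-aligned,
+    // NumPy style: broadcasting sh1 = (5, 1, 3) against sh2 = (4, 3)
+    // compares 3 with 3, then 1 with 4 (a 1 broadcasts), pads sh2 with
+    // a leading 1, and produces the output shape (5, 4, 3).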
+  } else {
+    auto suffix_len = static_cast<int>(std::min(sh1.size(), sh2.size()));
+    auto full_len = static_cast<int>(std::max(sh1.size(), sh2.size()));
+
+    auto rev_sh1 = sh1.rbegin();
+    auto rev_sh2 = sh2.rbegin();
+
+    while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) {
+      auto dim1 = ToInt(*rev_sh1);
+      auto dim2 = ToInt(*rev_sh2);
+      if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) {
+        CHECK(false) << "Dimension mismatch "
+                     << "dim1: " << dim1 << " dim2: " << dim2 << std::endl;
+      }
+      rev_sh1++;
+      rev_sh2++;
+    }
+
+    Array<ShapeExpr> larger;
+    Array<ShapeExpr> smaller;
+
+    for (int i = 0; i < (full_len - suffix_len); i++) {
+      smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1));
+    }
+
+    if (sh1.size() < sh2.size()) {
+      for (auto sh : sh1) {
+        smaller.push_back(sh);
+      }
+      larger = sh2;
+    } else if (sh1.size() > sh2.size()) {
+      // sh2 is the shorter shape here, so it receives the leading
+      // padding collected above; sh1 becomes the larger shape.
+      for (auto sh : sh2) {
+        smaller.push_back(sh);
+      }
+      larger = sh1;
+    } else {
+      larger = sh1;
+      smaller = sh2;
+    }
+
+    CHECK_EQ(larger.size(), smaller.size());
+
+    Array<ShapeExpr> out_shape;
+    for (size_t i = 0; i < smaller.size(); i++) {
+      auto left = smaller[i].as<tvm::ir::IntImm>();
+      auto right = larger[i].as<tvm::ir::IntImm>();
+      CHECK(left);
+      CHECK(right);
+      int64_t dim = std::max(left->value, right->value);
+      out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim));
+    }
+
+    return TensorTypeNode::make(out_shape, output_dtype);
+  }
+}
+
+Array<Type> BroadcastRel(const Array<Type>& types, int num_args) {
+  CHECK_EQ(types.size(), 3);
+  RELAY_LOG(INFO) << "In1: " << types[0] << " In2: " << types[1]
+                  << " Out: " << types[2] << std::endl;
+  if (auto t1 = ToTensorType(types[0])) {
+    if (auto t2 = ToTensorType(types[1])) {
+      CHECK_EQ(t1->dtype, t2->dtype);
+      return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)};
+    }
+  }
+
+  return types;
+}
+
+/*! \brief A relation which specifies the broadcasting rules for operations
+ * which compute boolean results.
+ */
+Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args) {
+  CHECK_EQ(types.size(), 3);
+  if (auto t1 = ToTensorType(types[0])) {
+    if (auto t2 = ToTensorType(types[1])) {
+      return {t1, t2, ConcreteBroadcast(t1, t2, HalideIR::Bool())};
+    }
+  }
+
+  return types;
+}
+
+/*! \brief Handle concrete concat case from known input to output. */
+inline Type ConcreteConcatRel(const Type& input_type) {
+  if (auto tuple_node = input_type.as<TupleTypeNode>()) {
+    // NB: For now the axis argument is hardwired to be 0.
+    std::vector<int> dims;
+    DataType dtype;
+
+    CHECK_LT(1, tuple_node->fields.size());
+    bool skip_first = true;
+
+    // Collect the suffix dimensions since axis is zero.
+    // TODO(@jroesch): This is a demonstration of how
+    // to do varargs. It requires a little more work to
+    // fully type the behavior of concat.
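+    // For example, with the axis hardwired to 0:
+    //   concat((Tensor[(2, 3)], Tensor[(4, 3)]))
+    // checks the suffix dimensions against each other (3 == 3), sums the
+    // axis dimensions (2 + 4 = 6), and yields Tensor[(6, 3)].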
+
+    auto first = Downcast<TensorType>(tuple_node->fields[0]);
+    dtype = first->dtype;
+
+    for (auto dim_expr : first->shape) {
+      if (!skip_first) {
+        dims.push_back(ToInt(dim_expr));
+      } else {
+        skip_first = false;
+      }
+    }
+
+    std::vector<int> axis_dims;
+    for (auto field_ty : tuple_node->fields) {
+      auto ttype = Downcast<TensorType>(field_ty);
+      for (size_t i = 0; i < ttype->shape.size(); i++) {
+        if (i != 0) {
+          // dims already holds plain integers, so compare it directly
+          // against the converted shape dimension.
+          CHECK_EQ(dims[i - 1], ToInt(ttype->shape[i]));
+        } else {
+          axis_dims.push_back(ToInt(ttype->shape[i]));
+        }
+      }
+    }
+
+    auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0);
+
+    Array<ShapeExpr> out_shape = { tvm::ir::IntImm::make(HalideIR::Int(64), out_axis_dim) };
+
+    for (auto dim : dims) {
+      out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim));
+    }
+
+    return TensorTypeNode::make(out_shape, dtype);
+
+  } else {
+    throw TypeRelationError("concat can only be used with a tuple as its argument");
+  }
+}
+
+Array<Type> ConcatRel(const Array<Type>& types, int num_args) {
+  CHECK_EQ(types.size(), 2);
+
+  if (types[0].as<IncompleteTypeNode>() && types[1].as<IncompleteTypeNode>()) {
+    return types;
+  } else if (types[1].as<IncompleteTypeNode>()) {
+    return { types[0], ConcreteConcatRel(types[0]) };
+  } else {
+    throw TypeRelationError(
+      "cannot deduce the relationship between the "
+      "types of concat's input and output");
+  }
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
new file mode 100644
index 000000000000..9dfc29022ee3
--- /dev/null
+++ b/src/relay/op/type_relations.h
@@ -0,0 +1,67 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/op/type_relations.h
+ * \brief A set of utilities and common functionality
+ * for type relations.
+ */
+#ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_
+#define TVM_RELAY_OP_TYPE_RELATIONS_H_
+
+#include
+#include
+#include
+
+namespace tvm {
+namespace relay {
+
+/*! \brief The error raised by a type relation.
+ *
+ * This error is how a type relation signals that it has failed.
+ */
+struct TypeRelationError : Error {
+  explicit TypeRelationError(const std::string& msg)
+    : Error(msg) {}
+};
+
+/*! \brief The identity type relation, which maps a single input variable
+ * to the output variable.
+ *
+ * \param types The input and output types to the relation.
+ * \param num_args The number of input arguments.
+ * \return The (potentially partial) solution to the relation.
+ */
+Array<Type> IdentityRel(const Array<Type>& types, int num_args);
+
+/*! \brief The broadcast type relation, which implements the broadcasting
+ * rule over the two input types, producing the broadcasted type.
+ *
+ * \param types The input and output types to the relation.
+ * \param num_args The number of input arguments.
+ * \return The (potentially partial) solution to the relation.
+ */
+Array<Type> BroadcastRel(const Array<Type>& types, int num_args);
+
+/*! \brief The broadcast type relation, which implements the broadcasting
+ * rule over the two input types, producing the broadcasted type.
+ *
+ * This differs from BroadcastRel in the return dtype:
+ * it instead returns bool, for use in comparison operators
+ * such as equal, not_equal, lt, and so on.
+ *
+ * \param types The input and output types to the relation.
+ * \param num_args The number of input arguments.
+ * \return The (potentially partial) solution to the relation.
+ */
+Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args);
+
+/*! \brief The concat relation.
+ *
+ * This relation takes a single input, which must be a single tensor
+ * or an arbitrarily sized tuple.
It combines these input dimensions + * together to produce the output example. + */ +Array ConcatRel(const Array& types, int num_args); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_TYPE_RELATIONS_H_ diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc new file mode 100644 index 000000000000..f76da793c503 --- /dev/null +++ b/src/relay/pass/alpha_eq.cc @@ -0,0 +1,258 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/pass/alpha_eq.cc + * \brief Compute the set of variables not bound in the expression. + */ +#include +#include "./type_visitor.h" +#include "tvm/relay/pass.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +struct TypeAlphaEq : TypeVisitor { + tvm::Map eq_map; + bool equal; + + TypeAlphaEq() : eq_map(), equal(true) {} + + void DataTypeEqual(const DataType& dt1, const DataType& dt2) { + equal = equal && dt1 == dt2; + } + void ShapeEqual(Array s1, Array s2) {} + + void VisitType_(const TensorTypeNode *tt1, const Type& t2) final { + if (const TensorTypeNode *tt2 = t2.as()) { + DataTypeEqual(tt1->dtype, tt2->dtype); + ShapeEqual(tt1->shape, tt2->shape); + } else { + equal = false; + } + } + + void VisitType_(const IncompleteTypeNode *bt1, const Type& t2) final { + if (const IncompleteTypeNode *bt2 = t2.as()) { + equal = equal && bt1 == bt2; + return; + } else { + equal = false; + } + } + + void VisitType_(const TypeParamNode *ti1, const Type& t2) final { + if (const TypeParamNode *ti2 = t2.as()) { + auto tid1 = GetRef(ti1); + auto tid2 = GetRef(ti2); + + // We handle open terms with this rule assuming variables are identical. + // + // Not sure if we should do this. + if (tid1 == tid2) { + return; + } + + // Check that they are same kind + if (tid1->kind != tid2->kind) { + equal = false; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(tid1) != eq_map.end()) { + equal = equal && eq_map[tid1] == tid2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitType_(const FuncTypeNode *op, const Type& t2) final { + if (const FuncTypeNode *ta2 = t2.as()) { + if (op->arg_types.size() != ta2->arg_types.size()) { + equal = false; + return; + } + + for (size_t i = 0; i < op->arg_types.size(); i++) { + this->VisitType(op->arg_types[i], ta2->arg_types[i]); + if (!equal) { + return; + } + } + + this->VisitType(op->ret_type, ta2->ret_type); + } else { + equal = false; + } + } + + void VisitType_(const TypeRelationNode *tr1, const Type& t2) final { + if (const TypeRelationNode *tr2 = t2.as()) { + equal = tr1 == tr2; + } else { + equal = false; + } + } + + void VisitType_(const TupleTypeNode *op, const Type& t2) final { + if (const TupleTypeNode *pt = t2.as()) { + if (op->fields.size() != pt->fields.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < op->fields.size(); i++) { + if (!equal) { + return; + } + this->VisitType(op->fields[i], pt->fields[i]); + } + } else { + equal = false; + } + } +}; + +bool AlphaEqual(const Type& t1, const Type& t2) { + TypeAlphaEq aeq; + aeq.VisitType(t1, t2); + return aeq.equal; +} + +struct AlphaEq : ExprFunctor { + public: + tvm::Map eq_map; + + bool equal; + AlphaEq() : eq_map(), equal(true) {} + + void VisitExpr_(const VarNode *e1, const Expr& e2) final { + if (const VarNode *id2 = e2.as()) { + auto local1 = GetRef(e1); + auto local2 = GetRef(id2); + // We handle open terms with this rule assuming variables are identical. 
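+    // For example, `fn (x) { x }` and `fn (y) { y }` are alpha-equal:
+    // visiting the parameters records x -> y in eq_map, so the two bound
+    // occurrences compared below are treated as equal.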
+ if (local1 == local2) { + equal = true; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(local1) != eq_map.end()) { + equal = equal && eq_map[local1] == local2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitExpr_(const GlobalVarNode *g1, const Expr& e2) final { + if (const GlobalVarNode *g2 = e2.as()) { + equal = equal && g1 == g2; + } else { + equal = false; + } + } + + void VisitExpr_(const TupleNode *pl1, const Expr& e2) final { + Tuple prod1 = GetRef(pl1); + if (const TupleNode *pl2 = e2.as()) { + Tuple prod2 = GetRef(pl2); + if (prod1->fields.size() != prod2->fields.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < prod1->fields.size(); i++) { + this->VisitExpr(prod1->fields[i], prod2->fields[i]); + } + } else { + equal = false; + } + } + + void VisitExpr_(const ParamNode *p1, const Expr& e2) final { + if (const ParamNode *p2 = e2.as()) { + eq_map.Set(p1->var, p2->var); + equal = equal && AlphaEqual(p1->type, p2->type); + } else { + equal = false; + } + } + + void VisitExpr_(const FunctionNode *func1, const Expr& e2) final { + if (const FunctionNode *func2 = e2.as()) { + if (func1->params.size() != func2->params.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < func1->params.size(); i++) { + this->VisitExpr(func1->params[i], func2->params[i]); + } + + this->VisitExpr(func1->body, func2->body); + } else { + equal = false; + } + } + + void VisitExpr_(const CallNode *op, const Expr& e2) final { + if (const CallNode *call = e2.as()) { + this->VisitExpr(op->op, call->op); + + if (op->args.size() != call->args.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < op->args.size(); i++) { + this->VisitExpr(op->args[i], call->args[i]); + } + + } else { + equal = false; + } + } + + void VisitExpr_(const LetNode *op, const Expr& e2) final { + if (const LetNode *let = e2.as()) { + eq_map.Set(op->var, let->var); + this->VisitExpr(op->value, let->value); + this->VisitExpr(op->body, let->body); + } else { + equal = false; + } + } +}; + +bool AlphaEqual(const Expr& e1, const Expr& e2) { + AlphaEq eq; + eq.VisitExpr(e1, e2); + return eq.equal; +} + +// TODO(@jroesch): move to correct namespace? +TVM_REGISTER_API("relay._make._alpha_eq") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Expr e1 = args[0]; + Expr e2 = args[1]; + *ret = AlphaEqual(e1, e2); + }); + +TVM_REGISTER_API("relay._make._type_alpha_eq") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Type t1 = args[0]; + Type t2 = args[1]; + *ret = AlphaEqual(t1, t2); + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/incomplete_type.h b/src/relay/pass/incomplete_type.h new file mode 100644 index 000000000000..78771dc6e9b7 --- /dev/null +++ b/src/relay/pass/incomplete_type.h @@ -0,0 +1,38 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file incomplete_type.h + * \brief A way to defined arbitrary function signature with dispatch on types. + */ + +#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H_ +#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H_ + +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Represents a portion of an incomplete type. + */ +class IncompleteType; + +/*! 
\brief IncompleteType container node */ +class IncompleteTypeNode : public TypeNode { + public: + TypeParamNode::Kind kind; + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); } + + TVM_DLL static IncompleteType make(TypeParamNode::Kind kind); + + static constexpr const char* _type_key = "relay.IncompleteType"; + TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H_ diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc new file mode 100644 index 000000000000..522eb93483fb --- /dev/null +++ b/src/relay/pass/kind_check.cc @@ -0,0 +1,42 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file kindchecker.cc + * + * \brief Check that types are well formed by applying "kinding rules". + * + * This pass ensures we do not do things that violate the design of the + * type system when writing down types. + * + * For example tensors are not allowed to contain functions in Relay. + * + * We check this by ensuring the `dtype` field of a Tensor always + * contains a data type such as `int`, `float`, `uint`. + */ +#include +#include +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +struct KindChecker : TypeVisitor<> { + bool valid; + + KindChecker() : valid(true) {} + + bool Check(const Type &t) { + this->VisitType(t); + return valid; + } +}; + +bool KindCheck(const Environment& env, const Type &t) { + KindChecker kc; + return kc.Check(t); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc new file mode 100644 index 000000000000..b073613bafc2 --- /dev/null +++ b/src/relay/pass/resolve.cc @@ -0,0 +1,100 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file resolve.cc + * \brief Resolve incomplete types to complete types. + */ + +#include +#include +#include "./resolve.h" +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +struct ResolveTypeType : TypeMutator { + const TypeUnifier &unifier; + + explicit ResolveTypeType(const TypeUnifier &unifier) : unifier(unifier) {} + + Type VisitType(const Type &t) override { + if (!t.defined()) { + auto inc_ty = IncompleteTypeNode::make(TypeParamNode::Kind::kType); + unifier->Insert(inc_ty); + return inc_ty; + } else { + return TypeMutator::VisitType(t); + } + } + + Type VisitType_(const IncompleteTypeNode *op) override { + return unifier->Subst(GetRef(op)); + } +}; + +struct ResolveTypeExpr : ExprMutator { + const TypeUnifier &unifier; + + explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {} + + Expr Mutate(const Expr &e) { + // NB: a bit tricky here. + // + // We want to store resolved type without having + // to re-typecheck the entire term. + // + // Since we know that e : T[...] under some holes + // then it is the case that if we resolve types + // present in e, then we can type it under T + // with the wholes filled in. + // + // We will visit e like normal building a new + // term, then resolve e's old type and write + // it back into the new node. 
+    auto new_e = ExprMutator::Mutate(e);
+    CHECK(e->checked_type_.defined());
+    auto resolved_cty = VisitType(e->checked_type_);
+    new_e->checked_type_ = resolved_cty;
+    return new_e;
+  }
+
+  Type VisitType(const Type &t) {
+    return ResolveTypeType(unifier).VisitType(t);
+  }
+};
+
+Type Resolve(const TypeUnifier &unifier, const Type &ty) {
+  CHECK(ty.defined());
+  return ResolveTypeType(unifier).VisitType(ty);
+}
+
+Expr Resolve(const TypeUnifier &unifier, const Expr &expr) {
+  return ResolveTypeExpr(unifier).Mutate(expr);
+}
+
+struct FullyResolved : TypeVisitor<> {
+  bool resolved;
+
+  FullyResolved() : resolved(true) {}
+
+  void VisitType(const Type &t) override {
+    if (!t.defined()) {
+      // An undefined type can never count as fully resolved.
+      resolved = false;
+    } else {
+      return TypeVisitor<>::VisitType(t);
+    }
+  }
+
+  void VisitType_(const IncompleteTypeNode *ty_var) override {
+    // Any remaining incomplete type means resolution has not finished.
+    resolved = false;
+  }
+};
+
+bool IsFullyResolved(const Type &t) {
+  auto fr = FullyResolved();
+  fr.VisitType(t);
+  return fr.resolved;
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/resolve.h b/src/relay/pass/resolve.h
new file mode 100644
index 000000000000..0cd7dce2d88d
--- /dev/null
+++ b/src/relay/pass/resolve.h
@@ -0,0 +1,47 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tvm/relay/resolve.h
+ * \brief Resolve incomplete types to complete types.
+ */
+#ifndef TVM_RELAY_PASS_RESOLVE_H_
+#define TVM_RELAY_PASS_RESOLVE_H_
+
+#include
+#include
+#include "./unifier.h"
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Resolve a type containing incomplete types.
+ *
+ * This pass replaces incomplete types with their representative, and
+ * converts types which are not defined into fresh variables.
+ *
+ * \param unifier The unifier containing the unification data.
+ * \param ty The type to resolve.
+ * \returns The resolved type.
+ */
+Type Resolve(const TypeUnifier& unifier, const Type& ty);
+
+/*! \brief Resolve an expression containing incomplete types.
+ *
+ * This pass replaces incomplete types with their representative, and
+ * converts types which are not defined into fresh variables.
+ *
+ * \param unifier The unifier containing the unification data.
+ * \param expr The expression to resolve.
+ * \returns The resolved expression.
+ */
+Expr Resolve(const TypeUnifier& unifier, const Expr& expr);
+
+/*! \brief Check if all types have been filled in.
+ * \param t The type.
+ * \returns True if the type is resolved, false otherwise.
+ */
+bool IsFullyResolved(const Type& t);
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_PASS_RESOLVE_H_
diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h
new file mode 100644
index 000000000000..339552108af4
--- /dev/null
+++ b/src/relay/pass/type_functor.h
@@ -0,0 +1,93 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file type_functor.h
+ * \brief A way to define an arbitrary function signature with dispatch on types.
+ */
+#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_
+#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_
+
+#include
+#include
+#include "./incomplete_type.h"
+
+namespace tvm {
+namespace relay {
+
+template <typename FType>
+class TypeFunctor;
+
+// functions to be overridden.
+#define TYPE_FUNCTOR_DEFAULT \
+  { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)         \
+  vtable.template set_dispatch<OP>(             \
+    [](const NodeRef& n, TSelf* self, Args...
args) { \ + return self->VisitType_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); + +template +class TypeFunctor { + private: + using TSelf = TypeFunctor; + using FType = tvm::IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~TypeFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Type& n, Args... args) { + return VisitType(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitType(const Type& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitType_(const TensorTypeNode* op, + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + + virtual R VisitTypeDefault_(const Node* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->type_key(); + return R(); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); + RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); + return vtable; + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_TYPE_FUNCTOR_H_ diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc new file mode 100644 index 000000000000..f4f6d82eb5e1 --- /dev/null +++ b/src/relay/pass/type_infer.cc @@ -0,0 +1,629 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_infer.cc + * \brief Relay type inference and checking. + * + * This file implements one of the most important passes to the + * Relay IR. In order to do many transformations and generate the + * most efficient code we need to obtain type information for the + * IR. + * + * Like computation graphs the IR leaves most type information + * implicit and relies performing analysis of the program to + * generate this information. + * + * This pass given an expression `e` will infer a type `t` for + * the expression simultaneous checking the property `e : t` + * (i.e we can show e has type t). + * + * If we can not infer a type or there are conflicting typing + * constraints we will trigger an error. + */ + +#include +#include +#include +#include +#include "./incomplete_type.h" +#include "./resolve.h" +#include "./type_subst.h" +#include "./type_visitor.h" +#include "./unifier.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +// // We declare this for forward compatibility. +struct ConstraintData {}; + +/*! 
\brief A more efficient representation of the type relation
+ * data needed for type checking.
+ */
+struct TypeRelationData : ConstraintData {
+  std::string name;
+  std::vector<Type> args;
+  TypeRelationFn func;
+  Span span;
+
+  explicit TypeRelationData(const TypeRelation& ty_rel)
+      : TypeRelationData(ty_rel->args, ty_rel->func_, ty_rel->span) {
+    // The delegating constructor does not see the relation's name, so
+    // copy it here; otherwise ToTypeRel() would rebuild an unnamed
+    // relation.
+    this->name = ty_rel->name;
+  }
+
+  TypeRelationData(const Array<Type>& args, const TypeRelationFn& func, const Span& sp)
+      : func(func), span(sp) {
+    for (auto arg : args) {
+      this->args.push_back(arg);
+    }
+  }
+
+  TypeRelation ToTypeRel() const {
+    Array<Type> args = Array<Type>(this->args.begin(), this->args.end());
+    return TypeRelationNode::make(
+      this->name, this->func, args);
+  }
+};
+
+struct TypeContext {
+  std::unordered_map<Var, Type, NodeHash, NodeEqual> var_map;
+  std::vector<std::vector<TypeRelationData> > constraints;
+
+  TypeContext() { constraints.push_back({}); }
+
+  void Insert(const Var& id, const Type& t) { var_map[id] = t; }
+
+  void AddConstraint(const TypeConstraint& constraint) {
+    constraints.back().push_back(TypeRelationData(Downcast<TypeRelation>(constraint)));
+  }
+
+  Type Lookup(const Var& var) {
+    auto type = var_map.find(var);
+    if (type != var_map.end()) {
+      return (*type).second;
+    } else {
+      throw FatalTypeError(std::string("undeclared local variable: ") + var->name_hint);
+    }
+  }
+
+  struct Scope {
+    TypeContext& tc;
+    explicit Scope(TypeContext& tc) : tc(tc) { tc.constraints.push_back({}); }
+    ~Scope() { tc.constraints.pop_back(); }
+  };
+};
+
+struct CheckedExpr {
+  Expr expr;
+  Type type;
+  CheckedExpr(Expr e, Type t) : expr(e), type(t) {}
+  CheckedExpr() {}
+};
+
+enum SolverResult : int;
+
+class TypeInferencer : private ExprFunctor<CheckedExpr(const Expr&)> {
+ private:
+  TypeContext context;
+
+ public:
+  Environment env;
+  TypeUnifier unifier;
+
+  template <typename T>
+  T WithScope(const std::function<T()>& f) {
+    TypeContext::Scope fr(context);
+    return f();
+  }
+
+  TypeInferencer();
+  TypeInferencer(Environment env, TypeUnifier unifier)
+      : env(env), unifier(unifier) {}
+  explicit TypeInferencer(Environment env);
+
+  CheckedExpr Infer(const Expr &expr);
+
+  FuncType Instantiate(FuncType fn_ty, tvm::Array<Type> &ty_args);
+
+  Type Normalize(const Type& t);
+
+  void ReportError(const std::string& msg, Span sp);
+  [[noreturn]] void FatalError(const std::string& msg, Span sp);
+
+  Type Unify(const Type &t1, const Type& t2, Span sp);
+  Type Resolve(const Type &t);
+  Expr Resolve(const Expr &e);
+
+  /*! \brief Attempt to solve a single relation. */
+  void Solve(TypeRelationData& ty_rel);
+
+  /*! \brief Attempt to solve all pending relations.
+   *
+   * If the solver can make no further progress, it reports failure.
+   */
+  SolverResult Solve(std::vector<TypeRelationData>& rels);
+
+  /*! \brief Check that all relations hold. */
+  bool RelationsHold(bool scope_only = false);
+
+  /*! \brief Visit a function node; the extra flag controls whether to generalize. */
+  CheckedExpr VisitFunction(const Function& f, bool generalize);
+
+ private:
+  CheckedExpr VisitExpr_(const VarNode* op) override;
+  CheckedExpr VisitExpr_(const GlobalVarNode* op) override;
+  CheckedExpr VisitExpr_(const ConstantNode* op) override;
+  CheckedExpr VisitExpr_(const TupleNode* op) override;
+  CheckedExpr VisitExpr_(const ParamNode* op) override;
+  CheckedExpr VisitExpr_(const FunctionNode* op) override;
+  CheckedExpr VisitExpr_(const CallNode* op) override;
+  CheckedExpr VisitExpr_(const LetNode* op) override;
+  CheckedExpr VisitExpr_(const IfNode* op) override;
+  CheckedExpr VisitExpr_(const OpNode* op) override;
+};
+
+TypeInferencer::TypeInferencer() {
+  this->env = EnvironmentNode::make({});
+  this->unifier = TypeUnifierNode::make(UnionFindNode::make({}));
+}
+
+TypeInferencer::TypeInferencer(Environment env) : env(env) {
+  this->unifier = TypeUnifierNode::make(UnionFindNode::make({}));
+}
+
+CheckedExpr TypeInferencer::Infer(const Expr& expr) {
+  RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl;
+  CheckedExpr checked_expr = this->VisitExpr(expr);
+  RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type
+                  << std::endl;
+  Type final_type = checked_expr.type;
+  RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type
+                  << std::endl;
+  checked_expr.expr->checked_type_ = final_type;
+  return checked_expr;
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const VarNode* op) {
+  auto var = GetRef<Var>(op);
+  return {var, this->context.Lookup(var)};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode* op) {
+  GlobalVar var = GetRef<GlobalVar>(op);
+  Expr e = this->env->Lookup(var);
+  return {var, e->checked_type()};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode* const_node) {
+  return {GetRef<Constant>(const_node), const_node->tensor_type()};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const TupleNode* op) {
+  Tuple pl = GetRef<Tuple>(op);
+
+  std::vector<Expr> field_exprs;
+  std::vector<Type> field_types;
+  for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) {
+    auto checked_field = Infer(*field);
+    field_exprs.push_back(checked_field.expr);
+    field_types.push_back(checked_field.type);
+  }
+
+  return {TupleNode::make(field_exprs), TupleTypeNode::make(field_types)};
+}
+
+CheckedExpr TypeInferencer::VisitExpr_(const ParamNode* param) {
+  // We should trigger an error here and move the param handling directly
+  // into function checking.
+  auto rtype = this->Resolve(param->type);
+  // This is a special case ... not sure if there is a better way
+  // to handle this.
+  param->var->checked_type_ = rtype;
+  return {ParamNode::make(param->var, rtype), rtype};
+}
+
+CheckedExpr TypeInferencer::VisitFunction(const Function& f, bool generalize) {
+  // First we add the parameters to the context, allowing us to check their
+  // types.
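+  // Sketch of the flow: for fn (x : Tensor[dtype, s]) { log(x) }, the
+  // parameter type is recorded in the context, the body is inferred
+  // under it, and the whole function is typed
+  // fn(Tensor[dtype, s]) -> Tensor[dtype, s].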
+ + // TODO(@jroesch): support polymorphism + + std::vector param_types; + std::vector params; + + return this->WithScope([&]() -> CheckedExpr { + for (auto param : f->params) { + CheckedExpr checked_param = this->Infer(param); + Type arg_type; + param_types.push_back(checked_param.type); + params.push_back(GetRef(checked_param.expr.as())); + this->context.Insert(param->var, checked_param.type); + } + + auto checked_body = this->Infer(f->body); + auto inferred_rtype = checked_body.type; + auto annotated_rtype = Resolve(f->ret_type); + + auto unified_rtype = this->Unify(inferred_rtype, annotated_rtype, f->span); + + CHECK(RelationsHold(true)); + + Array cs; + + for (auto cons : this->context.constraints.back()) { + cs.push_back(cons.ToTypeRel()); + } + + return {FunctionNode::make(params, unified_rtype, checked_body.expr, {}), + FuncTypeNode::make(param_types, unified_rtype, {}, cs)}; + }); +} + +CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode* op) { + return this->VisitFunction(GetRef(op), false); +} + +FuncType TypeInferencer::Instantiate(FuncType fn_ty, + tvm::Array& ty_args) { + tvm::Map subst_map; + + // Build a subsitituion map up from the function type and type arguments. + // Eventually allow the type vars to be passed in. + for (auto ty_param : fn_ty->type_params) { + IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); + this->unifier->Insert(fresh); + ty_args.push_back(fresh); + subst_map.Set(ty_param, fresh); + } + + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, + fn_ty->type_constraints); + inst_ty = TypeSubst(inst_ty, subst_map); + + CHECK(KindCheck(this->env, inst_ty)); + + return GetRef(inst_ty.as()); +} + +CheckedExpr TypeInferencer::VisitExpr_(const CallNode* op) { + Call c = GetRef(op); + + auto checked_op = this->Infer(c->op); + + RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl + << "fn_ty=" << checked_op.type << std::endl; + + auto fn_ty_node = checked_op.type.as(); + + if (!fn_ty_node) { + this->FatalError("only expressions with function types can be called", + c->op->span); + } + + // We now have a function type. + FuncType fn_ty = GetRef(fn_ty_node); + + tvm::Array ty_args; + if (ty_args.size() != 0) { + throw Error("found manually suplied type args, not supported"); + } + + fn_ty = Instantiate(fn_ty, ty_args); + + std::vector arg_types; + std::vector checked_args; + + for (auto arg : c->args) { + auto checked_arg = this->Infer(arg); + arg_types.push_back(checked_arg.type); + checked_args.push_back(checked_arg.expr); + } + + auto type_arity = fn_ty->arg_types.size(); + auto number_of_args = arg_types.size(); + + if (type_arity != number_of_args) { + if (type_arity < number_of_args) { + this->FatalError("the function is provided too many arguments", c->span); + } else { + this->FatalError("the function is provided too few arguments", c->span); + } + } + + for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { + this->Unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); + } + + // After we unify the arguments we should know more about the type + // arguments, let's run a quick pass over them to find new + // representatives. + + for (size_t i = 0; i < ty_args.size(); i++) { + ty_args.Set(i, this->unifier->Subst(ty_args[i])); + } + + // Add type constraints from the function types. 
+ for (auto cs : fn_ty->type_constraints) { + context.AddConstraint(cs); + } + + auto new_call = + CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args); + + return {new_call, fn_ty->ret_type}; +} + +CheckedExpr TypeInferencer::VisitExpr_(const LetNode* op) { + Let let = GetRef(op); + + CheckedExpr checked_value; + Type annotated_ty = Resolve(let->value_type); + + // If we are let-defining a function, we want to be able to + // recursively name the function in order to support recursive + // local definitions. + if (let->value.as()) { + context.Insert(let->var, annotated_ty); + checked_value = Infer(let->value); + } else { + checked_value = Infer(let->value); + } + + Type unified_ty = this->Unify(checked_value.type, annotated_ty, let->span); + + // Update type context with unified type now that we have + // solved this equation. + context.Insert(let->var, unified_ty); + + auto checked_body = Infer(let->body); + + auto checked_let = LetNode::make(let->var, checked_value.expr, + checked_body.expr, let->value_type); + + return {checked_let, checked_body.type}; +} + +CheckedExpr TypeInferencer::VisitExpr_(const IfNode* op) { + If ifn = GetRef(op); + + // Ensure the type of the guard is of Tensor[Bool, ()], + // that is a rank-0 boolean tensor. + auto checked_cond = this->Infer(ifn->cond); + auto cond_type = checked_cond.type; + + this->Unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()), + ifn->cond->span); + auto checked_true = this->Infer(ifn->true_branch); + auto checked_false = this->Infer(ifn->false_branch); + auto unified_type = + this->Unify(checked_true.type, checked_false.type, ifn->span); + auto checked_if = + IfNode::make(checked_cond.expr, checked_true.expr, checked_false.expr); + return {checked_if, unified_type}; +} + +CheckedExpr TypeInferencer::VisitExpr_(const OpNode* op_node) { + auto op = GetRef(op_node); + return {op, op->op_type}; +} + +Type TypeInferencer::Resolve(const Type &t) { + if (t.defined()) { + return ::tvm::relay::Resolve(this->unifier, t); + } else { + return IncompleteTypeNode::make(TypeParamNode::Kind::kType); + } +} + +Expr TypeInferencer::Resolve(const Expr &e) { + CHECK(e.defined()); + return ::tvm::relay::Resolve(this->unifier, e); +} + +void TypeInferencer::Solve(TypeRelationData & ty_rel) { + Array normalized_args; + + for (auto arg : ty_rel.args) { + normalized_args.push_back(Resolve(arg)); + } + + auto new_args = ty_rel.func(normalized_args, ty_rel.args.size()); + + CHECK(new_args.size() == normalized_args.size()); + tvm::Array final_args; + + for (size_t i = 0; i < new_args.size(); i++) { + ty_rel.args[i] = Unify(normalized_args[i], new_args[i], ty_rel.span); + } +} + +int NumSolvedVars(const Array& vars) { + int num = 0; + for (auto var : vars) { + if (!var.as()) { + num += 1; + } + } + return num; +} + +enum SolverResult : int { + Failed = -1, + Progress = 0, + Done = 1, +}; + +SolverResult TypeInferencer::Solve(std::vector& rels) { + // We start in the done state with zero progress. + SolverResult status = SolverResult::Done; + int progress = 0; + + do { + // Upon rentering the loop we reset the state. + status = SolverResult::Done; + progress = 0; + + std::vector complete; + + int i = 0; + // We will now process each relation in order. 
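+    // This is a fixpoint loop. For example, with a single pending
+    // relation Broadcast(t1, t2, ?out), the first pass runs the relation
+    // and solves ?out (Progress); the next pass sees every argument
+    // solved, marks the relation complete, and returns Done.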
+    for (TypeRelationData& ty_rel : rels) {
+      int arity = static_cast<int>(ty_rel.args.size());
+      int pre_solved = NumSolvedVars(ty_rel.args);
+      RELAY_LOG(INFO) << "TypeInferencer::Solve: "
+                      << "TypeRelation=" << ty_rel.name
+                      << ", Arity=" << arity << ", Solved=" << pre_solved
+                      << std::endl;
+      // If the relation is already solved then we will make no progress but
+      // try to set the status to done.
+      if (pre_solved == arity) {
+        status = std::min(status, SolverResult::Done);
+        complete.push_back(i);
+        // If there are unsolved variables we will try to solve some.
+      } else if (pre_solved < arity) {
+        Solve(ty_rel);
+        int post_solved = NumSolvedVars(ty_rel.args);
+
+        // If we solved any variables we will try to downgrade the status to
+        // progress, update the type relation, and then bump the progress
+        // counter by one.
+        if (post_solved > pre_solved) {
+          status = std::min(status, SolverResult::Progress);
+          progress += 1;
+        }
+      }
+      i++;
+    }
+
+    // If we made no progress and we aren't finished, then the state should be
+    // downgraded to fail, then we should exit the loop.
+    if (progress == 0 && status != SolverResult::Done) {
+      status = SolverResult::Failed;
+      break;
+    }
+
+    // Remove the satisfied relations. Walk the recorded indices from the
+    // back so the swap-with-last removal never moves an element that we
+    // still have to erase.
+    for (auto it = complete.rbegin(); it != complete.rend(); ++it) {
+      rels[*it] = rels.back();
+      rels.pop_back();
+    }
+  } while (status == SolverResult::Progress);
+  return status;
+}
+
+bool TypeInferencer::RelationsHold(bool scope_only) {
+  // If we are only checking the current scope,
+  // slice out its constraints.
+  //
+  // Otherwise we use all of them.
+  std::vector<std::vector<TypeRelationData> > constraints;
+
+  if (scope_only) {
+    // The innermost scope lives at the back of the stack.
+    constraints = {context.constraints.back()};
+  } else {
+    constraints = context.constraints;
+  }
+
+  RELAY_LOG(INFO) << "TypeInferencer::RelationsHold: scope_only= " << scope_only
+                  << std::endl;
+  bool all_hold = true;
+  for (auto ty_rels : constraints) {
+    auto status = Solve(ty_rels);
+    RELAY_LOG(INFO) << "status= " << status << std::endl;
+    if (status == SolverResult::Failed || status == SolverResult::Progress) {
+      all_hold = false;
+    } else if (status == SolverResult::Done) {
+      continue;
+    } else {
+      throw InternalError("found invalid value for SolverResult");
+    }
+  }
+
+  return all_hold;
+}
+
+Expr InferType(const Environment& env, const Expr& e) {
+  TypeInferencer ti(env);
+  auto checked_expr = ti.Infer(e);
+  CHECK(ti.RelationsHold());
+  return ti.Resolve(checked_expr.expr);
+}
+
+Expr InferType(const Environment& env, const GlobalVar& var,
+               const Function& func) {
+  TypeInferencer ti(env);
+  auto func_copy = FunctionNode::make(func->params, func->ret_type, func->body,
+                                      func->type_params);
+  func_copy->checked_type_ = ti.Resolve(func_copy->fn_type());
+  env->functions.Set(var, func_copy);
+  auto checked_expr = ti.Infer(func);
+  CHECK(ti.RelationsHold());
+  auto map_node = env->functions.CopyOnWrite();
+  map_node->data.erase(var.node_);
+  return ti.Resolve(checked_expr.expr);
+}
+
+void TypeInferencer::FatalError(const std::string& msg, Span sp) {
+  throw FatalTypeError(
+    "internal error: this exception should "
+    "be handled and errors reported with Environment::display_errors\n" +
+    msg);
+}
+
+Type TypeInferencer::Unify(const Type& t1, const Type& t2, Span sp) {
+  try {
+    return this->unifier->Unify(t1, t2);
+  } catch (const dmlc::Error &e) {
+    std::stringstream ss;
+    ss << "Error unifying `";
+    ss << t1;
+    ss << "` and `";
+    ss << t2;
+    ss << "`: " << e.what();
+    this->FatalError(ss.str(), sp);
+  }
+}
+
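+// Typical use of the checker (sketch; the Python bindings below expose
+// the same flow):
+//   Environment env = EnvironmentNode::make({});
+//   Expr typed = InferType(env, some_expr);
+//   Type t = typed->checked_type();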
+TVM_REGISTER_API("relay._ir_pass.check_expr") + .set_body([](TVMArgs args, TVMRetValue* ret) { + Environment env = args[0]; + Expr e = args[1]; + *ret = InferType(env, e); + }); + +// TODO(@jroesch): put in a better namespace. +TVM_REGISTER_API("relay._ir_pass._get_checked_type") + .set_body([](TVMArgs args, TVMRetValue* ret) { + Expr e = args[0]; + *ret = e->checked_type(); + }); + +/* Incomplete Type */ + +IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { + std::shared_ptr n = + std::make_shared(); + n->kind = std::move(kind); + return IncompleteType(n); +} + +TVM_REGISTER_API("relay._make.IncompleteType") + .set_body([](TVMArgs args, TVMRetValue* ret) { + int kind = args[0]; + *ret = IncompleteTypeNode::make(static_cast(kind)); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IncompleteTypeNode* node, + tvm::IRPrinter* p) { + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc new file mode 100644 index 000000000000..0b17fa0bc4f8 --- /dev/null +++ b/src/relay/pass/type_subst.cc @@ -0,0 +1,39 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_subst.cc + * \brief Function for substituting a concrete type in place of a type ID + */ +#include "./type_subst.h" +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +struct TypeSubstV : TypeMutator { + tvm::Map subst_map; + + explicit TypeSubstV(tvm::Map subst_map) + : subst_map(subst_map) {} + + Type VisitType_(const TypeParamNode* op) override { + auto id = GetRef(op); + if (subst_map.find(id) != subst_map.end()) { + return this->subst_map[id]; + } else { + return id; + } + } +}; + +Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst) { + TypeSubstV ty_sub({ {target, subst} }); + return ty_sub.VisitType(type); +} + +Type TypeSubst(const Type& type, tvm::Map subst_map) { + TypeSubstV ty_sub(subst_map); + return ty_sub.VisitType(type); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/type_subst.h b/src/relay/pass/type_subst.h new file mode 100644 index 000000000000..aee3209afb7a --- /dev/null +++ b/src/relay/pass/type_subst.h @@ -0,0 +1,19 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/pass/type_subst.h + * \brief Utility functions for substituting types. + */ +#ifndef TVM_RELAY_PASS_TYPE_SUBST_H_ +#define TVM_RELAY_PASS_TYPE_SUBST_H_ + +#include + +namespace tvm { +namespace relay { + +Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst); +Type TypeSubst(const Type& type, tvm::Map subst_map); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_TYPE_SUBST_H_ diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h new file mode 100644 index 000000000000..725e3d9b3846 --- /dev/null +++ b/src/relay/pass/type_visitor.h @@ -0,0 +1,120 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_visitor.h + * \brief A wrapper around TypeFunctor for common use cases. + */ +#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_ +#define TVM_RELAY_PASS_TYPE_VISITOR_H_ + +#include +#include "./type_functor.h" + +namespace tvm { +namespace relay { + +/*! \brief A type visitor for vistiors which make use of internal + * mutable state. + * + * We recursively visit each type contained inside the visitor. + */ +template +struct TypeVisitor : ::tvm::relay::TypeFunctor { + void VisitType_(const TypeParamNode* op, Args... 
args) override {} + + void VisitType_(const FuncTypeNode* op, Args... args) override { + for (auto type_param : op->type_params) { + this->VisitType(type_param, std::forward(args)...); + } + + for (auto type_cs : op->type_constraints) { + this->VisitType(type_cs, std::forward(args)...); + } + + for (auto arg_type : op->arg_types) { + this->VisitType(arg_type, std::forward(args)...); + } + this->VisitType(op->ret_type, std::forward(args)...); + } + + void VisitType_(const TensorTypeNode* op, Args... args) override {} + + void VisitType_(const TupleTypeNode* op, Args... args) override { + for (const Type& t : op->fields) { + this->VisitType(t, std::forward(args)...); + } + } + + void VisitType_(const TypeRelationNode* op, Args... args) override { + for (const Type& t : op->args) { + this->VisitType(t, std::forward(args)...); + } + } + + void VisitType_(const IncompleteTypeNode* op, Args... args) override {} +}; + +// A functional visitor for rebuilding an AST in place. +struct TypeMutator : TypeFunctor { + Type VisitType_(const TensorTypeNode* op) override { + // TODO(@jroesch): maybe we should recursively visit + return TensorTypeNode::make(op->shape, op->dtype); + } + + Type VisitType_(const TypeParamNode* op) override { + return GetRef(op); + } + + Type VisitType_(const FuncTypeNode* op) override { + Array type_params; + for (auto type_param : op->type_params) { + auto new_type_param = VisitType(type_param); + if (const TypeParamNode* tin = new_type_param.as()) { + type_params.push_back(GetRef(tin)); + } else { + CHECK(false) << new_type_param << std::endl; + } + } + + Array type_constraints; + for (auto type_cs : op->type_constraints) { + auto new_type_cs = VisitType(type_cs); + if (const TypeConstraintNode* tin = As(new_type_cs)) { + type_constraints.push_back(GetRef(tin)); + } else { + CHECK(false) << new_type_cs << std::endl; + } + } + + std::vector args; + for (auto arg_type : op->arg_types) { + args.push_back(VisitType(arg_type)); + } + + return FuncTypeNode::make(tvm::Array(args), VisitType(op->ret_type), + type_params, type_constraints); + } + + Type VisitType_(const TupleTypeNode* op) override { + std::vector new_fields; + for (const Type& t : op->fields) { + new_fields.push_back(this->VisitType(t)); + } + return TupleTypeNode::make(new_fields); + } + + Type VisitType_(const TypeRelationNode* type_rel) override { + std::vector new_args; + for (const Type& t : type_rel->args) { + new_args.push_back(this->VisitType(t)); + } + return TypeRelationNode::make(type_rel->name, type_rel->func_, new_args); + } + + Type VisitType_(const IncompleteTypeNode* op) override { + return GetRef(op); + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_TYPE_VISITOR_H_ diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc new file mode 100644 index 000000000000..b0ed71d17911 --- /dev/null +++ b/src/relay/pass/unifier.cc @@ -0,0 +1,324 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/src/relay/pass/unifier.cc + * \brief The type unifier which solves a system of equations between + * incomplete types. 
+ */ + +#include "./unifier.h" +#include +#include +#include +#include +#include +#include "./type_subst.h" +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +UnionFind UnionFindNode::make(tvm::Map uf_map) { + std::shared_ptr n = std::make_shared(); + n->uf_map = uf_map; + return UnionFind(n); +} + +void UnionFindNode::Insert(const IncompleteType& v) { this->uf_map.Set(v, v); } + +void UnionFindNode::debug() { + for (const auto& entry : this->uf_map) { + RELAY_LOG(INFO) << entry.first << " = " << entry.second << std::endl; + } +} + +void UnionFindNode::AssertAlphaEqual(const Type& l, const Type& r) { + if (!AlphaEqual(l, r)) { + std::stringstream ss; + ss << "Incompatible parent types in UF:" << l << " and " << r; + throw UnionFindError(ss.str()); + } +} + +void UnionFindNode::Unify(const IncompleteType& v1, const Type& t) { + RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << ", t=" << t + << std::endl; + auto parent1 = this->Find(v1); + + // if t is a type var, then unify parents + const IncompleteTypeNode *tvn2 = t.as(); + if (tvn2) { + auto v2 = GetRef(tvn2); + auto parent2 = this->Find(v2); + + // if parents are exactly equal, then we're done + if (parent1 == parent2) { + return; + } + + // if first parent is a type var, then can just set its union find map to + // second parent + if (const IncompleteTypeNode *pvn1 = parent1.as()) { + auto pv1 = GetRef(pvn1); + this->uf_map.Set(pv1, parent2); + return; + } + + // if second parent is a type var but first isn't, can set second type var + if (const IncompleteTypeNode *pvn2 = parent2.as()) { + auto pv2 = GetRef(pvn2); + this->uf_map.Set(pv2, parent1); + return; + } + + // if both parents are not type vars themselves, check alpha-equality + AssertAlphaEqual(parent1, parent2); + return; + } + + // if t is not a type var, then unify with v1's parent if parent is a type + // var; else, check alpha-equality for compatibility + if (const IncompleteTypeNode *pvn1 = parent1.as()) { + auto pv1 = GetRef(pvn1); + this->uf_map.Set(pv1, t); + return; + } + + AssertAlphaEqual(parent1, t); +} + +Type UnionFindNode::Find(const IncompleteType& v) { + // The node has no mapping, so its representative is just itself. 
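+  // For example, after Unify(?a, ?b) and Unify(?b, T), Find(?a) walks
+  // the chain ?a -> ?b -> T, and the path compression at the end of this
+  // function repoints ?a directly at the representative T.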
+ if (this->uf_map.find(v) == this->uf_map.end()) { + return v; + } + + Type parent = this->uf_map.at(v); + + if (v == parent) { + return v; + } + + // if parent is not a type var, then it must be the representative type + const IncompleteTypeNode *rep = parent.as(); + if (!rep) { + return parent; + } + + // otherwise, recurse and perform path compression + IncompleteType pv = GetRef(rep); + Type higher_up = this->Find(pv); + this->uf_map.Set(v, higher_up); + return higher_up; +} + +TVM_REGISTER_API("relay._make.UnionFind") + .set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() == 0) { + *ret = UnionFindNode::make({}); + } else { + *ret = UnionFindNode::make(args[0]); + } + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const UnionFindNode *node, + tvm::IRPrinter *p) { + p->stream << "UnionFindNode(" << node->uf_map << ")"; + }); + +TypeUnifier TypeUnifierNode::make(UnionFind union_find) { + std::shared_ptr n = std::make_shared(); + n->union_find = union_find; + return TypeUnifier(n); +} + +void TypeUnifierNode::Insert(const IncompleteType& v) { + this->union_find->Insert(v); +} + +Type TypeUnifierNode::Unify(const Type& t1, const Type& t2) { + RELAY_LOG(INFO) << "TypeUnifierNode::unify: t1=" << t1 << " t2=" << t2 + << std::endl; + + Type unified = this->VisitType(t1, t2); + // TODO(@jroesch): Restore this code when we finish kind checker. + // if (!check_kind(unified)) { + // throw UnificationError("Invalid kinds in unified type"); + // } + return unified; +} + +struct IncompleteTypeSubst : TypeMutator { + const TypeUnifierNode *unifier; + + IncompleteTypeSubst(const TypeUnifierNode *unifier) : unifier(unifier) {} + + // type var: look it up in the type map and recurse + Type VisitType_(const IncompleteTypeNode* op) override { + auto tv = GetRef(op); + auto parent = unifier->union_find->Find(tv); + if (parent == tv) { + return tv; + } + return this->VisitType(parent); + } +}; + +Type TypeUnifierNode::Subst(const Type& t) { + IncompleteTypeSubst tvsubst(this); + // normalize first so substitutions in quantifiers will be correct + Type ret = tvsubst.VisitType(t); + // TODO(@jroesch): Restore this code when we finish kind checker. + // if (!check_kind(ret)) { + // std::stringstream ss; + // ss << "Invalid Kinds in substituted type!"; + // ss << t << std::endl; + // ss << ret << std::endl; + // throw SubstitutionError(ss.str()); + // } + return ret; +} + +Type TypeUnifierNode::VisitType(const Type& t1, const Type t2) { + // When the right hand size is a type variable immediately unify. 
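+  // For example, unifying Tensor[dtype, s] with an incomplete type ?t
+  // records ?t := Tensor[dtype, s] in the union-find and returns the
+  // representative, instead of dispatching on the two node kinds.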
+ if (const IncompleteTypeNode *tvn2 = t2.as()) { + return this->UnifyWithIncompleteType(t1, GetRef(tvn2)); + } else { + return TypeFunctor::VisitType(t1, t2); + } +} + +Type TypeUnifierNode::UnifyWithIncompleteType(const Type& t1, + const IncompleteType tv2) { + RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 + << std::endl; + // Fix unify to return new representative + this->union_find->Unify(tv2, t1); + auto rep = this->union_find->Find(tv2); + RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl; + return rep; +} + +Type TypeUnifierNode::VisitType_(const IncompleteTypeNode* t1, const Type rt2) { + IncompleteType tv1 = GetRef(t1); + RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << t1 << " = " << rt2 + << std::endl; + this->union_find->Unify(tv1, rt2); + auto rep = this->union_find->Find(tv1); + RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode rep=" << rep << std::endl; + return rep; +} + +Type TypeUnifierNode::VisitType_(const TypeParamNode* t1, const Type rt2) { + TypeParam ti1 = GetRef(t1); + + if (const TypeParamNode *tin2 = rt2.as()) { + TypeParam ti2 = GetRef(tin2); + + if (ti1 != ti2) { + throw UnificationError("Attempting to unify non-matching TypeParams"); + } + + return ti1; + } + + throw UnificationError("Unable to unify TypeParamNode"); +} + +Type TypeUnifierNode::VisitType_(const FuncTypeNode* t1, const Type rt2) { + FuncType ft1 = GetRef(t1); + + if (const FuncTypeNode *tan2 = rt2.as()) { + FuncType ft2 = GetRef(tan2); + + if (ft1->type_params.size() != ft2->type_params.size()) { + throw UnificationError( + "unable to unify functions with differing number of type parameters"); + } + + tvm::Map subst_map; + + for (size_t i = 0; i < ft1->arg_types.size(); i++) { + subst_map.Set(ft1->type_params[i], ft2->type_params[i]); + } + + ft1 = Downcast(TypeSubst(ft1, subst_map)); + + if (ft1->arg_types.size() != ft2->arg_types.size()) { + throw UnificationError("unable to unify functions of different arities"); + } + + tvm::Array unified_args; + for (size_t i = 0; i < ft1->arg_types.size(); i++) { + unified_args.push_back( + this->VisitType(ft1->arg_types[i], ft2->arg_types[i])); + } + + Type unified_ret_type = this->VisitType(ft1->ret_type, ft2->ret_type); + + return FuncTypeNode::make(unified_args, unified_ret_type, {}, {}); + } + + throw UnificationError("unable to unify function types"); +} + +Type TypeUnifierNode::VisitType_(const TensorTypeNode* t1, const Type rt2) { + TensorType tt1 = GetRef(t1); + + if (const TensorTypeNode *ttn2 = rt2.as()) { + TensorType tt2 = GetRef(ttn2); + + if (!AlphaEqual(tt1, tt2)) { + throw UnificationError("dtypes do not match"); + } + + RELAY_LOG(INFO) << "Unify Tensor Shape s1=" << tt1->shape + << " s2= " << tt2->shape << std::endl; + + if (tt1->shape.size() != tt2->shape.size()) { + throw UnificationError("shapes are not of the same length"); + } + + for (size_t i = 0U; i < tt1->shape.size(); i++) { + if (!tt1->shape[i].same_as(tt2->shape[i])) { + throw UnificationError("shapes do not match at index"); + } + } + + return rt2; + } + + throw UnificationError("Cannot unify TensorTypeNode"); +} + +Type TypeUnifierNode::VisitType_(const TupleTypeNode* t1, const Type rt2) { + TupleType pt1 = GetRef(t1); + + if (const TupleTypeNode *ptn2 = rt2.as()) { + TupleType pt2 = GetRef(ptn2); + + std::vector unified_fields; + if (pt1->fields.size() != pt2->fields.size()) { + throw UnificationError("Product types are of different dimensions"); + } + + for (size_t i = 0U; i < pt1->fields.size(); i++) { + Type 
unified = this->VisitType(pt1->fields[i], pt2->fields[i]); + unified_fields.push_back(unified); + } + + return TupleTypeNode::make(unified_fields); + } + + throw UnificationError("Cannot unify TupleTypeNode"); +} + +Type TypeUnifierNode::VisitType_(const TypeRelationNode* tr1, const Type t2) { + throw InternalError("Cannot unify different type relations"); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h new file mode 100644 index 000000000000..4e939cc26bca --- /dev/null +++ b/src/relay/pass/unifier.h @@ -0,0 +1,141 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file include/tvm/relay/pass/unifier.h + * \brief The type unifier which solves a system of equations between + * incomplete types. + */ +#ifndef TVM_RELAY_PASS_UNIFIER_H_ +#define TVM_RELAY_PASS_UNIFIER_H_ + +#include +#include +#include "./type_functor.h" + +namespace tvm { +namespace relay { + +struct UnionFindError : dmlc::Error { + explicit UnionFindError(const std::string& msg) : Error(msg) {} +}; + +struct UnificationError : dmlc::Error { + explicit UnificationError(const std::string& msg) : Error(msg) {} +}; + +struct SubstitutionError : dmlc::Error { + explicit SubstitutionError(const std::string& msg) : Error(msg) {} +}; + +/*! \brief A union-find data structure for the type-checker */ +class UnionFind; + +class UnionFindNode : public Node { + public: + /*! \brief The inernal map from incomplete types to their representatives. */ + tvm::Map uf_map; + + UnionFindNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf_map", &uf_map); } + + TVM_DLL static UnionFind make(tvm::Map uf_map); + + /*! \brief Insert it into the union find. + * \param it The type to add to the union find. + */ + void Insert(const IncompleteType& it); + + /*! \brief Union operation, combine two equivalence classes. + * \param it The incomplete type to unify. + * \param ty The other type. + */ + void Unify(const IncompleteType& it, const Type& t); + + /*! \brief Find operation, returns the representative of the argument. + * \param it The element to lookup. + */ + Type Find(const IncompleteType& it); + + void debug(); + + void AssertAlphaEqual(const Type& l, const Type& r); + + static constexpr const char* _type_key = "relay.UnionFind"; + TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, Node); +}; + +class UnionFind : public NodeRef { + public: + UnionFind() {} + explicit UnionFind(std::shared_ptr p) : NodeRef(p) {} + + // The union find structure is mutable so we do not use the standard macros + // and expose the pointer via `->`. + UnionFindNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = UnionFindNode; +}; + +class TypeUnifier; +class TypeUnifierNode : public Node, + private TypeFunctor { + public: + UnionFind union_find; + + TypeUnifierNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("union_find", &union_find); } + + TVM_DLL static TypeUnifier make(UnionFind uf); + + /*! \brief Introduces a new type var into the unifier */ + void Insert(const IncompleteType& v); + + /*! \brief Unifies two types if possible, throws a unification error if it + * cannot */ + Type Unify(const Type& t1, const Type& t2); + + /*! \brief Attempts to substitute all type vars in t with concrete types, + * throws substitution error if it cannot concretize*/ + Type Subst(const Type& t); + + // /*! 
+
+class TypeUnifier;
+class TypeUnifierNode : public Node,
+                        private TypeFunctor<Type(const Type&, const Type)> {
+ public:
+  UnionFind union_find;
+
+  TypeUnifierNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("union_find", &union_find);
+  }
+
+  TVM_DLL static TypeUnifier make(UnionFind uf);
+
+  /*! \brief Introduces a new type var into the unifier. */
+  void Insert(const IncompleteType& v);
+
+  /*! \brief Unifies two types if possible; throws a unification error if it
+   * cannot. */
+  Type Unify(const Type& t1, const Type& t2);
+
+  /*! \brief Attempts to substitute all type vars in t with concrete types;
+   * throws a substitution error if it cannot concretize. */
+  Type Subst(const Type& t);
+
+  // /*! \brief Checks the kinds in the given type. */
+  // Type CheckKinds(const Type& t);
+
+  static constexpr const char* _type_key = "relay.TypeUnifier";
+  TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node);
+
+ private:
+  /*! \brief Unify an incomplete type with another type. */
+  Type UnifyWithIncompleteType(const Type& t1, const IncompleteType tvn2);
+  /*! \brief Implements unification between two types with incomplete portions.
+   */
+  Type VisitType(const Type& t1, const Type t2) override;
+
+  // Visitor cases
+  Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override;
+  Type VisitType_(const TensorTypeNode* t1, const Type t2) override;
+  Type VisitType_(const TypeParamNode* t1, const Type t2) override;
+  Type VisitType_(const FuncTypeNode* t1, const Type t2) override;
+  Type VisitType_(const TupleTypeNode* t1, const Type t2) override;
+  Type VisitType_(const TypeRelationNode* s1, const Type t2) override;
+};
+
+class TypeUnifier : public NodeRef {
+ public:
+  TypeUnifier() {}
+  explicit TypeUnifier(std::shared_ptr<Node> p) : NodeRef(p) {}
+
+  // Not const, so that the unifier can be mutated as a member of the
+  // type checker.
+  inline TypeUnifierNode* operator->() const {
+    return static_cast<TypeUnifierNode*>(node_.get());
+  }
+
+  using ContainerType = TypeUnifierNode;
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_PASS_UNIFIER_H_
diff --git a/tests/python/relay/test_ir_builder.py b/tests/python/relay/test_ir_builder.py
new file mode 100644
index 000000000000..c98f920ca491
--- /dev/null
+++ b/tests/python/relay/test_ir_builder.py
@@ -0,0 +1,20 @@
+import numpy as np
+from tvm.relay.expr import Let, Constant
+from tvm.relay.ir_builder import IRBuilder
+
+def test_let():
+    b = IRBuilder()
+    x = b.let('x', 1)
+    b.ret(x)
+    prog, _ = b.get()
+    assert isinstance(prog, Let)
+    var = prog.var
+    value = prog.value
+    assert var.name_hint == 'x'
+    assert var == prog.body
+    assert isinstance(value, Constant)
+    assert value.data.asnumpy() == np.array(1)
+    assert prog.value_type is None
+
+if __name__ == "__main__":
+    test_let()
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
new file mode 100644
index 000000000000..803b3d0faa0c
--- /dev/null
+++ b/tests/python/relay/test_ir_nodes.py
@@ -0,0 +1,159 @@
+"""Test ir nodes."""
+import tvm
+from tvm import relay
+from tvm.expr import *
+
+# Span
+def test_span():
+    span = relay.Span(None, 1, 1)
+    assert span.source is None
+    assert span.lineno == 1
+    assert span.col_offset == 1
+    assert span.same_as(span)
+    assert span == span
+    assert isinstance(span, relay.base.Span)
+    str(span)
+
+# Types
+
+def test_tensor_type():
+    shape = tvm.convert([1, 2, 3])
+    dtype = 'float32'
+    tt = relay.TensorType(shape, dtype)
+    assert tt.dtype == dtype
+    assert tt.shape == shape
+    assert tt.span is None
+    str(tt)
+
+
+def test_type_param():
+    tp = relay.TypeParam('name', relay.Kind.Shape)
+    assert tp.kind == relay.Kind.Shape
+    tp.span  # TODO allow us to set span
+    str(tp)
+
+
+def test_func_type():
+    type_params = tvm.convert([])
+    type_constraints = tvm.convert([])  # TODO: fill me in
+    arg_types = tvm.convert([])
+    ret_type = None
+    tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
+    assert tf.type_params == type_params
+    assert tf.type_constraints == type_constraints
+    assert tf.arg_types == arg_types
+    assert tf.ret_type == ret_type
+    assert tf.span is None
+    # TODO make sure we can set
+    str(tf)
+
+
+def test_constant():
+    arr = tvm.nd.array(10)
+    const = relay.Constant(arr)
+    assert const.data == arr
+    assert const.span is None
+    str(const)
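Editor's note: as a usage aside on the constructors above, a Constant wraps a tvm.nd.NDArray and the payload can be read back with asnumpy(), which is what test_ir_builder relies on. A small hypothetical sketch, assuming the same relay bindings these tests import:

    import numpy as np
    import tvm
    from tvm import relay

    data = np.array([1.0, 2.0], dtype='float32')
    c = relay.Constant(tvm.nd.array(data))     # wrap an NDArray in an IR node
    assert (c.data.asnumpy() == data).all()    # and read the payload back out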
+
+
+def test_tuple():
+    fields = tvm.convert([])
+    tup = relay.Tuple(fields)
+    assert tup.fields == fields
+    assert tup.span is None
+    str(tup)
+
+
+def test_local_var():
+    name_hint = 's'
+    lv = relay.Var(name_hint)
+    assert lv.name_hint == name_hint
+    # assert lv.span == None todo(@jroesch): what do we do about spans
+    str(lv)
+
+
+def test_global_var():
+    name_hint = 'g'
+    gv = relay.GlobalVar(name_hint)
+    assert gv.name_hint == name_hint
+    # assert gv.span == None todo(@jroesch): what do we do about spans
+    str(gv)
+
+
+def test_param():
+    lv = relay.Var('x')
+    ty = None
+    param = relay.Param(lv, ty)
+    assert param.var == lv
+    assert param.type == ty
+    assert param.span is None
+    str(param)
+
+
+def test_function():
+    param_names = ['a', 'b', 'c', 'd']
+    params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
+    ret_type = None
+    body = None
+    type_params = tvm.convert([])
+    fn = relay.Function(params, ret_type, body, type_params)
+    assert fn.params == params
+    assert fn.body == body
+    assert fn.type_params == type_params
+    assert fn.span is None
+    str(fn)
+
+
+def test_call():
+    op = relay.Var('f')
+    arg_names = ['a', 'b', 'c', 'd']
+    args = tvm.convert([relay.Var(n) for n in arg_names])
+    call = relay.Call(op, args, None, None)
+    assert call.op == op
+    assert call.args == args
+    assert call.span is None
+    str(call)
+
+
+def test_let():
+    lv = relay.Var('x')
+    ty = None
+    arr = tvm.nd.array(10)
+    value = relay.Constant(arr)
+    # I would prefer that the order of arguments matched the syntax
+    # `let x: t = v in b`.
+    let = relay.Let(lv, value, lv, ty)
+    assert let.var == lv
+    assert let.value == value
+    assert let.value_type == ty
+    assert let.body == lv
+    assert let.span is None
+    str(let)
+
+
+def test_if():
+    cond = relay.Var('cond')
+    left = relay.Var('left')
+    right = relay.Var('right')
+    ife = relay.If(cond, left, right)
+    assert ife.cond == cond
+    assert ife.true_branch == left
+    assert ife.false_branch == right
+    assert ife.span is None
+    str(ife)
+
+
+if __name__ == "__main__":
+    test_span()
+    test_tensor_type()
+    test_type_param()
+    test_func_type()
+    test_constant()
+    test_tuple()
+    test_local_var()
+    test_global_var()
+    test_param()
+    test_function()
+    test_call()
+    test_let()
+    test_if()
diff --git a/tests/python/relay/test_relay_op.py b/tests/python/relay/test_relay_op.py
new file mode 100644
index 000000000000..1f95a3f72c15
--- /dev/null
+++ b/tests/python/relay/test_relay_op.py
@@ -0,0 +1,27 @@
+from tvm import relay
+
+def test_op_attr():
+    log_op = relay.op.get("log")
+
+    @relay.op.register("exp", "ftest")
+    def test(x):
+        return x + 1
+
+    assert log_op.num_inputs == 1
+    assert log_op.get_attr("ftest") is None
+    assert relay.op.get("exp").get_attr("ftest")(1) == 2
+
+def test_op_level1():
+    x = relay.Var("x")
+
+    for op_name in ["log", "exp", "sqrt"]:
+        y = getattr(relay, op_name)(x)
+        assert y.op.name == op_name
+        assert y.op.support_level == 1
+        assert y.args[0] == x
+
+
+if __name__ == "__main__":
+    test_op_attr()
+    test_op_level1()
+
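Editor's note: the register/get_attr pattern that test_op_attr exercises generalizes to any op and any attribute name. A minimal sketch, where the attribute name "fcustom" and the helper are invented for illustration:

    from tvm import relay

    @relay.op.register("log", "fcustom")   # attach a custom attribute to an existing op
    def _double(x):
        return x * 2

    assert relay.op.get("log").get_attr("fcustom")(3) == 6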
diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py
new file mode 100644
index 000000000000..d95cda0ba819
--- /dev/null
+++ b/tests/python/relay/test_tyck_eval_integration.py
@@ -0,0 +1,162 @@
+"""Test that the type checker correctly computes types
+   for expressions.
+"""
+import tvm
+import numpy as np
+from tvm.relay.ir_pass import check_expr
+from tvm.relay.ir_builder import IRBuilder, func_type
+from tvm.relay.ir_builder import scalar_type, convert, tensor_type
+from tvm.relay.env import Environment
+from tvm.relay.op import log, add, equal, subtract, concat
+from tvm.relay.expr import Function
+
+def assert_has_type(expr, typ, env=Environment({})):
+    checked_expr = check_expr(env, expr)
+    assert checked_expr.checked_type() == typ
+
+
+def assert_decl_has_type(env, name, typ):
+    func = env[name]
+    assert func.checked_type() == typ
+
+
+def test_monomorphic_let():
+    "Program: let x = 1.0; return x"
+    b = IRBuilder()
+    x = b.let('x', 1.0, value_type=scalar_type('float64'))
+    b.ret(x)
+
+    prog, env = b.get()
+    assert_has_type(prog, scalar_type('float64'))
+
+
+def test_single_op():
+    "Program: fn (x : float32) { let t1 = log(x); t1 }"
+    b = IRBuilder()
+    with b.function(('x', 'float32')) as func:
+        x, = func.param_ids()
+        t1 = b.let('t1', log(x))
+        b.ret(t1)
+    assert_has_type(func.to_func(), func_type(['float32'], 'float32'))
+
+def test_add_op():
+    """
+    Program:
+       fn (x, y) {
+           return x + y;
+       }
+    """
+    b = IRBuilder()
+    x = b.param('x', tensor_type(5, 5, 5))
+    y = b.param('y', tensor_type(5, 5, 5))
+    with b.function(x, y) as func:
+        b.ret(add(x.var, y.var))
+    b.ret(func)
+    prog, env = b.get()
+    ttype = tensor_type(5, 5, 5)
+    expected_ty = func_type([ttype, ttype], ttype)
+    assert_has_type(func.to_func(), expected_ty)
+
+def test_add_broadcast_op():
+    """
+    Program:
+       fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
+           return x + y;
+       }
+    """
+    b = IRBuilder()
+    x = b.param('x', tensor_type(10, 4))
+    y = b.param('y', tensor_type(5, 10, 1))
+    with b.function(x, y) as func:
+        b.ret(add(x.var, y.var))
+    b.ret(func)
+    prog, env = b.get()
+    expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
+                            tensor_type(5, 10, 4))
+    assert_has_type(func.to_func(), expected_ty)
+
+def test_dual_op():
+    """Program:
+       fn (x : Tensor[f32, (10, 10)]) {
+         let t1 = log(x);
+         let t2 = add(t1, x);
+         return t2;
+       }
+    """
+    b = IRBuilder()
+    with b.function(('x', tensor_type(10, 10))) as func:
+        x, = func.param_ids()
+        t1 = b.let('t1', log(x))
+        t2 = b.let('t2', add(t1, x))
+        b.ret(t2)
+    ttype = tensor_type(10, 10)
+    assert_has_type(func.to_func(), func_type([ttype], ttype))
+
+
+def test_decl():
+    """Program:
+       def f(x : float32) {
+           let lx = log(x);
+           return lx;
+       }
+    """
+    b = IRBuilder()
+    x = b.param('x')
+    with b.decl('f', x):
+        lx = b.let('lx', log(x))
+        b.ret(lx)
+    _, env = b.get()
+    assert_decl_has_type(env, 'f', func_type(['float32'], 'float32'))
+
+
+def test_recursion():
+    """
+    Program:
+       def f(n: i32, data: f32) -> f32 {
+          if (n == 0) {
+              return f(n - 1, log(data));
+          } else {
+              return data;
+          }
+       }
+       f(2, 10000);
+    """
+    b = IRBuilder()
+    f = b.global_var('f')
+    n = b.param('n', ty='int32')
+    data = b.param('data', ty='float32')
+    with b.decl(f, n, data):
+        with b.if_scope(equal(n, convert(0))):
+            b.ret(f(subtract(n, convert(1)), log(data)))
+        with b.else_scope():
+            b.ret(data)
+    b.ret(f(convert(2), convert(10000.0)))
+    assert_decl_has_type(b.env, 'f', func_type(
+        ['int32', 'float32'], 'float32'))
+    # TODO(@jroesch): need evaluator or new runtime
+    # to execute this.
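Editor's note: the declaration-style tests above and test_concat below all follow the same pattern: declare into an environment, then check the declaration's type. A hypothetical minimal sketch of that pattern with a fresh builder, assuming the IRBuilder API used throughout this file (the name 'ident' is invented for illustration):

    from tvm.relay.ir_builder import IRBuilder, func_type

    ib = IRBuilder()
    x = ib.param('x', ty='int32')
    with ib.decl('ident', x):        # declare identity over int32
        ib.ret(x)
    assert_decl_has_type(ib.env, 'ident', func_type(['int32'], 'int32'))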
+
+def test_concat():
+    """
+    Program:
+       def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) {
+          return concat(x, y);
+       }
+    """
+    ib = IRBuilder()
+    try_concat2 = ib.global_var('try_concat2')
+    x = ib.param('x', ty=tensor_type(3, 2))
+    y = ib.param('y', ty=tensor_type(2, 2))
+    with ib.decl(try_concat2, x, y):
+        ib.ret(concat(x, y))
+    fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)],
+                      tensor_type(5, 2))
+    assert_decl_has_type(ib.env, try_concat2, fn_ty)
+
+if __name__ == "__main__":
+    test_monomorphic_let()
+    test_single_op()
+    test_add_op()
+    test_add_broadcast_op()
+    test_dual_op()
+    test_decl()
+    test_recursion()
+    test_concat()
diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh
index 8104bf079502..7dcd5c921905 100755
--- a/tests/scripts/task_python_integration.sh
+++ b/tests/scripts/task_python_integration.sh
@@ -18,6 +18,8 @@ TVM_FFI=cython python -m nose -v tests/python/integration || exit -1
 TVM_FFI=ctypes python3 -m nose -v tests/python/integration || exit -1
 TVM_FFI=cython python -m nose -v tests/python/contrib || exit -1
 TVM_FFI=ctypes python3 -m nose -v tests/python/contrib || exit -1
+TVM_FFI=cython python -m nose -v tests/python/relay || exit -1
+TVM_FFI=ctypes python3 -m nose -v tests/python/relay || exit -1
 
 # Do not enable OpenGL
 # TVM_FFI=cython python -m nose -v tests/webgl || exit -1
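Editor's note: the Float(3, 2) and Float(2, 2) giving Float(5, 2) signature in test_concat above follows the usual concat typing rule: sizes add on the concatenation axis, and every other axis must agree. A hypothetical standalone check of that arithmetic, independent of TVM:

    def concat_shape(s1, s2, axis=0):
        # Concat typing rule: sizes add on `axis`; all other axes must match.
        assert len(s1) == len(s2)
        out = list(s1)
        for i, (a, b) in enumerate(zip(s1, s2)):
            if i == axis:
                out[i] = a + b
            else:
                assert a == b, "non-concat axes must match"
        return tuple(out)

    assert concat_shape((3, 2), (2, 2)) == (5, 2)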