From cc7706e3ed57c33b5d7565d52b3139040d465916 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 16:27:52 -0700 Subject: [PATCH 001/136] Add stripped down version of expr.h and type.h This commit adds a simplified version of type.h and expr.h from the previous Relay version. We implement the basic data types and the associated machinery for exporting these to Python, as well as tests that they can be constructed, all fields are live, and can be printed using `str`. --- CMakeLists.txt | 6 + include/tvm/relay/base.h | 154 ++++++++++++ include/tvm/relay/expr.h | 361 ++++++++++++++++++++++++++++ include/tvm/relay/type.h | 243 +++++++++++++++++++ python/tvm/relay/__init__.py | 12 + python/tvm/relay/_make.py | 9 + python/tvm/relay/_make.pyi | 91 +++++++ python/tvm/relay/base.py | 27 +++ python/tvm/relay/expr.py | 69 ++++++ python/tvm/relay/make.py | 20 ++ python/tvm/relay/type.py | 51 ++++ src/relay/base.cc | 40 +++ src/relay/expr.cc | 181 ++++++++++++++ src/relay/type.cc | 100 ++++++++ tests/python/relay/test_ir_nodes.py | 154 ++++++++++++ 15 files changed, 1518 insertions(+) create mode 100644 include/tvm/relay/base.h create mode 100644 include/tvm/relay/expr.h create mode 100644 include/tvm/relay/type.h create mode 100644 python/tvm/relay/__init__.py create mode 100644 python/tvm/relay/_make.py create mode 100644 python/tvm/relay/_make.pyi create mode 100644 python/tvm/relay/base.py create mode 100644 python/tvm/relay/expr.py create mode 100644 python/tvm/relay/make.py create mode 100644 python/tvm/relay/type.py create mode 100644 src/relay/base.cc create mode 100644 src/relay/expr.cc create mode 100644 src/relay/type.cc create mode 100644 tests/python/relay/test_ir_nodes.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 572f4aef1432..65a7d9e36e2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -104,6 +104,12 @@ file(GLOB COMPILER_SRCS src/schedule/*.cc ) +file(GLOB_RECURSE RELAY_SRCS + src/relay/*.cc + ) +list(APPEND COMPILER_SRCS ${RELAY_SRCS}) + + if(NOT MSVC) file(GLOB COMPILER_VERILOG_SRCS src/codegen/verilog/*.cc) list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS}) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h new file mode 100644 index 000000000000..3b31aae52617 --- /dev/null +++ b/include/tvm/relay/base.h @@ -0,0 +1,154 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/base.h + * \brief Base data structure for relay. + */ +#ifndef TVM_RELAY_BASE_H_ +#define TVM_RELAY_BASE_H_ + +#include +#include +#include +#include + +namespace tvm { +/*! + * \brief Relay: high level functional IR + */ +namespace relay { +/*! + * \brief we always used NodeRef for referencing nodes. + * + * By default, NodePtr is a std::shared_ptr of node + */ +using NodeRef = tvm::NodeRef; + +/*! + * \brief Content data type. + */ +using DataType = ::tvm::Type; + +/*! + * \brief Symbolic expression for tensor shape. + */ +using ShapeExpr = ::tvm::Expr; + +/*! + * \brief Hash function for nodes. + * e.g. std::unordered_map + */ +using NodeHash = ::tvm::NodeHash; +/*! + * \brief Equality check function for nodes. + */ +using NodeEqual = ::tvm::NodeEqual; + +/*! + * \brief Macro to make it easy to define node ref type given node + * \param TypeName The name of the reference type. + * \param NodeName The internal contrainer name. + * \param NodeRefBase The base type. + */ +#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ + class TypeName : public NodeRefBase { \ + public: \ + TypeName() {} \ + explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + using ContainerType = NodeName; \ + }; + + +/*! + * \brief The source name in the Span + * \sa SourceNameNode, Span + */ +class SourceName; +/*! + * \brief The source name in the Span + */ +class SourceNameNode : public Node { + public: + /*! \brief The source name */ + std::string name; + // override attr visitor + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + } + + TVM_DLL static SourceName make(std::string name); + + static constexpr const char* _type_key = "relay.SourceName"; + TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); +}; + +RELAY_DEFINE_NODE_REF(SourceName, SourceNameNode, NodeRef); + +/*! + * \brief Span information for debugging purposes + */ +class Span; +/*! + * \brief Stores locations in frontend source that generated a node. + * + */ +class SpanNode : public Node { + public: + /*! \brief The source name */ + SourceName source; + /*! \brief Line number */ + int lineno; + /*! \brief column offset */ + int col_offset; + // override attr visitor + void VisitAttrs(AttrVisitor* v) final { + v->Visit("source", &source); + v->Visit("lineno", &lineno); + v->Visit("col_offset", &col_offset); + } + + TVM_DLL static Span make(SourceName source, int lineno, int col_offset); + + static constexpr const char* _type_key = "relay.Span"; + TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node); +}; + +RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef); + +/*! + * \brief This is the base node container of all relay structures. + */ +class RelayNode : public Node { + public: + /*! \brief The debug information, can be null, check with span.defined() */ + mutable Span span; + + static constexpr const char* _type_key = "relay.Node"; + TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); +}; + +/*! + * \brief Get a reference type from a Node ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam NodeType The node type + * \return The corresponding RefType + */ +template +RefType GetRef(const NodeType* ptr) { + static_assert(std::is_same::value, + "Can only cast to the ref of same container type"); + return RefType(const_cast(ptr)->shared_from_this()); +} + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BASE_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h new file mode 100644 index 000000000000..b830c7ce04ef --- /dev/null +++ b/include/tvm/relay/expr.h @@ -0,0 +1,361 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/expr.h + * \brief Relay expression IR Node. + */ +#ifndef TVM_RELAY_EXPR_H_ +#define TVM_RELAY_EXPR_H_ + +#include +#include +#include +#include +#include "./base.h" +#include "./type.h" + +namespace tvm { +namespace relay { +/*! + * \brief Relay expression. + */ +class Expr; +/*! + * \brief Base type of the Relay type hiearchy. + */ +class ExprNode : public RelayNode { + public: + /*! + * \brief Stores the result of type inference(type checking). + * + * \note This can be undefined before type inference. + * this value is discarded during serialization. + */ + Type checked_type_ = Type(nullptr); + /*! + * \return The checked_type + */ + const Type& checked_type() const { + CHECK(checked_type_.defined()) << "internal error: the type checker has " + "not populated the checked_type " + << "field for this node"; + return this->checked_type_; + } + + static constexpr const char* _type_key = "relay.Expr"; + TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); +}; + +RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); + +/*! + * \brief Constant tensor, backed by an NDArray on cpu(0). + * + * \note scalar constants are represented by rank-0 const tensor. + * Constant folding are handled uniformly via Tensor types. + */ +class Constant; +/*! + * \brief Constant tensor type. + */ +class ConstantNode : public ExprNode { + public: + /*! \brief The data of the tensor */ + runtime::NDArray data; + + // TODO(tqchen) add the function after we get TensorType constructor + // TODO(tqchen) create simple TensorType constructor for concrete types. + /*! \return The corresponding tensor type of the data */ + TensorType tensor_type() const; + + /*! \return whether it is scalar(rank-0 tensor) */ + bool is_scalar() const { return data->ndim == 0; } + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("data", &data); + v->Visit("span", &span); + } + + TVM_DLL static Constant make(runtime::NDArray data); + + static constexpr const char* _type_key = "relay.Constant"; + TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr); + +/*! \brief Tuple of multiple Exprs */ +class Tuple; +/*! \brief Tuple container */ +class TupleNode : public ExprNode { + public: + /*! \brief the fields of the tuple */ + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("fields", &fields); + v->Visit("span", &span); + } + + TVM_DLL static Tuple make(tvm::Array fields); + + static constexpr const char* _type_key = "relay.Tuple"; + TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr); + +/*! + * \brief Local variables used in the let expression. + * This is similar to Var that is being used in the low level tensor expression. + * + * \note Each LocalVar is bind only once and is immutable/ + */ +class LocalVar; +/*! \brief Container for LocalVar */ +class LocalVarNode : public ExprNode { + public: + /*! \brief The name of the variable, this only acts as a hint. */ + std::string name_hint; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name_hint", &name_hint); + } + + TVM_DLL static LocalVar make(std::string name_hint); + + static constexpr const char* _type_key = "relay.LocalVar"; + TVM_DECLARE_NODE_TYPE_INFO(LocalVarNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(LocalVar, LocalVarNode, Expr); + +/*! + * \brief Global variable that leaves in the top-level environment. + * This is used to enable recursive calls between function. + * + * \note GlobalVar can only corresponds to functions. + */ +class GlobalVar; +/*! \brief A GlobalId from the node's current type to target type. */ +class GlobalVarNode : public ExprNode { + public: + /*! \brief The name of the variable, this only acts as a hint. */ + std::string name_hint; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name_hint", &name_hint); + } + + TVM_DLL static GlobalVar make(std::string name_hint); + + static constexpr const char* _type_key = "relay.GlobalVar"; + TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr); + +/*! + * \brief Function parameter declaration. + */ +class Param; +/*! \brief A parameter. */ +class ParamNode : public ExprNode { + public: + /*! \brief The variable */ + LocalVar var; + /*! \brief The type of the parameter */ + Type type; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("type", &type); + v->Visit("span", &span); + } + + TVM_DLL static Param make(LocalVar var, Type type); + + static constexpr const char* _type_key = "relay.Param"; + TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr); + +/*! + * \brief Function (subgraph in computational graph) + */ +class Function; +/*! \brief Function container */ +class FunctionNode : public ExprNode { + public: + /*! \brief Function parameters */ + tvm::Array params; + /*! \brief User annotated return type of the function. */ + Type ret_type; + /*! + * \brief + * The expression which represents the computation of the function, + * the expression may reference the parameters, and the type of it + * or sub-expressions may reference the type variables. + */ + Expr body; + /*! + * \brief Type parameters of the function. + * Enables the function to vary its type based on these. + * This corresponds to template paramaters in c++'s terminology. + * + * \note This can be usually empty for non-polymorphic functions. + */ + tvm::Array type_params; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("params", ¶ms); + v->Visit("ret_type", &ret_type); + v->Visit("body", &body); + v->Visit("type_params", &type_params); + v->Visit("span", &span); + } + + TVM_DLL static Function make(tvm::Array params, Type ret_type, + Expr body, tvm::Array ty_params); + + static constexpr const char* _type_key = "relay.Function"; + TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); + +// TODO(tqchen) change Expr to Attr after we introduce Attr system. +using Attrs = tvm::Map; + +/*! + * \brief Call corresponds to operator invocation. + * Corresponds to the operator in computational graph terminology. + */ +class Call; +/*! \brief Call container. */ +class CallNode : public ExprNode { + public: + /*! + * \brief The operator(function) being invoked + * + * - It can be relay::Op which corresponds to the primitive operators. + * - It can also be user defined functions (Function, GlobalVar, LocalVar). + */ + Expr op; + + /*! \brief The arguments(inputs) of the call */ + tvm::Array args; + + /*! \brief The additional attributes */ + Attrs attrs; + + /*! + * \brief The type arguments passed to polymorphic(template) function. + * + * This is the advance feature that is only used when the function is + * polymorphic. It is safe to be ignored in most cases. For example, in the + * following code, the type_args of addone call is [int]. + * + * \code + * + * template + * T addone(T a) { return a + 1; } + * + * void main() { + * int x = addone(10); + * } + * + * \endcode + */ + tvm::Array type_args; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("op", &op); + v->Visit("args", &args); + v->Visit("type_args", &type_args); + v->Visit("span", &span); + } + + TVM_DLL static Call make(Expr op, Array args, Attrs attrs, + Array ty_args); + + static constexpr const char* _type_key = "relay.Call"; + TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Call, CallNode, Expr); + +/*! + * \brief Let binding that binds a local var and optionally a type annotation. + * + * \note Let is useful to transform the program to be A-normal form. + * where each of the expression corresponds to a let binding. + * + * For developers who are familar with the computational graph. + * Each of the let can be viewed as a operator node in the computational graph. + * Traversing the list of let bindings is similar to running + * PostDFS-order(topo-order) traversal on the computational graph. + */ +class Let; +/*! \brief A binding of a sub-network. */ +class LetNode : public ExprNode { + public: + /*! \brief The variable we bind to */ + LocalVar var; + /*! \brief The value we bind var to */ + Expr value; + /*! \brief The body of the let binding */ + Expr body; + /*! \brief type annotation of value, this can be null */ + Type value_type; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("body", &body); + v->Visit("value_type", &value_type); + v->Visit("span", &span); + } + + TVM_DLL static Let make(LocalVar var, Expr value, Expr body, Type value_type); + + static constexpr const char* _type_key = "relay.Let"; + TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); + +/*! + * \brief Condition expression + */ +class If; +/*! \brief container of If */ +class IfNode : public ExprNode { + public: + /*! \brief The condition */ + Expr cond; + /*! \brief The value to take when condition is true */ + Expr true_value; + /*! \brief The value to take when condition is false */ + Expr false_value; + + IfNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("cond", &cond); + v->Visit("true_value", &true_value); + v->Visit("false_value", &false_value); + v->Visit("span", &span); + } + + TVM_DLL static If make(Expr cond, Expr true_value, Expr false_value); + + static constexpr const char* _type_key = "relay.If"; + TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(If, IfNode, Expr); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h new file mode 100644 index 000000000000..4c6995646114 --- /dev/null +++ b/include/tvm/relay/type.h @@ -0,0 +1,243 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/type.h + * \brief Relay typed AST nodes. + */ +#ifndef TVM_RELAY_TYPE_H_ +#define TVM_RELAY_TYPE_H_ + +#include +#include +#include +#include + +#include "./base.h" + +namespace tvm { +namespace relay { + +/*! \brief Base type of the Relay type hiearchy. */ +class TypeNode : public RelayNode { + public: + static constexpr const char* _type_key = "relay.Type"; + TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node); +}; + +/*! + * \brief Type is the base type of relay type hiearchy. + * + * Relay's type system contains following two key concepts: + * + * - TensorType: type of certain Tensor values in the expression. + * - FunctionType: the type of the function. + * + * There are also advanced types to support generic(polymorphic types), + * which can be ignored when first reading the code base. + */ +class Type : public NodeRef { + public: + Type() {} + explicit Type(std::shared_ptr p) : NodeRef(p) {} + + using ContainerType = TypeNode; +}; + +/*! + * \brief Base of all Tensor types + * This container can hold TensorType or GenericTensorType. + */ +class BaseTensorTypeNode : public TypeNode { + public: + static constexpr const char* _type_key = "relay.BaseTensorType"; + TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type); + +/*! + * \brief This is the most commonly used type in relay. + * TensorType have a fixed dimension, data type. + * + * The elements of shape can be either IntImm(constant integer), + * or any symbolic integer expression. + * The symbolic integer allows generic shape inference in certain cases. + * \sa TensorTypeNode The container class of TensorType. + */ +class TensorType; +/*! \brief TensorType container node */ +class TensorTypeNode : public BaseTensorTypeNode { + public: + /*! + * \brief The shape of the tensor, + * represented by ShapeExpr(tvm::Expr). + */ + Array shape; + /*! \brief The content data type */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + TVM_DLL static TensorType make(Array shape, DataType dtype); + + static constexpr const char* _type_key = "relay.TensorType"; + TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode); +}; + +RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); + +/*! + * \brief Type parameter in the function. + * This can be viewed as template parameter in c++ template function. + * + * For example, in the following pesudo code, + * the TypeParam of f is TypeParam(kind=kShapeVar, var=n). + * This function can take in a Tensor with shape=(3, 3) and + * returns a Tensor with shape=(9,) + * + * \code + * + * template + * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] + * + * \endcode + * \sa TypeParamNode The actual container class of TypeParam + */ +class TypeParam; +/*! \brief TypeParam container node */ +class TypeParamNode : public TypeNode { + public: + /*! \brief possible kinds of TypeParam */ + enum Kind : int { + /*! \brief template variable in shape expression */ + kShapeVar = 0 + }; + /*! + * \brief The variable + * The variable itself is only meaningful when + * kind is ShapeVar, otherwise, we can only use the name. + */ + tvm::Var var; + /*! \brief The kind of type parameter */ + Kind kind; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("kind", &kind); + v->Visit("span", &span); + } + + TVM_DLL static TypeParam make(std::string name, Kind kind); + + static constexpr const char* _type_key = "relay.TypeParam"; + TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type); + +/*! + * \brief Potential Constraints in the type. + * \note This is reserved for future use. + */ +class TypeConstraint; +/*! \brief TypeConstraint container node. */ +class TypeConstraintNode : public Node { + public: + static constexpr const char* _type_key = "relay.TypeConstraint"; + TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, Node); +}; + +RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, NodeRef); + +class FuncType; +/*! + * \brief Function type in Relay. + * + * Relay support polymorphic function type. + * This can be roughly viewed as template function in C++. + * + * \sa TypeParam, TypeConstraint + */ +class FuncTypeNode : public TypeNode { + public: + /*! \brief type type of arguments */ + tvm::Array arg_types; + /*! \brief The type of return value. */ + Type ret_type; + // The following fields are used in polymorphic(template) functions + // For normal functions, the following two fields will be empty. + /*! \brief The type parameters of the function */ + tvm::Array type_params; + /*! + * \brief potential constraint the type need to obey + * \note this field is reserved for futher purposes. + */ + tvm::Array type_constraints; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("arg_types", &arg_types); + v->Visit("ret_type", &ret_type); + v->Visit("type_params", &type_params); + v->Visit("type_constraints", &type_constraints); + v->Visit("span", &span); + } + + TVM_DLL static FuncType make(tvm::Array arg_types, Type ret_type, + tvm::Array type_params, + tvm::Array type_constraints); + + static constexpr const char* _type_key = "relay.FuncType"; + TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); + +/*! + * \brief Opaque type inference function. + */ +class TypeFunction; +/*! + * \brief TypeFunction container. + * \note This node is not directly serializable. + * The type function need to be lookedup in the environment. + */ +class TypeFunctionNode : public RelayNode { + public: + /*! \brief The name of the function */ + std::string name; + /*! \brief Number of input type arguments, can be -1, which means VarArgs */ + int num_args; + /*! + * \brief The type function, + * this is not directly serializable, + * need to be looked-up in the environment. + */ + mutable std::function& arg_types)> func_; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("num_args", &num_args); + } + + TVM_DLL static TypeFunction make(std::string name, int num_args); + + static constexpr const char* _type_key = "relay.TypeFunction"; + TVM_DECLARE_NODE_TYPE_INFO(TypeFunctionNode, RelayNode); +}; + +RELAY_DEFINE_NODE_REF(TypeFunction, TypeFunctionNode, NodeRef); + +// The following fields contains advanced typing +// Only keep the class name and reserved for future usage. +class GenericTensorType; +// stores a DataType. +class GenericDataType; +// stores a DataType. +class GenericShape; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TYPE_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py new file mode 100644 index 000000000000..c90875db4178 --- /dev/null +++ b/python/tvm/relay/__init__.py @@ -0,0 +1,12 @@ +"""Relay namespace.""" +from . import base +from . import type as tpe +from . import make + +# Type +Type = tpe.Type +TensorType = tpe.TensorType +Kind = tpe.Kind +TypeParam = tpe.TypeParam +TypeConstraint = tpe.TypeConstraint +FuncType = tpe.FuncType diff --git a/python/tvm/relay/_make.py b/python/tvm/relay/_make.py new file mode 100644 index 000000000000..20a582e76d6a --- /dev/null +++ b/python/tvm/relay/_make.py @@ -0,0 +1,9 @@ +""" +The constructors for all Relay AST nodes exposed from C++. + +This module includes MyPy type signatures for all of the +exposed modules. +""" +from .._ffi.function import _init_api + +_init_api("relay._make", __name__) diff --git a/python/tvm/relay/_make.pyi b/python/tvm/relay/_make.pyi new file mode 100644 index 000000000000..d94857916319 --- /dev/null +++ b/python/tvm/relay/_make.pyi @@ -0,0 +1,91 @@ +# from typing import Dict, List, Any, Callable, TypeVar as PyTypeVar +# import nnvm.relay.ir as ir +# import nnvm.relay.env as env +# import ctypes + +# # Environment +# def Environment(items: Dict[ir.GlobalId, ir.Item]) -> env.Environment: ... + +# # Items TODO(@jroesch) Correct Anys to the right type. +# def Operator(id: ir.OperatorId, tvm_name: str, ty: ir.Type, compiler: Any, fwd_mode: Any, rev_mode: Any) -> ir.Operator: ... +# def Defn(id: ir.GlobalId, ty: ir.Type, body: ir.Function) -> ir.Defn: ... + +# # Types +# def IntType(bits: int, lanes: int) -> ir.Type: ... +# def UIntType(bits: int, lanes: int) -> ir.Type: ... +# def FloatType(bits: int, lanes: int) -> ir.Type: ... +# def BoolType(lanes: int) -> ir.Type: ... +# def TupleType(fields: List[ir.Type]) -> ir.Type: ... +# def TensorType(dtype: ir.Type, shape: ir.Type) -> ir.Type: ... +# def TypeParam(name: str, kind: ir.Kind) -> ir.Type: ... +# def TypeQuantifier(id: ir.TypeId, body: ir.Type) -> ir.Type: ... +# def TypeArrow(left: ir.Type, right: ir.Type) -> ir.Type: ... +# def TypeVar(kind: ir.Kind) -> ir.Type: ... +# def PlaceholderType() -> ir.Type: ... +# def ShapeSeq(shapes: List[ir.Type]) -> ir.ShapeSeq: ... +# def ShapeSingleton(value: int) -> ir.ShapeSingleton: ... +# def ShapeAttr(id: ir.StringLit) -> ir.ShapeAttr: ... +# def ShapeProjection(shape: ir.Type, value: int) -> ir.ShapeProjection: ... +# def ShapeBinaryOp(op: ir.ShapeOp, left: ir.Type, right: ir.Type) -> ir.ShapeBinaryOp: ... +# def ShapeBroadcast(left: ir.Type, right: ir.Type) -> ir.ShapeBroadcast: ... +# def ShapeExtension(name: str, eval: Any) -> ir.ShapeExtension: ... +# def TypeCall(func: ir.Type, args: List[ir.Type]) -> ir.TypeCall: ... +# def RefType(data_type: ir.Type) -> ir.RefType: ... + +# # Expressions +# def Param(id: ir.LocalId, type: ir.Type) -> ir.Param: ... +# def Function(ty_params: List[ir.TypeId], params: List[ir.Param], ret_type: ir.Type, body: ir.Expr) -> ir.Function: ... +# def LocalId(name: str) -> ir.Expr: ... +# def GlobalId(name: str) -> ir.Expr: ... +# def OperatorId(name: str) -> ir.Expr: ... +# def Let(id: ir.LocalId, ty: ir.Type, value: ir.Expr, body: ir.Expr) -> ir.Expr: ... +# def IntLit(value: int) -> ir.IntLit: ... +# def FloatLit(value: float) -> ir.FloatLit: ... +# def TensorLit(value: List[ir.Expr]) -> ir.TensorLit: ... +# def Tuple(fields: List[ir.Expr]) -> ir.Expr: ... +# def BoolLit(value: bool) -> ir.BoolLit: ... +# def StringLit(value: str) -> ir.StringLit: ... +# def Attributes(attrs: Dict[str, ir.Expr]) -> ir.Attributes: ... +# def Call(func: ir.Expr, args: List[ir.Expr], attrs: ir.Attributes) -> ir.Call: ... +# def UnaryOp(op: ir.UOp, arg: ir.Expr) -> ir.Expr: ... +# def BinaryOp(op: ir.BOp, left: ir.Expr, right: ir.Expr) -> ir.Expr: ... +# def Projection(tuple: ir.Expr, field : int) -> ir.Expr: ... +# def Gradient(node: ir.Expr) -> ir.Expr: ... +# def Cast(target: ir.Type, node: ir.Expr) -> ir.Expr: ... +# def Debug(node: ir.Expr) -> ir.Expr: ... +# def Zero(type: ir.Type) -> ir.Expr: ... +# def If(guard: ir.Expr, true_branch: ir.Expr, false_branch: ir.Expr) -> ir.Expr: ... +# def Ref(value: ir.Expr) -> ir.Expr: ... +# def ReadRef(ref: ir.Expr) -> ir.Expr: ... +# def WriteRef(ref: ir.Expr, value: ir.Expr) -> ir.Expr: ... + +# # Values +# def IntValue(value: int) -> ir.TensorValue: ... +# def FloatValue(value: float) -> ir.TensorValue: ... +# def BoolValue(value: bool) -> ir.TensorValue: ... +# def TensorValue(handle: ctypes.c_void_p) -> ir.TensorValue: ... +# def Closure(env: Dict[ir.LocalId, ir.Value], fn: ir.Function) -> ir.Closure: ... + +# # Error Reporting +# def Span(file_id: ir.FileId, lineno: int, col_offset: int) -> ir.NodeBase: ... +# def FileId(file_id: int) -> ir.FileId: ... + +# # Utils +# def _alpha_eq(e1: ir.Expr, e2: ir.Expr) -> bool: ... +# def _type_alpha_eq(e1: ir.Type, e2: ir.Type) -> bool: ... +# def _expr_set_span(e: ir.Expr, sp: ir.Span) -> None: ... +# def _type_set_span(t: ir.Type, sp: ir.Span) -> None: ... +# def _item_set_span(t: ir.Item, sp: ir.Span) -> None: ... +# def Node_hash(n: ir.Node) -> int: ... +# def Operator_is_generic(op: ir.Operator) -> bool: ... + +# # FIXME +# def UnionFind() -> Any: ... +# def TypeUnifier() -> Any: ... + +# T = PyTypeVar('T') +# U = PyTypeVar('U') +# PassFunc = Callable[[env.Environment], Callable[[T], U]] + +# # Passes +# def ItemPass(name: str, pass_func: PassFunc[ir.Item, ir.Item]) -> ir.ItemPass: ... diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py new file mode 100644 index 000000000000..687ba53ac005 --- /dev/null +++ b/python/tvm/relay/base.py @@ -0,0 +1,27 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck +"""The base node types for the Relay language.""" +from __future__ import absolute_import as _abs +from typing import Union +from .._ffi.node import NodeBase, register_node as _register_tvm_node + +NodeBase = NodeBase + +def register_relay_node(type_key=None): + """register relay node type + + Parameters + ---------- + type_key : str or cls + The type key of the node + """ + if not isinstance(type_key, str): + return _register_tvm_node( + "relay." + type_key.__name__)(type_key) + return _register_tvm_node(type_key) + + +@register_relay_node +class Span(NodeBase): + source: "FileSource" + lineno: int + col_offset: int diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py new file mode 100644 index 000000000000..dea3a99f5f09 --- /dev/null +++ b/python/tvm/relay/expr.py @@ -0,0 +1,69 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The expression nodes of Relay.""" +import tvm +from typing import Tuple as PyTuple, List +from enum import IntEnum +from .base import Span, NodeBase, register_relay_node +from .type import Type, TypeParam +from tvm import expr + +class Expr(NodeBase): + """The base type for all Relay exprressions.""" + pass + +@register_relay_node +class Constant(Expr): + """A constant tensor in Relay, see tvm/relay/type.h for more details. + """ + data: tvm.nd.NDArray + +@register_relay_node +class Tuple(Expr): + """A hetereogenous sequence of values. + see tvm/relay/type.h for more details. + """ + fields: List[Expr] + +@register_relay_node +class LocalVar(Expr): + """A local variable in Relay.""" + name_hint: str + +@register_relay_node +class GlobalVar(Expr): + """A global variable in Relay.""" + name_hint: str + +@register_relay_node +class Param(Expr): + """A function type in Relay, see tvm/relay/type.h for more details. + """ + var: LocalVar + type: Type + +@register_relay_node +class Function(Expr): + type_params: List[TypeParam] + params: List[Param] + ret_type: Type + body: Expr + +class Call(Expr): + op: Expr + args: List[Expr] + # todo(@jroesch): add attrs + +@register_relay_node +class Let(Expr): + var: LocalVar + value: Expr + body: Expr + value_type: Type # should be type nanotation + +@register_relay_node +class If(Expr): + cond: Expr + true_value: Expr + false_value: Expr + span: Span + diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py new file mode 100644 index 000000000000..14d9ac040dc9 --- /dev/null +++ b/python/tvm/relay/make.py @@ -0,0 +1,20 @@ +from . import _make + +# Base Constructors +Span = _make.Span + +# Type Constructors +TensorType = _make.TensorType +TypeParam = _make.TypeParam +FuncType = _make.FuncType + +# Expr Constructors +Constant = _make.Constant +Tuple = _make.Tuple +LocalVar = _make.LocalVar +GlobalVar = _make.GlobalVar +Param = _make.Param +Function = _make.Function +Call = _make.Call +Let = _make.Let +If = _make.If diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py new file mode 100644 index 000000000000..c92f0d756587 --- /dev/null +++ b/python/tvm/relay/type.py @@ -0,0 +1,51 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The type nodes of the Relay language.""" +from typing import Tuple, List +from enum import IntEnum +from .base import Span, NodeBase, register_relay_node +from tvm import expr + +class Type(NodeBase): + """The base type for all Relay types.""" + pass + +@register_relay_node +class TensorType(Type): + """A concrete TensorType in Relay, see tvm/relay/type.h for more details. + """ + dtype: str + shape: List[expr.Expr] + span: Span + +class Kind(IntEnum): + """The kind of a type parameter, represents a variable shape, + base type, type, or dimension. + """ + Shape = 0 + BaseType = 1 + Type = 2 + Elem = 3 + +@register_relay_node +class TypeParam(Type): + """A type parameter used for generic types in Relay, + see tvm/relay/type.h for more details. + """ + var: expr.Var + kind: Kind + span: Span + +@register_relay_node +class TypeConstraint(Type): + """Abstract class representing a type constraint.""" + pass + +@register_relay_node +class FuncType(Type): + """A function type in Relay, see tvm/relay/type.h for more details. + """ + type_params: List[TypeParam] + type_constraints: List[TypeConstraint] + arg_types: List[Type] + ret_type: Type + span: Span diff --git a/src/relay/base.cc b/src/relay/base.cc new file mode 100644 index 000000000000..5fdf96ded224 --- /dev/null +++ b/src/relay/base.cc @@ -0,0 +1,40 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file base.cc + * \brief The core base types for Relay. + */ +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +Span SpanNode::make(SourceName source, int lineno, int col_offset) { + std::shared_ptr n = std::make_shared(); + n->source = std::move(source); + n->lineno = lineno; + n->col_offset = col_offset; + return Span(n); +} + +TVM_REGISTER_API("relay._make.Span") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SpanNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { + p->stream << node->name; + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const SpanNode *node, tvm::IRPrinter *p) { + p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", " + << node->col_offset << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/expr.cc b/src/relay/expr.cc new file mode 100644 index 000000000000..38df81940e48 --- /dev/null +++ b/src/relay/expr.cc @@ -0,0 +1,181 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file expr.cc + * \brief The expression AST nodes of Relay. + */ +#include "tvm/relay/expr.h" +#include "tvm/ir_functor.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +Constant ConstantNode::make(runtime::NDArray data) { + std::shared_ptr n = std::make_shared(); + n->data = std::move(data); + return Constant(n); +} + +TVM_REGISTER_API("relay._make.Constant") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ConstantNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const ConstantNode *node, + tvm::IRPrinter *p) { + p->stream << "ConstantNode(TODO)"; + }); + +Tuple TupleNode::make(tvm::Array fields) { + std::shared_ptr n = std::make_shared(); + n->fields = std::move(fields); + return Tuple(n); +} + +TVM_REGISTER_API("relay._make.Tuple") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TupleNode *node, tvm::IRPrinter *p) { + p->stream << "TupleNode(" << node->fields << ")"; + }); + +LocalVar LocalVarNode::make(std::string name_hint) { + std::shared_ptr n = std::make_shared(); + n->name_hint = std::move(name_hint); + return LocalVar(n); +} + +TVM_REGISTER_API("relay._make.LocalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LocalVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const LocalVarNode *node, + tvm::IRPrinter *p) { + p->stream << "LocalVarNode(" << node->name_hint << ")"; + }); + +GlobalVar GlobalVarNode::make(std::string name_hint) { + std::shared_ptr n = std::make_shared(); + n->name_hint = std::move(name_hint); + return GlobalVar(n); +} + +TVM_REGISTER_API("relay._make.GlobalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = GlobalVarNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const GlobalVarNode *node, + tvm::IRPrinter *p) { + p->stream << "GlobalVarNode(" << node->name_hint << ")"; + }); + +Param ParamNode::make(LocalVar var, Type type) { + std::shared_ptr n = std::make_shared(); + n->var = std::move(var); + n->type = std::move(type); + return Param(n); +} + +TVM_REGISTER_API("relay._make.Param") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ParamNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const ParamNode *node, tvm::IRPrinter *p) { + p->stream << "ParamNode(" << node->var << ", " << node->type << ")"; + }); + +Function FunctionNode::make(tvm::Array params, Type ret_type, Expr body, + tvm::Array type_params) { + std::shared_ptr n = std::make_shared(); + n->params = std::move(params); + n->ret_type = std::move(ret_type); + n->body = std::move(body); + n->type_params = std::move(type_params); + return Function(n); +} + +TVM_REGISTER_API("relay._make.Function") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const FunctionNode *node, + tvm::IRPrinter *p) { + p->stream << "FunctionNode(TODO)"; + }); + +Call CallNode::make(Expr op, Array args, Attrs attrs, + Array type_args) { + std::shared_ptr n = std::make_shared(); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->type_args = std::move(type_args); + return Call(n); +} + +TVM_REGISTER_API("relay._make.Call") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CallNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const CallNode *node, tvm::IRPrinter *p) { + p->stream << "CallNode(" << node->op << ", " << node->args << ", " + << node->attrs << ", " << node->type_args << ")"; + }); + +Let LetNode::make(LocalVar var, Expr value, Expr body, Type value_type) { + std::shared_ptr n = std::make_shared(); + n->var = std::move(var); + n->value = std::move(value); + n->body = std::move(body); + n->value_type = std::move(value_type); + return Let(n); +} + +TVM_REGISTER_API("relay._make.Let") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LetNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { + p->stream << "LetNode(" << node->var << node->value << node->body << node->value_type << ")"; + }); + +If IfNode::make(Expr cond, Expr true_value, Expr false_value) { + std::shared_ptr n = std::make_shared(); + n->cond = std::move(cond); + n->true_value = std::move(true_value); + n->false_value = std::move(false_value); + return If(n); +} + +TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IfNode::make(args[0], args[1], args[2]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { + p->stream << "IfNode(" << + node->cond << ", " << + node->true_value << + node->false_value << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/type.cc b/src/relay/type.cc new file mode 100644 index 000000000000..156207e1b73a --- /dev/null +++ b/src/relay/type.cc @@ -0,0 +1,100 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type.cc + * \brief The type system AST nodes of Relay. + */ +#include "tvm/relay/type.h" +#include "tvm/ir_functor.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +TensorType TensorTypeNode::make(Array shape, DataType dtype) { + std::shared_ptr n = std::make_shared(); + n->shape = std::move(shape); + n->dtype = std::move(dtype); + return TensorType(n); +} + +TVM_REGISTER_API("relay._make.TensorType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Array shape = args[0]; + *ret = TensorTypeNode::make(shape, args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TensorTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape + << ")"; + }); + +TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { + std::shared_ptr n = std::make_shared(); + n->var = tvm::Var(name); + n->kind = std::move(kind); + return TypeParam(n); +} + +TVM_REGISTER_API("relay._make.TypeParam") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[1]; + *ret = + TypeParamNode::make(args[0], static_cast(kind)); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeParamNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeParamNode(" << node->var->name_hint << ", " + << node->kind << ")"; + }); + + +FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, + tvm::Array type_params, + tvm::Array type_constraints) { + std::shared_ptr n = std::make_shared(); + n->arg_types = std::move(arg_types); + n->ret_type = std::move(ret_type); + n->type_params = std::move(type_params); + n->type_constraints = std::move(type_constraints); + return FuncType(n); +} + +TVM_REGISTER_API("relay._make.FuncType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const FuncTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "FuncTypeNode(" << node->type_params << ", " + << node->arg_types << ", " << node->ret_type << ", " + << node->type_constraints << ")"; + }); + +TypeFunction TypeFunctionNode::make(std::string name, int num_args) { + std::shared_ptr n = std::make_shared(); + n->name = std::move(name); + n->num_args = std::move(num_args); + return TypeFunction(n); +} + +TVM_REGISTER_API("relay._make.TypeFunction") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TypeFunctionNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeFunctionNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeFunctionNode(" << node->name << ", " << node->num_args << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py new file mode 100644 index 000000000000..26fe06109513 --- /dev/null +++ b/tests/python/relay/test_ir_nodes.py @@ -0,0 +1,154 @@ +""" test ir""" +import tvm +from tvm import relay +import tvm.relay.make as mk +from tvm import expr + +# Span + + +def test_span() -> None: + span = mk.Span(None, 1, 1) + assert span.source == None + assert span.lineno == 1 + assert span.col_offset == 1 + assert span.same_as(span) + assert span == span + assert isinstance(span, relay.base.Span) + str(span) + +# Types + + +def test_tensor_type() -> None: + shape = tvm.convert([1, 2, 3]) + dtype = 'float32' + tt = mk.TensorType(shape, dtype) + assert tt.dtype == dtype + assert tt.shape == shape + assert tt.span == None + str(tt) + + +def test_type_param() -> None: + tp = mk.TypeParam('name', relay.Kind.Shape) + tp.kind == relay.Kind.Shape + tp.span # TODO allow us to set span + str(tp) + + +def test_func_type() -> None: + type_params = tvm.convert([]) + type_constraints = tvm.convert([]) # TODO: fill me in + arg_types = tvm.convert([]) + ret_type = None + tf = mk.FuncType(arg_types, ret_type, type_params, type_constraints) + assert tf.type_params == type_params + assert tf.type_constraints == type_constraints + assert tf.arg_types == arg_types + assert tf.ret_type == ret_type + assert tf.span == None + # TODO make sure we can set + str(tf) + + +def test_constant() -> None: + arr = tvm.nd.array(10) + const = mk.Constant(arr) + assert const.data == arr + assert const.span == None + str(const) + + +def test_tuple() -> None: + fields = tvm.convert([]) + tup = mk.Tuple(fields) + assert tup.fields == fields + assert tup.span == None + str(tup) + + +def test_local_var() -> None: + name_hint = 's' + lv = mk.LocalVar(name_hint) + lv.name_hint == name_hint + # assert lv.span == None todo(@jroesch): what do we do about spans + str(lv) + + +def test_global_var() -> None: + name_hint = 'g' + gv = mk.GlobalVar(name_hint) + gv.name_hint == name_hint + # assert lv.span == None todo(@jroesch): what do we do about spans + str(gv) + + +def test_param() -> None: + lv = mk.LocalVar('x') + ty = None + param = mk.Param(lv, ty) + assert param.var == lv + assert param.type == ty + assert param.span == None + str(param) + + +def test_function() -> None: + param_names = ['a', 'b', 'c', 'd'] + params = tvm.convert([mk.Param(mk.LocalVar(n), None) for n in param_names]) + ret_type = None + body = None + type_params = tvm.convert([]) + fn = mk.Function(params, ret_type, body, type_params) + assert fn.params == params + assert fn.body == body + assert fn.type_params == type_params + assert fn.span == None + str(fn) + + +def test_call() -> None: + op = mk.LocalVar('f') + arg_names = ['a', 'b', 'c', 'd'] + args = tvm.convert([mk.LocalVar(n) for n in arg_names]) + call = mk.Call(op, args, None, None) + assert call.op == op + assert call.args == args + assert call.span == None + str(call) + + +def test_let() -> None: + lv = mk.LocalVar('x') + ty = None + arr = tvm.nd.array(10) + value = mk.Constant(arr) + # I would prefer that the order of arguments + # matches syntax let x : t = v in b + let = mk.Let(lv, value, lv, ty) + assert let.var == lv + assert let.value == value + assert let.value_type == ty + assert let.body == lv + assert let.span == None + str(let) + + +def test_if() -> None: + cond = mk.LocalVar('cond') + left = mk.LocalVar('left') + right = mk.LocalVar('right') + ife = mk.If(cond, left, right) + assert ife.cond == cond + assert ife.true_value == left + assert ife.false_value == right + assert ife.span == None + str(ife) + + +if __name__ == "__main__": + test_span() + test_tensor_type() + test_type_param() + test_func_type() From 8b79ae0cf355fb3472c504577f1fbdc46f5a92f0 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 16:31:02 -0700 Subject: [PATCH 002/136] Add InternTable data structure --- include/tvm/relay/compiler/intern_table.h | 55 +++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 include/tvm/relay/compiler/intern_table.h diff --git a/include/tvm/relay/compiler/intern_table.h b/include/tvm/relay/compiler/intern_table.h new file mode 100644 index 000000000000..1850e513e5e5 --- /dev/null +++ b/include/tvm/relay/compiler/intern_table.h @@ -0,0 +1,55 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/compiler/intern_table.h + * \brief A table which maps string keys to data. + * + * These are useful for mapping user-readable names + * to globally unique allocations which use pointer + * equality for comparsion. + */ +#ifndef TVM_RELAY_COMPILER_INTERN_TABLE_H_ +#define TVM_RELAY_COMPILER_INTERN_TABLE_H_ + +#include +#include +#include "dmlc/logging.h" + +namespace tvm { +namespace relay { + +struct KeyNotFound : dmlc::Error { + explicit KeyNotFound(std::string msg) : dmlc::Error(msg) {} +}; + +template +class InternTable { +private: + /*! \brief The internal table mapping from strings to T. */ + std::unordered_map table_; + + public: + /*! \brief Insert a new key into the table. + * \note Attempting to reinsert a key triggers an error. + */ + void Insert(const std::string& key, const T& value) { + if (table_.find(key) == table_.end()) { + table_.insert({key, value}); + } else { + throw dmlc::Error( + std::string("you have previously interred a value for: ") + key); + } + } + + /*! \brief Lookup the data in the table. */ + const T& Lookup(std::string key) { + if (table_.find(key) != table_.end()) { + return table_.at(key); + } else { + throw KeyNotFound(std::string("could not find match") + key); + } + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_COMPILER_INTERN_TABLE_H_ From e2b6a2f66dd5ff95bde14e263f0fed30ee90e5ea Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 17:56:59 -0700 Subject: [PATCH 003/136] Add placeholder defn of Operator --- include/tvm/relay/expr.h | 2 +- include/tvm/relay/op.h | 47 ++++++++++++++++++++++++++++++++++++++++ src/relay/op.cc | 31 ++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 include/tvm/relay/op.h create mode 100644 src/relay/op.cc diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index b830c7ce04ef..c1dd557717af 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/expr.h - * \brief Relay expression IR Node. + * \brief The Relay IR expression nodes. */ #ifndef TVM_RELAY_EXPR_H_ #define TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h new file mode 100644 index 000000000000..fa152945d38c --- /dev/null +++ b/include/tvm/relay/op.h @@ -0,0 +1,47 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/op.h + * \brief Relay's representation of operators. + */ +#ifndef TVM_RELAY_OP_H_ +#define TVM_RELAY_OP_H_ + +#include "./expr.h" + +namespace tvm { +namespace relay { + + +/*! + * \brief A primitive Relay operator defined externally to Relay. + * + * \note Currently these are expected to be backed by a TVM's operator, + * such as the ones defined in TOPI. + * + * For developers who are familar with the computational graph this + * directly maps to the concept of operators in NNVM. + */ +class Operator; +/*! \brief Container for Operator */ +class OperatorNode : public ExprNode { + public: + /*! \brief A type which specifies the relationship between the inputs and outputs + * of the operator. + */ + Type op_type; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("op_type", &op_type); + } + + TVM_DLL static Operator make(Type op_type); + + static constexpr const char* _type_key = "relay.Operator"; + TVM_DECLARE_NODE_TYPE_INFO(OperatorNode, OperatorNode); +}; + +RELAY_DEFINE_NODE_REF(Operator, OperatorNode, Expr); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_H_ diff --git a/src/relay/op.cc b/src/relay/op.cc new file mode 100644 index 000000000000..07ad5f0ae4ed --- /dev/null +++ b/src/relay/op.cc @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file op.cc + * \brief Relay's representation of operators. + */ +#include "tvm/relay/op.h" +#include "tvm/ir_functor.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace runtime; + +Operator OperatorNode::make(Type op_type) { + std::shared_ptr n = std::make_shared(); + n->op_type = std::move(op_type); + return Operator(n); +} + +TVM_REGISTER_API("relay._make.Operator").set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = OperatorNode::make(args[0]); +}); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const OperatorNode *node, tvm::IRPrinter *p) { + p->stream << "OperatorNode(" << node->op_type << ")"; + }); + +} // namespace relay +} // namespace tvm From 4d0a60de0d89e31bdc228755e59b2c136a515f62 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 22:37:16 -0700 Subject: [PATCH 004/136] Add initial port of environment.h --- include/tvm/relay/compiler/environment.h | 110 +++++++++ include/tvm/relay/error.h | 28 +++ src/relay/compiler/environment.cc | 292 +++++++++++++++++++++++ 3 files changed, 430 insertions(+) create mode 100644 include/tvm/relay/compiler/environment.h create mode 100644 include/tvm/relay/error.h create mode 100644 src/relay/compiler/environment.cc diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h new file mode 100644 index 000000000000..ddb7f0dca192 --- /dev/null +++ b/include/tvm/relay/compiler/environment.h @@ -0,0 +1,110 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file environment.h + * \brief The global environment containing + */ +#ifndef TVM_RELAY_ENVIRONMENT_H_ +#define TVM_RELAY_ENVIRONMENT_H_ + +#include +#include +#include "tvm/relay/compiler/intern_table.h" +#include "../expr.h" +#include "../type.h" +#include "../op.h" +#include "../error.h" +// #include "tvm/relay/options.h" +// #include "tvm/relay/source_map.h" + +namespace tvm { +namespace relay { + +struct Environment; + +/*! \brief The global environment of Relay programs. + * + * The global environment contains all the global + * information needed to compile a Relay program, + * including the set of operators, the set of + * global functions, and configuration options. + * + * Many operations require acess to the global + * Environment. We mostly pass the argument by value + * in a functional style as an explicit argument. + * + * This means users can construct custom environments + * easily, for example a fresh environment for each + * thread while auto-tuning. + * */ + +class EnvironmentNode : public RelayNode { + private: + /*! A map from string names to GlobalIds, ensures global uniqueness. */ + InternTable global_map_; + /*! A map from string names to Operators, ensures global uniqueness. */ + InternTable operator_map_; + // /*! \brief A map from file names to source fragments. */ + // SourceMap source_map_; + // /*! \brief A list of the errors reported during the current run. */ + // std::vector errors_; + + public: + // This map contains all items *except* operators. + std::unordered_map items; + + // Options options; + + tvm::PackedFunc jit_for(Operator op); + tvm::PackedFunc reverse(Operator op); + + EnvironmentNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final {} + + TVM_DLL static Environment make( + std::unordered_map global_funcs); + + // Add an item to the Enviroment. + // void add(const Operator& op, bool update = false); + // void add(const Operator& op, bool update = false); + + // void try_add(const Item& item, bool update=false); + // void update(const Item& item); + // void remove(const GlobalId& id); + + // GlobalId global_id(const std::string& str); + // OperatorId operator_id(const std::string& str); + + // We can lookup a GlobalId, OperatorId. + // Defn lookup(const GlobalId& id); + // Operator lookup(const OperatorId& id); + // Defn lookup_global(const std::string& str); + // Item lookup_operator(const std::string& str); + // FileId add_source(std::string file_name, std::string source); + + // tvm::Array get_operators(); + // tvm::Array get_defns(); + + // void report_error(std::string msg, Span sp); + // void display_errors(); + // void register_shape_ext(ShapeExtension ext); + + static constexpr const char* _type_key = "relay.Environment"; + TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); +}; + +struct Environment : public NodeRef { + Environment() {} + explicit Environment(std::shared_ptr p) : NodeRef(p) {} + + inline EnvironmentNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = EnvironmentNode; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ENVIRONMENT_H_ diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h new file mode 100644 index 000000000000..d2698f8e380b --- /dev/null +++ b/include/tvm/relay/error.h @@ -0,0 +1,28 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file error.h + * \brief The set of errors raised by Relay. + */ +#ifndef TVM_RELAY_ERROR_H_ +#define TVM_RELAY_ERROR_H_ + +#include +#include "./base.h" + +namespace tvm { +namespace relay { + +struct Error : dmlc::Error { + Error(std::string msg) : dmlc::Error(msg) {} +}; + +struct SpannedError { + std::string msg; + Span sp; + SpannedError(std::string msg, Span sp) : msg(msg), sp(sp) {} +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ERROR_H_ diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc new file mode 100644 index 000000000000..125ceae834b3 --- /dev/null +++ b/src/relay/compiler/environment.cc @@ -0,0 +1,292 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file environment.cc + * \brief Relay global environment. + */ +#include +#include "tvm/relay/compiler/environment.h" +// #include "tvm/relay/alpha_eq.h" +// #include "tvm/relay/debug.h" +// #include "tvm/relay/typeck/typechecker.h" +// #include "tvm/relay/util/rang.h" +// #include "tvm/runtime/packed_func_ext.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +Environment EnvironmentNode::make( + std::unordered_map global_funcs) { + std::shared_ptr n = std::make_shared(); + n->items = std::move(global_funcs); + return Environment(n); +} + +// tvm::PackedFunc EnvironmentNode::jit_for(OperatorId id) { +// return this->lookup(id)->compiler; +// } + +// GlobalId EnvironmentNode::global_id(const std::string &str) { +// try { +// return global_map_.Lookup(str); +// } catch (const KeyNotFound &err) { +// GlobalId id = GlobalIdNode::make(str); +// global_map_.Insert(str, id); +// return id; +// } +// } + +// OperatorId EnvironmentNode::operator_id(const std::string &str) { +// try { +// return operator_map_.Lookup(str); +// } catch (const KeyNotFound &err) { +// OperatorId id = OperatorIdNode::make(str); +// operator_map_.Insert(str, id); +// return id; +// } +// } + +// // Add a new item to the global environment +// // throws an exception if the item already +// // exists. +// void EnvironmentNode::add(const Item &unchecked_item, bool update) { +// // Type check the item before we add it to the environment. +// auto env = GetRef(this); +// Item item = check(env, unchecked_item); + +// if (const OperatorNode *op_node = item.as()) { +// Operator op = GetRef(op_node); +// auto type = op->type; +// if (operators.find(op->id) != operators.end()) { +// if (!update) { +// throw dmlc::Error("already have definition for XXXX."); +// } + +// auto old_type = operators[op->id]->type; + +// if (!alpha_eq(type, old_type)) { +// throw dmlc::Error( +// "Environment#update changes type, not possible in this mode."); +// } + +// operators.insert({op->id, op}); +// } else { +// operators.insert({op->id, op}); +// } +// } else if (const DefnNode *d = item.as()) { +// auto def = GetRef(d); +// auto type = def->type; +// if (items.find(def->id) != items.end()) { +// if (!update) { +// throw dmlc::Error("already have definition for XXXX."); +// } + +// auto old_type = items[def->id].as()->type; + +// if (!alpha_eq(type, old_type)) { +// throw dmlc::Error( +// "Environment#update changes type, not possible in this mode."); +// } + +// this->items.insert({def->id, def}); +// } else { +// this->items.insert({def->id, def}); +// } +// } else { +// throw EnvError("internal error: unknown item type, unreachable code"); +// } +// } + +// void EnvironmentNode::update(const Item &item) { return this->add(item, true); } + +// void EnvironmentNode::remove(const GlobalId &id) { this->items.erase(id); } + +// Defn EnvironmentNode::lookup(const GlobalId &id) { +// if (items.find(id) != items.end()) { +// return items.at(id); +// } else { +// throw EnvError(std::string("there is no definition of ") + id->name); +// } +// } + +// Operator EnvironmentNode::lookup(const OperatorId &id) { +// if (operators.find(id) != operators.end()) { +// return operators.at(id); +// } else { +// throw EnvError(std::string("there is no definition of ") + id->name); +// } +// } + +// Item EnvironmentNode::lookup_operator(const std::string &str) { +// OperatorId id = this->operator_id(str); +// return lookup(id); +// } + +// Defn EnvironmentNode::lookup_global(const std::string &str) { +// GlobalId id = this->global_id(str); +// return this->lookup(id); +// } + +// inline FileId EnvironmentNode::add_source(std::string file_name, +// std::string source) { +// return this->source_map_.add_source(file_name, source); +// } + +// void EnvironmentNode::report_error(std::string msg, Span sp) { +// this->errors_.push_back(Error(msg, sp)); +// } + +// void EnvironmentNode::display_errors() { +// for (auto err : this->errors_) { +// auto sp = err.sp; +// auto source_file = this->source_map_.GetSource(err.sp->file_id); +// auto file_name = source_file.file_name; +// auto source_at_span = source_file.SourceAt(err.sp, 1); +// std::string error_marker = "error:"; +// auto line_info = +// std::to_string(sp->lineno) + ":" + std::to_string(sp->col_offset); + +// std::cout << rang::style::bold << rang::fg::red << error_marker +// << rang::fg::reset << file_name << ":" << line_info +// << rang::style::reset << " " << source_at_span << std::endl; + +// // Build the cursor. + +// // Fix this code, hardwired to compute alignment of pointer. +// size_t spaces = error_marker.size() + line_info.size() + file_name.size() + +// sp->col_offset - 3; + +// std::string cursor = "~~~~^~~~~"; +// for (size_t i = 0; i < spaces; i++) { +// std::cout << " "; +// } +// std::cout << rang::fg::red << cursor << " " << err.msg << rang::style::reset +// << std::endl; +// } +// } + +// Array EnvironmentNode::get_operators() { +// std::vector ops; +// for (auto pair : this->operators) { +// ops.push_back(pair.second); +// } +// return Array(ops); +// } + +// Array EnvironmentNode::get_defns() { +// std::vector defns; +// for (auto pair : this->items) { +// defns.push_back(pair.second); +// } +// return Array(defns); +// } + +// void EnvironmentNode::register_shape_ext(ShapeExtension ext) { +// this->shape_exts_.Insert(ext->name, ext); +// } + +// TVM_REGISTER_API("relay._make.Environment") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// *ret = EnvironmentNode::make({}); +// }); + +// TVM_REGISTER_API("relay._env.Environment_add") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// Item item = args[1]; +// env->add(item, true); // REMOVE ME +// }); + +// TVM_REGISTER_API("relay._env.Environment_lookup_global") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// GlobalId id = args[1]; +// *ret = env->lookup(id); +// }); + +// TVM_REGISTER_API("relay._env.Environment_lookup_operator") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// OperatorId id = args[1]; +// *ret = env->lookup(id); +// }); + +// // TVM_REGISTER_API("relay._env.Environment_remove_global") +// // .set_body([](TVMArgs args, TVMRetValue *ret) { +// // Environment env = args[0]; +// // GlobalId id = args[1]; +// // env->remove(id); +// // }); + +// TVM_REGISTER_API("relay._env.Environment_global_id") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string str = args[1]; +// *ret = env->global_id(str); +// }); + +// TVM_REGISTER_API("relay._env.Environment_operator_id") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string str = args[1]; +// *ret = env->operator_id(str); +// }); + +// TVM_REGISTER_API("relay._env.Environment_register_shape_ext") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// ShapeExtension ext = args[1]; +// env->register_shape_ext(ext); +// }); + +// TVM_REGISTER_API("relay._env.Environment_register_primitive") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string str = args[1]; +// *ret = env->global_id(str); +// }); + +// TVM_REGISTER_API("relay._env.Environment_add_source") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string file_name = args[1]; +// std::string source_name = args[2]; +// *ret = env->add_source(file_name, source_name); +// }); + +// TVM_REGISTER_API("relay._env.Environment_report_error") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// std::string msg = args[1]; +// Span sp = args[2]; +// env->report_error(msg, sp); +// }); + +// TVM_REGISTER_API("relay._env.Environment_display_errors") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// return env->display_errors(); +// }); + +// TVM_REGISTER_API("relay._env.Environment_get_operators") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// *ret = env->get_operators(); +// }); + +// TVM_REGISTER_API("relay._env.Environment_get_defns") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// *ret = env->get_defns(); +// }); + +// TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +// .set_dispatch([](const EnvironmentNode *node, +// tvm::IRPrinter *p) { +// p->stream << "EnvironmentNode(todo)"; // << node->items << ")"; +// }); + +} // namespace relay +} // namespace tvm From dd392be2cf18d2e0ee72a7690c1116704de16134 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 22:44:16 -0700 Subject: [PATCH 005/136] Add expr_functor.h --- include/tvm/relay/expr_functor.h | 143 +++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 include/tvm/relay/expr_functor.h diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h new file mode 100644 index 000000000000..922892e8a7a5 --- /dev/null +++ b/include/tvm/relay/expr_functor.h @@ -0,0 +1,143 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file expr_functor.h + * \brief A more powerful Visitor that enables defining arbitrary function + * signatures with dispatch on first argument. + */ +#ifndef TVM_RELAY_EXPR_FUNCTOR_H_ +#define TVM_RELAY_EXPR_FUNCTOR_H_ + +#include +#include +#include "ir.h" + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor that dispatches on in the first Expr argument. + * You can use this as a more powerful Visitor, since it allows you to + * define function signatures of Visit Function. + * + * This helps you to avoid to book-keep return value of Visitor via state, + * which can cause bugs easily when state is incorrectly maintained. + * + * \code + * // A functor that set variable to b. and calculate results. + * class MyExprFunctor + * : public ir::ExprFunctor { + * public: + * int VisitExpr_(const Variable* op, int b) final { + * return b; + * } + * int VisitExpr_(const IntImm* op, int b) final { + * return op->value; + * } + * int VisitExpr_(const Add* op, int b) final { + * return Visit(op->a, b) + Visit(op->b, b); + * } + * }; + * MyExprFunctor f; + * Var x("x"); + * CHECK_EQ(f(x + 1, 2), 3); + * \endcode + * + * \note Why do we need this more powerful Functor: + * + * We often need to implement a transformer tasks. + * Say we want to take Expr and transform it to some analysis result, + * This easily be done incorrectly using plain Visitor. See IRVisitor's + * document for possible error cases. + * + * \tparam FType function signiture + * This type if only defined for FType with function signiture R(const Expr&, + * Args...) + */ +template +class ExprFunctor; + +// functions to be overriden. +#define EXPR_FUNCTOR_DEFAULT \ + { return VisitExprDefault_(op, std::forward(args)...); } + +#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); + +template +class ExprFunctor { + private: + using TSelf = ExprFunctor; + using FType = tvm::IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~ExprFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Expr& n, Args... args) { + return VisitExpr(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitExpr(const Expr& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitExpr_(const ConstantNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LocalVarNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GlobalVarNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FunctionNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IfNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OperatorNode* op, + Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExprDefault_(const Node* op, Args...) { + throw dmlc::Error(std::string("Do not have a default for ") + op->type_key()); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode); + RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); + RELAY_EXPR_FUNCTOR_DISPATCH(LocalVarNode); + RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); + RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode); + RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); + RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); + RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); + RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); + RELAY_EXPR_FUNCTOR_DISPATCH(OperatorNode); + return vtable; + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_FUNCTOR_H_ From 336fe2b57211401c563ada4763185be2d95793f6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 22:47:14 -0700 Subject: [PATCH 006/136] Add initial version of type_functor.h --- include/tvm/relay/compiler/type_functor.h | 93 +++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 include/tvm/relay/compiler/type_functor.h diff --git a/include/tvm/relay/compiler/type_functor.h b/include/tvm/relay/compiler/type_functor.h new file mode 100644 index 000000000000..66454725db48 --- /dev/null +++ b/include/tvm/relay/compiler/type_functor.h @@ -0,0 +1,93 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_functor.h + * \brief A way to defined arbitrary function signature with dispatch on types. + */ +#ifndef TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ +#define TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ + +#include +#include "ir.h" + +namespace tvm { +namespace relay { + +template +class TypeFunctor; + +// functions to be overriden. +#define TYPE_FUNCTOR_DEFAULT \ + { return VisitTypeDefault_(op, std::forward(args)...); } + +#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitType_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); + +template +class TypeFunctor { + private: + using TSelf = TypeFunctor; + using FType = tvm::IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~TypeFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Type& n, Args... args) { + return VisitType(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitType(const Type& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitType_(const TensorTypeNode* op, + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeFunction* op, Args... args) TYPE_FUNCTOR_DEFAULT; + Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + + virtual R VisitTypeDefault_(const Node* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->type_key(); + return R(); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); + RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeFunctionNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); + return vtable; + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ From e31e945941c0f24cb03204acf693759a1365f44f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 22:53:00 -0700 Subject: [PATCH 007/136] Make type_functor.h a private header --- {include/tvm => src}/relay/compiler/type_functor.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {include/tvm => src}/relay/compiler/type_functor.h (100%) diff --git a/include/tvm/relay/compiler/type_functor.h b/src/relay/compiler/type_functor.h similarity index 100% rename from include/tvm/relay/compiler/type_functor.h rename to src/relay/compiler/type_functor.h From 425d5f24d285163ee855b18db4be7f0b3113995d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 23:05:16 -0700 Subject: [PATCH 008/136] Add ir.h --- include/tvm/relay/ir.h | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 include/tvm/relay/ir.h diff --git a/include/tvm/relay/ir.h b/include/tvm/relay/ir.h new file mode 100644 index 000000000000..73c275cf1c98 --- /dev/null +++ b/include/tvm/relay/ir.h @@ -0,0 +1,20 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/ir.h + * \brief The Relay intermediate representation's core data structures. + */ +#ifndef TVM_RELAY_IR_H_ +#define TVM_RELAY_IR_H_ + +#include "./base.h" +#include "./type.h" +#include "./expr.h" +#include "./op.h" + +// namespace tvm { +// namespace relay { + +// } // namespace relay +// } // namespace tvm + +#endif // TVM_RELAY_IR_H_ From a429689fc0ce978ad02b62ed5cd3dcbcf965bb18 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 23:07:02 -0700 Subject: [PATCH 009/136] Add back Relay's logging.h --- include/tvm/relay/logging.h | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 include/tvm/relay/logging.h diff --git a/include/tvm/relay/logging.h b/include/tvm/relay/logging.h new file mode 100644 index 000000000000..99cfc44de6cb --- /dev/null +++ b/include/tvm/relay/logging.h @@ -0,0 +1,33 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/logging.h + * \brief A wrapper around dmlc-core/logging.h which adds the ability + * to toggle logging via an environment variable. + */ + +#ifndef TVM_RELAY_LOGGING_H_ +#define TVM_RELAY_LOGGING_H_ + +#include +#include +#include +#include "dmlc/logging.h" + +namespace tvm { +namespace relay { + +static bool logging_enabled() { + if (auto var = std::getenv("RELAY_LOG")) { + std::string is_on(var); + return is_on == "1"; + } else { + return false; + } +} + +#define RELAY_LOG(severity) LOG_IF(severity, logging_enabled()) + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_LOGGING_H_ From 90337475382a11351e48db335ed0d5d6a9a3dabf Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 19 Aug 2018 23:14:49 -0700 Subject: [PATCH 010/136] Add type checker header --- include/tvm/relay/compiler/typechecker.h | 25 ++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 include/tvm/relay/compiler/typechecker.h diff --git a/include/tvm/relay/compiler/typechecker.h b/include/tvm/relay/compiler/typechecker.h new file mode 100644 index 000000000000..c71f78c1a5b0 --- /dev/null +++ b/include/tvm/relay/compiler/typechecker.h @@ -0,0 +1,25 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file tvm/relay/typechecker.h + * \brief Type check a Relay program producing a type checked program + * with its checked_type field populated and incomplete types resolved. + */ +#ifndef TVM_RELAY_COMPILER_TYPECHECKER_H_ +#define TVM_RELAY_COMPILER_TYPECHECKER_H_ + +#include "tvm/relay/ir.h" +#include "tvm/relay/environment.h" + +namespace tvm { +namespace relay { + +/*! The result of type checking an expression is a new expression + * with unambigous type information filled in, as well as it's + * checked type field populated with the result type. + */ +Expr check(const Environment & env, const Expr & e); +Operator check(const Environment & env, const Operator & op); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_COMPILER_TYPECHECKER_H_ From bb1e501d51f7aa16f92970a01e6532a0ad17d0ae Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 00:04:16 -0700 Subject: [PATCH 011/136] Add alpha_eq --- include/tvm/relay/compiler/alpha_eq.h | 19 + src/relay/compiler/alpha_eq.cc | 284 +++++++++++++ tests/python/relay/test_alpha_eq.py | 576 ++++++++++++++++++++++++++ 3 files changed, 879 insertions(+) create mode 100644 include/tvm/relay/compiler/alpha_eq.h create mode 100644 src/relay/compiler/alpha_eq.cc create mode 100644 tests/python/relay/test_alpha_eq.py diff --git a/include/tvm/relay/compiler/alpha_eq.h b/include/tvm/relay/compiler/alpha_eq.h new file mode 100644 index 000000000000..ba91afc21015 --- /dev/null +++ b/include/tvm/relay/compiler/alpha_eq.h @@ -0,0 +1,19 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/alpha_eq.h + * \brief Check expressions & types for structural equivalence. + */ +#ifndef TVM_RELAY_ALPHA_EQ_H_ +#define TVM_RELAY_ALPHA_EQ_H_ + +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +bool alpha_eq(const Expr & e1, const Expr & e2); +bool alpha_eq(const Type & t1, const Type & t2); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ALPHA_EQ_H_ diff --git a/src/relay/compiler/alpha_eq.cc b/src/relay/compiler/alpha_eq.cc new file mode 100644 index 000000000000..4b8e904bf29e --- /dev/null +++ b/src/relay/compiler/alpha_eq.cc @@ -0,0 +1,284 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file alpha_eq.cc + * \brief Compute the set of variables not bound in the expression. + */ +#include "tvm/relay/compiler/alpha_eq.h" +#include "tvm/relay/expr_visitor.h" +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +struct TypeAlphaEq : TypeVisitor { + tvm::Map eq_map; + bool equal; + + TypeAlphaEq() : eq_map(), equal(true) {} + + void DataTypeEqual(const DataType & dt1, const DataType & dt2) { + equal = equal && dt1 == dt2; + } + void ShapeEqual(Array s1, Array s2) { + } + + void VisitType_(const TensorTypeNode *tt1, const Type &t2) override { + if (const TensorTypeNode *tt2 = t2.as()) { + DataTypeEqual(tt1->dtype, tt2->dtype); + ShapeEqual(tt1->shape, tt2->shape); + } else { + equal = false; + } + } + +// void VisitType_(const TypeVarNode *bt1, const Type &t2) override { +// if (const TypeVarNode *bt2 = t2.as()) { +// equal = equal && bt1 == bt2; +// return; +// } else { +// equal = false; +// } +// } + + void VisitType_(const TypeParamNode *ti1, const Type &t2) override { + if (const TypeParamNode *ti2 = t2.as()) { + auto tid1 = GetRef(ti1); + auto tid2 = GetRef(ti2); + + // We handle open terms with this rule assuming variables are identical. + // + // Not sure if we should do this. + if (tid1 == tid2) { + return; + } + + // Check that they are same kind + if (tid1->kind != tid2->kind) { + equal = false; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(tid1) != eq_map.end()) { + equal = equal && eq_map[tid1] == tid2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitType_(const FuncTypeNode *op, const Type &t2) override { + if (const FuncTypeNode *ta2 = t2.as()) { + if (op->arg_types.size() != ta2->arg_types.size()) { + equal = false; + return; + } + + for (size_t i = 0; i < op->arg_types.size(); i++) { + this->VisitType(op->arg_types[i], ta2->arg_types[i]); + if (!equal) { + return; + } + } + + this->VisitType(op->ret_type, ta2->ret_type); + } else { + equal = false; + } + } + + void VisitType_(const TypeFunctionNode *op, const Type &t2) override { + } +// void VisitType_(const TupleTypeNode *op, const Type &t2) override { +// if (const TupleTypeNode *pt = t2.as()) { +// if (op->fields.size() != pt->fields.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < op->fields.size(); i++) { +// if (!equal) { +// return; +// } +// this->VisitType(op->fields[i], pt->fields[i]); +// } +// } else { +// equal = false; +// } +// } + +// void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { +// TypeCall tycall = GetRef(tyn1); +// if (const TypeCallNode *tyn2 = t2.as()) { +// if (tycall->func != tyn2->func) { +// equal = false; +// return; +// } + +// if (tycall->args.size() != tyn2->args.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < tycall->args.size(); i++) { +// this->VisitType(tycall->args[i], tyn2->args[i]); +// } +// } else { +// equal = false; +// } +// } +}; + +bool alpha_eq(const Type &t1, const Type &t2) { + TypeAlphaEq aeq; + aeq.VisitType(t1, t2); + return aeq.equal; +} + +// struct AlphaEq : ExprVisitor { +// public: +// tvm::Map eq_map; +// bool equal; +// AlphaEq() : eq_map(), equal(true) {} + +// void VisitExpr_(const LocalIdNode *e1, const Expr &e2) override { +// if (const LocalIdNode *id2 = e2.as()) { +// auto local1 = GetRef(e1); +// auto local2 = GetRef(id2); +// // +// // We handle open terms with this rule assuming variables are identical. +// // +// // Not sure if we should do this. +// if (local1 == local2) { +// equal = true; +// return; +// } + +// // Next we see if there is mapping for local1 into the rhs term. +// // If there is we check to see if those are equal. +// if (eq_map.find(local1) != eq_map.end()) { +// equal = equal && eq_map[local1] == local2; +// } else { +// equal = false; +// } +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const GlobalIdNode *g1, const Expr &e2) override { +// if (const GlobalIdNode *g2 = e2.as()) { +// equal = equal && g1 == g2; +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const OperatorIdNode *i1, const Expr &e2) override { +// if (const OperatorIdNode *i2 = e2.as()) { +// equal = equal && i1 == i2; +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const TupleNode *pl1, const Expr &e2) override { +// Tuple prod1 = GetRef(pl1); +// if (const TupleNode *pl2 = e2.as()) { +// Tuple prod2 = GetRef(pl2); +// if (prod1->fields.size() != prod2->fields.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < prod1->fields.size(); i++) { +// this->VisitExpr(prod1->fields[i], prod2->fields[i]); +// } +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const ParamNode *p1, const Expr &e2) override { +// if (const ParamNode *p2 = e2.as()) { +// eq_map.Set(p1->id, p2->id); +// equal = equal && alpha_eq(p1->type, p2->type); +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const FunctionNode *func1, const Expr &e2) override { +// if (const FunctionNode *func2 = e2.as()) { +// if (func1->params.size() != func2->params.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < func1->params.size(); i++) { +// this->VisitExpr(func1->params[i], func2->params[i]); +// } + +// this->VisitExpr(func1->body, func2->body); +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const CallNode *op, const Expr &e2) override { +// if (const CallNode *call = e2.as()) { +// this->VisitExpr(op->fn, call->fn); + +// if (op->args.size() != call->args.size()) { +// equal = false; +// return; +// } + +// for (size_t i = 0U; i < op->args.size(); i++) { +// this->VisitExpr(op->args[i], call->args[i]); +// } + +// } else { +// equal = false; +// } +// } + +// void VisitExpr_(const LetNode *op, const Expr &e2) override { +// if (const LetNode *let = e2.as()) { +// eq_map.Set(op->id, let->id); +// this->VisitExpr(op->value, let->value); +// this->VisitExpr(op->body, let->body); +// } else { +// equal = false; +// } +// } +// }; + +// bool alpha_eq(const Expr &e1, const Expr &e2) { +// AlphaEq eq; +// eq.VisitExpr(e1, e2); +// return eq.equal; +// } + +// // TODO(@jroesch): move to correct namespace? +// TVM_REGISTER_API("relay._make._alpha_eq") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Expr e1 = args[0]; +// Expr e2 = args[1]; +// *ret = alpha_eq(e1, e2); +// }); + +TVM_REGISTER_API("relay._make._type_alpha_eq") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Type t1 = args[0]; + Type t2 = args[1]; + *ret = alpha_eq(t1, t2); + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_alpha_eq.py b/tests/python/relay/test_alpha_eq.py new file mode 100644 index 000000000000..f1dc81c3c483 --- /dev/null +++ b/tests/python/relay/test_alpha_eq.py @@ -0,0 +1,576 @@ +"""Test alpha-equivalence of expressions and types.""" +# pylint: disable=invalid-name, missing-docstring +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * +from relay.ir import alpha_eq, ShapeOp, Kind +from relay.typing import TYPE_DEFAULTS +from relay import ir + +INT_TYPE_WIDTH = TYPE_DEFAULTS["INT_WIDTH"] +INT_TYPE_LANES = TYPE_DEFAULTS["INT_LANES"] + +def int_type(width=32) -> ir.Type: + return TensorType(IntType(width), ShapeSeq([])) + +def float_type(width=32) -> ir.Type: + return TensorType(FloatType(width), ShapeSeq([])) + +def bool_type() -> ir.Type: + return TensorType(BoolType(), ShapeSeq([])) + +def nest_quantifiers(ids, body) -> ir.Type: + ret = body + for tid in reversed(ids): + ret = TypeQuantifier(tid, ret) + return ret + +def test_local_id_not_eq() -> None: + assert not alpha_eq(LocalId("x"), LocalId("y")) + +def test_local_id_eq() -> None: + x = LocalId("x") + assert alpha_eq(x, x) + +def test_global_id_not_eq() -> None: + left = GlobalId("xyz") + right = GlobalId("xyz") + assert not alpha_eq(left, right) + +def test_global_id_eq() -> None: + ident = GlobalId("xyz") + assert alpha_eq(ident, ident) + +def test_operator_id_not_eq() -> None: + left = OperatorId("xyz") + right = OperatorId("xyz") + # equality on operator id is pointer equality + assert not alpha_eq(left, right) + +def test_operator_id_eq() -> None: + x = OperatorId("xyz") + assert alpha_eq(x, x) + +def test_float_literal_eq() -> None: + x = FloatLit(1.0) + y = FloatLit(1.0) + assert alpha_eq(x, y) + +def test_float_literal_not_eq() -> None: + x = FloatLit(1.0) + y = FloatLit(2.0) + assert not alpha_eq(x, y) + +def test_int_literal_eq() -> None: + x = IntLit(1) + y = IntLit(1) + assert alpha_eq(x, y) + +def test_int_literal_not_eq() -> None: + x = IntLit(1) + y = IntLit(2) + assert not alpha_eq(x, y) + +def test_bool_literal_eq() -> None: + x = BoolLit(True) + y = BoolLit(True) + assert alpha_eq(x, y) + +def test_bool_literal_not_eq() -> None: + x = BoolLit(True) + y = BoolLit(False) + assert not alpha_eq(x, y) + +def test_tensor_literal_eq() -> None: + x = TensorLit([IntLit(1), IntLit(2)]) + y = TensorLit([IntLit(1), IntLit(2)]) + assert alpha_eq(x, y) + +def test_tensor_literal_not_eq() -> None: + x = TensorLit([IntLit(1), IntLit(2)]) + y = TensorLit([IntLit(1), IntLit(3)]) + z = TensorLit([IntLit(1)]) + assert not alpha_eq(x, y) + assert not alpha_eq(x, z) + +def test_product_literal_eq() -> None: + x = Tuple([IntLit(1), IntLit(2)]) + y = Tuple([IntLit(1), IntLit(2)]) + assert alpha_eq(x, y) + +def test_product_literal_not_eq() -> None: + x = Tuple([IntLit(1), IntLit(2)]) + y = Tuple([IntLit(2), IntLit(2)]) + z = Tuple([IntLit(1), IntLit(2), IntLit(3)]) + assert not alpha_eq(x, y) + assert not alpha_eq(x, z) + +def test_projection_eq() -> None: + prod = Tuple([IntLit(3), FloatLit(3.5)]) + + assert alpha_eq(Projection(prod, 0), Projection(prod, 0)) + assert alpha_eq(Projection(prod, 1), Projection(prod, 1)) + +def test_projection_not_eq() -> None: + prod1 = Tuple([IntLit(3), IntLit(4)]) + prod2 = Tuple([IntLit(3)]) + prod3 = Tuple([IntLit(3), IntLit(4), FloatLit(3.5)]) + + assert not alpha_eq(Projection(prod1, 0), Projection(prod1, 1)) + assert not alpha_eq(Projection(prod1, 0), Projection(prod2, 0)) + assert not alpha_eq(Projection(prod1, 0), Projection(prod3, 0)) + assert not alpha_eq(Projection(prod1, 1), Projection(prod3, 1)) + +def test_cast_not_eq() -> None: + left = Cast(IntType(1), IntLit(2)) + right = Cast(IntType(1), IntLit(1)) + assert not alpha_eq(left, right) + + # same literal, different type + left = Cast(IntType(1), IntLit(2)) + right = Cast(IntType(2), IntLit(2)) + assert not alpha_eq(left, right) + +def test_cast_eq() -> None: + left = Cast(IntType(1), IntLit(2)) + right = Cast(IntType(1), IntLit(2)) + assert alpha_eq(left, right) + +def test_param_not_eq() -> None: + left = Param(LocalId("foo"), int_type()) + right = Param(LocalId("foo"), bool_type()) + assert not alpha_eq(left, right) + +def test_param_eq() -> None: + left = Param(LocalId("foo"), int_type()) + right = Param(LocalId("bar"), int_type()) + assert alpha_eq(left, right) + +def test_function_not_eq() -> None: + params1 = [Param(LocalId("x"), int_type())] + fn1 = Function([], params1, int_type(), LocalId("x")) + params2 = [Param(LocalId("y"), bool_type())] + fn2 = Function([], params2, int_type(), LocalId("y")) + assert not alpha_eq(fn1, fn2) + + params3 = [Param(LocalId("x"), int_type()), Param(LocalId("y"), int_type())] + fn3 = Function([], params3, int_type(), LocalId("z")) + assert not alpha_eq(fn1, fn3) + +def test_function_eq() -> None: + x = LocalId("x") + y = LocalId("y") + params1 = [Param(x, int_type())] + fn1 = Function([], params1, int_type(), x) + params2 = [Param(y, int_type())] + fn2 = Function([], params2, int_type(), y) + assert alpha_eq(fn1, fn2) + +def test_call_not_eq() -> None: + x = LocalId("x") + y = LocalId("y") + params1 = [Param(x, int_type())] + fn1 = Function([], params1, int_type(), x) + args1 = [IntLit(1)] + call1 = Call(fn1, args1) + + args2 = [IntLit(2)] + call2 = Call(fn1, args2) + assert not alpha_eq(call1, call2) + + params2 = [Param(y, int_type())] + fn2 = Function([], params2, float_type(), FloatLit(0.0)) + call3 = Call(fn2, args1) + assert not alpha_eq(call1, call3) + assert not alpha_eq(call2, call3) + +def test_call_eq() -> None: + x = LocalId("x") + y = LocalId("y") + params1 = [Param(x, int_type())] + fn1 = Function([], params1, int_type(), x) + args = [IntLit(1)] + call1 = Call(fn1, args) + + params2 = [Param(y, int_type())] + fn2 = Function([], params2, int_type(), y) + call2 = Call(fn2, args) + assert alpha_eq(call1, call2) + +def test_debug_not_eq() -> None: + left = Debug(IntLit(1)) + right = Debug(IntLit(2)) + assert not alpha_eq(left, right) + +def test_debug_eq() -> None: + left = Debug(IntLit(1)) + right = Debug(IntLit(1)) + assert alpha_eq(left, right) + +def test_let_not_eq() -> None: + x = LocalId("x") + y = LocalId("y") + let1 = Let(x, int_type(), IntLit(10), IntLit(11)) + let2 = Let(y, int_type(), IntLit(10), IntLit(12)) + assert not alpha_eq(let1, let2) + + let3 = Let(x, int_type(), IntLit(10), x) + let4 = Let(y, int_type(), IntLit(12), y) + assert not alpha_eq(let3, let4) + +def test_let_eq() -> None: + x = LocalId("x") + y = LocalId("y") + let1 = Let(x, int_type(), IntLit(10), x) + let2 = Let(y, int_type(), IntLit(10), y) + assert alpha_eq(let1, let2) + +def test_ref_eq() -> None: + r1 = Ref(IntLit(5)) + r2 = Ref(IntLit(5)) + assert alpha_eq(r1, r2) + +def test_ref_not_eq() -> None: + r1 = Ref(IntLit(5)) + r2 = Ref(FloatLit(3.5)) + r3 = Ref(r1) + assert not alpha_eq(r1, r2) + assert not alpha_eq(r1, r3) + assert not alpha_eq(r2, r3) + +def test_val_ref_eq() -> None: + vr1 = ReadRef(Ref(IntLit(35))) + vr2 = ReadRef(Ref(Tuple([IntLit(12), FloatLit(2.5)]))) + assert alpha_eq(vr1, vr1) + assert alpha_eq(vr2, vr2) + +def test_val_ref_not_eq() -> None: + vr1 = ReadRef(Ref(IntLit(5))) + vr2 = ReadRef(Ref(vr1)) + vr3 = ReadRef(Ref(FloatLit(5.0))) + assert not alpha_eq(vr1, vr2) + assert not alpha_eq(vr1, vr3) + assert not alpha_eq(vr2, vr3) + +def test_set_ref_eq() -> None: + sr1 = WriteRef(Ref(FloatLit(5.0)), FloatLit(6.0)) + sr2 = WriteRef(Ref(Tuple([IntLit(3), BoolLit(False)])), + Tuple([IntLit(5), BoolLit(True)])) + assert alpha_eq(sr1, sr1) + assert alpha_eq(sr2, sr2) + +def test_set_ref_not_eq() -> None: + r1 = Ref(FloatLit(5.0)) + r2 = Ref(IntLit(5)) + r3 = Ref(IntLit(6)) + + assert not alpha_eq(WriteRef(r1, FloatLit(6.0)), + WriteRef(r2, IntLit(6))) + assert not alpha_eq(WriteRef(r2, IntLit(6)), WriteRef(r2, IntLit(7))) + assert not alpha_eq(WriteRef(r2, IntLit(7)), WriteRef(r3, IntLit(7))) + +# Type alpha-equality tests + +def test_base_type_eq() -> None: + assert alpha_eq(IntType(32), IntType(32)) + assert alpha_eq(BoolType(), BoolType()) + assert alpha_eq(FloatType(32), FloatType(32)) + +def test_tensor_type_eq() -> None: + tt1 = TensorType( + IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) + tt2 = TensorType( + FloatType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) + assert alpha_eq(tt1, tt1) + assert alpha_eq(tt2, tt2) + +def test_tensor_type_not_eq() -> None: + tt1 = TensorType( + IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) + tt2 = TensorType( + FloatType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) + tt3 = TensorType( + IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) + assert not alpha_eq(tt1, tt2) + assert not alpha_eq(tt1, tt3) + +def test_ref_type_eq() -> None: + rt1 = RefType(int_type()) + rt2 = RefType(float_type()) + assert alpha_eq(rt1, rt1) + assert alpha_eq(rt2, rt2) + +def test_ref_type_not_eq() -> None: + rt1 = RefType(int_type()) + rt2 = RefType(float_type()) + assert not alpha_eq(rt1, rt2) + +def test_product_type_eq() -> None: + pt1 = TupleType([int_type(), RefType(float_type())]) + pt2 = TupleType([float_type(), float_type(), int_type()]) + assert alpha_eq(pt1, pt1) + assert alpha_eq(pt2, pt2) + +def test_product_type_not_eq() -> None: + pt1 = TupleType([int_type(), int_type()]) + pt2 = TupleType([int_type(), int_type(), float_type()]) + pt3 = TupleType([bool_type(), float_type()]) + assert not alpha_eq(pt1, pt2) + assert not alpha_eq(pt1, pt3) + +def test_type_id_eq() -> None: + id1 = TypeParam("id1", Kind.Shape) + id2 = TypeParam("id2", Kind.BaseType) + id3 = TypeParam("id2", Kind.Type) + + assert alpha_eq(id1, id1) + assert alpha_eq(id2, id2) + assert alpha_eq(id3, id3) + +def test_type_id_not_eq() -> None: + # name is just a hint, we use pointer equality as the rule + # (unless there is a quantifier to give context) + id1 = TypeParam("id1", Kind.Shape) + id2 = TypeParam("id1", Kind.Shape) + id3 = TypeParam("id3", Kind.BaseType) + + assert not alpha_eq(id1, id2) + assert not alpha_eq(id1, id3) + +def test_arrow_type_eq() -> None: + ar1 = TypeArrow([int_type()], bool_type()) + ar2 = TypeArrow([int_type(), int_type()], TupleType([])) + assert alpha_eq(ar1, ar1) + assert alpha_eq(ar2, ar2) + +def test_arrow_type_not_eq() -> None: + t1 = int_type() + t2 = bool_type() + t3 = [int_type(), bool_type()] + + assert not alpha_eq(TypeArrow([t1], t2), TypeArrow([t1], t1)) + assert not alpha_eq(TypeArrow(t3, t1), TypeArrow([t2], t1)) + assert not alpha_eq(TypeArrow([t1], TypeArrow([t1], t1)), + TypeArrow([t1], t1)) + +def test_type_quantifier_eq() -> None: + id1 = TypeParam("id1", Kind.Shape) + id2 = TypeParam("id2", Kind.Shape) + tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) + tq2 = TypeQuantifier(id2, TensorType(IntType(32), id2)) + + assert alpha_eq(tq1, tq1) + assert alpha_eq(tq1, tq2) + +def test_nested_type_quantifier_eq() -> None: + id1 = TypeParam("id1", Kind.BaseType) + id2 = TypeParam("id2", Kind.Shape) + id3 = TypeParam("id3", Kind.BaseType) + id4 = TypeParam("id4", Kind.Shape) + tq1 = TypeQuantifier(id1, TypeQuantifier(id2, TensorType(id1, id2))) + tq2 = TypeQuantifier(id3, TypeQuantifier(id4, TensorType(id3, id4))) + + assert alpha_eq(tq1, tq1) + assert alpha_eq(tq1, tq2) + +def test_type_quantifier_not_eq() -> None: + id1 = TypeParam("id1", Kind.Shape) + id2 = TypeParam("id2", Kind.BaseType) + id3 = TypeParam("id3", Kind.Shape) + + tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) + tq2 = TypeQuantifier(id2, TensorType(id2, ShapeSeq([ShapeSingleton(3)]))) + tq3 = TypeQuantifier(id1, TensorType(IntType(32), id3)) + tq4 = TypeQuantifier(id1, TensorType(FloatType(32), id1)) + + assert not alpha_eq(tq1, tq2) + assert not alpha_eq(tq1, tq3) + assert not alpha_eq(tq1, tq4) + assert not alpha_eq(tq2, tq3) + assert not alpha_eq(tq2, tq4) + +def test_shape_singleton_eq() -> None: + single1 = ShapeSingleton(10) + single2 = ShapeSingleton(10) + + assert alpha_eq(single1, single1) + assert alpha_eq(single1, single2) + +def test_shape_singelton_not_eq() -> None: + single1 = ShapeSingleton(10) + single2 = ShapeSingleton(11) + + assert not alpha_eq(single1, single2) + +def test_shape_attr_eq() -> None: + attr1 = ShapeAttr("x") + attr2 = ShapeAttr("x") + + assert alpha_eq(attr1, attr1) + assert alpha_eq(attr1, attr2) + +def test_shape_attr_not_eq() -> None: + id1 = "x" + id2 = "y" + attr1 = ShapeAttr(id1) + attr2 = ShapeAttr(id2) + + assert not alpha_eq(attr1, attr2) + +def test_shape_seq_eq() -> None: + empty = ShapeSeq([]) + seq1 = ShapeSeq([ShapeSingleton(5)]) + seq2 = ShapeSeq([ShapeSingleton(5)]) + + assert alpha_eq(empty, empty) + assert alpha_eq(seq1, seq2) + +def test_shape_seq_not_eq() -> None: + empty = ShapeSeq([]) + seq = ShapeSeq([ShapeSingleton(5)]) + single = ShapeSingleton(5) + + assert not alpha_eq(empty, seq) + assert not alpha_eq(seq, single) + +def test_shape_projection_eq() -> None: + proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) + proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) + + assert alpha_eq(proj1, proj2) + +def test_shape_projection_not_eq() -> None: + proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) + proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 1) + proj3 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 0) + proj4 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 1) + + assert not alpha_eq(proj1, proj2) + assert not alpha_eq(proj1, proj3) + assert not alpha_eq(proj1, proj4) + assert not alpha_eq(proj2, proj3) + assert not alpha_eq(proj2, proj4) + assert not alpha_eq(proj3, proj4) + +def test_shape_binary_op_eq() -> None: + empty = ShapeSeq([]) + single = ShapeSingleton(5) + seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) + + op1 = ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty) + op2 = ShapeBinaryOp(ShapeOp.SHSUB, single, single) + op3 = ShapeBinaryOp(ShapeOp.SHMUL, seq, seq) + op4 = ShapeBinaryOp(ShapeOp.SHDIV, seq, seq) + + assert alpha_eq(op1, op1) + assert alpha_eq(op2, op2) + assert alpha_eq(op3, op3) + assert alpha_eq(op4, op4) + +def test_shape_binary_op_not_eq() -> None: + empty = ShapeSeq([]) + single = ShapeSingleton(5) + seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) + + assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), empty) + assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHMUL, seq, ShapeSingleton(1)), seq) + assert not alpha_eq( + ShapeBinaryOp(ShapeOp.SHPLUS, single, single), + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([single]), + ShapeSeq([single]))) + assert not alpha_eq( + ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), + ShapeBinaryOp(ShapeOp.SHSUB, empty, empty)) + assert not alpha_eq( + ShapeBinaryOp(ShapeOp.SHMUL, empty, empty), + ShapeBinaryOp(ShapeOp.SHDIV, empty, empty)) + +def test_shape_nested_in_quantifier() -> None: + b1 = TypeParam("b", Kind.BaseType) + x1 = TypeParam("x", Kind.Shape) + y1 = TypeParam("y", Kind.Shape) + + b2 = TypeParam("b", Kind.BaseType) + x2 = TypeParam("x", Kind.Shape) + y2 = TypeParam("y", Kind.Shape) + + b3 = TypeParam("b", Kind.BaseType) + x3 = TypeParam("x", Kind.Shape) + y3 = TypeParam("y", Kind.Shape) + + tq1 = nest_quantifiers( + [b1, x1, y1], + TypeArrow( + [TensorType(b1, x1), TensorType(b1, y2)], + TensorType( + b1, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x1, ShapeProjection(y1, 1), + ShapeSingleton(5), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + tq2 = nest_quantifiers( + [b2, x2, y2], + TypeArrow( + [TensorType(b2, x2), TensorType(b2, y2)], + TensorType( + b2, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x2, ShapeProjection(y2, 1), + ShapeSingleton(5), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + # different attr, var order, position, and constant + tq3 = nest_quantifiers( + [b3, x3, y3], + TypeArrow( + [TensorType(b3, x3), TensorType(b3, y3)], + TensorType( + b3, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x3, ShapeProjection(y3, 1), + ShapeSingleton(4), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + tq4 = nest_quantifiers( + [b3, x3, y3], + TypeArrow( + [TensorType(b3, x3), TensorType(b3, y3)], + TensorType( + b3, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x3, ShapeProjection(y3, 2), + ShapeSingleton(5), ShapeAttr("att2")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + tq5 = nest_quantifiers( + [b3, x3, y3], + TypeArrow( + [TensorType(b3, x3), TensorType(b3, y3)], + TensorType( + b3, + ShapeBinaryOp(ShapeOp.SHMUL, + ShapeSeq([x3, ShapeProjection(y3, 1), + ShapeSingleton(5), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + tq6 = nest_quantifiers( + [b3, y3, x3], + TypeArrow( + [TensorType(b3, x3), TensorType(b3, y3)], + TensorType( + b3, + ShapeBinaryOp(ShapeOp.SHPLUS, + ShapeSeq([x3, ShapeProjection(y3, 1), + ShapeSingleton(5), ShapeAttr("att")]), + ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + + assert alpha_eq(tq1, tq2) + assert not alpha_eq(tq1, tq3) + assert not alpha_eq(tq2, tq3) + assert not alpha_eq(tq1, tq4) + assert not alpha_eq(tq2, tq4) + assert not alpha_eq(tq1, tq5) + assert not alpha_eq(tq2, tq5) + assert not alpha_eq(tq1, tq6) + assert not alpha_eq(tq2, tq6) From 4a84f2595e7ce10a146550609a271ffd8e0f07a5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 00:06:44 -0700 Subject: [PATCH 012/136] Add incomplete_type.h --- src/relay/compiler/incomplete_type.h | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 src/relay/compiler/incomplete_type.h diff --git a/src/relay/compiler/incomplete_type.h b/src/relay/compiler/incomplete_type.h new file mode 100644 index 000000000000..8f360d1cd51c --- /dev/null +++ b/src/relay/compiler/incomplete_type.h @@ -0,0 +1,36 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file incomplete_type.h + * \brief A way to defined arbitrary function signature with dispatch on types. + */ + +#ifndef TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H +#define TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H + +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +/*! + * \brief Represents a portion of an incomplete type. + */ +class IncompleteType; + +/*! \brief IncompleteType container node */ +class IncompleteTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) final {} + + TVM_DLL static IncompleteType make(); + + static constexpr const char* _type_key = "relay.IncompleteType"; + TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H From 2a2131cb174bf6b00f0e3f4f76cb071d9599aa02 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 00:07:05 -0700 Subject: [PATCH 013/136] Add type call --- include/tvm/relay/type.h | 33 ++++++++++++++++++++++++++++++++- src/relay/type.cc | 19 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 4c6995646114..dfe4309b7c77 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -228,7 +228,38 @@ class TypeFunctionNode : public RelayNode { TVM_DECLARE_NODE_TYPE_INFO(TypeFunctionNode, RelayNode); }; -RELAY_DEFINE_NODE_REF(TypeFunction, TypeFunctionNode, NodeRef); +RELAY_DEFINE_NODE_REF(TypeFunction, TypeFunctionNode, Type); + +/*! + * \brief Call a type function with some number of arguments. + */ +class TypeCall; +/*! + * \brief TypeCall container. + */ +class TypeCallNode : public TypeNode { + public: + /*! \brief The type function to be called. */ + Type func; + /*! \brief The type arguments to the type function. */ + tvm::Array args; + + TypeCallNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("args", &args); + } + + Type eval() const; + + TVM_DLL static TypeCall make(Type func, tvm::Array args); + + static constexpr const char* _type_key = "relay.TypeCall"; + TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); // The following fields contains advanced typing // Only keep the class name and reserved for future usage. diff --git a/src/relay/type.cc b/src/relay/type.cc index 156207e1b73a..22d37ea05fda 100644 --- a/src/relay/type.cc +++ b/src/relay/type.cc @@ -96,5 +96,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TypeFunctionNode(" << node->name << ", " << node->num_args << ")"; }); +TypeCall TypeCallNode::make(Type func, Array args) { + std::shared_ptr n = std::make_shared(); + n->func = std::move(func); + n->args = std::move(args); + return TypeCall(n); +} + +TVM_REGISTER_API("relay._make.TypeCall") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TypeCallNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeCallNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; + }); + + } // namespace relay } // namespace tvm From 9e7b18abf5dea70f9644faec83f7c1d700b3f7e4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 13:23:13 -0700 Subject: [PATCH 014/136] Add test for let with IR builder --- python/tvm/relay/ir_builder.py | 104 ++++++++++++++++++++++++++ tests/python/relay/test_ir_builder.py | 23 ++++++ 2 files changed, 127 insertions(+) create mode 100644 python/tvm/relay/ir_builder.py create mode 100644 tests/python/relay/test_ir_builder.py diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py new file mode 100644 index 000000000000..497479140ec9 --- /dev/null +++ b/python/tvm/relay/ir_builder.py @@ -0,0 +1,104 @@ +from typing import Any +import numpy as np +import tvm +from . import type as ty +from . import expr +from . import make as mk + + +def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: + """Convert Python values into the appropriate types + for the Relay evaluator. + """ + if isinstance(arg, int): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, float): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, bool): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, np.ndarray): + return tvm.nd.array(arg, ctxt) + elif isinstance(arg, tvm.ndarray.NDArray): + return arg + else: + # raise Exception(f"can't convert {type(arg)} to a Relay AST") + raise Exception(f"unsupported argument type {type(arg)}") + +def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> expr.Expr: + if isinstance(arg, tuple): + raise Exception("..") + else: + value = convert(arg, ctxt) + return mk.Constant(value) + +class WithScope(object): + """Auxiliary scope with""" + + def __init__(self, enter_value, exit_cb): + self._enter_value = enter_value + self._exit_cb = exit_cb + + def __enter__(self): + return self._enter_value + + def __exit__(self, ptype, value, trace): + self._exit_cb() + +def _mk_let(bindings, ret_value): + let_expr = ret_value + for var, value in reversed(list(bindings.items())): + let_expr = mk.Let(var, value, let_expr, None) + + return let_expr + +class IRBuilder(): + def __init__(self): + self.bindings = [{}] + self.scopes = [{}] + self.ret_value = None + + def bind(self, name, type, value): + lv = mk.LocalVar(name) + self.scopes[-1][name] = lv + self.bindings[-1][lv] = value + return lv + + + def let(self, name, value): + if not isinstance(value, expr.Expr): + value = into_ast(value) + + return self.bind(name, None, value) + + def function(self, params): + def _on_exit(): + bindings = self.bindings.pop() + scope = self.scopes.pop() + import pdb + pdb.set_trace() + return WithScope(None, _on_exit) + + def ret(self, x): + if not self.ret_value: + self.ret_value = x + else: + raise Exception( + "return value already set, a function can only have one return value") + + def get(self): + """Get the full program""" + bindings = self.bindings.pop() + scope = self.scopes.pop() + + if self.bindings: + raise Exception("...") + if self.scopes: + raise Exception("...") + + if not self.ret_value: + raise Exception("...") + + return _mk_let(bindings, self.ret_value) + + + diff --git a/tests/python/relay/test_ir_builder.py b/tests/python/relay/test_ir_builder.py new file mode 100644 index 000000000000..666d7ff25659 --- /dev/null +++ b/tests/python/relay/test_ir_builder.py @@ -0,0 +1,23 @@ +import numpy as np +from tvm.relay.expr import Let, Constant +from tvm.relay.ir_builder import IRBuilder + +def test_let(): + b = IRBuilder() + x = b.let('x', 1) + b.ret(x) + prog = b.get() + assert isinstance(prog, Let) + var = prog.var + value = prog.value + assert var.name_hint == 'x' + assert var == prog.body + assert isinstance(value, Constant) + assert value.data.asnumpy() == np.array(1) + assert prog.value_type == None + +# def test_function(): +# b = IRBuilder() + +if __name__ == "__main__": + test_let() From 133b1dc53769d1b7c576cd0f6280b92b4a54afbd Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 13:24:43 -0700 Subject: [PATCH 015/136] Add initial version of unifier and old tests --- src/relay/compiler/unifier.cc | 477 ++++++++++++++++++++++++++++ src/relay/compiler/unifier.h | 129 ++++++++ tests/python/relay/test_unifier.py | 480 +++++++++++++++++++++++++++++ 3 files changed, 1086 insertions(+) create mode 100644 src/relay/compiler/unifier.cc create mode 100644 src/relay/compiler/unifier.h create mode 100644 tests/python/relay/test_unifier.py diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc new file mode 100644 index 000000000000..bfd3e1a5ff32 --- /dev/null +++ b/src/relay/compiler/unifier.cc @@ -0,0 +1,477 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file unifier.cc + * \brief Data structures for type unification + */ + +#include "tvm/relay/ir.h" +#include "tvm/relay/logging.h" +#include "tvm/relay/compiler/alpha_eq.h" +#include "./unifier.h" +#include "./type_visitor.h" +// #include "tvm/relay/typeck/kindchecker.h" +// #include "tvm/relay/typeck/type_subst.h" + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +UnionFind UnionFindNode::make(tvm::Map uf_map) { + std::shared_ptr n = std::make_shared(); + n->uf_map = uf_map; + return UnionFind(n); +} + +void UnionFindNode::insert(const IncompleteType &v) { this->uf_map.Set(v, v); } + +void UnionFindNode::debug() { + for (auto entry : this->uf_map) { + std::cout << entry.first << " = " << entry.second << std::endl; + } +} + +void UnionFindNode::assertAlphaEq(const Type & l, const Type & r) { + if (!alpha_eq(l, r)) { + std::stringstream ss; + ss << "Incompatible parent types in UF:" << l << " and " << r; + throw UnionFindError(ss.str()); + } +} + +void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { + RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << "t=" << t << std::endl; + auto parent1 = this->find(v1); + + // if t is a type var, then unify parents + const IncompleteTypeNode *tvn2 = t.as(); + if (tvn2) { + auto v2 = GetRef(tvn2); + auto parent2 = this->find(v2); + + // if parents are exactly equal, then we're done + if (parent1 == parent2) { + return; + } + + // if first parent is a type var, then can just set its union find map to + // second parent + if (const IncompleteTypeNode *pvn1 = parent1.as()) { + auto pv1 = GetRef(pvn1); + this->uf_map.Set(pv1, parent2); + // path compression: can also set v1 directly + this->uf_map.Set(v1, parent2); + return; + } + + // if second parent is a type var but first isn't, can set second type var + if (const IncompleteTypeNode *pvn2 = parent2.as()) { + auto pv2 = GetRef(pvn2); + this->uf_map.Set(pv2, parent1); + // path compression: can also set v2 directly + this->uf_map.Set(v2, parent1); + return; + } + + // if both parents are not type vars themselves, check alpha-equality + assertAlphaEq(parent1, parent2); + return; + } + + // if t is not a type var, then unify with v1's parent if parent is a type + // var; else, check alpha-equality for compatibility + if (const IncompleteTypeNode *pvn1 = parent1.as()) { + auto pv1 = GetRef(pvn1); + this->uf_map.Set(pv1, t); + // path compression: can also set v1 directly + this->uf_map.Set(v1, t); + return; + } + + assertAlphaEq(parent1, t); +} + +Type UnionFindNode::find(const IncompleteType &v) { + // The node has no mapping, so its representative is just itself. + if (this->uf_map.find(v) == this->uf_map.end()) { + return v; + } + + Type parent = this->uf_map.at(v); + + if (v == parent) { + return v; + } + + // if parent is not a type var, then it must be the representative type + const IncompleteTypeNode *rep = parent.as(); + if (!rep) { + return parent; + } + + // otherwise, recurse and perform path compression + IncompleteType pv = GetRef(rep); + Type higher_up = this->find(pv); + this->uf_map.Set(v, higher_up); + return higher_up; +} + +TVM_REGISTER_API("relay._make.UnionFind") + .set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() == 0) { + *ret = UnionFindNode::make({}); + } else { + *ret = UnionFindNode::make(args[0]); + } + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const UnionFindNode *node, + tvm::IRPrinter *p) { + p->stream << "UnionFindNode(" << node->uf_map << ")"; + }); + +TypeUnifier TypeUnifierNode::make(UnionFind uf) { + std::shared_ptr n = std::make_shared(); + n->uf = uf; + return TypeUnifier(n); +} + +void TypeUnifierNode::insert(const IncompleteType &v) { this->uf->insert(v); } + +Type TypeUnifierNode::unify(const Type &t1, const Type &t2) { + RELAY_LOG(INFO) << "TypeUnifierNode::unify: t1=" << t1 << " t2=" << t2 + << std::endl; + + Type unified = this->VisitType(t1, t2); + // if (!check_kind(unified)) { + // throw UnificationError("Invalid kinds in unified type"); + // } + return unified; +} + +struct IncompleteTypeSubst : TypeFVisitor { + const TypeUnifierNode *unifier; + + IncompleteTypeSubst(const TypeUnifierNode *unifier) : unifier(unifier) {} + + // type var: look it up in the type map and recurse + Type VisitType_(const IncompleteTypeNode *op) override { + auto tv = GetRef(op); + auto parent = unifier->uf->find(tv); + if (parent == tv) { + return tv; + } + return this->VisitType(parent); + } +}; + +Type TypeUnifierNode::subst(const Type &t) { + IncompleteTypeSubst tvsubst(this); + // normalize first so substitutions in quantifiers will be correct + Type ret = tvsubst.VisitType(t); + // if (!check_kind(ret)) { + // std::stringstream ss; + // ss << "Invalid Kinds in substituted type!"; + // ss << t << std::endl; + // ss << ret << std::endl; + // throw SubstitutionError(ss.str()); + // } + return ret; +} + +Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { + IncompleteType tv1 = GetRef(t1); + RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << t1 << " = " << rt2 + << std::endl; + this->uf->unify(tv1, rt2); + auto rep = this->uf->find(tv1); + RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode rep=" << rep << std::endl; + return rep; +} + +Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { + TypeParam ti1 = GetRef(t1); + + // for typevars, remap and attempt to unify if already defined + if (const IncompleteTypeNode *tvn2 = rt2.as()) { + return this->unifyWithIncompleteType(ti1, GetRef(tvn2)); + } + + // for other type ids, only check equality + if (const TypeParamNode *tin2 = rt2.as()) { + TypeParam ti2 = GetRef(tin2); + + if (ti1 != ti2) { + throw UnificationError("Attempting to unify non-matching TypeParams"); + } + + return ti1; + } + + // cannot unify TypeParam with non-TypeParam + throw UnificationError("Unable to unify TypeParamNode"); +} + +Type TypeUnifierNode::VisitType_(const FuncTypeNode *t1, const Type rt2) { + return rt2; +// TypeArrow ta1 = GetRef(t1); + +// // for typevar, remap if necessary +// if (const IncompleteTypeNode *tvn2 = rt2.as()) { +// return this->unifyWithIncompleteType(ta1, GetRef(tvn2)); +// } + +// // for other arrow, unify arg and ret types +// if (const TypeArrowNode *tan2 = rt2.as()) { +// TypeArrow ta2 = GetRef(tan2); + +// if (ta1->arg_types.size() != ta2->arg_types.size()) { +// throw UnificationError("unable to unify functions of different arities"); +// } + +// tvm::Array unified_args; +// for (size_t i = 0; i < ta1->arg_types.size(); i++) { +// unified_args.push_back( +// this->VisitType(ta1->arg_types[i], ta2->arg_types[i])); +// } + +// Type unified_ret_type = this->VisitType(ta1->ret_type, ta2->ret_type); +// return TypeArrowNode::make(unified_args, unified_ret_type); +// } + +// throw UnificationError("Unable to unify TypeArrowNode"); +// } + +// Type TypeUnifierNode::VisitType_(const TypeQuantifierNode *t1, const Type rt2) { +// TypeQuantifier tq1 = GetRef(t1); + +// // for typevars, remap and attempt to unify if already defined +// if (const IncompleteTypeNode *tvn2 = rt2.as()) { +// return this->unifyWithIncompleteType(tq1, GetRef(tvn2)); +// } + +// // for other quantifiers, attempt to unify bound types after normalizing +// if (const TypeQuantifierNode *tqn2 = rt2.as()) { +// TypeQuantifier tq2 = GetRef(tqn2); +// TypeParam id1 = tq1->id; +// TypeParam id2 = tq2->id; + +// if (id1->kind != id2->kind) { +// throw UnificationError( +// "Cannot unify quantifiers over ids of different kinds"); +// } + +// TypeParam fresh = TypeParamNode::make(id1->name, id1->kind); + +// auto bt1 = type_subst(tq1->boundType, id1, fresh); +// auto bt2 = type_subst(tq2->boundType, id2, fresh); + +// Type unified_bound_type = this->VisitType(bt1, bt2); +// return TypeQuantifierNode::make(fresh, unified_bound_type); +// } + +// // anything else cannot be unified +// throw UnificationError("Cannot unify TypeQuantifierNode"); +} + +Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { + TensorType tt1 = GetRef(t1); + + // for typevars, remap and attempt to unify if already defined + if (const IncompleteTypeNode *tvn2 = rt2.as()) { + return this->unifyWithIncompleteType(tt1, GetRef(tvn2)); + } + + if (const TensorTypeNode *ttn2 = rt2.as()) { + TensorType tt2 = GetRef(ttn2); + + if (!alpha_eq(tt1, tt2)) { + throw UnificationError("dtypes do not match"); + } + + RELAY_LOG(INFO) << "Unify Tensor Shape s1=" << tt1->shape + << " s2= " << tt2->shape << std::endl; + try { + // Type unified_shape = this->VisitType(tt1->shape, tt2->shape); + return rt2; + } catch (const UnificationError & err) { + std::cout << "Need to check constraint " << tt1->shape << " = " << tt2->shape << std::endl; + } + + // fix me + return rt2; + // return TensorTypeNode::make(unified_bt, tt2->shape); + } + + // nothing else can unify + throw UnificationError("Cannot unify TensorTypeNode"); +} + +// Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { +// TupleType pt1 = GetRef(t1); + +// // for typevar, remap and attempt to unify if already defined +// if (const IncompleteTypeNode *tvn2 = rt2.as()) { +// return this->unifyWithIncompleteType(pt1, GetRef(tvn2)); +// } + +// // for other product types, unify item by item +// if (const TupleTypeNode *ptn2 = rt2.as()) { +// TupleType pt2 = GetRef(ptn2); + +// std::vector unified_fields; +// if (pt1->fields.size() != pt2->fields.size()) { +// throw UnificationError("Product types are of different dimensions"); +// } + +// for (size_t i = 0U; i < pt1->fields.size(); i++) { +// Type unified = this->VisitType(pt1->fields[i], pt2->fields[i]); +// unified_fields.push_back(unified); +// } + +// return TupleTypeNode::make(unified_fields); +// } + +// // otherwise cannot unify +// throw UnificationError("Cannot unify TupleTypeNode"); +// } + +Type TypeUnifierNode::VisitType_(const TypeFunctionNode *sen1, const Type t2) { +// ShapeExtension sh_ext1 = GetRef(sen1); + +// if (const IncompleteTypeNode *tvn2 = t2.as()) { +// return this->unifyWithIncompleteType(sh_ext1, GetRef(tvn2)); +// } + +// // will only attempt to unify with binary op with same op +// if (const ShapeExtensionNode *sen2 = t2.as()) { +// if (sh_ext1->name != sen2->name) { +// throw UnificationError( +// "Cannot unify shape projections of different index"); +// } +// } + +// return sh_ext1; + return t2; +} + +Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { + TypeCall ty_call1 = GetRef(tcn1); + + if (const IncompleteTypeNode *tvn2 = t2.as()) { + return this->unifyWithIncompleteType(ty_call1, GetRef(tvn2)); + } + + if (const TypeCallNode *tcn2 = t2.as()) { + Type unified_func = this->VisitType(ty_call1->func, tcn2->func); + + // For now, we will only unify if they are equal. + if (ty_call1->args.size() != tcn2->args.size()) { + throw UnificationError("Cannot unify calls of different number of arguments"); + } + + // Unify members, if possible + tvm::Array new_args; + for (size_t i = 0U; i < ty_call1->args.size(); i++) { + Type unified_member = this->VisitType(ty_call1->args[i], tcn2->args[i]); + new_args.push_back(unified_member); + } + + return TypeCallNode::make(unified_func, new_args); + } else { + throw UnificationError("Cannot unify call with non-call"); + } +} + +Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, const IncompleteType tv2) { + RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; + // Fix unify to return new representative + this->uf->unify(tv2, t1); + auto rep = this->uf->find(tv2); + RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl; + return rep; +} + +TVM_REGISTER_API("relay._make.TypeUnifier") + .set_body([](TVMArgs args, TVMRetValue *ret) { + if (args.size() < 3) { + *ret = TypeUnifierNode::make(UnionFindNode::make({})); + } else { + *ret = TypeUnifierNode::make(args[0]); + } + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TypeUnifierNode *node, + tvm::IRPrinter *p) { + p->stream << "TypeUnifierNode(" << node->uf << ")"; + }); + +TVM_REGISTER_API("relay._unifier.UnionFind_insert") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + UnionFind uf = args[0]; + uf->insert(args[1]); + } catch (std::exception &e) { + throw UnionFindError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.UnionFind_unify") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + UnionFind uf = args[0]; + uf->unify(args[1], args[2]); + } catch (std::exception &e) { + throw UnionFindError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.UnionFind_find") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + UnionFind uf = args[0]; + *ret = uf->find(args[1]); + } catch (std::exception &e) { + throw UnionFindError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.TypeUnifier_insert") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + TypeUnifier unifier = args[0]; + IncompleteType var = args[1]; + unifier->insert(var); + } catch (std::exception &e) { + throw UnificationError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.TypeUnifier_unify") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + TypeUnifier unifier = args[0]; + Type t1 = args[1]; + Type t2 = args[2]; + *ret = unifier->unify(t1, t2); + } catch (std::exception &e) { + throw UnificationError(e.what()); + } + }); + +TVM_REGISTER_API("relay._unifier.TypeUnifier_subst") + .set_body([](TVMArgs args, TVMRetValue *ret) { + try { + TypeUnifier unifier = args[0]; + Type t = args[1]; + *ret = unifier->subst(t); + } catch (std::exception &e) { + throw SubstitutionError(e.what()); + } + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/compiler/unifier.h b/src/relay/compiler/unifier.h new file mode 100644 index 000000000000..6788265c90f2 --- /dev/null +++ b/src/relay/compiler/unifier.h @@ -0,0 +1,129 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file unifier.h + * \brief The type unifier which solves a system of equations between + * incomplete types. + */ +#ifndef TVM_RELAY_COMPILER_UNIFIER_H_ +#define TVM_RELAY_COMPILER_UNIFIER_H_ + +#include +#include "./type_functor.h" +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +struct UnionFindError : dmlc::Error { + explicit UnionFindError(const std::string& msg) : Error(msg) {} +}; + +struct UnificationError : dmlc::Error { + explicit UnificationError(const std::string& msg) : Error(msg) {} +}; + +struct SubstitutionError : dmlc::Error { + explicit SubstitutionError(const std::string& msg) : Error(msg) {} +}; + +/*! \brief a union-find data structure for the type-checker */ +class UnionFind; // forward declaration + +class UnionFindNode : public Node { + public: + tvm::Map uf_map; + + UnionFindNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf_map", &uf_map); } + + TVM_DLL static UnionFind make(tvm::Map uf_map); + + // insert v into UF + void insert(const IncompleteType& v); + + // infers that v1 and v2 must be of the smae type + void unify(const IncompleteType& v1, const Type& v2); + + // returns representative of v's UF-group + Type find(const IncompleteType& v); + + void debug(); + + void assertAlphaEq(const Type& l, const Type& r); + + static constexpr const char* _type_key = "relay.UnionFind"; + TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, Node); +}; + +class UnionFind : public NodeRef { + public: + UnionFind() {} + explicit UnionFind(std::shared_ptr p) : NodeRef(p) {} + + // no const so that union find can be mutable as a member of unifier + inline UnionFindNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = UnionFindNode; +}; + +class TypeUnifier; +class TypeUnifierNode : public Node, + private TypeFunctor { + public: + UnionFind uf; + + TypeUnifierNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf", &uf); } + + TVM_DLL static TypeUnifier make(UnionFind uf); + + /*! \brief Introduces a new type var into the unifier */ + void insert(const IncompleteType& v); + + /*! \brief Unifies two types if possible, throws a unification error if it + * cannot */ + Type unify(const Type& t1, const Type& t2); + + /*! \brief Attempts to substitute all type vars in t with concrete types, + * throws substitution error if it cannot concretize*/ + Type subst(const Type& t); + + // /*! \brief Checks the kinds in the given type */ + // Type CheckKinds(const Type& t); + + static constexpr const char* _type_key = "relay.TypeUnifier"; + TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node); + + private: + // unify non-typevar with typevar + Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); + + Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override; + Type VisitType_(const TensorTypeNode* t1, const Type t2) override; + Type VisitType_(const TypeParamNode* t1, const Type t2) override; + Type VisitType_(const FuncTypeNode* t1, const Type t2) override; + // Type VisitType_(const TupleTypeNode* t1, const Type t2) override; + Type VisitType_(const TypeFunctionNode* s1, const Type t2) override; + Type VisitType_(const TypeCallNode* s1, const Type t2) override; +}; + +class TypeUnifier : public NodeRef { + public: + TypeUnifier() {} + explicit TypeUnifier(std::shared_ptr p) : NodeRef(p) {} + + // no const so that unifier can be mutable as a member of typechecker + inline TypeUnifierNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = TypeUnifierNode; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TYPECK_UNIFIER_H_ diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py new file mode 100644 index 000000000000..7938a5a3ae5f --- /dev/null +++ b/tests/python/relay/test_unifier.py @@ -0,0 +1,480 @@ +"""Tests unification of types.""" +# pylint: disable=invalid-name, missing-docstring, bare-except +import relay.ir +# pylint: disable=unused-import +import relay.unifier # TODO (@jroesch) fix me +# pylint: disable=wildcard-import, unused-wildcard-import +from relay.make import * + +def unify_types(t1, t2): + unifier = TypeUnifier() + return unifier.unify(t1, t2) + +def int_type(): + return TensorType(IntType(32), ShapeSeq([])) + +def float_type(): + return TensorType(FloatType(32), ShapeSeq([])) + +def bool_type(): + return TensorType(BoolType(), ShapeSeq([])) + +def make_shape(dims): + return ShapeSeq([ShapeSingleton(dim) for dim in dims]) + +def test_insert_and_find(): + uf = UnionFind() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + uf.insert(v1) + uf.insert(v2) + assert uf.find(v1) == v1 + assert uf.find(v2) == v2 + +def test_insert_error(): + uf = UnionFind() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + uf.insert(v1) + try: + uf.find(v2) + assert False + except: + return + +def test_unify(): + uf = UnionFind() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + v3 = TypeVar(ir.Kind.Type) + uf.insert(v1) + uf.insert(v2) + uf.insert(v3) + uf.unify(v1, v2) + rep = uf.find(v1) + assert (rep == v1 or rep == v2) + assert uf.find(v1) == rep + assert uf.find(v2) == rep + assert uf.find(v3) == v3 + assert v3 != rep + uf.unify(v1, v3) + new_rep = uf.find(v3) + assert (rep == v1 or rep == v2 or rep == v3) + assert uf.find(v1) == new_rep + assert uf.find(v2) == new_rep + assert uf.find(v3) == new_rep + +def test_unify_multiple_levels(): + uf = UnionFind() + v = [TypeVar(ir.Kind.Type) for _ in range(9)] + for var in v: + uf.insert(var) + uf.unify(v[0], v[1]) + uf.unify(v[0], v[2]) + uf.unify(v[3], v[4]) + uf.unify(v[4], v[5]) + uf.unify(v[6], v[7]) + uf.unify(v[6], v[8]) + rep1 = uf.find(v[0]) + rep2 = uf.find(v[3]) + rep3 = uf.find(v[6]) + assert (rep1 == v[0] or rep1 == v[1] or rep1 == v[2]) + assert (rep2 == v[3] or rep2 == v[4] or rep2 == v[5]) + assert (rep3 == v[6] or rep3 == v[7] or rep3 == v[8]) + for i in range(3): + assert uf.find(v[i]) == rep1 + assert uf.find(v[i + 3]) == rep2 + assert uf.find(v[i + 6]) == rep3 + # now unify two of the groups + uf.unify(v[1], v[4]) + new_rep1 = uf.find(v[0]) + new_rep2 = uf.find(v[6]) + assert (new_rep1 == v[0] or new_rep1 == v[1] or new_rep1 == v[2] + or new_rep1 == v[3] or new_rep1 == v[4] or new_rep1 == v[5]) + assert (new_rep2 == v[6] or new_rep2 == v[7] or new_rep2 == v[8]) + for i in range(6): + assert uf.find(v[i]) == new_rep1 + for i in range(3): + assert uf.find(v[i + 6]) == new_rep2 + +# TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work +def test_unify_int(): + intty = IntType(1) + unified = unify_types(intty, intty) + assert intty == unified + +def test_unify_bool(): + boolty = BoolType() + unified = unify_types(boolty, boolty) + assert boolty == unified + +def test_unify_float(): + floatty = FloatType(4) + unified = unify_types(floatty, floatty) + assert floatty == unified + +def test_unify_incompatible_basetypes(): + bt = BoolType() + intty = IntType(32) + try: + unify_types(bt, intty) + assert False + except: + return + +def test_unify_concrete_type_arrow(): + arr1 = TypeArrow([int_type()], int_type()) + arr2 = TypeArrow([int_type()], int_type()) + unified = unify_types(arr1, arr2) + assert unified == arr1 + +def test_unify_type_arrow_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unifier.unify(v1, bool_type()) + arr1 = TypeArrow([int_type()], bool_type()) + arr2 = TypeArrow([int_type()], v1) + unified = unifier.unify(arr1, arr2) + assert unified == arr1 + + v2 = TypeVar(ir.Kind.BaseType) + unifier.insert(v2) + unifier.unify(v2, int_type()) + arr3 = TypeArrow([v2], bool_type()) + unified = unifier.unify(arr1, arr3) + assert unified == arr1 + +def test_reject_incompatible_type_arrows(): + arr1 = TypeArrow([int_type()], bool_type()) + arr2 = TypeArrow([int_type(), bool_type()], bool_type()) + try: + unify_types(arr1, arr2) + assert False + except: + return + +def test_unify_concrete_type_quantifiers(): + tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) + tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), int_type()) + unified = unify_types(tq1, tq2) + assert unified == tq1 + +def test_unify_basetype_with_quantifier_error(): + bt = bool_type() + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) + try: + unify_types(bt, tq) + assert False + except: + return + +def test_unify_typevars_with_each_other(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + v3 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + unifier.insert(v2) + unifier.insert(v3) + unified = unifier.unify(v1, v2) + assert (unified == v1 or unified == v2) + assert unified != v3 + new_unified = unifier.unify(v1, v3) + assert (new_unified == v1 or new_unified == v2 or new_unified == v3) + +def test_unify_typevars_with_basetype(): + unifier = TypeUnifier() + bt = BoolType() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unifier.insert(v2) + unified1 = unifier.unify(v1, bt) + assert unified1 == bt + unified2 = unifier.unify(v1, v2) + assert unified2 == bt + +def test_unify_compatible_typevars(): + unifier = TypeUnifier() + bt = BoolType() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v1, bt) + unifier.unify(v2, bt) + # because types to which v1 and v2 have been assigned are compatible, + # this should proceed without problems + unified = unifier.unify(v1, v2) + assert unified == bt + +def test_unify_incompatible_typevars(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.BaseType) + bt = bool_type() + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v1, bt) + unifier.unify(v2, tq) + # bt cannot be unified with tq, so unifying v1 and v2 should give an error + try: + unifier.unify(v1, v2) + assert False + except: + return + +def test_unify_typevar_with_quantifier(): + unifier = TypeUnifier() + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) + v1 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unified = unifier.unify(v1, tq) + assert unified == tq + +def test_unify_typevars_inside_concrete_quantifier(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) + tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) + unified = unifier.unify(tq1, tq2) + assert unified == tq2 + +def test_unify_concrete_tensors(): + bt = BoolType() + shape = make_shape([1, 2, 3]) + tt1 = TensorType(bt, shape) + tt2 = TensorType(bt, shape) + unified = unify_types(tt1, tt2) + assert unified == tt1 + +def test_unify_tensor_shape_reject(): + bt = BoolType() + shape1 = make_shape([1, 2, 3]) + shape2 = make_shape([2, 3, 4]) + tt1 = TensorType(bt, shape1) + tt2 = TensorType(bt, shape2) + try: + unify_types(tt1, tt2) + assert False + except: + return + +def test_unify_tensor_dtype_reject(): + bt1 = BoolType() + bt2 = IntType(32) + shape = make_shape([1, 2, 3]) + tt1 = TensorType(bt1, shape) + tt2 = TensorType(bt2, shape) + try: + unify_types(tt1, tt2) + assert False + except: + return + +def test_unify_quantified_tensors(): + x = TypeParam("x", ir.type.Kind.Shape) + y = TypeParam("y", ir.type.Kind.Shape) + tq1 = TypeQuantifier(x, TensorType(BoolType(), x)) + tq2 = TypeQuantifier(y, TensorType(BoolType(), y)) + unified = unify_types(tq1, tq2) + assert unified == tq1 + + a = TypeParam("a", ir.type.Kind.BaseType) + b = TypeParam("b", ir.type.Kind.BaseType) + tq3 = TypeQuantifier(a, TensorType(a, make_shape([1, 2, 3]))) + tq4 = TypeQuantifier(b, TensorType(b, make_shape([1, 2, 3]))) + unified = unify_types(tq3, tq4) + assert unified == tq3 + +def test_unify_concrete_products(): + bt = bool_type() + intty = int_type() + pt1 = TupleType([bt, intty]) + pt2 = TupleType([bt, intty]) + unified = unify_types(pt1, pt2) + assert unified == pt1 + +def test_unify_products_reject_size(): + bt = BoolType() + intty = IntType(32) + pt1 = TupleType([bt, bt, intty]) + pt2 = TupleType([bt, intty]) + try: + unify_types(pt1, pt2) + assert False + except: + return + +def test_unify_products_reject_member(): + bt = BoolType() + intty = IntType(32) + pt1 = TupleType([bt, bt]) + pt2 = TupleType([bt, intty]) + try: + unify_types(pt1, pt2) + assert False + except: + return + +def test_unify_products_typevar(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + bt = bool_type() + pt1 = TupleType([bt, bt]) + pt2 = TupleType([v1, bt]) + unifier.insert(v1) + unified = unifier.unify(pt1, pt2) + assert unified == pt1 + +def test_unify_quantified_products(): + x = TypeParam("x", ir.Kind.Type) + y = TypeParam("y", ir.Kind.Type) + p1 = TypeQuantifier(x, TupleType([int_type(), x])) + p2 = TypeQuantifier(y, TupleType([int_type(), y])) + unified = unify_types(p1, p2) + assert unified == p1 + +def test_unify_ref_types(): + r1 = RefType(bool_type()) + r2 = RefType(bool_type()) + assert unify_types(r1, r2) == r1 + +def test_unify_ref_reject_inner(): + r1 = RefType(BoolType()) + r2 = RefType(IntType(32)) + try: + unify_types(r1, r2) + assert False + except: + return + +def test_subst_basetype(): + unifier = TypeUnifier() + bt = BoolType() + assert bt == unifier.subst(bt) + +def test_subst_simple_hole(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + bt = BoolType() + unifier.insert(v1) + unifier.unify(v1, bt) + assert unifier.subst(v1) == bt + +def test_subst_typevar_for_typevar(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + unifier.insert(v2) + + unifier.unify(v1, v2) + assert unifier.subst(v1) == v2 + +def test_subst_concrete_arrow(): + unifier = TypeUnifier() + arr1 = TypeArrow([int_type()], int_type()) + assert unifier.subst(arr1) == arr1 + +def test_subst_arrow_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.BaseType) + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v1, int_type()) + unifier.unify(v2, bool_type()) + arr1 = TypeArrow([v1], v2) + arr2 = TypeArrow([int_type()], bool_type()) + assert unifier.subst(arr1) == arr2 + +def test_subst_concrete_quantifier(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) + unifier.insert(v1) + unifier.unify(v1, tq) + assert unifier.subst(v1) == tq + +def test_subst_quantifier_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) + intty = int_type() + tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) + + unifier.insert(v1) + unifier.insert(v2) + unifier.unify(v2, intty) + unifier.unify(v1, tq1) + assert unifier.subst(v1) == tq2 + +def test_subst_concrete_tensor(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + tt = TensorType(BoolType(), make_shape([1, 2, 3])) + unifier.unify(v1, tt) + assert unifier.subst(v1) == tt + +def test_subst_concrete_product(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + bt = bool_type() + pt = TupleType([bt, bt]) + unifier.unify(v1, pt) + assert unifier.subst(v1) == pt + +def test_subst_product_with_holes(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + v2 = TypeVar(ir.Kind.Type) + v3 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + unifier.insert(v2) + unifier.insert(v3) + + tt1 = TensorType(IntType(32), ShapeSeq([])) + tt2 = TensorType(FloatType(32), ShapeSeq([])) + pt1 = TupleType([tt1, v2, v3]) + unifier.unify(v2, tt2) + unifier.unify(v3, v2) + unifier.unify(v1, pt1) + pt2 = TupleType([tt1, tt2, tt2]) + assert unifier.subst(v1) == pt2 + +def test_subst_concrete_ref(): + unifier = TypeUnifier() + rt = RefType(bool_type()) + assert unifier.subst(rt) == rt + +def test_subst_ref_with_hole(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.Type) + unifier.insert(v1) + + unifier.unify(v1, bool_type()) + rt1 = RefType(v1) + rt2 = RefType(bool_type()) + assert unifier.subst(rt1) == rt2 + +def test_typevar_on_lhs(): + unifier = TypeUnifier() + v1 = TypeVar(ir.Kind.BaseType) + v2 = TypeVar(ir.Kind.Type) + bt = bool_type() + tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) + unifier.insert(v1) + unifier.insert(v2) + unified1 = unifier.unify(bt, v1) + assert unified1 == bt + unified2 = unifier.unify(tq, v2) + assert unified2 == tq + assert unifier.subst(v1) == bt + assert unifier.subst(v2) == tq From bc9754f2506fd914ca2c9a83550f553443306f18 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 13:25:20 -0700 Subject: [PATCH 016/136] Update type_functor.h for incomplete type. --- include/tvm/relay/compiler/typechecker.h | 2 +- src/relay/compiler/type_functor.h | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/compiler/typechecker.h b/include/tvm/relay/compiler/typechecker.h index c71f78c1a5b0..c69aba3c1e71 100644 --- a/include/tvm/relay/compiler/typechecker.h +++ b/include/tvm/relay/compiler/typechecker.h @@ -8,7 +8,7 @@ #define TVM_RELAY_COMPILER_TYPECHECKER_H_ #include "tvm/relay/ir.h" -#include "tvm/relay/environment.h" +#include "tvm/relay/compiler/environment.h" namespace tvm { namespace relay { diff --git a/src/relay/compiler/type_functor.h b/src/relay/compiler/type_functor.h index 66454725db48..3840c902bfe8 100644 --- a/src/relay/compiler/type_functor.h +++ b/src/relay/compiler/type_functor.h @@ -7,7 +7,8 @@ #define TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ #include -#include "ir.h" +#include "tvm/relay/ir.h" +#include "./incomplete_type.h" namespace tvm { namespace relay { @@ -61,12 +62,10 @@ class TypeFunctor { Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeFunction* op, Args... args) TYPE_FUNCTOR_DEFAULT; - Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeFunctionNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); @@ -84,6 +83,7 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeFunctionNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); + RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); return vtable; } }; From d9d11316ff339f0067572d240ce56c93dc541c14 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 13:40:34 -0700 Subject: [PATCH 017/136] Add Python side of unifier --- python/tvm/relay/_unifier.py | 5 +++ python/tvm/relay/_unifier.pyi | 12 ++++++ python/tvm/relay/ir.py | 18 +++++++++ python/tvm/relay/type.py | 5 +++ python/tvm/relay/unifier.py | 61 ++++++++++++++++++++++++++++++ tests/python/relay/test_unifier.py | 12 +++--- 6 files changed, 107 insertions(+), 6 deletions(-) create mode 100644 python/tvm/relay/_unifier.py create mode 100644 python/tvm/relay/_unifier.pyi create mode 100644 python/tvm/relay/ir.py create mode 100644 python/tvm/relay/unifier.py diff --git a/python/tvm/relay/_unifier.py b/python/tvm/relay/_unifier.py new file mode 100644 index 000000000000..41f5fe374b3e --- /dev/null +++ b/python/tvm/relay/_unifier.py @@ -0,0 +1,5 @@ +"""FFI functions for the Unifier.""" + +from tvm._ffi.function import _init_api + +_init_api("relay._unifier", __name__) diff --git a/python/tvm/relay/_unifier.pyi b/python/tvm/relay/_unifier.pyi new file mode 100644 index 000000000000..6ecd309250a6 --- /dev/null +++ b/python/tvm/relay/_unifier.pyi @@ -0,0 +1,12 @@ +from tvm.relay.ir import NodeBase + +class UnionFind(NodeBase): ... +class TypeUnifier(NodeBase): ... + +def UnionFind_insert(self: UnionFind, var: ir.IncompleteType) -> None: ... +def UnionFind_unify(self: UnionFind, var1: ir.IncompleteType, var2: ir.IncompleteType) -> None: ... +def UnionFind_find(self: UnionFind, var: ir.IncompleteType) -> ir.Type: ... + +def TypeUnifier_insert(self: TypeUnifier, var: ir.IncompleteType) -> None: ... +def TypeUnifier_unify(self, type1: ir.Type, type2: ir.Type) -> ir.Type: ... +def TypeUnifier_subst(self, type1: ir.Type) -> ir.Type: ... diff --git a/python/tvm/relay/ir.py b/python/tvm/relay/ir.py new file mode 100644 index 000000000000..a95f29abe6de --- /dev/null +++ b/python/tvm/relay/ir.py @@ -0,0 +1,18 @@ +from . import base +from . import type as ty +from . import expr + +# Base +register_relay_node = base.register_relay_node +NodeBase = base.NodeBase + +# Type +Type = ty.Type +TensorType = ty.Type +Kind = ty.Kind +TypeParam = ty.TypeParam +TypeConstraint = ty.TypeConstraint +FuncType = ty.FuncType +IncompleteType = ty.IncompleteType + +# Expr diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index c92f0d756587..4d53cf88a218 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -49,3 +49,8 @@ class FuncType(Type): arg_types: List[Type] ret_type: Type span: Span + +@register_relay_node +class IncompleteType(Type): + """An incomplete type.""" + pass diff --git a/python/tvm/relay/unifier.py b/python/tvm/relay/unifier.py new file mode 100644 index 000000000000..cb818de19c1d --- /dev/null +++ b/python/tvm/relay/unifier.py @@ -0,0 +1,61 @@ +"""The Python interface to Relay's UnionFind and TypeUnifier.""" + +from typing import Dict +from .ir import register_relay_node, NodeBase +from . import ir +from . import _unifier + +@register_relay_node +class UnionFind(NodeBase): + """Python API for UnionFind. + + The UnionFind maintains equality classes of type variables, the + representative of an equality class may be a type (which can) + contain type variables. The TypeUnifier uses this to build a + unification procedure between types. + """ + uf_map: Dict[ir.IncompleteType, ir.IncompleteType] + + def insert(self, var: ir.IncompleteType) -> None: + """Insert a type variable into the union find. + + :param: var: The variable to be inserted. + """ + return _unifier.UnionFind_insert(self, var) + + def unify(self, var: ir.IncompleteType, typ: ir.Type) -> None: + """Unify a type variable with an arbitrary type. + + :param: var: A type variable to be unified. + :param: typ: The type to be unified with. + """ + return _unifier.UnionFind_unify(self, var, typ) + + def find(self, var: ir.IncompleteType) -> ir.IncompleteType: + """Find the representative element of the type var. + + :param: var: The variable to lookup in the union find. + """ + return _unifier.UnionFind_find(self, var) + +@register_relay_node +class TypeUnifier(NodeBase): + """Python API for the TypeUnifier.""" + #pylint: disable=invalid-name + uf: UnionFind + eq_map: Dict[ir.TypeParam, ir.TypeParam] + + def insert(self, var: ir.IncompleteType) -> None: + return _unifier.TypeUnifier_insert(self, var) + + def unify(self, type1: ir.Type, type2: ir.Type) -> ir.Type: + """Unify two types producing the unified type as a result. + + :param: type1: The first type to be unified. + :param: type2: The second type to be unified. + :returns: The unified type. + """ + return _unifier.TypeUnifier_unify(self, type1, type2) + + def subst(self, type1: ir.Type) -> ir.Type: + return _unifier.TypeUnifier_subst(self, type1) diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index 7938a5a3ae5f..875502808563 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -1,10 +1,10 @@ -"""Tests unification of types.""" -# pylint: disable=invalid-name, missing-docstring, bare-except +""" +Test the type unifier, which solves systems of equations +between incomplete types. +""" import relay.ir -# pylint: disable=unused-import -import relay.unifier # TODO (@jroesch) fix me -# pylint: disable=wildcard-import, unused-wildcard-import -from relay.make import * +import relay.unifier + def unify_types(t1, t2): unifier = TypeUnifier() From 860d7e6841257ecacca1cc3459513b491a3b0aa9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:06:09 -0700 Subject: [PATCH 018/136] Add to incomplete_type and add impl in typechecker.cc --- include/tvm/relay/type.h | 5 +- python/tvm/relay/type.py | 4 +- src/relay/compiler/incomplete_type.h | 8 +- src/relay/compiler/typechecker.cc | 771 +++++++++++++++++++++++++++ 4 files changed, 783 insertions(+), 5 deletions(-) create mode 100644 src/relay/compiler/typechecker.cc diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index dfe4309b7c77..4eeb42168d68 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -113,7 +113,10 @@ class TypeParamNode : public TypeNode { /*! \brief possible kinds of TypeParam */ enum Kind : int { /*! \brief template variable in shape expression */ - kShapeVar = 0 + kShapeVar = 0, + kShape = 1, + kBaseType = 2, + kType = 3, }; /*! * \brief The variable diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index 4d53cf88a218..2790b546cfe5 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -21,10 +21,10 @@ class Kind(IntEnum): """The kind of a type parameter, represents a variable shape, base type, type, or dimension. """ - Shape = 0 + ShapeVar = 0 + Shape = 1 BaseType = 1 Type = 2 - Elem = 3 @register_relay_node class TypeParam(Type): diff --git a/src/relay/compiler/incomplete_type.h b/src/relay/compiler/incomplete_type.h index 8f360d1cd51c..f31a2efdf78d 100644 --- a/src/relay/compiler/incomplete_type.h +++ b/src/relay/compiler/incomplete_type.h @@ -20,9 +20,13 @@ class IncompleteType; /*! \brief IncompleteType container node */ class IncompleteTypeNode : public TypeNode { public: - void VisitAttrs(tvm::AttrVisitor* v) final {} + TypeParamNode::Kind kind; - TVM_DLL static IncompleteType make(); + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("kind", &kind); + } + + TVM_DLL static IncompleteType make(TypeParamNode::Kind kind); static constexpr const char* _type_key = "relay.IncompleteType"; TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); diff --git a/src/relay/compiler/typechecker.cc b/src/relay/compiler/typechecker.cc new file mode 100644 index 000000000000..c1f7b7f88765 --- /dev/null +++ b/src/relay/compiler/typechecker.cc @@ -0,0 +1,771 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file typechecker.cc + * \brief Relay typechecker + */ + +#include "tvm/relay/compiler/typechecker.h" +#include "./incomplete_type.h" +// #include "tvm/relay/alpha_eq.h" +// #include "tvm/relay/debug.h" +// #include "tvm/relay/first_order_reverse_ad.h" +// #include "tvm/relay/free_type_vars.h" +// #include "tvm/relay/gen_fresh.h" +// #include "tvm/relay/ir.h" +// #include "tvm/relay/logging.h" +// #include "tvm/relay/pretty_printer.h" +// #include "tvm/relay/reverse_ad.h" +// #include "tvm/relay/type_visitor.h" +// #include "tvm/relay/typeck/kindchecker.h" +// #include "tvm/relay/typeck/resolve.h" +// #include "tvm/relay/typeck/shape_evaluator.h" + +namespace tvm { +namespace relay { + +// using namespace tvm::runtime; + +// struct FatalTypeError : dmlc::Error { +// explicit FatalTypeError(const std::string & s) : dmlc::Error(s) {} +// }; + +// struct TypeContext { +// std::vector> stack; +// TypeContext() { +// stack.push_back({}); +// } +// void insert(const LocalId &id, const Type &t) { stack.back()[id] = t; } +// Type lookup(const LocalId &id) { +// for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) { +// if (frame->find(id) != frame->end()) { +// return frame->at(id); +// } +// } +// throw FatalTypeError("Could not resolve local id"); +// } +// struct LocalFrame { +// TypeContext & tc; +// explicit LocalFrame(TypeContext & tc) : tc(tc) { +// tc.stack.push_back({}); +// } +// ~LocalFrame() { +// tc.stack.pop_back(); +// } +// }; +// }; + +// class Typechecker : private ExprFunctor { +// private: +// TypeContext local_stack; +// public: +// Environment env; +// TypeUnifier unifier; + +// template +// T with_frame(const std::function & f) { +// TypeContext::LocalFrame fr(local_stack); +// return f(); +// } + +// Typechecker(); +// Typechecker(Environment env, TypeUnifier unifier) : env(env), unifier(unifier) {} +// explicit Typechecker(Environment env); +// Type Check(const Expr & expr); +// Type instantiate(Type t, tvm::Array & ty_args); + +// void report_error(const std::string & msg, Span sp); +// [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); + +// Type unify(const Type &t1, const Type &t2, Span sp); +// Type resolve(const Type &t); +// Expr resolve(const Expr &e); +// Type VisitFunction(const Function & f, bool generalize); +// Operator CheckOp(Operator op); +// Defn CheckDefn(Defn def); +// private: +// Type VisitExpr_(const LocalIdNode* op) override; +// Type VisitExpr_(const GlobalIdNode* op) override; +// Type VisitExpr_(const OperatorIdNode* op) override; +// Type VisitExpr_(const FloatLitNode* op) override; +// Type VisitExpr_(const BoolLitNode* op) override; +// Type VisitExpr_(const IntLitNode* op) override; +// Type VisitExpr_(const TensorLitNode* op) override; +// Type VisitExpr_(const TupleNode* op) override; +// Type VisitExpr_(const CastNode* op) override; +// Type VisitExpr_(const ParamNode* op) override; +// Type VisitExpr_(const FunctionNode* op) override; +// Type VisitExpr_(const CallNode* op) override; +// Type VisitExpr_(const DebugNode* op) override; +// Type VisitExpr_(const LetNode* op) override; +// Type VisitExpr_(const ReverseNode* op) override; +// Type VisitExpr_(const GradientNode* op) override; +// Type VisitExpr_(const ProjectionNode* op) override; +// Type VisitExpr_(const IfNode* op) override; +// Type VisitExpr_(const RefNode* op) override; +// Type VisitExpr_(const ReadRefNode* op) override; +// Type VisitExpr_(const WriteRefNode* op) override; +// Type simple_eval_shape(const Type &shape); +// }; +// struct TypecheckerError : public dmlc::Error { +// explicit TypecheckerError(const std::string &msg) : Error(msg) {} +// }; + +// Typechecker::Typechecker() { +// this->env = EnvironmentNode::make({}); +// this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); +// } + +// Typechecker::Typechecker(Environment env) : env(env) { +// this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); +// } + +// Type Typechecker::Check(const Expr &expr) { +// RELAY_LOG(INFO) << "Typechecker::Check expr=" << expr << std::endl; +// Type ret = this->VisitExpr(expr); +// RELAY_LOG(INFO) << "Typechecker::Check type=" << expr << std::endl; +// ret = this->unifier->subst(ret); +// RELAY_LOG(INFO) << "Typechecker::Check type_after_subst=" << ret << std::endl; +// expr->checked_type_ = ret; +// return ret; +// } + +// Type Typechecker::VisitExpr_(const LocalIdNode *op) { +// LocalId id = GetRef(op); +// return this->local_stack.lookup(id); +// } + +// Type Typechecker::VisitExpr_(const GlobalIdNode *op) { +// GlobalId id = GetRef(op); +// Item item = this->env->lookup(id); + +// if (const OperatorNode *op = item.as()) { +// return op->type; +// } + +// if (const DefnNode *dn = item.as()) { +// Defn def = GetRef(dn); +// return def->type; +// } + +// this->fatal_error("Unhandled case in GlobalId", op->span); +// } + +// Type Typechecker::VisitExpr_(const OperatorIdNode *op) { +// OperatorId id = GetRef(op); +// Item item = this->env->lookup(id); + +// if (const OperatorNode *pn = item.as()) { +// Operator prim = GetRef(pn); +// return prim->type; +// } else { +// this->fatal_error("internal error in InstrinsicId case", op->span); +// } +// } + +// Type Typechecker::VisitExpr_(const FloatLitNode *op) { return FloatType(); } + +// Type Typechecker::VisitExpr_(const BoolLitNode *op) { return BoolType(); } + +// Type Typechecker::VisitExpr_(const IntLitNode *op) { return IntType(); } + +// Type Typechecker::VisitExpr_(const TensorLitNode *op) { +// TensorLit lit = GetRef(op); + +// if (lit->data.size() == 0) { +// this->fatal_error("Tensor literal must have at least one member", op->span); +// } + +// // unify types of all members to figure out shape, also ensure that +// // each member has compatible shape +// Type unified = this->Check(lit->data[0]); +// for (auto elt = lit->data.begin(); elt != lit->data.end(); elt++) { +// // evaluate all shape ASTs so they can be in standard form +// // TODO(sslyu): eventually we'd want this to be symbolic evaluation +// auto elt_el = *elt; +// Type elt_type = simple_eval_shape(this->Check(*elt)); +// if (!elt_type.as()) { +// this->fatal_error("All members in tensor literal must be tensors", +// elt_el->span); +// } +// unified = this->unify(unified, elt_type, lit->span); +// } + +// // types must unify into a tensor +// const TensorTypeNode *ttn = unified.as(); +// // shouldn't be possible due to check inside the loop +// if (!ttn) { +// this->fatal_error("Tensor literal contains non-tensor member", op->span); +// } + +// TensorType unified_tt = GetRef(ttn); + +// // new shape: add length of this tensor to front of existing shape +// // i.e., sequence and simplify +// // TODO(sslyu): should be symbolic evaluation eventually? +// Type new_shape = ShapeSeqNode::make( +// {ShapeSingletonNode::make(lit->data.size()), unified_tt->shape}); +// return TensorTypeNode::make(unified_tt->dtype, simple_eval_shape(new_shape)); +// } + +// Type Typechecker::VisitExpr_(const TupleNode *op) { +// Tuple pl = GetRef(op); + +// std::vector field_types; +// for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) { +// field_types.push_back(this->Check(*field)); +// } + +// return TupleTypeNode::make(field_types); +// } + +// Type Typechecker::VisitExpr_(const CastNode *op) { +// // will take the cast at its word +// Cast cast = GetRef(op); +// return cast->target; +// } + +// Type Typechecker::VisitExpr_(const ParamNode *op) { +// Param param = GetRef(op); +// return resolve(param->type); +// } + +// // We should probably generalize the subst code. +// struct GeneralizeTypeType : TypeFVisitor { +// Map vars_to_id; +// const TypeUnifier &unifier; + +// GeneralizeTypeType(Map vars_to_id, +// const TypeUnifier &unifier) +// : vars_to_id(vars_to_id), unifier(unifier) {} + +// Type VisitType_(const TypeVarNode *op) override { +// auto repr = unifier->subst(GetRef(op)); +// if (auto tvn = repr.as()) { +// auto ty_var = GetRef(tvn); +// if (vars_to_id.find(ty_var) != vars_to_id.end()) { +// return vars_to_id[ty_var]; +// } else { +// return ty_var; +// } +// } else { +// return this->VisitType(repr); +// } +// } +// }; + +// struct GeneralizeTypeExpr : ExprFVisitor<> { +// Map vars_to_id; +// const TypeUnifier &unifier; + +// GeneralizeTypeExpr(const TypeUnifier &unifier, +// Map vars_to_id) +// : vars_to_id(vars_to_id), unifier(unifier) {} + +// Type VisitType(const Type &t) { +// return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); +// } +// }; + +// Type Typechecker::VisitFunction(const Function &f, bool generalize) { +// // enter params into context +// auto fn_type = this->with_frame([&]() { +// std::vector arg_types; +// for (auto arg : f->params) { +// this->Check(arg); +// Type arg_type; +// // if arg type can be simply evaluated, try it +// // should be replaced with symbolic evaluation once it exists, +// // you will not have attr information at this point +// try { +// arg_type = simple_eval_shape(arg->type); +// } catch (const dmlc::Error &e) { +// this->report_error(e.what(), arg->span); +// arg_type = arg->type; +// } +// arg_types.push_back(arg_type); +// this->local_stack.insert(arg->id, arg_type); +// } + +// // typecheck body and ensure that it matches stated return type +// // TODO(sslyu): should the unified return type override the annotated one? +// Type checked_return = this->Check(f->body); +// Type ret_type = resolve(f->ret_type); +// Type unified = this->unify(simple_eval_shape(ret_type), +// simple_eval_shape(checked_return), f->span); +// return TypeArrowNode::make(arg_types, unified); +// }); +// if (generalize) { +// auto free_vars = free_type_vars(resolve(fn_type)); +// std::set dedup_free_vars; + +// for (auto free_var : free_vars) { +// auto repr = this->unifier->subst(free_var); +// if (auto new_free_var_node = repr.as()) { +// dedup_free_vars.insert(GetRef(new_free_var_node)); +// } else { +// // debug(repr); +// throw dmlc::Error( +// "internal error: this list should only contain type var nodes"); +// } +// } + +// Map vars_to_id; + +// GenFresh gf; +// for (auto free_var : dedup_free_vars) { +// vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); +// } + +// fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); +// for (std::pair pair : vars_to_id) { +// // NB: In generalization we want to find type variables with +// // *no constraints* on them, and convert them to universally quantified +// // variables. +// // +// // i.e the program can be abstracted over the details of *that* type. + +// // For example a program that works irrespective of shape or datatype. + +// // In order to do this we find the set of free type variables in the +// // term, and then unify them with the fresh type ids we generate. +// // +// // Remember importantly these type variables still may appear in many +// // places in the program including both types and expressions. + +// // Our method for resolving these is to unify them with the variables +// // as we build the new quanitifer, changing from a program with "holes" +// // to one that is properly abstracted over. + +// // Finally later on we can iterate over the whole term and change from +// // type variables to these type ids. +// this->unify(pair.first, pair.second, pair.second->span); +// fn_type = TypeQuantifierNode::make(pair.second, fn_type); +// } +// } else { +// for (auto i = f->ty_params.size(); i > 0; i--) { +// auto ty_param = f->ty_params[i - 1]; +// auto ty_param_node = ty_param.as(); +// if (!ty_param_node) { +// throw dmlc::Error("internal error should be TypeParam"); +// } +// auto fresh_tid = +// TypeParamNode::make(ty_param_node->name, ty_param_node->kind); +// fn_type = +// type_subst(fn_type, GetRef(ty_param_node), fresh_tid); +// fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); +// } +// } + +// return fn_type; +// } + +// Type Typechecker::VisitExpr_(const FunctionNode *op) { +// return this->VisitFunction(GetRef(op), false); +// } + +// Type Typechecker::instantiate(Type t, tvm::Array &ty_args) { +// const TypeQuantifierNode *ty_quant; +// while ((ty_quant = t.as())) { +// TypeParam id = ty_quant->id; +// TypeVar fresh = TypeVarNode::make(id->kind); +// this->unifier->insert(fresh); +// ty_args.push_back(fresh); +// t = type_subst(ty_quant->boundType, id, fresh); +// } + +// if (!check_kind(t)) { +// this->fatal_error("Kind rules broken when instantiating type variables", +// t->span); +// } + +// return t; +// } + +// Type Typechecker::VisitExpr_(const CallNode *op) { +// Call c = GetRef(op); +// Type fn_ty = this->Check(c->fn); + +// RELAY_LOG(INFO) << "Typechecker::VisitExpr_ op=" << c << std::endl +// << "fn_ty=" << fn_ty << std::endl; + +// // for each type id, insert a type variable and unify with the argument types +// // in order +// // to obtain the concrete instantiation +// tvm::Array ty_args; +// if (const TypeQuantifierNode *ty_quant = fn_ty.as()) { +// fn_ty = instantiate(GetRef(ty_quant), ty_args); +// } + +// if (!fn_ty.as()) { +// this->fatal_error("only expressions with function types can be called", +// c->fn->span); +// } + +// // evaluate all shapes up front (require that types be fully concrete) +// Type evaluated = evaluate_concrete_shape(fn_ty, op->attrs); +// std::vector arg_types; + +// TypeArrow arrow = GetRef(evaluated.as()); + +// // TODO(sslyu): figure out how to handle type ids +// // fn_ty = instantiate(fn_ty, ty_args); +// for (auto arg : c->args) { +// auto ty = this->Check(arg); +// arg_types.push_back(ty); +// } + +// auto type_arity = arrow->arg_types.size(); +// auto number_of_args = arg_types.size(); +// if (type_arity != number_of_args) { +// if (type_arity < number_of_args) { +// this->fatal_error("the function is provided too many arguments", c->span); +// } else { +// this->fatal_error("the function is provided too few arguments", c->span); +// } +// } + +// for (size_t i = 0; i < arrow->arg_types.size(); i++) { +// this->unify(arrow->arg_types[i], arg_types[i], c->args[i]->span); +// } + +// // After we unify the arguments we should know more about the type +// // arguments, let's run a quick pass over them to find new representatives. +// for (size_t i = 0; i < ty_args.size(); i++) { +// ty_args.Set(i, this->unifier->subst(ty_args[i])); +// } + +// // Write the type arguments into the call node, recording what inference +// // solves. This solution might need some work. +// c->ty_args = ty_args; + +// return arrow->ret_type; +// } + +// Type Typechecker::VisitExpr_(const DebugNode *op) { +// return this->Check(op->node); +// } + +// Type Typechecker::VisitExpr_(const LetNode *op) { +// Let let = GetRef(op); + +// Type checked_ty; +// Type annotated_ty = resolve(let->type); + +// // if we are let-defining a function, treat it as a let-rec and insert +// // the id with the annotated type in case there is recursion; +// // no such recursion permitted with anything that's not a function! +// if (let->value.as()) { +// with_frame([&]() { +// local_stack.insert(let->id, annotated_ty); +// checked_ty = Check(let->value); +// }); +// } else { +// checked_ty = Check(let->value); +// } + +// // ensure annotated type and checked type are compatible +// // TODO(sslyu): should the annotated type override the unified one? +// Type unified_ty = +// this->unify(checked_ty, simple_eval_shape(annotated_ty), let->span); + +// return with_frame([&]() { +// local_stack.insert(let->id, unified_ty); +// return Check(let->body); +// }); +// } + +// Type Typechecker::VisitExpr_(const ReverseNode *op) { +// // apply reverse mode to node and typecheck that instead +// std::shared_ptr gf = std::make_shared(); +// return this->Check(ReverseExpr(env, op->node, gf)); +// } + +// Type Typechecker::VisitExpr_(const GradientNode *op) { +// auto node = op->node; +// this->Check(node); +// auto gf = std::make_shared(); +// return FOWithGradientType(node->checked_type()); +// } + +// Type Typechecker::VisitExpr_(const ProjectionNode *op) { +// Projection proj = GetRef(op); + +// Type tup_type = this->Check(proj->tuple); + +// const TupleTypeNode *ptn = tup_type.as(); +// if (!ptn) { +// this->fatal_error("Cannot project into non-product type", op->span); +// } + +// TupleType pt = GetRef(ptn); +// size_t field = (size_t)proj->field; +// if (field >= pt->fields.size()) { +// this->fatal_error("Projecting past bounds of product", op->span); +// } + +// return pt->fields[field]; +// } + +// Type Typechecker::VisitExpr_(const IfNode *op) { +// If ifn = GetRef(op); + +// // Ensure the type of the guard is of Tensor[Bool, ()], +// // that is a rank-0 boolean tensor. +// Type guardType = this->Check(ifn->guard); +// bool is_bool = false; +// bool zero_rank = false; +// if (const TensorTypeNode *ttn = guardType.as()) { +// TensorType tt = GetRef(ttn); + +// if (const BaseTypeNode *btn = tt->dtype.as()) { +// is_bool = btn->type.is_bool(); +// } + +// Type shape = simple_eval_shape(tt->shape); + +// if (const ShapeSeqNode *sn = shape.as()) { +// zero_rank = (sn->shapes.size() == 0); +// } +// } + +// if (!(is_bool && zero_rank)) { +// this->fatal_error("IfNode guard must be a rank 0 bool tensor", +// ifn->guard->span); +// } + +// // unify types of different branches +// Type left = this->Check(ifn->true_b); +// Type right = this->Check(ifn->false_b); +// return this->unify(left, right, ifn->span); +// } + +// Type Typechecker::VisitExpr_(const RefNode *op) { +// Ref r = GetRef(op); +// Type inner = this->Check(r->expr); +// return RefTypeNode::make(inner); +// } + +// Type Typechecker::VisitExpr_(const ReadRefNode *op) { +// ReadRef vr = GetRef(op); +// Type ref_type = this->Check(vr->ref); + +// // reject if not a ref type +// const RefTypeNode *rtn = ref_type.as(); +// if (!rtn) { +// this->fatal_error( +// "the de-reference operation can only be used with references", +// op->span); +// } + +// RefType rt = GetRef(rtn); +// return rt->data_type; +// } + +// Type Typechecker::VisitExpr_(const WriteRefNode *op) { +// WriteRef sr = GetRef(op); +// Type ref_type = this->Check(sr->ref); + +// const RefTypeNode *rtn = ref_type.as(); +// if (!rtn) { +// this->fatal_error("Cannot mutate non-ref", op->span); +// } +// RefType rt = GetRef(rtn); + +// // ensure ref type's inner type and expr's type are compatible; return unit +// Type expr_type = this->Check(sr->val); +// this->unify(rt->data_type, expr_type, sr->span); +// return UnitType(); +// } + +// Type Typechecker::resolve(const Type &t) { +// return ::tvm::relay::resolve(this->unifier, t); +// } + +// Expr Typechecker::resolve(const Expr &e) { +// return ::tvm::relay::resolve(this->unifier, e); +// } + +// Type Typechecker::simple_eval_shape(const Type &shape) { +// // TODO(sslyu): Do we want to propagate attributes? +// Attributes empty = AttributesNode::make({}); +// return evaluate_concrete_shape(shape, empty); +// } + +// Operator Typechecker::CheckOp(Operator op) { +// if (!check_kind(op->type)) { +// report_error("the type of the operator is ill formed", op->type->span); +// } + +// // Fix me +// return op; +// } + +// Defn Typechecker::CheckDefn(Defn defn) { +// // This is to handle recursion, but we need to speculatively +// // put it in env, then remove it. +// env->items.insert({defn->id, defn}); + +// Type expected_ty = this->resolve(defn->type); + +// Expr body = defn->body; + +// auto checked_ty = Check(body); + +// try { +// Type uret_type = unify(expected_ty, checked_ty, defn->body->span); +// CHECK(is_fully_resolved(uret_type)); +// // Now let's clean up our work from earlier. +// env->items.erase(defn->id); +// return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); +// } catch (const UnificationError& err) { +// std::string msg = std::string("mismatch between `") + +// PrintType(env, expected_ty, WrapWidth(40)) + "` and `" + +// PrintType(env, checked_ty, WrapWidth(40)) + "`"; +// fatal_error(msg, defn->span); +// } +// } + +// Type check(const Environment &env, const Expr &e) { +// Typechecker tc(env); +// return tc.Check(e); +// } + +// Item check(const Environment &env, const Item &i) { +// Typechecker tc(env); + +// try { +// if (const DefnNode *defn = i.as()) { +// return tc.CheckDefn(GetRef(defn)); +// } else if (const OperatorNode *op_node = i.as()) { +// return tc.CheckOp(GetRef(op_node)); +// } else { +// throw dmlc::Error("internal error: unknown Item type"); +// } +// } catch (const FatalTypeError &err) { +// env->display_errors(); +// throw dmlc::Error( +// "We encountered a fatal error while type checking your program, please " +// "read above for more details."); +// } +// } + +// inline void Typechecker::report_error(const std::string &msg, Span sp) { +// this->env->report_error(msg, sp); +// } + +// void Typechecker::fatal_error(const std::string &msg, Span sp) { +// this->env->report_error(msg, sp); +// throw FatalTypeError( +// "internal error: this exception should" +// "be handled and errors reported with Environment::display_errors\n" + +// msg); +// } + +// Type Typechecker::unify(const Type &t1, const Type &t2, Span sp) { +// try { +// return this->unifier->unify(t1, t2); +// } catch (const dmlc::Error &e) { +// std::stringstream ss; +// ss << "Error unifying `"; +// ss << PrintType(env, t1, WrapWidth(40)); +// ss << "` and `"; +// ss << PrintType(env, t2, WrapWidth(40)); +// ss << "`: " << e.what(); +// this->fatal_error(ss.str(), sp); +// } +// } + +// // template + +// // Add safe dynamic Array downcast. +// // Add static upcast? + +// // Add to type utils. +// Array type_parameters(const Type &t) { +// Array params; +// auto type = t; +// const TypeQuantifierNode *ty_quant; +// while ((ty_quant = type.as())) { +// params.push_back(ty_quant->id); +// type = ty_quant->boundType; +// } + +// return params; +// } + +// template +// Array ArrayMap(const Array &data, F f) { +// // probably a way to use std::transform. +// Array output; +// for (const I &el : data) { +// output.push_back(f(el)); +// } +// return output; +// } + +// // There are some important questions around generalization +// // that we need to answer. +// Expr generalize(const Environment &env, const Expr &e) { +// if (auto fn_node = e.as()) { +// Typechecker tc(env); +// auto ty = tc.VisitFunction(GetRef(fn_node), true); +// auto ty_params = type_parameters(ty); +// auto params = ArrayMap(fn_node->params, [&](const Param &p) { +// return ParamNode::make(p->id, tc.resolve(p->type)); +// }); +// auto body = tc.resolve(fn_node->body); +// auto ret_type = tc.resolve(fn_node->ret_type); +// auto fn = FunctionNode::make(ty_params, params, ret_type, body); +// // we should check in empty context to ensure typing is preserved. +// // check(env, fn); +// return fn; +// } else { +// throw dmlc::Error("can only apply generalize to a function."); +// } +// } + +// TVM_REGISTER_API("relay._tyck.check_expr") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// Expr e = args[1]; +// *ret = check(env, e); +// }); + +// TVM_REGISTER_API("relay._tyck.check_item") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Environment env = args[0]; +// Item i = args[1]; +// *ret = check(env, i); +// }); + +// TVM_REGISTER_API("relay._tyck.get_checked_type") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// Expr e = args[0]; +// *ret = e->checked_type(); +// }); + +// TVM_REGISTER_API("relay._tyck.generalize") +// .set_body([](TVMArgs args, TVMRetValue *ret) { +// *ret = generalize(args[0], args[1]); +// }); + +IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { + std::shared_ptr n = std::make_shared(); + n->kind = std::move(kind); + return IncompleteType(n); +} + +TVM_REGISTER_API("relay._make.IncompleteType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[0]; + *ret = IncompleteTypeNode::make(static_cast(kind)); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IncompleteTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "IncompleteTypeNode(" << node->kind << ", " << &node << ")"; + }); + +} // namespace relay +} // namespace tvm From 32bb3c2d6620b061ecfc3f971090b1af81b8595f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:06:44 -0700 Subject: [PATCH 019/136] Add type_visitor.h --- src/relay/compiler/type_visitor.h | 107 ++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 src/relay/compiler/type_visitor.h diff --git a/src/relay/compiler/type_visitor.h b/src/relay/compiler/type_visitor.h new file mode 100644 index 000000000000..5ae100a8de6d --- /dev/null +++ b/src/relay/compiler/type_visitor.h @@ -0,0 +1,107 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_visitor.h + * \brief A wrapper around TypeFunctor for common use cases. + */ +#ifndef TVM_RELAY_TYPE_VISITOR_H_ +#define TVM_RELAY_TYPE_VISITOR_H_ + +#include +#include "./type_functor.h" + +namespace tvm { +namespace relay { + +/*! \brief A type visitor for vistiors which make use of internal + * mutable state. + * + * We recursively visit each type contained inside the visitor. + */ +template +struct TypeVisitor : TypeFunctor { + // void VisitType_(const TypeVarNode* op, Args... args) override {} + void VisitType_(const TypeParamNode* op, Args... args) override {} + + void VisitType_(const FuncTypeNode* op, Args... args) override { + // this->VisitType(op->id, args...); + // this->VisitType(op->boundType, args...); + // for (auto arg_type : op->arg_types) { + // this->VisitType(arg_type, args...); + // } + // this->VisitType(op->ret_type, args...); + } + + void VisitType_(const TensorTypeNode* op, Args... args) override { + // this->VisitType(op->dtype, args...); + // this->VisitType(op->shape, args...); + } + +// void VisitType_(const TupleTypeNode* op, Args... args) override { +// for (const Type& t : op->fields) { +// this->VisitType(t, args...); +// } +// } + +// void VisitType_(const TypeCallNode* op, Args... args) override { +// for (const Type& t : op->args) { +// this->VisitType(t, args...); +// } +// } + + void VisitType_(const TypeFunctionNode* op, Args... args) override {} + void VisitType_(const IncompleteTypeNode* op, Args... args) override {} +}; + +// A functional visitor for rebuilding an AST in place. +struct TypeFVisitor : TypeFunctor { + Type VisitType_(const TensorTypeNode* op) override { + // TODO (@jroesch): maybe we should recursively visit + return TensorTypeNode::make(op->shape, op->dtype); + } + + Type VisitType_(const TypeParamNode* op) override { + return GetRef(op); + } + +// Type VisitType_(const TypeArrowNode* op) override { +// std::vector args; +// for (auto arg_type : op->arg_types) { +// args.push_back(VisitType(arg_type)); +// } +// return TypeArrowNode::make(tvm::Array(args), VisitType(op->ret_type)); +// } + +// Type VisitType_(const TypeQuantifierNode* op) override { +// auto new_id = this->VisitType(op->id); +// if (const TypeParamNode* tin = new_id.as()) { +// return TypeQuantifierNode::make(GetRef(tin), +// this->VisitType(op->boundType)); +// } else { +// throw dmlc::Error("Cannot quantify something that is not a type ID"); +// } +// } + +// Type VisitType_(const TupleTypeNode* op) override { +// std::vector new_fields; +// for (const Type& t : op->fields) { +// new_fields.push_back(this->VisitType(t)); +// } +// return TupleTypeNode::make(new_fields); +// } + +// Type VisitType_(const TypeCallNode* op) override { +// auto func = this->VisitType(op->func); +// std::vector new_args; +// for (const Type& t : op->args) { +// new_args.push_back(this->VisitType(t)); +// } +// return TypeCallNode::make(func, new_args); +// } + Type VisitType_(const IncompleteTypeNode* op) override { + return GetRef(op); + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TYPE_VISITOR_H_ From e044e89373739b04b68e1d03e4e6a0e5d047e623 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:07:13 -0700 Subject: [PATCH 020/136] Add expr_visitor.h --- include/tvm/relay/expr_visitor.h | 166 +++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 include/tvm/relay/expr_visitor.h diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h new file mode 100644 index 000000000000..d7ac1465f70a --- /dev/null +++ b/include/tvm/relay/expr_visitor.h @@ -0,0 +1,166 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file expr_visitor.h + * \brief A simple visitor wrapper around ExprFunctor designed for visitors which + * maintain mutable state. + */ +#ifndef TVM_RELAY_EXPR_VISITOR_H_ +#define TVM_RELAY_EXPR_VISITOR_H_ + +#include "expr_functor.h" + +namespace tvm { +namespace relay { + +template +class ExprVisitor : public ExprFunctor { + public: + void VisitExpr_(const LocalVarNode* op, Args... args) override { return; } + + void VisitExpr_(const GlobalVarNode* op, Args... args) override { return; } + + void VisitExpr_(const ConstantNode* op, Args... args) override { return; } + + void VisitExpr_(const TupleNode* op, Args... args) override { + for (auto field : op->fields) { + this->VisitExpr(field, args...); + } + } + + void VisitExpr_(const ParamNode* op, Args... args) override { + this->VisitExpr(op->var, args...); + } + + void VisitExpr_(const FunctionNode* op, Args... args) override { + for (auto param : op->params) { + this->VisitExpr(param, args...); + } + + this->VisitExpr(op->body, args...); + } + + void VisitExpr_(const CallNode* op, Args... args) override { + this->VisitExpr(op->op, args...); + for (auto arg : op->args) { + this->VisitExpr(arg, args...); + } + } + + void VisitExpr_(const LetNode* op, Args... args) override { + this->VisitExpr(op->var, args...); + this->VisitExpr(op->value, args...); + this->VisitExpr(op->body, args...); + } + + void VisitExpr_(const IfNode* op, Args... args) override { + this->VisitExpr(op->cond, args...); + this->VisitExpr(op->true_value, args...); + this->VisitExpr(op->false_value, args...); + } + + void VisitExpr_(const OperatorNode* op, Args... args) override { return; } +}; + +template +class ExprFVisitor : public ExprFunctor { + public: + Expr VisitExpr_(const LocalVarNode* op, Args... args) override { + return GetRef(op); + } + + Expr VisitExpr_(const GlobalVarNode* op, Args... args) override { + return GetRef(op); + } + + Expr VisitExpr_(const OperatorNode* op, Args... args) override { + return GetRef(op); + } + + Expr VisitExpr_(const TupleNode* op, Args... args) override { + tvm::Array fields; + for (auto field : op->fields) { + fields.push_back(this->VisitExpr(field, args...)); + } + + return TupleNode::make(fields); + } + + Expr VisitExpr_(const ParamNode* op, Args... args) override { + Expr var_expr = this->VisitExpr(op->var, args...); + if (const LocalVarNode* var_node = var_expr.as()) { + auto var = GetRef(var_node); + auto type = this->VisitType(op->type, args...); + return ParamNode::make(var, type); + } else { + throw dmlc::Error("the default param visitor has bug"); + } + } + + Expr VisitExpr_(const FunctionNode* op, Args... args) override { + tvm::Array ty_params; + for (auto ty : op->type_params) { + ty_params.push_back(this->VisitType(ty, args...)); + } + + tvm::Array params; + for (auto param : op->params) { + Expr param_expr = this->VisitExpr(param, args...); + if (const ParamNode* param_node = param_expr.as()) { + auto param = GetRef(param_node); + params.push_back(param); + } else { + throw dmlc::Error("the default func visitor has bug"); + } + } + + auto ret_type = this->VisitType(op->ret_type, args...); + auto body = this->VisitExpr(op->body, args...); + return FunctionNode::make(ty_params, params, ret_type, body); + } + + Expr VisitExpr_(const CallNode* call_node, Args... args) override { + auto fn = this->VisitExpr(call_node->op, args...); + + tvm::Array ty_args; + for (auto ty_arg : call_node->type_args) { + auto new_ty_arg = this->VisitType(ty_arg, args...); + ty_args.push_back(new_ty_arg); + } + + tvm::Array call_args; + for (auto arg : call_node->args) { + call_args.push_back(this->VisitExpr(arg, args...)); + } + + auto call = CallNode::make(fn, call_args, call_node->attrs); + call->ty_args = ty_args; + + return call; + } + + Expr VisitExpr_(const LetNode* op, Args... args) override { + Expr var_expr = this->VisitExpr(op->var, args...); + if (const LocalVarNode* var_node = var_expr.as()) { + auto var = GetRef(var_node); + auto type = this->VisitType(op->value_type, args...); + auto value = this->VisitExpr(op->value, args...); + auto body = this->VisitExpr(op->body, args...); + return LetNode::make(var, type, value, body); + } else { + throw dmlc::Error("the default let visitor has error"); + } + } + + Expr VisitExpr_(const IfNode* op, Args... args) override { + auto guard = this->VisitExpr(op->cond, args...); + auto true_b = this->VisitExpr(op->true_value, args...); + auto false_b = this->VisitExpr(op->false_value, args...); + return IfNode::make(guard, true_b, false_b); + } + + virtual Type VisitType(const Type& t, Args... args) { return t; } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_VISITOR_H_ From 0795d41f80400df40ce469a7bbf4e93afd86873a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:07:25 -0700 Subject: [PATCH 021/136] Start reparing unifier and tests --- python/tvm/relay/ir_builder.py | 11 ++++++++ tests/python/relay/test_unifier.py | 43 ++++++++++++------------------ 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 497479140ec9..3c842e480c70 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -100,5 +100,16 @@ def get(self): return _mk_let(bindings, self.ret_value) +# def int_type(): +# return TensorType(IntType(32), ShapeSeq([])) + +# def float_type(): +# return TensorType(FloatType(32), ShapeSeq([])) + +# def bool_type(): +# return TensorType(BoolType(), ShapeSeq([])) + +# def make_shape(dims): +# return ShapeSeq([ShapeSingleton(dim) for dim in dims]) diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index 875502808563..b2ed075ca3de 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -2,30 +2,14 @@ Test the type unifier, which solves systems of equations between incomplete types. """ -import relay.ir -import relay.unifier - - -def unify_types(t1, t2): - unifier = TypeUnifier() - return unifier.unify(t1, t2) - -def int_type(): - return TensorType(IntType(32), ShapeSeq([])) - -def float_type(): - return TensorType(FloatType(32), ShapeSeq([])) - -def bool_type(): - return TensorType(BoolType(), ShapeSeq([])) - -def make_shape(dims): - return ShapeSeq([ShapeSingleton(dim) for dim in dims]) +import tvm.relay.ir +from tvm.relay.unifier import UnionFind, TypeUnifier +import tvm.relay.make as mk def test_insert_and_find(): uf = UnionFind() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) + v1 = mk.TypeVar(ir.Kind.Type) + v2 = mk.TypeVar(ir.Kind.Type) uf.insert(v1) uf.insert(v2) assert uf.find(v1) == v1 @@ -33,8 +17,8 @@ def test_insert_and_find(): def test_insert_error(): uf = UnionFind() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) + v1 = mk.TypeVar(ir.Kind.Type) + v2 = mk.TypeVar(ir.Kind.Type) uf.insert(v1) try: uf.find(v2) @@ -44,9 +28,9 @@ def test_insert_error(): def test_unify(): uf = UnionFind() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) - v3 = TypeVar(ir.Kind.Type) + v1 = mk.TypeVar(ir.Kind.Type) + v2 = mk.TypeVar(ir.Kind.Type) + v3 = mk.TypeVar(ir.Kind.Type) uf.insert(v1) uf.insert(v2) uf.insert(v3) @@ -97,6 +81,13 @@ def test_unify_multiple_levels(): for i in range(3): assert uf.find(v[i + 6]) == new_rep2 +# We have checked that the basic machinery in the UnionFind works +# and now we will test the type unifier which will fill in holes +# between type equalities by the process of unification. +def unify_types(t1, t2): + unifier = TypeUnifier() + return unifier.unify(t1, t2) + # TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work def test_unify_int(): intty = IntType(1) From b6803b5477c5a37c335be821bade0c9e424aaa92 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Aug 2018 14:21:12 -0700 Subject: [PATCH 022/136] Fix test_unifier.py, now runs but all tests fail --- python/tvm/relay/make.py | 5 + tests/python/relay/test_alpha_eq.py | 1148 +++++++++++++-------------- tests/python/relay/test_unifier.py | 130 +-- 3 files changed, 643 insertions(+), 640 deletions(-) diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py index 14d9ac040dc9..a2b87f2700af 100644 --- a/python/tvm/relay/make.py +++ b/python/tvm/relay/make.py @@ -18,3 +18,8 @@ Call = _make.Call Let = _make.Let If = _make.If +IncompleteType = _make.IncompleteType + +# Unifier +UnionFind = _make.UnionFind +TypeUnifier = _make.TypeUnifier diff --git a/tests/python/relay/test_alpha_eq.py b/tests/python/relay/test_alpha_eq.py index f1dc81c3c483..e4fbbcca93ce 100644 --- a/tests/python/relay/test_alpha_eq.py +++ b/tests/python/relay/test_alpha_eq.py @@ -1,576 +1,574 @@ """Test alpha-equivalence of expressions and types.""" -# pylint: disable=invalid-name, missing-docstring -# pylint: disable=wildcard-import, unused-wildcard-import -from relay.make import * -from relay.ir import alpha_eq, ShapeOp, Kind -from relay.typing import TYPE_DEFAULTS -from relay import ir - -INT_TYPE_WIDTH = TYPE_DEFAULTS["INT_WIDTH"] -INT_TYPE_LANES = TYPE_DEFAULTS["INT_LANES"] - -def int_type(width=32) -> ir.Type: - return TensorType(IntType(width), ShapeSeq([])) - -def float_type(width=32) -> ir.Type: - return TensorType(FloatType(width), ShapeSeq([])) - -def bool_type() -> ir.Type: - return TensorType(BoolType(), ShapeSeq([])) - -def nest_quantifiers(ids, body) -> ir.Type: - ret = body - for tid in reversed(ids): - ret = TypeQuantifier(tid, ret) - return ret - -def test_local_id_not_eq() -> None: - assert not alpha_eq(LocalId("x"), LocalId("y")) - -def test_local_id_eq() -> None: - x = LocalId("x") - assert alpha_eq(x, x) - -def test_global_id_not_eq() -> None: - left = GlobalId("xyz") - right = GlobalId("xyz") - assert not alpha_eq(left, right) - -def test_global_id_eq() -> None: - ident = GlobalId("xyz") - assert alpha_eq(ident, ident) - -def test_operator_id_not_eq() -> None: - left = OperatorId("xyz") - right = OperatorId("xyz") - # equality on operator id is pointer equality - assert not alpha_eq(left, right) - -def test_operator_id_eq() -> None: - x = OperatorId("xyz") - assert alpha_eq(x, x) - -def test_float_literal_eq() -> None: - x = FloatLit(1.0) - y = FloatLit(1.0) - assert alpha_eq(x, y) - -def test_float_literal_not_eq() -> None: - x = FloatLit(1.0) - y = FloatLit(2.0) - assert not alpha_eq(x, y) - -def test_int_literal_eq() -> None: - x = IntLit(1) - y = IntLit(1) - assert alpha_eq(x, y) - -def test_int_literal_not_eq() -> None: - x = IntLit(1) - y = IntLit(2) - assert not alpha_eq(x, y) - -def test_bool_literal_eq() -> None: - x = BoolLit(True) - y = BoolLit(True) - assert alpha_eq(x, y) - -def test_bool_literal_not_eq() -> None: - x = BoolLit(True) - y = BoolLit(False) - assert not alpha_eq(x, y) - -def test_tensor_literal_eq() -> None: - x = TensorLit([IntLit(1), IntLit(2)]) - y = TensorLit([IntLit(1), IntLit(2)]) - assert alpha_eq(x, y) - -def test_tensor_literal_not_eq() -> None: - x = TensorLit([IntLit(1), IntLit(2)]) - y = TensorLit([IntLit(1), IntLit(3)]) - z = TensorLit([IntLit(1)]) - assert not alpha_eq(x, y) - assert not alpha_eq(x, z) - -def test_product_literal_eq() -> None: - x = Tuple([IntLit(1), IntLit(2)]) - y = Tuple([IntLit(1), IntLit(2)]) - assert alpha_eq(x, y) - -def test_product_literal_not_eq() -> None: - x = Tuple([IntLit(1), IntLit(2)]) - y = Tuple([IntLit(2), IntLit(2)]) - z = Tuple([IntLit(1), IntLit(2), IntLit(3)]) - assert not alpha_eq(x, y) - assert not alpha_eq(x, z) - -def test_projection_eq() -> None: - prod = Tuple([IntLit(3), FloatLit(3.5)]) - - assert alpha_eq(Projection(prod, 0), Projection(prod, 0)) - assert alpha_eq(Projection(prod, 1), Projection(prod, 1)) - -def test_projection_not_eq() -> None: - prod1 = Tuple([IntLit(3), IntLit(4)]) - prod2 = Tuple([IntLit(3)]) - prod3 = Tuple([IntLit(3), IntLit(4), FloatLit(3.5)]) - - assert not alpha_eq(Projection(prod1, 0), Projection(prod1, 1)) - assert not alpha_eq(Projection(prod1, 0), Projection(prod2, 0)) - assert not alpha_eq(Projection(prod1, 0), Projection(prod3, 0)) - assert not alpha_eq(Projection(prod1, 1), Projection(prod3, 1)) - -def test_cast_not_eq() -> None: - left = Cast(IntType(1), IntLit(2)) - right = Cast(IntType(1), IntLit(1)) - assert not alpha_eq(left, right) - - # same literal, different type - left = Cast(IntType(1), IntLit(2)) - right = Cast(IntType(2), IntLit(2)) - assert not alpha_eq(left, right) - -def test_cast_eq() -> None: - left = Cast(IntType(1), IntLit(2)) - right = Cast(IntType(1), IntLit(2)) - assert alpha_eq(left, right) - -def test_param_not_eq() -> None: - left = Param(LocalId("foo"), int_type()) - right = Param(LocalId("foo"), bool_type()) - assert not alpha_eq(left, right) - -def test_param_eq() -> None: - left = Param(LocalId("foo"), int_type()) - right = Param(LocalId("bar"), int_type()) - assert alpha_eq(left, right) - -def test_function_not_eq() -> None: - params1 = [Param(LocalId("x"), int_type())] - fn1 = Function([], params1, int_type(), LocalId("x")) - params2 = [Param(LocalId("y"), bool_type())] - fn2 = Function([], params2, int_type(), LocalId("y")) - assert not alpha_eq(fn1, fn2) - - params3 = [Param(LocalId("x"), int_type()), Param(LocalId("y"), int_type())] - fn3 = Function([], params3, int_type(), LocalId("z")) - assert not alpha_eq(fn1, fn3) - -def test_function_eq() -> None: - x = LocalId("x") - y = LocalId("y") - params1 = [Param(x, int_type())] - fn1 = Function([], params1, int_type(), x) - params2 = [Param(y, int_type())] - fn2 = Function([], params2, int_type(), y) - assert alpha_eq(fn1, fn2) - -def test_call_not_eq() -> None: - x = LocalId("x") - y = LocalId("y") - params1 = [Param(x, int_type())] - fn1 = Function([], params1, int_type(), x) - args1 = [IntLit(1)] - call1 = Call(fn1, args1) - - args2 = [IntLit(2)] - call2 = Call(fn1, args2) - assert not alpha_eq(call1, call2) - - params2 = [Param(y, int_type())] - fn2 = Function([], params2, float_type(), FloatLit(0.0)) - call3 = Call(fn2, args1) - assert not alpha_eq(call1, call3) - assert not alpha_eq(call2, call3) - -def test_call_eq() -> None: - x = LocalId("x") - y = LocalId("y") - params1 = [Param(x, int_type())] - fn1 = Function([], params1, int_type(), x) - args = [IntLit(1)] - call1 = Call(fn1, args) - - params2 = [Param(y, int_type())] - fn2 = Function([], params2, int_type(), y) - call2 = Call(fn2, args) - assert alpha_eq(call1, call2) - -def test_debug_not_eq() -> None: - left = Debug(IntLit(1)) - right = Debug(IntLit(2)) - assert not alpha_eq(left, right) - -def test_debug_eq() -> None: - left = Debug(IntLit(1)) - right = Debug(IntLit(1)) - assert alpha_eq(left, right) - -def test_let_not_eq() -> None: - x = LocalId("x") - y = LocalId("y") - let1 = Let(x, int_type(), IntLit(10), IntLit(11)) - let2 = Let(y, int_type(), IntLit(10), IntLit(12)) - assert not alpha_eq(let1, let2) - - let3 = Let(x, int_type(), IntLit(10), x) - let4 = Let(y, int_type(), IntLit(12), y) - assert not alpha_eq(let3, let4) - -def test_let_eq() -> None: - x = LocalId("x") - y = LocalId("y") - let1 = Let(x, int_type(), IntLit(10), x) - let2 = Let(y, int_type(), IntLit(10), y) - assert alpha_eq(let1, let2) - -def test_ref_eq() -> None: - r1 = Ref(IntLit(5)) - r2 = Ref(IntLit(5)) - assert alpha_eq(r1, r2) - -def test_ref_not_eq() -> None: - r1 = Ref(IntLit(5)) - r2 = Ref(FloatLit(3.5)) - r3 = Ref(r1) - assert not alpha_eq(r1, r2) - assert not alpha_eq(r1, r3) - assert not alpha_eq(r2, r3) - -def test_val_ref_eq() -> None: - vr1 = ReadRef(Ref(IntLit(35))) - vr2 = ReadRef(Ref(Tuple([IntLit(12), FloatLit(2.5)]))) - assert alpha_eq(vr1, vr1) - assert alpha_eq(vr2, vr2) - -def test_val_ref_not_eq() -> None: - vr1 = ReadRef(Ref(IntLit(5))) - vr2 = ReadRef(Ref(vr1)) - vr3 = ReadRef(Ref(FloatLit(5.0))) - assert not alpha_eq(vr1, vr2) - assert not alpha_eq(vr1, vr3) - assert not alpha_eq(vr2, vr3) - -def test_set_ref_eq() -> None: - sr1 = WriteRef(Ref(FloatLit(5.0)), FloatLit(6.0)) - sr2 = WriteRef(Ref(Tuple([IntLit(3), BoolLit(False)])), - Tuple([IntLit(5), BoolLit(True)])) - assert alpha_eq(sr1, sr1) - assert alpha_eq(sr2, sr2) - -def test_set_ref_not_eq() -> None: - r1 = Ref(FloatLit(5.0)) - r2 = Ref(IntLit(5)) - r3 = Ref(IntLit(6)) - - assert not alpha_eq(WriteRef(r1, FloatLit(6.0)), - WriteRef(r2, IntLit(6))) - assert not alpha_eq(WriteRef(r2, IntLit(6)), WriteRef(r2, IntLit(7))) - assert not alpha_eq(WriteRef(r2, IntLit(7)), WriteRef(r3, IntLit(7))) - -# Type alpha-equality tests - -def test_base_type_eq() -> None: - assert alpha_eq(IntType(32), IntType(32)) - assert alpha_eq(BoolType(), BoolType()) - assert alpha_eq(FloatType(32), FloatType(32)) - -def test_tensor_type_eq() -> None: - tt1 = TensorType( - IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) - tt2 = TensorType( - FloatType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) - assert alpha_eq(tt1, tt1) - assert alpha_eq(tt2, tt2) - -def test_tensor_type_not_eq() -> None: - tt1 = TensorType( - IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) - tt2 = TensorType( - FloatType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) - tt3 = TensorType( - IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) - assert not alpha_eq(tt1, tt2) - assert not alpha_eq(tt1, tt3) - -def test_ref_type_eq() -> None: - rt1 = RefType(int_type()) - rt2 = RefType(float_type()) - assert alpha_eq(rt1, rt1) - assert alpha_eq(rt2, rt2) - -def test_ref_type_not_eq() -> None: - rt1 = RefType(int_type()) - rt2 = RefType(float_type()) - assert not alpha_eq(rt1, rt2) - -def test_product_type_eq() -> None: - pt1 = TupleType([int_type(), RefType(float_type())]) - pt2 = TupleType([float_type(), float_type(), int_type()]) - assert alpha_eq(pt1, pt1) - assert alpha_eq(pt2, pt2) - -def test_product_type_not_eq() -> None: - pt1 = TupleType([int_type(), int_type()]) - pt2 = TupleType([int_type(), int_type(), float_type()]) - pt3 = TupleType([bool_type(), float_type()]) - assert not alpha_eq(pt1, pt2) - assert not alpha_eq(pt1, pt3) - -def test_type_id_eq() -> None: - id1 = TypeParam("id1", Kind.Shape) - id2 = TypeParam("id2", Kind.BaseType) - id3 = TypeParam("id2", Kind.Type) - - assert alpha_eq(id1, id1) - assert alpha_eq(id2, id2) - assert alpha_eq(id3, id3) - -def test_type_id_not_eq() -> None: - # name is just a hint, we use pointer equality as the rule - # (unless there is a quantifier to give context) - id1 = TypeParam("id1", Kind.Shape) - id2 = TypeParam("id1", Kind.Shape) - id3 = TypeParam("id3", Kind.BaseType) - - assert not alpha_eq(id1, id2) - assert not alpha_eq(id1, id3) - -def test_arrow_type_eq() -> None: - ar1 = TypeArrow([int_type()], bool_type()) - ar2 = TypeArrow([int_type(), int_type()], TupleType([])) - assert alpha_eq(ar1, ar1) - assert alpha_eq(ar2, ar2) - -def test_arrow_type_not_eq() -> None: - t1 = int_type() - t2 = bool_type() - t3 = [int_type(), bool_type()] - - assert not alpha_eq(TypeArrow([t1], t2), TypeArrow([t1], t1)) - assert not alpha_eq(TypeArrow(t3, t1), TypeArrow([t2], t1)) - assert not alpha_eq(TypeArrow([t1], TypeArrow([t1], t1)), - TypeArrow([t1], t1)) - -def test_type_quantifier_eq() -> None: - id1 = TypeParam("id1", Kind.Shape) - id2 = TypeParam("id2", Kind.Shape) - tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) - tq2 = TypeQuantifier(id2, TensorType(IntType(32), id2)) - - assert alpha_eq(tq1, tq1) - assert alpha_eq(tq1, tq2) - -def test_nested_type_quantifier_eq() -> None: - id1 = TypeParam("id1", Kind.BaseType) - id2 = TypeParam("id2", Kind.Shape) - id3 = TypeParam("id3", Kind.BaseType) - id4 = TypeParam("id4", Kind.Shape) - tq1 = TypeQuantifier(id1, TypeQuantifier(id2, TensorType(id1, id2))) - tq2 = TypeQuantifier(id3, TypeQuantifier(id4, TensorType(id3, id4))) - - assert alpha_eq(tq1, tq1) - assert alpha_eq(tq1, tq2) - -def test_type_quantifier_not_eq() -> None: - id1 = TypeParam("id1", Kind.Shape) - id2 = TypeParam("id2", Kind.BaseType) - id3 = TypeParam("id3", Kind.Shape) - - tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) - tq2 = TypeQuantifier(id2, TensorType(id2, ShapeSeq([ShapeSingleton(3)]))) - tq3 = TypeQuantifier(id1, TensorType(IntType(32), id3)) - tq4 = TypeQuantifier(id1, TensorType(FloatType(32), id1)) - - assert not alpha_eq(tq1, tq2) - assert not alpha_eq(tq1, tq3) - assert not alpha_eq(tq1, tq4) - assert not alpha_eq(tq2, tq3) - assert not alpha_eq(tq2, tq4) - -def test_shape_singleton_eq() -> None: - single1 = ShapeSingleton(10) - single2 = ShapeSingleton(10) - - assert alpha_eq(single1, single1) - assert alpha_eq(single1, single2) - -def test_shape_singelton_not_eq() -> None: - single1 = ShapeSingleton(10) - single2 = ShapeSingleton(11) - - assert not alpha_eq(single1, single2) - -def test_shape_attr_eq() -> None: - attr1 = ShapeAttr("x") - attr2 = ShapeAttr("x") - - assert alpha_eq(attr1, attr1) - assert alpha_eq(attr1, attr2) - -def test_shape_attr_not_eq() -> None: - id1 = "x" - id2 = "y" - attr1 = ShapeAttr(id1) - attr2 = ShapeAttr(id2) - - assert not alpha_eq(attr1, attr2) - -def test_shape_seq_eq() -> None: - empty = ShapeSeq([]) - seq1 = ShapeSeq([ShapeSingleton(5)]) - seq2 = ShapeSeq([ShapeSingleton(5)]) - - assert alpha_eq(empty, empty) - assert alpha_eq(seq1, seq2) - -def test_shape_seq_not_eq() -> None: - empty = ShapeSeq([]) - seq = ShapeSeq([ShapeSingleton(5)]) - single = ShapeSingleton(5) - - assert not alpha_eq(empty, seq) - assert not alpha_eq(seq, single) - -def test_shape_projection_eq() -> None: - proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) - proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) - - assert alpha_eq(proj1, proj2) - -def test_shape_projection_not_eq() -> None: - proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) - proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 1) - proj3 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 0) - proj4 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 1) - - assert not alpha_eq(proj1, proj2) - assert not alpha_eq(proj1, proj3) - assert not alpha_eq(proj1, proj4) - assert not alpha_eq(proj2, proj3) - assert not alpha_eq(proj2, proj4) - assert not alpha_eq(proj3, proj4) - -def test_shape_binary_op_eq() -> None: - empty = ShapeSeq([]) - single = ShapeSingleton(5) - seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) - - op1 = ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty) - op2 = ShapeBinaryOp(ShapeOp.SHSUB, single, single) - op3 = ShapeBinaryOp(ShapeOp.SHMUL, seq, seq) - op4 = ShapeBinaryOp(ShapeOp.SHDIV, seq, seq) - - assert alpha_eq(op1, op1) - assert alpha_eq(op2, op2) - assert alpha_eq(op3, op3) - assert alpha_eq(op4, op4) - -def test_shape_binary_op_not_eq() -> None: - empty = ShapeSeq([]) - single = ShapeSingleton(5) - seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) - - assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), empty) - assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHMUL, seq, ShapeSingleton(1)), seq) - assert not alpha_eq( - ShapeBinaryOp(ShapeOp.SHPLUS, single, single), - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([single]), - ShapeSeq([single]))) - assert not alpha_eq( - ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), - ShapeBinaryOp(ShapeOp.SHSUB, empty, empty)) - assert not alpha_eq( - ShapeBinaryOp(ShapeOp.SHMUL, empty, empty), - ShapeBinaryOp(ShapeOp.SHDIV, empty, empty)) - -def test_shape_nested_in_quantifier() -> None: - b1 = TypeParam("b", Kind.BaseType) - x1 = TypeParam("x", Kind.Shape) - y1 = TypeParam("y", Kind.Shape) - - b2 = TypeParam("b", Kind.BaseType) - x2 = TypeParam("x", Kind.Shape) - y2 = TypeParam("y", Kind.Shape) - - b3 = TypeParam("b", Kind.BaseType) - x3 = TypeParam("x", Kind.Shape) - y3 = TypeParam("y", Kind.Shape) - - tq1 = nest_quantifiers( - [b1, x1, y1], - TypeArrow( - [TensorType(b1, x1), TensorType(b1, y2)], - TensorType( - b1, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x1, ShapeProjection(y1, 1), - ShapeSingleton(5), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - tq2 = nest_quantifiers( - [b2, x2, y2], - TypeArrow( - [TensorType(b2, x2), TensorType(b2, y2)], - TensorType( - b2, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x2, ShapeProjection(y2, 1), - ShapeSingleton(5), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - # different attr, var order, position, and constant - tq3 = nest_quantifiers( - [b3, x3, y3], - TypeArrow( - [TensorType(b3, x3), TensorType(b3, y3)], - TensorType( - b3, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x3, ShapeProjection(y3, 1), - ShapeSingleton(4), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - tq4 = nest_quantifiers( - [b3, x3, y3], - TypeArrow( - [TensorType(b3, x3), TensorType(b3, y3)], - TensorType( - b3, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x3, ShapeProjection(y3, 2), - ShapeSingleton(5), ShapeAttr("att2")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - tq5 = nest_quantifiers( - [b3, x3, y3], - TypeArrow( - [TensorType(b3, x3), TensorType(b3, y3)], - TensorType( - b3, - ShapeBinaryOp(ShapeOp.SHMUL, - ShapeSeq([x3, ShapeProjection(y3, 1), - ShapeSingleton(5), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - tq6 = nest_quantifiers( - [b3, y3, x3], - TypeArrow( - [TensorType(b3, x3), TensorType(b3, y3)], - TensorType( - b3, - ShapeBinaryOp(ShapeOp.SHPLUS, - ShapeSeq([x3, ShapeProjection(y3, 1), - ShapeSingleton(5), ShapeAttr("att")]), - ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - - assert alpha_eq(tq1, tq2) - assert not alpha_eq(tq1, tq3) - assert not alpha_eq(tq2, tq3) - assert not alpha_eq(tq1, tq4) - assert not alpha_eq(tq2, tq4) - assert not alpha_eq(tq1, tq5) - assert not alpha_eq(tq2, tq5) - assert not alpha_eq(tq1, tq6) - assert not alpha_eq(tq2, tq6) +from tvm.relay import make as mk +# from relay.ir import alpha_eq, ShapeOp, Kind +# from relay.typing import TYPE_DEFAULTS +# from relay import ir + +# INT_TYPE_WIDTH = TYPE_DEFAULTS["INT_WIDTH"] +# INT_TYPE_LANES = TYPE_DEFAULTS["INT_LANES"] + +# def int_type(width=32) -> ir.Type: +# return TensorType(IntType(width), ShapeSeq([])) + +# def float_type(width=32) -> ir.Type: +# return TensorType(FloatType(width), ShapeSeq([])) + +# def bool_type() -> ir.Type: +# return TensorType(BoolType(), ShapeSeq([])) + +# def nest_quantifiers(ids, body) -> ir.Type: +# ret = body +# for tid in reversed(ids): +# ret = TypeQuantifier(tid, ret) +# return ret + +# def test_local_id_not_eq() -> None: +# assert not alpha_eq(LocalId("x"), LocalId("y")) + +# def test_local_id_eq() -> None: +# x = LocalId("x") +# assert alpha_eq(x, x) + +# def test_global_id_not_eq() -> None: +# left = GlobalId("xyz") +# right = GlobalId("xyz") +# assert not alpha_eq(left, right) + +# def test_global_id_eq() -> None: +# ident = GlobalId("xyz") +# assert alpha_eq(ident, ident) + +# def test_operator_id_not_eq() -> None: +# left = OperatorId("xyz") +# right = OperatorId("xyz") +# # equality on operator id is pointer equality +# assert not alpha_eq(left, right) + +# def test_operator_id_eq() -> None: +# x = OperatorId("xyz") +# assert alpha_eq(x, x) + +# def test_float_literal_eq() -> None: +# x = FloatLit(1.0) +# y = FloatLit(1.0) +# assert alpha_eq(x, y) + +# def test_float_literal_not_eq() -> None: +# x = FloatLit(1.0) +# y = FloatLit(2.0) +# assert not alpha_eq(x, y) + +# def test_int_literal_eq() -> None: +# x = IntLit(1) +# y = IntLit(1) +# assert alpha_eq(x, y) + +# def test_int_literal_not_eq() -> None: +# x = IntLit(1) +# y = IntLit(2) +# assert not alpha_eq(x, y) + +# def test_bool_literal_eq() -> None: +# x = BoolLit(True) +# y = BoolLit(True) +# assert alpha_eq(x, y) + +# def test_bool_literal_not_eq() -> None: +# x = BoolLit(True) +# y = BoolLit(False) +# assert not alpha_eq(x, y) + +# def test_tensor_literal_eq() -> None: +# x = TensorLit([IntLit(1), IntLit(2)]) +# y = TensorLit([IntLit(1), IntLit(2)]) +# assert alpha_eq(x, y) + +# def test_tensor_literal_not_eq() -> None: +# x = TensorLit([IntLit(1), IntLit(2)]) +# y = TensorLit([IntLit(1), IntLit(3)]) +# z = TensorLit([IntLit(1)]) +# assert not alpha_eq(x, y) +# assert not alpha_eq(x, z) + +# def test_product_literal_eq() -> None: +# x = Tuple([IntLit(1), IntLit(2)]) +# y = Tuple([IntLit(1), IntLit(2)]) +# assert alpha_eq(x, y) + +# def test_product_literal_not_eq() -> None: +# x = Tuple([IntLit(1), IntLit(2)]) +# y = Tuple([IntLit(2), IntLit(2)]) +# z = Tuple([IntLit(1), IntLit(2), IntLit(3)]) +# assert not alpha_eq(x, y) +# assert not alpha_eq(x, z) + +# def test_projection_eq() -> None: +# prod = Tuple([IntLit(3), FloatLit(3.5)]) + +# assert alpha_eq(Projection(prod, 0), Projection(prod, 0)) +# assert alpha_eq(Projection(prod, 1), Projection(prod, 1)) + +# def test_projection_not_eq() -> None: +# prod1 = Tuple([IntLit(3), IntLit(4)]) +# prod2 = Tuple([IntLit(3)]) +# prod3 = Tuple([IntLit(3), IntLit(4), FloatLit(3.5)]) + +# assert not alpha_eq(Projection(prod1, 0), Projection(prod1, 1)) +# assert not alpha_eq(Projection(prod1, 0), Projection(prod2, 0)) +# assert not alpha_eq(Projection(prod1, 0), Projection(prod3, 0)) +# assert not alpha_eq(Projection(prod1, 1), Projection(prod3, 1)) + +# def test_cast_not_eq() -> None: +# left = Cast(IntType(1), IntLit(2)) +# right = Cast(IntType(1), IntLit(1)) +# assert not alpha_eq(left, right) + +# # same literal, different type +# left = Cast(IntType(1), IntLit(2)) +# right = Cast(IntType(2), IntLit(2)) +# assert not alpha_eq(left, right) + +# def test_cast_eq() -> None: +# left = Cast(IntType(1), IntLit(2)) +# right = Cast(IntType(1), IntLit(2)) +# assert alpha_eq(left, right) + +# def test_param_not_eq() -> None: +# left = Param(LocalId("foo"), int_type()) +# right = Param(LocalId("foo"), bool_type()) +# assert not alpha_eq(left, right) + +# def test_param_eq() -> None: +# left = Param(LocalId("foo"), int_type()) +# right = Param(LocalId("bar"), int_type()) +# assert alpha_eq(left, right) + +# def test_function_not_eq() -> None: +# params1 = [Param(LocalId("x"), int_type())] +# fn1 = Function([], params1, int_type(), LocalId("x")) +# params2 = [Param(LocalId("y"), bool_type())] +# fn2 = Function([], params2, int_type(), LocalId("y")) +# assert not alpha_eq(fn1, fn2) + +# params3 = [Param(LocalId("x"), int_type()), Param(LocalId("y"), int_type())] +# fn3 = Function([], params3, int_type(), LocalId("z")) +# assert not alpha_eq(fn1, fn3) + +# def test_function_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# params1 = [Param(x, int_type())] +# fn1 = Function([], params1, int_type(), x) +# params2 = [Param(y, int_type())] +# fn2 = Function([], params2, int_type(), y) +# assert alpha_eq(fn1, fn2) + +# def test_call_not_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# params1 = [Param(x, int_type())] +# fn1 = Function([], params1, int_type(), x) +# args1 = [IntLit(1)] +# call1 = Call(fn1, args1) + +# args2 = [IntLit(2)] +# call2 = Call(fn1, args2) +# assert not alpha_eq(call1, call2) + +# params2 = [Param(y, int_type())] +# fn2 = Function([], params2, float_type(), FloatLit(0.0)) +# call3 = Call(fn2, args1) +# assert not alpha_eq(call1, call3) +# assert not alpha_eq(call2, call3) + +# def test_call_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# params1 = [Param(x, int_type())] +# fn1 = Function([], params1, int_type(), x) +# args = [IntLit(1)] +# call1 = Call(fn1, args) + +# params2 = [Param(y, int_type())] +# fn2 = Function([], params2, int_type(), y) +# call2 = Call(fn2, args) +# assert alpha_eq(call1, call2) + +# def test_debug_not_eq() -> None: +# left = Debug(IntLit(1)) +# right = Debug(IntLit(2)) +# assert not alpha_eq(left, right) + +# def test_debug_eq() -> None: +# left = Debug(IntLit(1)) +# right = Debug(IntLit(1)) +# assert alpha_eq(left, right) + +# def test_let_not_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# let1 = Let(x, int_type(), IntLit(10), IntLit(11)) +# let2 = Let(y, int_type(), IntLit(10), IntLit(12)) +# assert not alpha_eq(let1, let2) + +# let3 = Let(x, int_type(), IntLit(10), x) +# let4 = Let(y, int_type(), IntLit(12), y) +# assert not alpha_eq(let3, let4) + +# def test_let_eq() -> None: +# x = LocalId("x") +# y = LocalId("y") +# let1 = Let(x, int_type(), IntLit(10), x) +# let2 = Let(y, int_type(), IntLit(10), y) +# assert alpha_eq(let1, let2) + +# def test_ref_eq() -> None: +# r1 = Ref(IntLit(5)) +# r2 = Ref(IntLit(5)) +# assert alpha_eq(r1, r2) + +# def test_ref_not_eq() -> None: +# r1 = Ref(IntLit(5)) +# r2 = Ref(FloatLit(3.5)) +# r3 = Ref(r1) +# assert not alpha_eq(r1, r2) +# assert not alpha_eq(r1, r3) +# assert not alpha_eq(r2, r3) + +# def test_val_ref_eq() -> None: +# vr1 = ReadRef(Ref(IntLit(35))) +# vr2 = ReadRef(Ref(Tuple([IntLit(12), FloatLit(2.5)]))) +# assert alpha_eq(vr1, vr1) +# assert alpha_eq(vr2, vr2) + +# def test_val_ref_not_eq() -> None: +# vr1 = ReadRef(Ref(IntLit(5))) +# vr2 = ReadRef(Ref(vr1)) +# vr3 = ReadRef(Ref(FloatLit(5.0))) +# assert not alpha_eq(vr1, vr2) +# assert not alpha_eq(vr1, vr3) +# assert not alpha_eq(vr2, vr3) + +# def test_set_ref_eq() -> None: +# sr1 = WriteRef(Ref(FloatLit(5.0)), FloatLit(6.0)) +# sr2 = WriteRef(Ref(Tuple([IntLit(3), BoolLit(False)])), +# Tuple([IntLit(5), BoolLit(True)])) +# assert alpha_eq(sr1, sr1) +# assert alpha_eq(sr2, sr2) + +# def test_set_ref_not_eq() -> None: +# r1 = Ref(FloatLit(5.0)) +# r2 = Ref(IntLit(5)) +# r3 = Ref(IntLit(6)) + +# assert not alpha_eq(WriteRef(r1, FloatLit(6.0)), +# WriteRef(r2, IntLit(6))) +# assert not alpha_eq(WriteRef(r2, IntLit(6)), WriteRef(r2, IntLit(7))) +# assert not alpha_eq(WriteRef(r2, IntLit(7)), WriteRef(r3, IntLit(7))) + +# # Type alpha-equality tests + +# def test_base_type_eq() -> None: +# assert alpha_eq(IntType(32), IntType(32)) +# assert alpha_eq(BoolType(), BoolType()) +# assert alpha_eq(FloatType(32), FloatType(32)) + +# def test_tensor_type_eq() -> None: +# tt1 = TensorType( +# IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) +# tt2 = TensorType( +# FloatType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) +# assert alpha_eq(tt1, tt1) +# assert alpha_eq(tt2, tt2) + +# def test_tensor_type_not_eq() -> None: +# tt1 = TensorType( +# IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) +# tt2 = TensorType( +# FloatType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) +# tt3 = TensorType( +# IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) +# assert not alpha_eq(tt1, tt2) +# assert not alpha_eq(tt1, tt3) + +# def test_ref_type_eq() -> None: +# rt1 = RefType(int_type()) +# rt2 = RefType(float_type()) +# assert alpha_eq(rt1, rt1) +# assert alpha_eq(rt2, rt2) + +# def test_ref_type_not_eq() -> None: +# rt1 = RefType(int_type()) +# rt2 = RefType(float_type()) +# assert not alpha_eq(rt1, rt2) + +# def test_product_type_eq() -> None: +# pt1 = TupleType([int_type(), RefType(float_type())]) +# pt2 = TupleType([float_type(), float_type(), int_type()]) +# assert alpha_eq(pt1, pt1) +# assert alpha_eq(pt2, pt2) + +# def test_product_type_not_eq() -> None: +# pt1 = TupleType([int_type(), int_type()]) +# pt2 = TupleType([int_type(), int_type(), float_type()]) +# pt3 = TupleType([bool_type(), float_type()]) +# assert not alpha_eq(pt1, pt2) +# assert not alpha_eq(pt1, pt3) + +# def test_type_id_eq() -> None: +# id1 = TypeParam("id1", Kind.Shape) +# id2 = TypeParam("id2", Kind.BaseType) +# id3 = TypeParam("id2", Kind.Type) + +# assert alpha_eq(id1, id1) +# assert alpha_eq(id2, id2) +# assert alpha_eq(id3, id3) + +# def test_type_id_not_eq() -> None: +# # name is just a hint, we use pointer equality as the rule +# # (unless there is a quantifier to give context) +# id1 = TypeParam("id1", Kind.Shape) +# id2 = TypeParam("id1", Kind.Shape) +# id3 = TypeParam("id3", Kind.BaseType) + +# assert not alpha_eq(id1, id2) +# assert not alpha_eq(id1, id3) + +# def test_arrow_type_eq() -> None: +# ar1 = TypeArrow([int_type()], bool_type()) +# ar2 = TypeArrow([int_type(), int_type()], TupleType([])) +# assert alpha_eq(ar1, ar1) +# assert alpha_eq(ar2, ar2) + +# def test_arrow_type_not_eq() -> None: +# t1 = int_type() +# t2 = bool_type() +# t3 = [int_type(), bool_type()] + +# assert not alpha_eq(TypeArrow([t1], t2), TypeArrow([t1], t1)) +# assert not alpha_eq(TypeArrow(t3, t1), TypeArrow([t2], t1)) +# assert not alpha_eq(TypeArrow([t1], TypeArrow([t1], t1)), +# TypeArrow([t1], t1)) + +# def test_type_quantifier_eq() -> None: +# id1 = TypeParam("id1", Kind.Shape) +# id2 = TypeParam("id2", Kind.Shape) +# tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) +# tq2 = TypeQuantifier(id2, TensorType(IntType(32), id2)) + +# assert alpha_eq(tq1, tq1) +# assert alpha_eq(tq1, tq2) + +# def test_nested_type_quantifier_eq() -> None: +# id1 = TypeParam("id1", Kind.BaseType) +# id2 = TypeParam("id2", Kind.Shape) +# id3 = TypeParam("id3", Kind.BaseType) +# id4 = TypeParam("id4", Kind.Shape) +# tq1 = TypeQuantifier(id1, TypeQuantifier(id2, TensorType(id1, id2))) +# tq2 = TypeQuantifier(id3, TypeQuantifier(id4, TensorType(id3, id4))) + +# assert alpha_eq(tq1, tq1) +# assert alpha_eq(tq1, tq2) + +# def test_type_quantifier_not_eq() -> None: +# id1 = TypeParam("id1", Kind.Shape) +# id2 = TypeParam("id2", Kind.BaseType) +# id3 = TypeParam("id3", Kind.Shape) + +# tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) +# tq2 = TypeQuantifier(id2, TensorType(id2, ShapeSeq([ShapeSingleton(3)]))) +# tq3 = TypeQuantifier(id1, TensorType(IntType(32), id3)) +# tq4 = TypeQuantifier(id1, TensorType(FloatType(32), id1)) + +# assert not alpha_eq(tq1, tq2) +# assert not alpha_eq(tq1, tq3) +# assert not alpha_eq(tq1, tq4) +# assert not alpha_eq(tq2, tq3) +# assert not alpha_eq(tq2, tq4) + +# def test_shape_singleton_eq() -> None: +# single1 = ShapeSingleton(10) +# single2 = ShapeSingleton(10) + +# assert alpha_eq(single1, single1) +# assert alpha_eq(single1, single2) + +# def test_shape_singelton_not_eq() -> None: +# single1 = ShapeSingleton(10) +# single2 = ShapeSingleton(11) + +# assert not alpha_eq(single1, single2) + +# def test_shape_attr_eq() -> None: +# attr1 = ShapeAttr("x") +# attr2 = ShapeAttr("x") + +# assert alpha_eq(attr1, attr1) +# assert alpha_eq(attr1, attr2) + +# def test_shape_attr_not_eq() -> None: +# id1 = "x" +# id2 = "y" +# attr1 = ShapeAttr(id1) +# attr2 = ShapeAttr(id2) + +# assert not alpha_eq(attr1, attr2) + +# def test_shape_seq_eq() -> None: +# empty = ShapeSeq([]) +# seq1 = ShapeSeq([ShapeSingleton(5)]) +# seq2 = ShapeSeq([ShapeSingleton(5)]) + +# assert alpha_eq(empty, empty) +# assert alpha_eq(seq1, seq2) + +# def test_shape_seq_not_eq() -> None: +# empty = ShapeSeq([]) +# seq = ShapeSeq([ShapeSingleton(5)]) +# single = ShapeSingleton(5) + +# assert not alpha_eq(empty, seq) +# assert not alpha_eq(seq, single) + +# def test_shape_projection_eq() -> None: +# proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) +# proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) + +# assert alpha_eq(proj1, proj2) + +# def test_shape_projection_not_eq() -> None: +# proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) +# proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 1) +# proj3 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 0) +# proj4 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 1) + +# assert not alpha_eq(proj1, proj2) +# assert not alpha_eq(proj1, proj3) +# assert not alpha_eq(proj1, proj4) +# assert not alpha_eq(proj2, proj3) +# assert not alpha_eq(proj2, proj4) +# assert not alpha_eq(proj3, proj4) + +# def test_shape_binary_op_eq() -> None: +# empty = ShapeSeq([]) +# single = ShapeSingleton(5) +# seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) + +# op1 = ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty) +# op2 = ShapeBinaryOp(ShapeOp.SHSUB, single, single) +# op3 = ShapeBinaryOp(ShapeOp.SHMUL, seq, seq) +# op4 = ShapeBinaryOp(ShapeOp.SHDIV, seq, seq) + +# assert alpha_eq(op1, op1) +# assert alpha_eq(op2, op2) +# assert alpha_eq(op3, op3) +# assert alpha_eq(op4, op4) + +# def test_shape_binary_op_not_eq() -> None: +# empty = ShapeSeq([]) +# single = ShapeSingleton(5) +# seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) + +# assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), empty) +# assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHMUL, seq, ShapeSingleton(1)), seq) +# assert not alpha_eq( +# ShapeBinaryOp(ShapeOp.SHPLUS, single, single), +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([single]), +# ShapeSeq([single]))) +# assert not alpha_eq( +# ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), +# ShapeBinaryOp(ShapeOp.SHSUB, empty, empty)) +# assert not alpha_eq( +# ShapeBinaryOp(ShapeOp.SHMUL, empty, empty), +# ShapeBinaryOp(ShapeOp.SHDIV, empty, empty)) + +# def test_shape_nested_in_quantifier() -> None: +# b1 = TypeParam("b", Kind.BaseType) +# x1 = TypeParam("x", Kind.Shape) +# y1 = TypeParam("y", Kind.Shape) + +# b2 = TypeParam("b", Kind.BaseType) +# x2 = TypeParam("x", Kind.Shape) +# y2 = TypeParam("y", Kind.Shape) + +# b3 = TypeParam("b", Kind.BaseType) +# x3 = TypeParam("x", Kind.Shape) +# y3 = TypeParam("y", Kind.Shape) + +# tq1 = nest_quantifiers( +# [b1, x1, y1], +# TypeArrow( +# [TensorType(b1, x1), TensorType(b1, y2)], +# TensorType( +# b1, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x1, ShapeProjection(y1, 1), +# ShapeSingleton(5), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# tq2 = nest_quantifiers( +# [b2, x2, y2], +# TypeArrow( +# [TensorType(b2, x2), TensorType(b2, y2)], +# TensorType( +# b2, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x2, ShapeProjection(y2, 1), +# ShapeSingleton(5), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# # different attr, var order, position, and constant +# tq3 = nest_quantifiers( +# [b3, x3, y3], +# TypeArrow( +# [TensorType(b3, x3), TensorType(b3, y3)], +# TensorType( +# b3, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x3, ShapeProjection(y3, 1), +# ShapeSingleton(4), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# tq4 = nest_quantifiers( +# [b3, x3, y3], +# TypeArrow( +# [TensorType(b3, x3), TensorType(b3, y3)], +# TensorType( +# b3, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x3, ShapeProjection(y3, 2), +# ShapeSingleton(5), ShapeAttr("att2")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# tq5 = nest_quantifiers( +# [b3, x3, y3], +# TypeArrow( +# [TensorType(b3, x3), TensorType(b3, y3)], +# TensorType( +# b3, +# ShapeBinaryOp(ShapeOp.SHMUL, +# ShapeSeq([x3, ShapeProjection(y3, 1), +# ShapeSingleton(5), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# tq6 = nest_quantifiers( +# [b3, y3, x3], +# TypeArrow( +# [TensorType(b3, x3), TensorType(b3, y3)], +# TensorType( +# b3, +# ShapeBinaryOp(ShapeOp.SHPLUS, +# ShapeSeq([x3, ShapeProjection(y3, 1), +# ShapeSingleton(5), ShapeAttr("att")]), +# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) + +# assert alpha_eq(tq1, tq2) +# assert not alpha_eq(tq1, tq3) +# assert not alpha_eq(tq2, tq3) +# assert not alpha_eq(tq1, tq4) +# assert not alpha_eq(tq2, tq4) +# assert not alpha_eq(tq1, tq5) +# assert not alpha_eq(tq2, tq5) +# assert not alpha_eq(tq1, tq6) +# assert not alpha_eq(tq2, tq6) diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index b2ed075ca3de..065a91f0abbe 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -2,23 +2,23 @@ Test the type unifier, which solves systems of equations between incomplete types. """ -import tvm.relay.ir +from tvm.relay import ir from tvm.relay.unifier import UnionFind, TypeUnifier import tvm.relay.make as mk def test_insert_and_find(): - uf = UnionFind() - v1 = mk.TypeVar(ir.Kind.Type) - v2 = mk.TypeVar(ir.Kind.Type) + uf = mk.UnionFind()() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) uf.insert(v2) assert uf.find(v1) == v1 assert uf.find(v2) == v2 def test_insert_error(): - uf = UnionFind() - v1 = mk.TypeVar(ir.Kind.Type) - v2 = mk.TypeVar(ir.Kind.Type) + uf = mk.UnionFind()() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) try: uf.find(v2) @@ -27,10 +27,10 @@ def test_insert_error(): return def test_unify(): - uf = UnionFind() - v1 = mk.TypeVar(ir.Kind.Type) - v2 = mk.TypeVar(ir.Kind.Type) - v3 = mk.TypeVar(ir.Kind.Type) + uf = mk.UnionFind()() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) + v3 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) uf.insert(v2) uf.insert(v3) @@ -49,8 +49,8 @@ def test_unify(): assert uf.find(v3) == new_rep def test_unify_multiple_levels(): - uf = UnionFind() - v = [TypeVar(ir.Kind.Type) for _ in range(9)] + uf = mk.UnionFind()() + v = [mk.IncompleteType(ir.Kind.Type) for _ in range(9)] for var in v: uf.insert(var) uf.unify(v[0], v[1]) @@ -85,7 +85,7 @@ def test_unify_multiple_levels(): # and now we will test the type unifier which will fill in holes # between type equalities by the process of unification. def unify_types(t1, t2): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() return unifier.unify(t1, t2) # TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work @@ -120,8 +120,8 @@ def test_unify_concrete_type_arrow(): assert unified == arr1 def test_unify_type_arrow_with_holes(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.unify(v1, bool_type()) arr1 = TypeArrow([int_type()], bool_type()) @@ -129,7 +129,7 @@ def test_unify_type_arrow_with_holes(): unified = unifier.unify(arr1, arr2) assert unified == arr1 - v2 = TypeVar(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v2) unifier.unify(v2, int_type()) arr3 = TypeArrow([v2], bool_type()) @@ -161,10 +161,10 @@ def test_unify_basetype_with_quantifier_error(): return def test_unify_typevars_with_each_other(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) - v3 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) + v3 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.insert(v3) @@ -175,10 +175,10 @@ def test_unify_typevars_with_each_other(): assert (new_unified == v1 or new_unified == v2 or new_unified == v3) def test_unify_typevars_with_basetype(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() bt = BoolType() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.BaseType) + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.insert(v2) unified1 = unifier.unify(v1, bt) @@ -187,10 +187,10 @@ def test_unify_typevars_with_basetype(): assert unified2 == bt def test_unify_compatible_typevars(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() bt = BoolType() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.BaseType) + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, bt) @@ -201,9 +201,9 @@ def test_unify_compatible_typevars(): assert unified == bt def test_unify_incompatible_typevars(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) bt = bool_type() tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) unifier.insert(v1) @@ -218,16 +218,16 @@ def test_unify_incompatible_typevars(): return def test_unify_typevar_with_quantifier(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) - v1 = TypeVar(ir.Kind.BaseType) + v1 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unified = unifier.unify(v1, tq) assert unified == tq def test_unify_typevars_inside_concrete_quantifier(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) @@ -312,8 +312,8 @@ def test_unify_products_reject_member(): return def test_unify_products_typevar(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) bt = bool_type() pt1 = TupleType([bt, bt]) pt2 = TupleType([v1, bt]) @@ -344,22 +344,22 @@ def test_unify_ref_reject_inner(): return def test_subst_basetype(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() bt = BoolType() assert bt == unifier.subst(bt) def test_subst_simple_hole(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) bt = BoolType() unifier.insert(v1) unifier.unify(v1, bt) assert unifier.subst(v1) == bt def test_subst_typevar_for_typevar(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) @@ -367,14 +367,14 @@ def test_subst_typevar_for_typevar(): assert unifier.subst(v1) == v2 def test_subst_concrete_arrow(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() arr1 = TypeArrow([int_type()], int_type()) assert unifier.subst(arr1) == arr1 def test_subst_arrow_with_holes(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, int_type()) @@ -384,17 +384,17 @@ def test_subst_arrow_with_holes(): assert unifier.subst(arr1) == arr2 def test_subst_concrete_quantifier(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) unifier.insert(v1) unifier.unify(v1, tq) assert unifier.subst(v1) == tq def test_subst_quantifier_with_holes(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) intty = int_type() tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) @@ -406,16 +406,16 @@ def test_subst_quantifier_with_holes(): assert unifier.subst(v1) == tq2 def test_subst_concrete_tensor(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) tt = TensorType(BoolType(), make_shape([1, 2, 3])) unifier.unify(v1, tt) assert unifier.subst(v1) == tt def test_subst_concrete_product(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) bt = bool_type() pt = TupleType([bt, bt]) @@ -423,10 +423,10 @@ def test_subst_concrete_product(): assert unifier.subst(v1) == pt def test_subst_product_with_holes(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) - v2 = TypeVar(ir.Kind.Type) - v3 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) + v3 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.insert(v3) @@ -441,13 +441,13 @@ def test_subst_product_with_holes(): assert unifier.subst(v1) == pt2 def test_subst_concrete_ref(): - unifier = TypeUnifier() + unifier = mk.TypeUnifier() rt = RefType(bool_type()) assert unifier.subst(rt) == rt def test_subst_ref_with_hole(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.unify(v1, bool_type()) @@ -456,9 +456,9 @@ def test_subst_ref_with_hole(): assert unifier.subst(rt1) == rt2 def test_typevar_on_lhs(): - unifier = TypeUnifier() - v1 = TypeVar(ir.Kind.BaseType) - v2 = TypeVar(ir.Kind.Type) + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.BaseType) + v2 = mk.IncompleteType(ir.Kind.Type) bt = bool_type() tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) unifier.insert(v1) From fee67dee3dd0952a1fa9a65edfb9d49a12eb8780 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 00:46:19 -0700 Subject: [PATCH 023/136] Remove tests for ommitted features and fix Remove the tests for features we don't currently support, and fix the tests which were left. --- include/tvm/relay/expr_visitor.h | 6 +- include/tvm/relay/type.h | 12 + python/tvm/relay/ir_builder.py | 23 +- python/tvm/relay/make.py | 41 +++ python/tvm/relay/type.py | 16 +- src/relay/compiler/alpha_eq.cc | 16 +- src/relay/compiler/type_visitor.h | 115 ++++--- src/relay/compiler/typechecker.cc | 2 +- src/relay/compiler/unifier.cc | 126 +++---- src/relay/compiler/unifier.h | 2 +- src/relay/type.cc | 38 +++ tests/python/relay/test_unifier.py | 509 +++++++++++++++-------------- 12 files changed, 508 insertions(+), 398 deletions(-) diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index d7ac1465f70a..721fa531a7e3 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -7,13 +7,13 @@ #ifndef TVM_RELAY_EXPR_VISITOR_H_ #define TVM_RELAY_EXPR_VISITOR_H_ -#include "expr_functor.h" +#include "tvm/relay/expr_functor.h" namespace tvm { namespace relay { template -class ExprVisitor : public ExprFunctor { +class ExprVisitor : public ::tvm::relay::ExprFunctor { public: void VisitExpr_(const LocalVarNode* op, Args... args) override { return; } @@ -62,7 +62,7 @@ class ExprVisitor : public ExprFunctor { }; template -class ExprFVisitor : public ExprFunctor { +class ExprFVisitor : public ::tvm::relay::ExprFunctor { public: Expr VisitExpr_(const LocalVarNode* op, Args... args) override { return GetRef(op); diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 4eeb42168d68..07b047471aba 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -83,6 +83,18 @@ class TensorTypeNode : public BaseTensorTypeNode { TVM_DLL static TensorType make(Array shape, DataType dtype); + /*! \brief Constructing an unsigned integer type */ + TVM_DLL static TensorType Int(int bits, int lanes = 1); + + /*! \brief Constructing an unsigned integer type */ + TVM_DLL static TensorType UInt(int bits, int lanes = 1); + + /*! \brief Construct a floating-point type */ + TVM_DLL static TensorType Float(int bits, int lanes = 1); + + /*1 \brief Construct a boolean type */ + TVM_DLL static TensorType Bool(int lanes = 1); + static constexpr const char* _type_key = "relay.TensorType"; TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode); }; diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 3c842e480c70..8fa9b789f53c 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -100,16 +100,23 @@ def get(self): return _mk_let(bindings, self.ret_value) -# def int_type(): -# return TensorType(IntType(32), ShapeSeq([])) +def bool_dtype(): + return 'uint1' -# def float_type(): -# return TensorType(FloatType(32), ShapeSeq([])) +def int_dtype(): + return 'uint1' -# def bool_type(): -# return TensorType(BoolType(), ShapeSeq([])) +def int_type(bits=32, lanes=1): + return mk.IntType(bits, lanes) -# def make_shape(dims): -# return ShapeSeq([ShapeSingleton(dim) for dim in dims]) +def uint_type(bits=32, lanes=1): + return mk.UIntType(bits, lanes) +def float_type(bits=32, lanes=1): + return mk.FloatType(bits, lanes) +def bool_type(lanes=1): + return mk.BoolType(lanes) + +def func_type(args, ret_type, type_params=[], type_constraints=[]): + return mk.FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py index a2b87f2700af..236e2f6af596 100644 --- a/python/tvm/relay/make.py +++ b/python/tvm/relay/make.py @@ -1,4 +1,5 @@ from . import _make +from . import ir # Base Constructors Span = _make.Span @@ -8,6 +9,43 @@ TypeParam = _make.TypeParam FuncType = _make.FuncType +# Types +def IntType(bits: int, lanes: int=1) -> ir.Type: + """Constructs a integer base type. + + :param bits: The bit width of the integer type. + :param lanes: The number of vector elements for this datatype. + + """ + return _make.IntType(bits, lanes) + + +def UIntType(bits: int, lanes: int=1) -> ir.Type: + """Constructs a unsigned integer base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.UIntType(bits, lanes) + + +def FloatType(bits: int, lanes: int=1) -> ir.Type: + """Constructs a floating point base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.FloatType(bits, lanes) + + +def BoolType(lanes: int =1) -> ir.Type: + """Constructs a boolean base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.BoolType(lanes) + # Expr Constructors Constant = _make.Constant Tuple = _make.Tuple @@ -23,3 +61,6 @@ # Unifier UnionFind = _make.UnionFind TypeUnifier = _make.TypeUnifier + +# Utility Functionality @TODO(jroesch): move to another location +_type_alpha_eq = _make._type_alpha_eq diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index 2790b546cfe5..a04089792282 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -4,10 +4,24 @@ from enum import IntEnum from .base import Span, NodeBase, register_relay_node from tvm import expr +# TODO(@jroesch): move me +from ._make import _type_alpha_eq class Type(NodeBase): """The base type for all Relay types.""" - pass + + def __eq__(self, other) -> bool: + """Compares two Relay types for structural equivalence using + alpha equivalence. + """ + return bool(_type_alpha_eq(self, other)) + + def __ne__(self, other) -> bool: + return not self.__eq__(other) + + def same_as(self, other) -> bool: + """Compares two Relay types by referential equality.""" + return super().__eq__(other) @register_relay_node class TensorType(Type): diff --git a/src/relay/compiler/alpha_eq.cc b/src/relay/compiler/alpha_eq.cc index 4b8e904bf29e..688a93ae73fc 100644 --- a/src/relay/compiler/alpha_eq.cc +++ b/src/relay/compiler/alpha_eq.cc @@ -33,14 +33,14 @@ struct TypeAlphaEq : TypeVisitor { } } -// void VisitType_(const TypeVarNode *bt1, const Type &t2) override { -// if (const TypeVarNode *bt2 = t2.as()) { -// equal = equal && bt1 == bt2; -// return; -// } else { -// equal = false; -// } -// } + void VisitType_(const IncompleteTypeNode *bt1, const Type &t2) override { + if (const IncompleteTypeNode *bt2 = t2.as()) { + equal = equal && bt1 == bt2; + return; + } else { + equal = false; + } + } void VisitType_(const TypeParamNode *ti1, const Type &t2) override { if (const TypeParamNode *ti2 = t2.as()) { diff --git a/src/relay/compiler/type_visitor.h b/src/relay/compiler/type_visitor.h index 5ae100a8de6d..60ae810a6b96 100644 --- a/src/relay/compiler/type_visitor.h +++ b/src/relay/compiler/type_visitor.h @@ -18,35 +18,34 @@ namespace relay { * We recursively visit each type contained inside the visitor. */ template -struct TypeVisitor : TypeFunctor { - // void VisitType_(const TypeVarNode* op, Args... args) override {} +struct TypeVisitor : ::tvm::relay::TypeFunctor { void VisitType_(const TypeParamNode* op, Args... args) override {} void VisitType_(const FuncTypeNode* op, Args... args) override { - // this->VisitType(op->id, args...); + // fix me handle poly + // this->VisitType(op->var, args...); // this->VisitType(op->boundType, args...); - // for (auto arg_type : op->arg_types) { - // this->VisitType(arg_type, args...); - // } - // this->VisitType(op->ret_type, args...); + for (auto arg_type : op->arg_types) { + this->VisitType(arg_type, args...); + } + this->VisitType(op->ret_type, args...); } - void VisitType_(const TensorTypeNode* op, Args... args) override { - // this->VisitType(op->dtype, args...); - // this->VisitType(op->shape, args...); - } + void VisitType_(const TensorTypeNode* op, Args... args) override {} -// void VisitType_(const TupleTypeNode* op, Args... args) override { -// for (const Type& t : op->fields) { -// this->VisitType(t, args...); -// } -// } + // void VisitType_(const TupleTypeNode* op, Args... args) override { + // for (const Type& t : op->fields) { + // this->VisitType(t, args...); + // } + // } -// void VisitType_(const TypeCallNode* op, Args... args) override { -// for (const Type& t : op->args) { -// this->VisitType(t, args...); -// } -// } + void VisitType_(const TypeCallNode* op, Args... args) override { + this->VisitType(op->func, args...); + + for (const Type& t : op->args) { + this->VisitType(t, args...); + } + } void VisitType_(const TypeFunctionNode* op, Args... args) override {} void VisitType_(const IncompleteTypeNode* op, Args... args) override {} @@ -60,48 +59,46 @@ struct TypeFVisitor : TypeFunctor { } Type VisitType_(const TypeParamNode* op) override { - return GetRef(op); + return GetRef(op); } -// Type VisitType_(const TypeArrowNode* op) override { -// std::vector args; -// for (auto arg_type : op->arg_types) { -// args.push_back(VisitType(arg_type)); -// } -// return TypeArrowNode::make(tvm::Array(args), VisitType(op->ret_type)); -// } - -// Type VisitType_(const TypeQuantifierNode* op) override { -// auto new_id = this->VisitType(op->id); -// if (const TypeParamNode* tin = new_id.as()) { -// return TypeQuantifierNode::make(GetRef(tin), -// this->VisitType(op->boundType)); -// } else { -// throw dmlc::Error("Cannot quantify something that is not a type ID"); -// } -// } - -// Type VisitType_(const TupleTypeNode* op) override { -// std::vector new_fields; -// for (const Type& t : op->fields) { -// new_fields.push_back(this->VisitType(t)); -// } -// return TupleTypeNode::make(new_fields); -// } - -// Type VisitType_(const TypeCallNode* op) override { -// auto func = this->VisitType(op->func); -// std::vector new_args; -// for (const Type& t : op->args) { -// new_args.push_back(this->VisitType(t)); -// } -// return TypeCallNode::make(func, new_args); -// } + Type VisitType_(const FuncTypeNode* op) override { + // auto new_id = this->VisitType(op->var); + // if (const TypeParamNode* tin = new_id.as()) { + // return TypeQuantifierNode::make(GetRef(tin), + // this->VisitType(op->boundType)); + + std::vector args; + for (auto arg_type : op->arg_types) { + args.push_back(VisitType(arg_type)); + } + + return FuncTypeNode::make(tvm::Array(args), + VisitType(op->ret_type), {}, {}); // fix me + } + + // Type VisitType_(const TupleTypeNode* op) override { + // std::vector new_fields; + // for (const Type& t : op->fields) { + // new_fields.push_back(this->VisitType(t)); + // } + // return TupleTypeNode::make(new_fields); + // } + + Type VisitType_(const TypeCallNode* op) override { + auto func = this->VisitType(op->func); + std::vector new_args; + for (const Type& t : op->args) { + new_args.push_back(this->VisitType(t)); + } + return TypeCallNode::make(func, new_args); + } + Type VisitType_(const IncompleteTypeNode* op) override { - return GetRef(op); + return GetRef(op); } -}; + }; } // namespace relay -} // namespace tvm +} // namespace relay #endif // TVM_RELAY_TYPE_VISITOR_H_ diff --git a/src/relay/compiler/typechecker.cc b/src/relay/compiler/typechecker.cc index c1f7b7f88765..e16481b7f9e0 100644 --- a/src/relay/compiler/typechecker.cc +++ b/src/relay/compiler/typechecker.cc @@ -764,7 +764,7 @@ TVM_REGISTER_API("relay._make.IncompleteType") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const IncompleteTypeNode *node, tvm::IRPrinter *p) { - p->stream << "IncompleteTypeNode(" << node->kind << ", " << &node << ")"; + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; }); } // namespace relay diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc index bfd3e1a5ff32..ff46e8e863d1 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/compiler/unifier.cc @@ -9,8 +9,8 @@ #include "tvm/relay/compiler/alpha_eq.h" #include "./unifier.h" #include "./type_visitor.h" +#include "./type_subst.h" // #include "tvm/relay/typeck/kindchecker.h" -// #include "tvm/relay/typeck/type_subst.h" namespace tvm { namespace relay { @@ -60,8 +60,6 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { if (const IncompleteTypeNode *pvn1 = parent1.as()) { auto pv1 = GetRef(pvn1); this->uf_map.Set(pv1, parent2); - // path compression: can also set v1 directly - this->uf_map.Set(v1, parent2); return; } @@ -69,8 +67,6 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { if (const IncompleteTypeNode *pvn2 = parent2.as()) { auto pv2 = GetRef(pvn2); this->uf_map.Set(pv2, parent1); - // path compression: can also set v2 directly - this->uf_map.Set(v2, parent1); return; } @@ -84,8 +80,6 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { if (const IncompleteTypeNode *pvn1 = parent1.as()) { auto pv1 = GetRef(pvn1); this->uf_map.Set(pv1, t); - // path compression: can also set v1 directly - this->uf_map.Set(v1, t); return; } @@ -181,6 +175,24 @@ Type TypeUnifierNode::subst(const Type &t) { return ret; } +Type TypeUnifierNode::VisitType(const Type & t1, const Type t2) { + // When the right hand size is a type variable immediately unify. + if (const IncompleteTypeNode *tvn2 = t2.as()) { + return this->unifyWithIncompleteType(t1, GetRef(tvn2)); + } else { + return TypeFunctor::VisitType(t1, t2); + } +} + +Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, const IncompleteType tv2) { + RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; + // Fix unify to return new representative + this->uf->unify(tv2, t1); + auto rep = this->uf->find(tv2); + RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl; + return rep; +} + Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { IncompleteType tv1 = GetRef(t1); RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << t1 << " = " << rt2 @@ -194,11 +206,6 @@ Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { TypeParam ti1 = GetRef(t1); - // for typevars, remap and attempt to unify if already defined - if (const IncompleteTypeNode *tvn2 = rt2.as()) { - return this->unifyWithIncompleteType(ti1, GetRef(tvn2)); - } - // for other type ids, only check equality if (const TypeParamNode *tin2 = rt2.as()) { TypeParam ti2 = GetRef(tin2); @@ -215,75 +222,55 @@ Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { } Type TypeUnifierNode::VisitType_(const FuncTypeNode *t1, const Type rt2) { - return rt2; -// TypeArrow ta1 = GetRef(t1); - -// // for typevar, remap if necessary -// if (const IncompleteTypeNode *tvn2 = rt2.as()) { -// return this->unifyWithIncompleteType(ta1, GetRef(tvn2)); -// } + FuncType ft1 = GetRef(t1); -// // for other arrow, unify arg and ret types -// if (const TypeArrowNode *tan2 = rt2.as()) { -// TypeArrow ta2 = GetRef(tan2); + if (const FuncTypeNode *tan2 = rt2.as()) { + FuncType ft2 = GetRef(tan2); -// if (ta1->arg_types.size() != ta2->arg_types.size()) { -// throw UnificationError("unable to unify functions of different arities"); -// } + if (ft1->type_params.size() != ft2->type_params.size()) { + throw UnificationError("unable to unify functions with differing number of type parameters"); + } -// tvm::Array unified_args; -// for (size_t i = 0; i < ta1->arg_types.size(); i++) { -// unified_args.push_back( -// this->VisitType(ta1->arg_types[i], ta2->arg_types[i])); -// } + if (ft1->type_params.size() != 0) { + throw dmlc::Error("NYI"); + } -// Type unified_ret_type = this->VisitType(ta1->ret_type, ta2->ret_type); -// return TypeArrowNode::make(unified_args, unified_ret_type); -// } + // TypeParam id1 = tq1->id; + // TypeParam id2 = tq2->id; -// throw UnificationError("Unable to unify TypeArrowNode"); -// } + // if (id1->kind != id2->kind) { + // throw UnificationError( + // "Cannot unify quantifiers over ids of different kinds"); + // } -// Type TypeUnifierNode::VisitType_(const TypeQuantifierNode *t1, const Type rt2) { -// TypeQuantifier tq1 = GetRef(t1); + // TypeParam fresh = TypeParamNode::make(id1->name, id1->kind); -// // for typevars, remap and attempt to unify if already defined -// if (const IncompleteTypeNode *tvn2 = rt2.as()) { -// return this->unifyWithIncompleteType(tq1, GetRef(tvn2)); -// } + // auto bt1 = type_subst(tq1->boundType, id1, fresh); + // auto bt2 = type_subst(tq2->boundType, id2, fresh); -// // for other quantifiers, attempt to unify bound types after normalizing -// if (const TypeQuantifierNode *tqn2 = rt2.as()) { -// TypeQuantifier tq2 = GetRef(tqn2); -// TypeParam id1 = tq1->id; -// TypeParam id2 = tq2->id; + // Type unified_bound_type = this->VisitType(bt1, bt2); -// if (id1->kind != id2->kind) { -// throw UnificationError( -// "Cannot unify quantifiers over ids of different kinds"); -// } + if (ft1->arg_types.size() != ft2->arg_types.size()) { + throw UnificationError("unable to unify functions of different arities"); + } -// TypeParam fresh = TypeParamNode::make(id1->name, id1->kind); + tvm::Array unified_args; + for (size_t i = 0; i < ft1->arg_types.size(); i++) { + unified_args.push_back( + this->VisitType(ft1->arg_types[i], ft2->arg_types[i])); + } -// auto bt1 = type_subst(tq1->boundType, id1, fresh); -// auto bt2 = type_subst(tq2->boundType, id2, fresh); + Type unified_ret_type = this->VisitType(ft1->ret_type, ft2->ret_type); -// Type unified_bound_type = this->VisitType(bt1, bt2); -// return TypeQuantifierNode::make(fresh, unified_bound_type); -// } + return FuncTypeNode::make(unified_args, unified_ret_type, {}, {}); + } -// // anything else cannot be unified -// throw UnificationError("Cannot unify TypeQuantifierNode"); + throw UnificationError("unable to unify function types"); } Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { TensorType tt1 = GetRef(t1); - // for typevars, remap and attempt to unify if already defined - if (const IncompleteTypeNode *tvn2 = rt2.as()) { - return this->unifyWithIncompleteType(tt1, GetRef(tvn2)); - } - if (const TensorTypeNode *ttn2 = rt2.as()) { TensorType tt2 = GetRef(ttn2); @@ -360,10 +347,6 @@ Type TypeUnifierNode::VisitType_(const TypeFunctionNode *sen1, const Type t2) { Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { TypeCall ty_call1 = GetRef(tcn1); - if (const IncompleteTypeNode *tvn2 = t2.as()) { - return this->unifyWithIncompleteType(ty_call1, GetRef(tvn2)); - } - if (const TypeCallNode *tcn2 = t2.as()) { Type unified_func = this->VisitType(ty_call1->func, tcn2->func); @@ -385,14 +368,7 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { } } -Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, const IncompleteType tv2) { - RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; - // Fix unify to return new representative - this->uf->unify(tv2, t1); - auto rep = this->uf->find(tv2); - RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl; - return rep; -} + TVM_REGISTER_API("relay._make.TypeUnifier") .set_body([](TVMArgs args, TVMRetValue *ret) { diff --git a/src/relay/compiler/unifier.h b/src/relay/compiler/unifier.h index 6788265c90f2..cba96ff02451 100644 --- a/src/relay/compiler/unifier.h +++ b/src/relay/compiler/unifier.h @@ -101,7 +101,7 @@ class TypeUnifierNode : public Node, private: // unify non-typevar with typevar Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); - + Type VisitType(const Type & t1, const Type t2) override; Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override; Type VisitType_(const TensorTypeNode* t1, const Type t2) override; Type VisitType_(const TypeParamNode* t1, const Type t2) override; diff --git a/src/relay/type.cc b/src/relay/type.cc index 22d37ea05fda..2b6647a5807e 100644 --- a/src/relay/type.cc +++ b/src/relay/type.cc @@ -6,6 +6,7 @@ #include "tvm/relay/type.h" #include "tvm/ir_functor.h" + namespace tvm { namespace relay { @@ -19,12 +20,49 @@ TensorType TensorTypeNode::make(Array shape, DataType dtype) { return TensorType(n); } +TensorType TensorTypeNode::Int(int bits, int lanes) { + return TensorTypeNode::make({}, HalideIR::Int(bits, lanes)); +} + +TensorType TensorTypeNode::UInt(int bits, int lanes) { + return TensorTypeNode::make({}, HalideIR::UInt(bits, lanes)); +} + +TensorType TensorTypeNode::Float(int bits, int lanes) { + return TensorTypeNode::make({}, HalideIR::Float(bits, lanes)); +} + +TensorType TensorTypeNode::Bool(int lanes) { + return TensorTypeNode::make({}, HalideIR::Bool(lanes)); +} + TVM_REGISTER_API("relay._make.TensorType") .set_body([](TVMArgs args, TVMRetValue *ret) { Array shape = args[0]; *ret = TensorTypeNode::make(shape, args[1]); }); + +TVM_REGISTER_API("relay._make.IntType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TensorTypeNode::Int(args[0], args[1]); + }); + +TVM_REGISTER_API("relay._make.UIntType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TensorTypeNode::UInt(args[0], args[1]); + }); + +TVM_REGISTER_API("relay._make.BoolType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TensorTypeNode::Bool(args[0]); + }); + +TVM_REGISTER_API("relay._make.FloatType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TensorTypeNode::Float(args[0], args[1]); + }); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TensorTypeNode *node, tvm::IRPrinter *p) { diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index 065a91f0abbe..21889faa51ee 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -2,12 +2,16 @@ Test the type unifier, which solves systems of equations between incomplete types. """ +import tvm from tvm.relay import ir from tvm.relay.unifier import UnionFind, TypeUnifier +from tvm.relay.ir_builder import bool_type, uint_type, int_type, float_type, func_type +from tvm.relay import ir_builder as build import tvm.relay.make as mk + def test_insert_and_find(): - uf = mk.UnionFind()() + uf = mk.UnionFind() v1 = mk.IncompleteType(ir.Kind.Type) v2 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) @@ -15,8 +19,9 @@ def test_insert_and_find(): assert uf.find(v1) == v1 assert uf.find(v2) == v2 + def test_insert_error(): - uf = mk.UnionFind()() + uf = mk.UnionFind() v1 = mk.IncompleteType(ir.Kind.Type) v2 = mk.IncompleteType(ir.Kind.Type) uf.insert(v1) @@ -26,8 +31,9 @@ def test_insert_error(): except: return + def test_unify(): - uf = mk.UnionFind()() + uf = mk.UnionFind() v1 = mk.IncompleteType(ir.Kind.Type) v2 = mk.IncompleteType(ir.Kind.Type) v3 = mk.IncompleteType(ir.Kind.Type) @@ -48,8 +54,9 @@ def test_unify(): assert uf.find(v2) == new_rep assert uf.find(v3) == new_rep + def test_unify_multiple_levels(): - uf = mk.UnionFind()() + uf = mk.UnionFind() v = [mk.IncompleteType(ir.Kind.Type) for _ in range(9)] for var in v: uf.insert(var) @@ -84,81 +91,92 @@ def test_unify_multiple_levels(): # We have checked that the basic machinery in the UnionFind works # and now we will test the type unifier which will fill in holes # between type equalities by the process of unification. + + def unify_types(t1, t2): unifier = mk.TypeUnifier() return unifier.unify(t1, t2) # TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work + + def test_unify_int(): - intty = IntType(1) + intty = int_type(1) unified = unify_types(intty, intty) assert intty == unified + def test_unify_bool(): - boolty = BoolType() + boolty = bool_type() unified = unify_types(boolty, boolty) assert boolty == unified + def test_unify_float(): - floatty = FloatType(4) + floatty = float_type(4) unified = unify_types(floatty, floatty) assert floatty == unified + def test_unify_incompatible_basetypes(): - bt = BoolType() - intty = IntType(32) + bt = bool_type() + intty = int_type(32) try: unify_types(bt, intty) assert False except: return -def test_unify_concrete_type_arrow(): - arr1 = TypeArrow([int_type()], int_type()) - arr2 = TypeArrow([int_type()], int_type()) + +def test_unify_concrete_func_type(): + arr1 = func_type([int_type()], int_type()) + arr2 = func_type([int_type()], int_type()) unified = unify_types(arr1, arr2) assert unified == arr1 -def test_unify_type_arrow_with_holes(): + +def test_unify_func_type_with_holes(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.unify(v1, bool_type()) - arr1 = TypeArrow([int_type()], bool_type()) - arr2 = TypeArrow([int_type()], v1) + arr1 = func_type([int_type()], bool_type()) + arr2 = func_type([int_type()], v1) unified = unifier.unify(arr1, arr2) assert unified == arr1 v2 = mk.IncompleteType(ir.Kind.BaseType) unifier.insert(v2) unifier.unify(v2, int_type()) - arr3 = TypeArrow([v2], bool_type()) + arr3 = func_type([v2], bool_type()) unified = unifier.unify(arr1, arr3) assert unified == arr1 -def test_reject_incompatible_type_arrows(): - arr1 = TypeArrow([int_type()], bool_type()) - arr2 = TypeArrow([int_type(), bool_type()], bool_type()) + +def test_reject_incompatible_func_types(): + arr1 = func_type([int_type()], bool_type()) + arr2 = func_type([int_type(), bool_type()], bool_type()) try: unify_types(arr1, arr2) assert False except: return -def test_unify_concrete_type_quantifiers(): - tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) - tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), int_type()) - unified = unify_types(tq1, tq2) - assert unified == tq1 +# def test_unify_concrete_type_quantifiers(): +# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) +# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), int_type()) +# unified = unify_types(tq1, tq2) +# assert unified == tq1 + +# def test_unify_basetype_with_quantifier_error(): +# bt = bool_type() +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) +# try: +# unify_types(bt, tq) +# assert False +# except: +# return -def test_unify_basetype_with_quantifier_error(): - bt = bool_type() - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) - try: - unify_types(bt, tq) - assert False - except: - return def test_unify_typevars_with_each_other(): unifier = mk.TypeUnifier() @@ -174,11 +192,12 @@ def test_unify_typevars_with_each_other(): new_unified = unifier.unify(v1, v3) assert (new_unified == v1 or new_unified == v2 or new_unified == v3) + def test_unify_typevars_with_basetype(): unifier = mk.TypeUnifier() - bt = BoolType() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.BaseType) + bt = bool_type() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unified1 = unifier.unify(v1, bt) @@ -186,11 +205,12 @@ def test_unify_typevars_with_basetype(): unified2 = unifier.unify(v1, v2) assert unified2 == bt + def test_unify_compatible_typevars(): unifier = mk.TypeUnifier() - bt = BoolType() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.BaseType) + bt = bool_type() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, bt) @@ -200,162 +220,154 @@ def test_unify_compatible_typevars(): unified = unifier.unify(v1, v2) assert unified == bt -def test_unify_incompatible_typevars(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.BaseType) - bt = bool_type() - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) - unifier.insert(v1) - unifier.insert(v2) - unifier.unify(v1, bt) - unifier.unify(v2, tq) - # bt cannot be unified with tq, so unifying v1 and v2 should give an error - try: - unifier.unify(v1, v2) - assert False - except: - return +# def test_unify_incompatible_typevars(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# v2 = mk.IncompleteType(ir.Kind.Type) +# bt = bool_type() +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) +# unifier.insert(v1) +# unifier.insert(v2) +# unifier.unify(v1, bt) +# unifier.unify(v2, tq) +# # bt cannot be unified with tq, so unifying v1 and v2 should give an error +# try: +# unifier.unify(v1, v2) +# assert False +# except: +# return + +# def test_unify_typevar_with_quantifier(): +# unifier = mk.TypeUnifier() +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier.insert(v1) +# unified = unifier.unify(v1, tq) +# assert unified == tq + +# def test_unify_typevars_inside_concrete_quantifier(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier.insert(v1) +# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) +# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) +# unified = unifier.unify(tq1, tq2) +# assert unified == tq2 -def test_unify_typevar_with_quantifier(): - unifier = mk.TypeUnifier() - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) - v1 = mk.IncompleteType(ir.Kind.BaseType) - unifier.insert(v1) - unified = unifier.unify(v1, tq) - assert unified == tq - -def test_unify_typevars_inside_concrete_quantifier(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - unifier.insert(v1) - tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) - tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) - unified = unifier.unify(tq1, tq2) - assert unified == tq2 def test_unify_concrete_tensors(): - bt = BoolType() - shape = make_shape([1, 2, 3]) - tt1 = TensorType(bt, shape) - tt2 = TensorType(bt, shape) + bt = build.bool_dtype() + shape = tvm.convert([1, 2, 3]) + tt1 = mk.TensorType(shape, bt) + tt2 = mk.TensorType(shape, bt) unified = unify_types(tt1, tt2) assert unified == tt1 + def test_unify_tensor_shape_reject(): - bt = BoolType() - shape1 = make_shape([1, 2, 3]) - shape2 = make_shape([2, 3, 4]) - tt1 = TensorType(bt, shape1) - tt2 = TensorType(bt, shape2) + bt = build.bool_dtype() + shape1 = tvm.convert([1, 2, 3]) + shape2 = tvm.convert([2, 3, 4]) + tt1 = mk.TensorType(shape1, bt) + tt2 = mk.TensorType(shape2, bt) try: unify_types(tt1, tt2) assert False except: return + def test_unify_tensor_dtype_reject(): - bt1 = BoolType() - bt2 = IntType(32) - shape = make_shape([1, 2, 3]) - tt1 = TensorType(bt1, shape) - tt2 = TensorType(bt2, shape) + bt1 = build.bool_dtype() + bt2 = build.int_dtype() + shape = tvm.convert([1, 2, 3]) + tt1 = mk.TensorType(shape, bt1) + tt2 = mk.TensorType(shape, bt2) try: unify_types(tt1, tt2) assert False except: return -def test_unify_quantified_tensors(): - x = TypeParam("x", ir.type.Kind.Shape) - y = TypeParam("y", ir.type.Kind.Shape) - tq1 = TypeQuantifier(x, TensorType(BoolType(), x)) - tq2 = TypeQuantifier(y, TensorType(BoolType(), y)) - unified = unify_types(tq1, tq2) - assert unified == tq1 - - a = TypeParam("a", ir.type.Kind.BaseType) - b = TypeParam("b", ir.type.Kind.BaseType) - tq3 = TypeQuantifier(a, TensorType(a, make_shape([1, 2, 3]))) - tq4 = TypeQuantifier(b, TensorType(b, make_shape([1, 2, 3]))) - unified = unify_types(tq3, tq4) - assert unified == tq3 - -def test_unify_concrete_products(): - bt = bool_type() - intty = int_type() - pt1 = TupleType([bt, intty]) - pt2 = TupleType([bt, intty]) - unified = unify_types(pt1, pt2) - assert unified == pt1 - -def test_unify_products_reject_size(): - bt = BoolType() - intty = IntType(32) - pt1 = TupleType([bt, bt, intty]) - pt2 = TupleType([bt, intty]) - try: - unify_types(pt1, pt2) - assert False - except: - return - -def test_unify_products_reject_member(): - bt = BoolType() - intty = IntType(32) - pt1 = TupleType([bt, bt]) - pt2 = TupleType([bt, intty]) - try: - unify_types(pt1, pt2) - assert False - except: - return +# def test_unify_quantified_tensors(): +# x = TypeParam("x", ir.type.Kind.Shape) +# y = TypeParam("y", ir.type.Kind.Shape) +# tq1 = TypeQuantifier(x, mk.TensorType(bool_type(), x)) +# tq2 = TypeQuantifier(y, mk.TensorType(bool_type(), y)) +# unified = unify_types(tq1, tq2) +# assert unified == tq1 + +# a = TypeParam("a", ir.type.Kind.BaseType) +# b = TypeParam("b", ir.type.Kind.BaseType) +# tq3 = TypeQuantifier(a, mk.TensorType(a, make_shape([1, 2, 3]))) +# tq4 = TypeQuantifier(b, mk.TensorType(b, make_shape([1, 2, 3]))) +# unified = unify_types(tq3, tq4) +# assert unified == tq3 + +# def test_unify_concrete_products(): +# bt = bool_type() +# intty = int_type() +# pt1 = TupleType([bt, intty]) +# pt2 = TupleType([bt, intty]) +# unified = unify_types(pt1, pt2) +# assert unified == pt1 + +# def test_unify_products_reject_size(): +# bt = bool_type() +# intty = IntType(32) +# pt1 = TupleType([bt, bt, intty]) +# pt2 = TupleType([bt, intty]) +# try: +# unify_types(pt1, pt2) +# assert False +# except: +# return + +# def test_unify_products_reject_member(): +# bt = bool_type() +# intty = int_type() +# pt1 = TupleType([bt, bt]) +# pt2 = TupleType([bt, intty]) +# try: +# unify_types(pt1, pt2) +# assert False +# except: +# return + +# def test_unify_products_typevar(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# bt = bool_type() +# pt1 = TupleType([bt, bt]) +# pt2 = TupleType([v1, bt]) +# unifier.insert(v1) +# unified = unifier.unify(pt1, pt2) +# assert unified == pt1 + +# def test_unify_quantified_products(): +# x = TypeParam("x", ir.Kind.Type) +# y = TypeParam("y", ir.Kind.Type) +# p1 = TypeQuantifier(x, TupleType([int_type(), x])) +# p2 = TypeQuantifier(y, TupleType([int_type(), y])) +# unified = unify_types(p1, p2) +# assert unified == p1 -def test_unify_products_typevar(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - bt = bool_type() - pt1 = TupleType([bt, bt]) - pt2 = TupleType([v1, bt]) - unifier.insert(v1) - unified = unifier.unify(pt1, pt2) - assert unified == pt1 - -def test_unify_quantified_products(): - x = TypeParam("x", ir.Kind.Type) - y = TypeParam("y", ir.Kind.Type) - p1 = TypeQuantifier(x, TupleType([int_type(), x])) - p2 = TypeQuantifier(y, TupleType([int_type(), y])) - unified = unify_types(p1, p2) - assert unified == p1 - -def test_unify_ref_types(): - r1 = RefType(bool_type()) - r2 = RefType(bool_type()) - assert unify_types(r1, r2) == r1 - -def test_unify_ref_reject_inner(): - r1 = RefType(BoolType()) - r2 = RefType(IntType(32)) - try: - unify_types(r1, r2) - assert False - except: - return def test_subst_basetype(): unifier = mk.TypeUnifier() - bt = BoolType() + bt = bool_type() assert bt == unifier.subst(bt) + def test_subst_simple_hole(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.BaseType) - bt = BoolType() + bt = bool_type() unifier.insert(v1) unifier.unify(v1, bt) assert unifier.subst(v1) == bt + def test_subst_typevar_for_typevar(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.Type) @@ -364,13 +376,26 @@ def test_subst_typevar_for_typevar(): unifier.insert(v2) unifier.unify(v1, v2) - assert unifier.subst(v1) == v2 + assert unifier.subst(v1) == unifier.subst(v2) + + +def test_subst_typevar_for_typevar_comm(): + unifier = mk.TypeUnifier() + v1 = mk.IncompleteType(ir.Kind.Type) + v2 = mk.IncompleteType(ir.Kind.Type) + unifier.insert(v1) + unifier.insert(v2) + + unifier.unify(v2, v1) + assert unifier.subst(v1) == unifier.subst(v2) + def test_subst_concrete_arrow(): unifier = mk.TypeUnifier() - arr1 = TypeArrow([int_type()], int_type()) + arr1 = func_type([int_type()], int_type()) assert unifier.subst(arr1) == arr1 + def test_subst_arrow_with_holes(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.BaseType) @@ -379,93 +404,93 @@ def test_subst_arrow_with_holes(): unifier.insert(v2) unifier.unify(v1, int_type()) unifier.unify(v2, bool_type()) - arr1 = TypeArrow([v1], v2) - arr2 = TypeArrow([int_type()], bool_type()) + arr1 = func_type([v1], v2) + arr2 = func_type([int_type()], bool_type()) assert unifier.subst(arr1) == arr2 -def test_subst_concrete_quantifier(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) - unifier.insert(v1) - unifier.unify(v1, tq) - assert unifier.subst(v1) == tq +# def test_subst_concrete_quantifier(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) +# unifier.insert(v1) +# unifier.unify(v1, tq) +# assert unifier.subst(v1) == tq + +# def test_subst_quantifier_with_holes(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# v2 = mk.IncompleteType(ir.Kind.Type) +# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) +# intty = int_type() +# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) + # unifier.insert(v1) + # unifier.insert(v2) + # unifier.unify(v2, intty) + # unifier.unify(v1, tq1) + # assert unifier.subst(v1) == tq2 -def test_subst_quantifier_with_holes(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) - tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) - intty = int_type() - tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) - - unifier.insert(v1) - unifier.insert(v2) - unifier.unify(v2, intty) - unifier.unify(v1, tq1) - assert unifier.subst(v1) == tq2 def test_subst_concrete_tensor(): unifier = mk.TypeUnifier() v1 = mk.IncompleteType(ir.Kind.Type) unifier.insert(v1) - tt = TensorType(BoolType(), make_shape([1, 2, 3])) + tt = mk.TensorType(tvm.convert([1, 2, 3]), 'uint1') unifier.unify(v1, tt) assert unifier.subst(v1) == tt -def test_subst_concrete_product(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - bt = bool_type() - pt = TupleType([bt, bt]) - unifier.unify(v1, pt) - assert unifier.subst(v1) == pt - -def test_subst_product_with_holes(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) - v3 = mk.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - unifier.insert(v3) - - tt1 = TensorType(IntType(32), ShapeSeq([])) - tt2 = TensorType(FloatType(32), ShapeSeq([])) - pt1 = TupleType([tt1, v2, v3]) - unifier.unify(v2, tt2) - unifier.unify(v3, v2) - unifier.unify(v1, pt1) - pt2 = TupleType([tt1, tt2, tt2]) - assert unifier.subst(v1) == pt2 - -def test_subst_concrete_ref(): - unifier = mk.TypeUnifier() - rt = RefType(bool_type()) - assert unifier.subst(rt) == rt - -def test_subst_ref_with_hole(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - - unifier.unify(v1, bool_type()) - rt1 = RefType(v1) - rt2 = RefType(bool_type()) - assert unifier.subst(rt1) == rt2 - -def test_typevar_on_lhs(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.Type) - bt = bool_type() - tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) - unifier.insert(v1) - unifier.insert(v2) - unified1 = unifier.unify(bt, v1) - assert unified1 == bt - unified2 = unifier.unify(tq, v2) - assert unified2 == tq - assert unifier.subst(v1) == bt - assert unifier.subst(v2) == tq +# def test_subst_concrete_product(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# unifier.insert(v1) +# bt = bool_type() +# pt = TupleType([bt, bt]) +# unifier.unify(v1, pt) +# assert unifier.subst(v1) == pt + +# def test_subst_product_with_holes(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# v2 = mk.IncompleteType(ir.Kind.Type) +# v3 = mk.IncompleteType(ir.Kind.Type) +# unifier.insert(v1) +# unifier.insert(v2) +# unifier.insert(v3) + +# tt1 = mk.TensorType(int_type(), tvm.convert([])) +# tt2 = mk.TensorType(FloatType(32), tvm.convert([])) +# pt1 = TupleType([tt1, v2, v3]) +# unifier.unify(v2, tt2) +# unifier.unify(v3, v2) +# unifier.unify(v1, pt1) +# pt2 = TupleType([tt1, tt2, tt2]) +# assert unifier.subst(v1) == pt2 + +# def test_subst_concrete_ref(): +# unifier = mk.TypeUnifier() +# rt = RefType(bool_type()) +# assert unifier.subst(rt) == rt + +# def test_subst_ref_with_hole(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.Type) +# unifier.insert(v1) + +# unifier.unify(v1, bool_type()) +# rt1 = RefType(v1) +# rt2 = RefType(bool_type()) +# assert unifier.subst(rt1) == rt2 + +# def test_typevar_on_lhs(): +# unifier = mk.TypeUnifier() +# v1 = mk.IncompleteType(ir.Kind.BaseType) +# v2 = mk.IncompleteType(ir.Kind.Type) +# bt = bool_type() +# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) +# unifier.insert(v1) +# unifier.insert(v2) +# unified1 = unifier.unify(bt, v1) +# assert unified1 == bt +# unified2 = unifier.unify(tq, v2) +# assert unified2 == tq +# assert unifier.subst(v1) == bt +# assert unifier.subst(v2) == tq From 4cdf3f9b89cb085a5358b71bd435639e55475493 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 01:01:31 -0700 Subject: [PATCH 024/136] Start refactoring type checker Introduce both Environment and type inference Python interfaces for testing. --- .../compiler/{typechecker.h => type_infer.h} | 10 +- python/tvm/relay/_env.py | 5 + python/tvm/relay/_env.pyi | 18 ++++ python/tvm/relay/_type_infer.py | 5 + python/tvm/relay/_type_infer.pyi | 6 ++ python/tvm/relay/env.py | 98 +++++++++++++++++++ python/tvm/relay/ir_builder.py | 4 +- python/tvm/relay/type_infer.py | 6 ++ .../{typechecker.cc => type_infer.cc} | 28 +++--- tests/python/relay/test_typechecker.py | 17 ++++ 10 files changed, 177 insertions(+), 20 deletions(-) rename include/tvm/relay/compiler/{typechecker.h => type_infer.h} (70%) create mode 100644 python/tvm/relay/_env.py create mode 100644 python/tvm/relay/_env.pyi create mode 100644 python/tvm/relay/_type_infer.py create mode 100644 python/tvm/relay/_type_infer.pyi create mode 100644 python/tvm/relay/env.py create mode 100644 python/tvm/relay/type_infer.py rename src/relay/compiler/{typechecker.cc => type_infer.cc} (98%) create mode 100644 tests/python/relay/test_typechecker.py diff --git a/include/tvm/relay/compiler/typechecker.h b/include/tvm/relay/compiler/type_infer.h similarity index 70% rename from include/tvm/relay/compiler/typechecker.h rename to include/tvm/relay/compiler/type_infer.h index c69aba3c1e71..4c16defe977f 100644 --- a/include/tvm/relay/compiler/typechecker.h +++ b/include/tvm/relay/compiler/type_infer.h @@ -1,8 +1,10 @@ /*! - * Copyright (c) 2017 by Contributors - * \file tvm/relay/typechecker.h - * \brief Type check a Relay program producing a type checked program - * with its checked_type field populated and incomplete types resolved. + * Copyright (c) 2018 by Contributors + * \file tvm/relay/type_infer.h + * \brief Perform type inference and checking on Relay programs. + * + * The pass produces a new expression with its checked_type + * field populated and incomplete types resolved. */ #ifndef TVM_RELAY_COMPILER_TYPECHECKER_H_ #define TVM_RELAY_COMPILER_TYPECHECKER_H_ diff --git a/python/tvm/relay/_env.py b/python/tvm/relay/_env.py new file mode 100644 index 000000000000..25b8715a7816 --- /dev/null +++ b/python/tvm/relay/_env.py @@ -0,0 +1,5 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable +"""The interface to the Environment exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay._env", __name__) diff --git a/python/tvm/relay/_env.pyi b/python/tvm/relay/_env.pyi new file mode 100644 index 000000000000..d14e726e5443 --- /dev/null +++ b/python/tvm/relay/_env.pyi @@ -0,0 +1,18 @@ +from typing import Union, Tuple, Dict, List +from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId +from relay.ir import ShapeExtension, Operator, Defn + +class Environment(NodeBase): ... + +def Environment_add(self: Environment, func: GlobalId) -> None: ... +def Environment_global_id(self: Environment, name: str) -> GlobalId: ... +def Environment_operator_id(self: Environment, name: str) -> OperatorId: ... +def Environment_lookup_global(self: Environment, id: GlobalId) -> Item: ... +def Environment_lookup_operator(self: Environment, id: OperatorId) -> Item: ... +def Environment_remove_global(self: Environment, id: GlobalId) -> Item: ... +def Environment_add_source(self: Environment, file_name: str, source: str) -> FileId: ... +def Environment_report_error(self: Environment, message: str, span: Span) -> None: ... +def Environment_display_errors(self: Environment) -> None: ... +def Environment_register_shape_ext(self: Environment, shape_ext: ShapeExtension) -> None: ... +def Environment_get_operators(self: Environment) -> List[Operator]: ... +def Environment_get_defns(self: Environment) -> List[Defn]: ... diff --git a/python/tvm/relay/_type_infer.py b/python/tvm/relay/_type_infer.py new file mode 100644 index 000000000000..7213769a4164 --- /dev/null +++ b/python/tvm/relay/_type_infer.py @@ -0,0 +1,5 @@ +"""FFI exposing the Relay type inference and checking.""" + +from tvm._ffi.function import _init_api + +_init_api("relay._type_infer", __name__) diff --git a/python/tvm/relay/_type_infer.pyi b/python/tvm/relay/_type_infer.pyi new file mode 100644 index 000000000000..1bb42ab854c2 --- /dev/null +++ b/python/tvm/relay/_type_infer.pyi @@ -0,0 +1,6 @@ +from .env import Environment +from . import ir + +def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ... +def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ... +def _get_checked_type(expr: ir.Expr) -> ir.Type: ... diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py new file mode 100644 index 000000000000..9bd63476f1fb --- /dev/null +++ b/python/tvm/relay/env.py @@ -0,0 +1,98 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import +"""A global environment storing everything needed to interpret or compile a Realy program.""" +from typing import Union, List +from relay.ir import register_relay_node, NodeBase +from relay.ir import GlobalId, OperatorId, Item, FileId, Span, ShapeExtension +from relay.ir import Operator, Defn +from relay._env import * +import tvm + +# Move me to C++ if possible. +__tgt_host__ = __tgt__ = "llvm" +__relay_tvm_context__ = tvm.cpu() + +ADD_ID = "__add__" +SUB_ID = "__sub__" +MUL_ID = "__mul__" +DIV_ID = "__div__" +NEG_ID = "__neg__" +LT_ID = "__lt__" +LE_ID = "__le__" +GT_ID = "__gt__" +GE_ID = "__ge__" +EQ_ID = "__eq__" +NE_ID = "__ne__" + +@register_relay_node +class Environment(NodeBase): + """The global Relay environment containing definitions, + primitives, options, and more. + """ + def add(self, item: Item) -> None: + return Environment_add(self, item) + + def global_id(self, name: str) -> GlobalId: + return Environment_global_id(self, name) + + def operator_id(self, name: str) -> OperatorId: + return Environment_operator_id(self, name) + + def lookup(self, ident: Union[GlobalId, OperatorId]) -> Item: + if isinstance(ident, OperatorId): + return Environment_lookup_operator(self, ident) + else: + return Environment_lookup_global(self, ident) + + def add_source(self, file_name: str, source: str) -> FileId: + return Environment_add_source(self, file_name, source) + + def report_error(self, message: str, span: Span) -> None: + return Environment_report_error(self, message, span) + + def register_shape_ext(self, ext: ShapeExtension) -> None: + return Environment_register_shape_ext(self, ext) + + def display_errors(self) -> None: + return Environment_display_errors(self) + + def operators(self) -> List[Operator]: + return Environment_get_operators(self) + + def defns(self) -> List[Defn]: + return Environment_get_defns(self) + + def tvm_context(self): + return __relay_tvm_context__ + + def add_id(self) -> OperatorId: + return self.operator_id(ADD_ID) + + def sub_id(self) -> OperatorId: + return self.operator_id(SUB_ID) + + def mul_id(self) -> OperatorId: + return self.operator_id(MUL_ID) + + def div_id(self) -> OperatorId: + return self.operator_id(DIV_ID) + + def neg_id(self) -> OperatorId: + return self.operator_id(NEG_ID) + + def lt_id(self) -> OperatorId: + return self.operator_id(LT_ID) + + def le_id(self) -> OperatorId: + return self.operator_id(LE_ID) + + def gt_id(self) -> OperatorId: + return self.operator_id(GT_ID) + + def ge_id(self) -> OperatorId: + return self.operator_id(GE_ID) + + def eq_id(self) -> OperatorId: + return self.operator_id(EQ_ID) + + def ne_id(self) -> OperatorId: + return self.operator_id(NE_ID) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 8fa9b789f53c..2b2cdb432b43 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -64,11 +64,11 @@ def bind(self, name, type, value): return lv - def let(self, name, value): + def let(self, name, value, value_type=None): if not isinstance(value, expr.Expr): value = into_ast(value) - return self.bind(name, None, value) + return self.bind(name, value_type, value) def function(self, params): def _on_exit(): diff --git a/python/tvm/relay/type_infer.py b/python/tvm/relay/type_infer.py new file mode 100644 index 000000000000..17938dfdcbc4 --- /dev/null +++ b/python/tvm/relay/type_infer.py @@ -0,0 +1,6 @@ +#pylint: disable-all + +from . import _type_infer + +check_expr = _type_infer.check_expr +# generalize = _type_infer.generalize diff --git a/src/relay/compiler/typechecker.cc b/src/relay/compiler/type_infer.cc similarity index 98% rename from src/relay/compiler/typechecker.cc rename to src/relay/compiler/type_infer.cc index e16481b7f9e0..0b7435598d6d 100644 --- a/src/relay/compiler/typechecker.cc +++ b/src/relay/compiler/type_infer.cc @@ -1,10 +1,10 @@ /*! * Copyright (c) 2018 by Contributors - * \file typechecker.cc - * \brief Relay typechecker + * \file type_infer.cc + * \brief Relay type inference and checking. */ -#include "tvm/relay/compiler/typechecker.h" +#include "tvm/relay/compiler/type_infer.h" #include "./incomplete_type.h" // #include "tvm/relay/alpha_eq.h" // #include "tvm/relay/debug.h" @@ -724,12 +724,12 @@ namespace relay { // } // } -// TVM_REGISTER_API("relay._tyck.check_expr") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// Expr e = args[1]; -// *ret = check(env, e); -// }); +TVM_REGISTER_API("relay._type_infer.check_expr") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + Expr e = args[1]; + *ret = check(env, e); + }); // TVM_REGISTER_API("relay._tyck.check_item") // .set_body([](TVMArgs args, TVMRetValue *ret) { @@ -738,11 +738,11 @@ namespace relay { // *ret = check(env, i); // }); -// TVM_REGISTER_API("relay._tyck.get_checked_type") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Expr e = args[0]; -// *ret = e->checked_type(); -// }); +TVM_REGISTER_API("relay._type_infer._get_checked_type") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Expr e = args[0]; + *ret = e->checked_type(); + }); // TVM_REGISTER_API("relay._tyck.generalize") // .set_body([](TVMArgs args, TVMRetValue *ret) { diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py new file mode 100644 index 000000000000..e5466d4439b9 --- /dev/null +++ b/tests/python/relay/test_typechecker.py @@ -0,0 +1,17 @@ +"""Test that type checker correcly computes types + for expressions. +""" +import tvm.relay.make as mk +from tvm.relay.ir_builder import IRBuilder, float_type + +def test_monomorphic_let(): + b = IRBuilder() + # Program: let x = 1; x + x = b.let('x', 1, value_type=float_type()) + b.ret(x) + + prog = b.get() + e = check_expr(prog) + e.get_type() + + From a1027bf040d7b38f1b7fa74b740bcac08eb78c42 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 01:13:44 -0700 Subject: [PATCH 025/136] Get a failing test for the type checker --- include/tvm/relay/compiler/type_infer.h | 8 ++++++-- python/tvm/relay/expr.py | 4 +++- python/tvm/relay/make.py | 3 +++ src/relay/compiler/environment.cc | 18 +++++++++--------- src/relay/compiler/type_infer.cc | 13 +++++++------ tests/python/relay/test_typechecker.py | 9 +++++++-- 6 files changed, 35 insertions(+), 20 deletions(-) diff --git a/include/tvm/relay/compiler/type_infer.h b/include/tvm/relay/compiler/type_infer.h index 4c16defe977f..6d07de1c29e8 100644 --- a/include/tvm/relay/compiler/type_infer.h +++ b/include/tvm/relay/compiler/type_infer.h @@ -19,8 +19,12 @@ namespace relay { * with unambigous type information filled in, as well as it's * checked type field populated with the result type. */ -Expr check(const Environment & env, const Expr & e); -Operator check(const Environment & env, const Operator & op); +Expr Infer(const Environment & env, const Expr & e); + +/*! \brief Ensures that an operator is well-formed with respect + * to Relay's type system. + */ +Operator CheckOperator(const Environment & env, const Operator & op); } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index dea3a99f5f09..c17a69dd0dc9 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -6,10 +6,12 @@ from .base import Span, NodeBase, register_relay_node from .type import Type, TypeParam from tvm import expr +from ._type_infer import _get_checked_type class Expr(NodeBase): """The base type for all Relay exprressions.""" - pass + def checked_type(self): + return _get_checked_type(self) @register_relay_node class Constant(Expr): diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py index 236e2f6af596..bf9ec0e48f64 100644 --- a/python/tvm/relay/make.py +++ b/python/tvm/relay/make.py @@ -4,6 +4,9 @@ # Base Constructors Span = _make.Span +# Environment +Environment = _make.Environment + # Type Constructors TensorType = _make.TensorType TypeParam = _make.TypeParam diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc index 125ceae834b3..af8f5eeefab7 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/compiler/environment.cc @@ -187,10 +187,10 @@ Environment EnvironmentNode::make( // this->shape_exts_.Insert(ext->name, ext); // } -// TVM_REGISTER_API("relay._make.Environment") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// *ret = EnvironmentNode::make({}); -// }); +TVM_REGISTER_API("relay._make.Environment") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = EnvironmentNode::make({}); + }); // TVM_REGISTER_API("relay._env.Environment_add") // .set_body([](TVMArgs args, TVMRetValue *ret) { @@ -282,11 +282,11 @@ Environment EnvironmentNode::make( // *ret = env->get_defns(); // }); -// TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -// .set_dispatch([](const EnvironmentNode *node, -// tvm::IRPrinter *p) { -// p->stream << "EnvironmentNode(todo)"; // << node->items << ")"; -// }); +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const EnvironmentNode *node, + tvm::IRPrinter *p) { + p->stream << "EnvironmentNode(todo)"; // << node->items << ")"; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 0b7435598d6d..96d9dc92d97e 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -625,12 +625,13 @@ namespace relay { // } // } -// Type check(const Environment &env, const Expr &e) { -// Typechecker tc(env); -// return tc.Check(e); -// } +Expr Infer(const Environment &env, const Expr &e) { + //Typechecker tc(env); + // return tc.Check(e); + return e; +} -// Item check(const Environment &env, const Item &i) { +// Item Check(const Environment &env, const Item &i) { // Typechecker tc(env); // try { @@ -728,7 +729,7 @@ TVM_REGISTER_API("relay._type_infer.check_expr") .set_body([](TVMArgs args, TVMRetValue *ret) { Environment env = args[0]; Expr e = args[1]; - *ret = check(env, e); + *ret = Infer(env, e); }); // TVM_REGISTER_API("relay._tyck.check_item") diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index e5466d4439b9..5626fd8ce0bc 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -2,8 +2,14 @@ for expressions. """ import tvm.relay.make as mk +from tvm.relay.type_infer import check_expr from tvm.relay.ir_builder import IRBuilder, float_type +def has_type(expr, typ): + env = mk.Environment({}) + checked_expr = check_expr(env, expr) + return checked_expr.checked_type() == typ + def test_monomorphic_let(): b = IRBuilder() # Program: let x = 1; x @@ -11,7 +17,6 @@ def test_monomorphic_let(): b.ret(x) prog = b.get() - e = check_expr(prog) - e.get_type() + assert has_type(prog, float_type()) From fe6b31a48b022a55480fcbef393c2289fa1a20cb Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 02:24:00 -0700 Subject: [PATCH 026/136] Iterate on first test case --- include/tvm/relay/error.h | 10 + include/tvm/relay/expr.h | 8 +- include/tvm/relay/expr_visitor.h | 18 +- src/relay/compiler/resolve.cc | 99 ++ src/relay/compiler/resolve.h | 23 + src/relay/compiler/type_infer.cc | 1461 +++++++++++++++--------------- src/relay/compiler/unifier.h | 5 +- 7 files changed, 870 insertions(+), 754 deletions(-) create mode 100644 src/relay/compiler/resolve.cc create mode 100644 src/relay/compiler/resolve.h diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index d2698f8e380b..4f6a27d209c8 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -22,6 +22,16 @@ struct SpannedError { SpannedError(std::string msg, Span sp) : msg(msg), sp(sp) {} }; +// FIX, we should change spanned errors to have a method which allow them to report on the Environment, +// inverting control to error definition. +struct FatalTypeError : dmlc::Error { + explicit FatalTypeError(const std::string & s) : dmlc::Error(s) {} +}; + +struct TypecheckerError : public dmlc::Error { + explicit TypecheckerError(const std::string &msg) : Error(msg) {} +}; + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index c1dd557717af..a29c8486ffb6 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -15,6 +15,10 @@ namespace tvm { namespace relay { + +// TOD0(@jroesch): best way to define? +class TypeInferencer; + /*! * \brief Relay expression. */ @@ -24,13 +28,14 @@ class Expr; */ class ExprNode : public RelayNode { public: + // private: /*! * \brief Stores the result of type inference(type checking). * * \note This can be undefined before type inference. * this value is discarded during serialization. */ - Type checked_type_ = Type(nullptr); + mutable Type checked_type_ = Type(nullptr); /*! * \return The checked_type */ @@ -43,6 +48,7 @@ class ExprNode : public RelayNode { static constexpr const char* _type_key = "relay.Expr"; TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); + friend class TypeInferencer; }; RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 721fa531a7e3..2039414b4238 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -97,9 +97,16 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor ty_params; + tvm::Array ty_params; + for (auto ty : op->type_params) { - ty_params.push_back(this->VisitType(ty, args...)); + Type ty_param_type = VisitType(ty, args...); + if (auto ty_param = ty_param_type.as()) { + auto ty_param_ref = GetRef(ty_param); + ty_params.push_back(ty_param_ref); + } else { + throw dmlc::Error("the default func visitor has bug"); + } } tvm::Array params; @@ -115,7 +122,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitType(op->ret_type, args...); auto body = this->VisitExpr(op->body, args...); - return FunctionNode::make(ty_params, params, ret_type, body); + return FunctionNode::make(params, ret_type, body, ty_params); } Expr VisitExpr_(const CallNode* call_node, Args... args) override { @@ -132,8 +139,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitExpr(arg, args...)); } - auto call = CallNode::make(fn, call_args, call_node->attrs); - call->ty_args = ty_args; + auto call = CallNode::make(fn, call_args, call_node->attrs, ty_args); return call; } @@ -145,7 +151,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitType(op->value_type, args...); auto value = this->VisitExpr(op->value, args...); auto body = this->VisitExpr(op->body, args...); - return LetNode::make(var, type, value, body); + return LetNode::make(var, value, body, type); } else { throw dmlc::Error("the default let visitor has error"); } diff --git a/src/relay/compiler/resolve.cc b/src/relay/compiler/resolve.cc new file mode 100644 index 000000000000..2d3e84dc2160 --- /dev/null +++ b/src/relay/compiler/resolve.cc @@ -0,0 +1,99 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file unifier.cc + * \brief Data structures for type unification + */ + +#include "./resolve.h" +#include "./type_visitor.h" +#include "tvm/relay/expr_visitor.h" +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +// We should probably generalize the subst code. +struct ResolveTypeType : TypeFVisitor { + const TypeUnifier &unifier; + + explicit ResolveTypeType(const TypeUnifier &unifier) : unifier(unifier) {} + + Type VisitType(const Type &t) override { + if (!t.defined()) { + auto inc_ty = IncompleteTypeNode::make(TypeParamNode::Kind::kType); + unifier->insert(inc_ty); + return inc_ty; + } else { + return TypeFVisitor::VisitType(t); + } + } + + Type VisitType_(const IncompleteTypeNode *op) override { + return unifier->subst(GetRef(op)); + } +}; + +struct ResolveTypeExpr : ExprFVisitor<> { + const TypeUnifier &unifier; + + explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {} + + Expr VisitExpr(const Expr &e) { + // NB: a bit tricky here. + // + // We want to store resolved type without having + // to re-typecheck the entire term. + // + // Since we know that e : T[...] under some holes + // then it is the case that if we resolve types + // present in e, then we can type it under T + // with the wholes filled in. + // + // We will visit e like normal building a new + // term, then resolve e's old type and write + // it back into the new node. + auto new_e = ExprFVisitor::VisitExpr(e); + auto resolved_cty = VisitType(e->checked_type_); + new_e->checked_type_ = resolved_cty; + return new_e; + } + + Type VisitType(const Type &t) { + return ResolveTypeType(unifier).VisitType(t); + } +}; + +Type resolve(const TypeUnifier &unifier, const Type &ty) { + return ResolveTypeType(unifier).VisitType(ty); +} + +Expr resolve(const TypeUnifier &unifier, const Expr &expr) { + return ResolveTypeExpr(unifier).VisitExpr(expr); +} + +struct FullyResolved : TypeVisitor<> { + bool incomplete; + + FullyResolved() : incomplete(true) {} + + void VisitType(const Type &t) override { + if (!t.defined()) { + incomplete = true; + } else { + return TypeVisitor<>::VisitType(t); + } + } + + void VisitType_(const IncompleteTypeNode *ty_var) override { + incomplete = false; + } +}; + +bool is_fully_resolved(const Type &t) { + auto fr = FullyResolved(); + fr.VisitType(t); + return fr.incomplete; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/compiler/resolve.h b/src/relay/compiler/resolve.h new file mode 100644 index 000000000000..b4e164df6287 --- /dev/null +++ b/src/relay/compiler/resolve.h @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/options.h + * \brief Global options for the Relay IR. + */ +#ifndef TVM_RELAY_TYPECK_RESOLVE_H_ +#define TVM_RELAY_TYPECK_RESOLVE_H_ + +#include +#include "tvm/relay/ir.h" +#include "./unifier.h" + +namespace tvm { +namespace relay { + +Type resolve(const TypeUnifier & unifier, const Type & ty); +Expr resolve(const TypeUnifier & unifier, const Expr & expr); +bool is_fully_resolved(const Type & t); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TYPECK_RESOLVE_H_ diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 96d9dc92d97e..49c8bbf9627f 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -2,771 +2,740 @@ * Copyright (c) 2018 by Contributors * \file type_infer.cc * \brief Relay type inference and checking. + * + * This file implements one of the most important passes to the + * Relay IR. In order to do many transformations and generate the + * most efficient code we need to obtain type information for the + * IR. + * + * Like computation graphs the IR leaves most type information + * implicit and relies performing analysis of the program to + * generate this information. + * + * This pass given an expression `e` will infer a type `t` for + * the expression simultaneous checking the property `e : t` + * (i.e we can show e has type t). + * + * If we can not infer a type or there are conflicting typing + * constraints we will trigger an error. */ +#include "tvm/relay/logging.h" #include "tvm/relay/compiler/type_infer.h" +#include "tvm/relay/error.h" +#include "tvm/relay/expr_functor.h" #include "./incomplete_type.h" +#include "./unifier.h" +#include "./resolve.h" // #include "tvm/relay/alpha_eq.h" // #include "tvm/relay/debug.h" // #include "tvm/relay/first_order_reverse_ad.h" // #include "tvm/relay/free_type_vars.h" // #include "tvm/relay/gen_fresh.h" // #include "tvm/relay/ir.h" -// #include "tvm/relay/logging.h" // #include "tvm/relay/pretty_printer.h" // #include "tvm/relay/reverse_ad.h" // #include "tvm/relay/type_visitor.h" // #include "tvm/relay/typeck/kindchecker.h" -// #include "tvm/relay/typeck/resolve.h" // #include "tvm/relay/typeck/shape_evaluator.h" namespace tvm { namespace relay { -// using namespace tvm::runtime; - -// struct FatalTypeError : dmlc::Error { -// explicit FatalTypeError(const std::string & s) : dmlc::Error(s) {} -// }; - -// struct TypeContext { -// std::vector> stack; -// TypeContext() { -// stack.push_back({}); -// } -// void insert(const LocalId &id, const Type &t) { stack.back()[id] = t; } -// Type lookup(const LocalId &id) { -// for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) { -// if (frame->find(id) != frame->end()) { -// return frame->at(id); -// } -// } -// throw FatalTypeError("Could not resolve local id"); -// } -// struct LocalFrame { -// TypeContext & tc; -// explicit LocalFrame(TypeContext & tc) : tc(tc) { -// tc.stack.push_back({}); -// } -// ~LocalFrame() { -// tc.stack.pop_back(); -// } -// }; -// }; - -// class Typechecker : private ExprFunctor { -// private: -// TypeContext local_stack; -// public: -// Environment env; -// TypeUnifier unifier; - -// template -// T with_frame(const std::function & f) { -// TypeContext::LocalFrame fr(local_stack); -// return f(); -// } - -// Typechecker(); -// Typechecker(Environment env, TypeUnifier unifier) : env(env), unifier(unifier) {} -// explicit Typechecker(Environment env); -// Type Check(const Expr & expr); -// Type instantiate(Type t, tvm::Array & ty_args); - -// void report_error(const std::string & msg, Span sp); -// [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); - -// Type unify(const Type &t1, const Type &t2, Span sp); -// Type resolve(const Type &t); -// Expr resolve(const Expr &e); -// Type VisitFunction(const Function & f, bool generalize); -// Operator CheckOp(Operator op); -// Defn CheckDefn(Defn def); -// private: -// Type VisitExpr_(const LocalIdNode* op) override; -// Type VisitExpr_(const GlobalIdNode* op) override; -// Type VisitExpr_(const OperatorIdNode* op) override; -// Type VisitExpr_(const FloatLitNode* op) override; -// Type VisitExpr_(const BoolLitNode* op) override; -// Type VisitExpr_(const IntLitNode* op) override; -// Type VisitExpr_(const TensorLitNode* op) override; -// Type VisitExpr_(const TupleNode* op) override; -// Type VisitExpr_(const CastNode* op) override; -// Type VisitExpr_(const ParamNode* op) override; -// Type VisitExpr_(const FunctionNode* op) override; -// Type VisitExpr_(const CallNode* op) override; -// Type VisitExpr_(const DebugNode* op) override; -// Type VisitExpr_(const LetNode* op) override; -// Type VisitExpr_(const ReverseNode* op) override; -// Type VisitExpr_(const GradientNode* op) override; -// Type VisitExpr_(const ProjectionNode* op) override; -// Type VisitExpr_(const IfNode* op) override; -// Type VisitExpr_(const RefNode* op) override; -// Type VisitExpr_(const ReadRefNode* op) override; -// Type VisitExpr_(const WriteRefNode* op) override; -// Type simple_eval_shape(const Type &shape); -// }; -// struct TypecheckerError : public dmlc::Error { -// explicit TypecheckerError(const std::string &msg) : Error(msg) {} -// }; - -// Typechecker::Typechecker() { -// this->env = EnvironmentNode::make({}); -// this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); -// } - -// Typechecker::Typechecker(Environment env) : env(env) { -// this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); -// } - -// Type Typechecker::Check(const Expr &expr) { -// RELAY_LOG(INFO) << "Typechecker::Check expr=" << expr << std::endl; -// Type ret = this->VisitExpr(expr); -// RELAY_LOG(INFO) << "Typechecker::Check type=" << expr << std::endl; -// ret = this->unifier->subst(ret); -// RELAY_LOG(INFO) << "Typechecker::Check type_after_subst=" << ret << std::endl; -// expr->checked_type_ = ret; -// return ret; -// } - -// Type Typechecker::VisitExpr_(const LocalIdNode *op) { -// LocalId id = GetRef(op); -// return this->local_stack.lookup(id); -// } - -// Type Typechecker::VisitExpr_(const GlobalIdNode *op) { -// GlobalId id = GetRef(op); -// Item item = this->env->lookup(id); - -// if (const OperatorNode *op = item.as()) { -// return op->type; -// } - -// if (const DefnNode *dn = item.as()) { -// Defn def = GetRef(dn); -// return def->type; -// } - -// this->fatal_error("Unhandled case in GlobalId", op->span); -// } - -// Type Typechecker::VisitExpr_(const OperatorIdNode *op) { -// OperatorId id = GetRef(op); -// Item item = this->env->lookup(id); - -// if (const OperatorNode *pn = item.as()) { -// Operator prim = GetRef(pn); -// return prim->type; -// } else { -// this->fatal_error("internal error in InstrinsicId case", op->span); -// } -// } - -// Type Typechecker::VisitExpr_(const FloatLitNode *op) { return FloatType(); } - -// Type Typechecker::VisitExpr_(const BoolLitNode *op) { return BoolType(); } - -// Type Typechecker::VisitExpr_(const IntLitNode *op) { return IntType(); } - -// Type Typechecker::VisitExpr_(const TensorLitNode *op) { -// TensorLit lit = GetRef(op); - -// if (lit->data.size() == 0) { -// this->fatal_error("Tensor literal must have at least one member", op->span); -// } - -// // unify types of all members to figure out shape, also ensure that -// // each member has compatible shape -// Type unified = this->Check(lit->data[0]); -// for (auto elt = lit->data.begin(); elt != lit->data.end(); elt++) { -// // evaluate all shape ASTs so they can be in standard form -// // TODO(sslyu): eventually we'd want this to be symbolic evaluation -// auto elt_el = *elt; -// Type elt_type = simple_eval_shape(this->Check(*elt)); -// if (!elt_type.as()) { -// this->fatal_error("All members in tensor literal must be tensors", -// elt_el->span); -// } -// unified = this->unify(unified, elt_type, lit->span); -// } - -// // types must unify into a tensor -// const TensorTypeNode *ttn = unified.as(); -// // shouldn't be possible due to check inside the loop -// if (!ttn) { -// this->fatal_error("Tensor literal contains non-tensor member", op->span); -// } - -// TensorType unified_tt = GetRef(ttn); - -// // new shape: add length of this tensor to front of existing shape -// // i.e., sequence and simplify -// // TODO(sslyu): should be symbolic evaluation eventually? -// Type new_shape = ShapeSeqNode::make( -// {ShapeSingletonNode::make(lit->data.size()), unified_tt->shape}); -// return TensorTypeNode::make(unified_tt->dtype, simple_eval_shape(new_shape)); -// } - -// Type Typechecker::VisitExpr_(const TupleNode *op) { -// Tuple pl = GetRef(op); - -// std::vector field_types; -// for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) { -// field_types.push_back(this->Check(*field)); -// } - -// return TupleTypeNode::make(field_types); -// } - -// Type Typechecker::VisitExpr_(const CastNode *op) { -// // will take the cast at its word -// Cast cast = GetRef(op); -// return cast->target; -// } - -// Type Typechecker::VisitExpr_(const ParamNode *op) { -// Param param = GetRef(op); -// return resolve(param->type); -// } - -// // We should probably generalize the subst code. -// struct GeneralizeTypeType : TypeFVisitor { -// Map vars_to_id; -// const TypeUnifier &unifier; - -// GeneralizeTypeType(Map vars_to_id, -// const TypeUnifier &unifier) -// : vars_to_id(vars_to_id), unifier(unifier) {} - -// Type VisitType_(const TypeVarNode *op) override { -// auto repr = unifier->subst(GetRef(op)); -// if (auto tvn = repr.as()) { -// auto ty_var = GetRef(tvn); -// if (vars_to_id.find(ty_var) != vars_to_id.end()) { -// return vars_to_id[ty_var]; -// } else { -// return ty_var; -// } -// } else { -// return this->VisitType(repr); -// } -// } -// }; - -// struct GeneralizeTypeExpr : ExprFVisitor<> { -// Map vars_to_id; -// const TypeUnifier &unifier; - -// GeneralizeTypeExpr(const TypeUnifier &unifier, -// Map vars_to_id) -// : vars_to_id(vars_to_id), unifier(unifier) {} - -// Type VisitType(const Type &t) { -// return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); -// } -// }; - -// Type Typechecker::VisitFunction(const Function &f, bool generalize) { -// // enter params into context -// auto fn_type = this->with_frame([&]() { -// std::vector arg_types; -// for (auto arg : f->params) { -// this->Check(arg); -// Type arg_type; -// // if arg type can be simply evaluated, try it -// // should be replaced with symbolic evaluation once it exists, -// // you will not have attr information at this point -// try { -// arg_type = simple_eval_shape(arg->type); -// } catch (const dmlc::Error &e) { -// this->report_error(e.what(), arg->span); -// arg_type = arg->type; -// } -// arg_types.push_back(arg_type); -// this->local_stack.insert(arg->id, arg_type); -// } - -// // typecheck body and ensure that it matches stated return type -// // TODO(sslyu): should the unified return type override the annotated one? -// Type checked_return = this->Check(f->body); -// Type ret_type = resolve(f->ret_type); -// Type unified = this->unify(simple_eval_shape(ret_type), -// simple_eval_shape(checked_return), f->span); -// return TypeArrowNode::make(arg_types, unified); -// }); -// if (generalize) { -// auto free_vars = free_type_vars(resolve(fn_type)); -// std::set dedup_free_vars; - -// for (auto free_var : free_vars) { -// auto repr = this->unifier->subst(free_var); -// if (auto new_free_var_node = repr.as()) { -// dedup_free_vars.insert(GetRef(new_free_var_node)); -// } else { -// // debug(repr); -// throw dmlc::Error( -// "internal error: this list should only contain type var nodes"); -// } -// } - -// Map vars_to_id; - -// GenFresh gf; -// for (auto free_var : dedup_free_vars) { -// vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); -// } - -// fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); -// for (std::pair pair : vars_to_id) { -// // NB: In generalization we want to find type variables with -// // *no constraints* on them, and convert them to universally quantified -// // variables. -// // -// // i.e the program can be abstracted over the details of *that* type. - -// // For example a program that works irrespective of shape or datatype. - -// // In order to do this we find the set of free type variables in the -// // term, and then unify them with the fresh type ids we generate. -// // -// // Remember importantly these type variables still may appear in many -// // places in the program including both types and expressions. - -// // Our method for resolving these is to unify them with the variables -// // as we build the new quanitifer, changing from a program with "holes" -// // to one that is properly abstracted over. - -// // Finally later on we can iterate over the whole term and change from -// // type variables to these type ids. -// this->unify(pair.first, pair.second, pair.second->span); -// fn_type = TypeQuantifierNode::make(pair.second, fn_type); -// } -// } else { -// for (auto i = f->ty_params.size(); i > 0; i--) { -// auto ty_param = f->ty_params[i - 1]; -// auto ty_param_node = ty_param.as(); -// if (!ty_param_node) { -// throw dmlc::Error("internal error should be TypeParam"); -// } -// auto fresh_tid = -// TypeParamNode::make(ty_param_node->name, ty_param_node->kind); -// fn_type = -// type_subst(fn_type, GetRef(ty_param_node), fresh_tid); -// fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); -// } -// } - -// return fn_type; -// } - -// Type Typechecker::VisitExpr_(const FunctionNode *op) { -// return this->VisitFunction(GetRef(op), false); -// } - -// Type Typechecker::instantiate(Type t, tvm::Array &ty_args) { -// const TypeQuantifierNode *ty_quant; -// while ((ty_quant = t.as())) { -// TypeParam id = ty_quant->id; -// TypeVar fresh = TypeVarNode::make(id->kind); -// this->unifier->insert(fresh); -// ty_args.push_back(fresh); -// t = type_subst(ty_quant->boundType, id, fresh); -// } - -// if (!check_kind(t)) { -// this->fatal_error("Kind rules broken when instantiating type variables", -// t->span); -// } - -// return t; -// } - -// Type Typechecker::VisitExpr_(const CallNode *op) { -// Call c = GetRef(op); -// Type fn_ty = this->Check(c->fn); - -// RELAY_LOG(INFO) << "Typechecker::VisitExpr_ op=" << c << std::endl -// << "fn_ty=" << fn_ty << std::endl; - -// // for each type id, insert a type variable and unify with the argument types -// // in order -// // to obtain the concrete instantiation -// tvm::Array ty_args; -// if (const TypeQuantifierNode *ty_quant = fn_ty.as()) { -// fn_ty = instantiate(GetRef(ty_quant), ty_args); -// } - -// if (!fn_ty.as()) { -// this->fatal_error("only expressions with function types can be called", -// c->fn->span); -// } - -// // evaluate all shapes up front (require that types be fully concrete) -// Type evaluated = evaluate_concrete_shape(fn_ty, op->attrs); -// std::vector arg_types; - -// TypeArrow arrow = GetRef(evaluated.as()); - -// // TODO(sslyu): figure out how to handle type ids -// // fn_ty = instantiate(fn_ty, ty_args); -// for (auto arg : c->args) { -// auto ty = this->Check(arg); -// arg_types.push_back(ty); -// } - -// auto type_arity = arrow->arg_types.size(); -// auto number_of_args = arg_types.size(); -// if (type_arity != number_of_args) { -// if (type_arity < number_of_args) { -// this->fatal_error("the function is provided too many arguments", c->span); -// } else { -// this->fatal_error("the function is provided too few arguments", c->span); -// } -// } - -// for (size_t i = 0; i < arrow->arg_types.size(); i++) { -// this->unify(arrow->arg_types[i], arg_types[i], c->args[i]->span); -// } - -// // After we unify the arguments we should know more about the type -// // arguments, let's run a quick pass over them to find new representatives. -// for (size_t i = 0; i < ty_args.size(); i++) { -// ty_args.Set(i, this->unifier->subst(ty_args[i])); -// } - -// // Write the type arguments into the call node, recording what inference -// // solves. This solution might need some work. -// c->ty_args = ty_args; - -// return arrow->ret_type; -// } - -// Type Typechecker::VisitExpr_(const DebugNode *op) { -// return this->Check(op->node); -// } - -// Type Typechecker::VisitExpr_(const LetNode *op) { -// Let let = GetRef(op); - -// Type checked_ty; -// Type annotated_ty = resolve(let->type); - -// // if we are let-defining a function, treat it as a let-rec and insert -// // the id with the annotated type in case there is recursion; -// // no such recursion permitted with anything that's not a function! -// if (let->value.as()) { -// with_frame([&]() { -// local_stack.insert(let->id, annotated_ty); -// checked_ty = Check(let->value); -// }); -// } else { -// checked_ty = Check(let->value); -// } - -// // ensure annotated type and checked type are compatible -// // TODO(sslyu): should the annotated type override the unified one? -// Type unified_ty = -// this->unify(checked_ty, simple_eval_shape(annotated_ty), let->span); - -// return with_frame([&]() { -// local_stack.insert(let->id, unified_ty); -// return Check(let->body); -// }); -// } - -// Type Typechecker::VisitExpr_(const ReverseNode *op) { -// // apply reverse mode to node and typecheck that instead -// std::shared_ptr gf = std::make_shared(); -// return this->Check(ReverseExpr(env, op->node, gf)); -// } - -// Type Typechecker::VisitExpr_(const GradientNode *op) { -// auto node = op->node; -// this->Check(node); -// auto gf = std::make_shared(); -// return FOWithGradientType(node->checked_type()); -// } - -// Type Typechecker::VisitExpr_(const ProjectionNode *op) { -// Projection proj = GetRef(op); - -// Type tup_type = this->Check(proj->tuple); - -// const TupleTypeNode *ptn = tup_type.as(); -// if (!ptn) { -// this->fatal_error("Cannot project into non-product type", op->span); -// } - -// TupleType pt = GetRef(ptn); -// size_t field = (size_t)proj->field; -// if (field >= pt->fields.size()) { -// this->fatal_error("Projecting past bounds of product", op->span); -// } - -// return pt->fields[field]; -// } - -// Type Typechecker::VisitExpr_(const IfNode *op) { -// If ifn = GetRef(op); - -// // Ensure the type of the guard is of Tensor[Bool, ()], -// // that is a rank-0 boolean tensor. -// Type guardType = this->Check(ifn->guard); -// bool is_bool = false; -// bool zero_rank = false; -// if (const TensorTypeNode *ttn = guardType.as()) { -// TensorType tt = GetRef(ttn); - -// if (const BaseTypeNode *btn = tt->dtype.as()) { -// is_bool = btn->type.is_bool(); -// } - -// Type shape = simple_eval_shape(tt->shape); - -// if (const ShapeSeqNode *sn = shape.as()) { -// zero_rank = (sn->shapes.size() == 0); -// } -// } - -// if (!(is_bool && zero_rank)) { -// this->fatal_error("IfNode guard must be a rank 0 bool tensor", -// ifn->guard->span); -// } - -// // unify types of different branches -// Type left = this->Check(ifn->true_b); -// Type right = this->Check(ifn->false_b); -// return this->unify(left, right, ifn->span); -// } - -// Type Typechecker::VisitExpr_(const RefNode *op) { -// Ref r = GetRef(op); -// Type inner = this->Check(r->expr); -// return RefTypeNode::make(inner); -// } - -// Type Typechecker::VisitExpr_(const ReadRefNode *op) { -// ReadRef vr = GetRef(op); -// Type ref_type = this->Check(vr->ref); - -// // reject if not a ref type -// const RefTypeNode *rtn = ref_type.as(); -// if (!rtn) { -// this->fatal_error( -// "the de-reference operation can only be used with references", -// op->span); -// } - -// RefType rt = GetRef(rtn); -// return rt->data_type; -// } - -// Type Typechecker::VisitExpr_(const WriteRefNode *op) { -// WriteRef sr = GetRef(op); -// Type ref_type = this->Check(sr->ref); - -// const RefTypeNode *rtn = ref_type.as(); -// if (!rtn) { -// this->fatal_error("Cannot mutate non-ref", op->span); -// } -// RefType rt = GetRef(rtn); - -// // ensure ref type's inner type and expr's type are compatible; return unit -// Type expr_type = this->Check(sr->val); -// this->unify(rt->data_type, expr_type, sr->span); -// return UnitType(); -// } - -// Type Typechecker::resolve(const Type &t) { -// return ::tvm::relay::resolve(this->unifier, t); -// } - -// Expr Typechecker::resolve(const Expr &e) { -// return ::tvm::relay::resolve(this->unifier, e); -// } - -// Type Typechecker::simple_eval_shape(const Type &shape) { -// // TODO(sslyu): Do we want to propagate attributes? -// Attributes empty = AttributesNode::make({}); -// return evaluate_concrete_shape(shape, empty); -// } - -// Operator Typechecker::CheckOp(Operator op) { -// if (!check_kind(op->type)) { -// report_error("the type of the operator is ill formed", op->type->span); -// } - -// // Fix me -// return op; -// } - -// Defn Typechecker::CheckDefn(Defn defn) { -// // This is to handle recursion, but we need to speculatively -// // put it in env, then remove it. -// env->items.insert({defn->id, defn}); - -// Type expected_ty = this->resolve(defn->type); - -// Expr body = defn->body; - -// auto checked_ty = Check(body); - -// try { -// Type uret_type = unify(expected_ty, checked_ty, defn->body->span); -// CHECK(is_fully_resolved(uret_type)); -// // Now let's clean up our work from earlier. -// env->items.erase(defn->id); -// return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); -// } catch (const UnificationError& err) { -// std::string msg = std::string("mismatch between `") + -// PrintType(env, expected_ty, WrapWidth(40)) + "` and `" + -// PrintType(env, checked_ty, WrapWidth(40)) + "`"; -// fatal_error(msg, defn->span); -// } -// } - -Expr Infer(const Environment &env, const Expr &e) { - //Typechecker tc(env); - // return tc.Check(e); - return e; -} - -// Item Check(const Environment &env, const Item &i) { -// Typechecker tc(env); - -// try { -// if (const DefnNode *defn = i.as()) { -// return tc.CheckDefn(GetRef(defn)); -// } else if (const OperatorNode *op_node = i.as()) { -// return tc.CheckOp(GetRef(op_node)); -// } else { -// throw dmlc::Error("internal error: unknown Item type"); -// } -// } catch (const FatalTypeError &err) { -// env->display_errors(); -// throw dmlc::Error( -// "We encountered a fatal error while type checking your program, please " -// "read above for more details."); -// } -// } - -// inline void Typechecker::report_error(const std::string &msg, Span sp) { -// this->env->report_error(msg, sp); -// } - -// void Typechecker::fatal_error(const std::string &msg, Span sp) { -// this->env->report_error(msg, sp); -// throw FatalTypeError( -// "internal error: this exception should" -// "be handled and errors reported with Environment::display_errors\n" + -// msg); -// } - -// Type Typechecker::unify(const Type &t1, const Type &t2, Span sp) { -// try { -// return this->unifier->unify(t1, t2); -// } catch (const dmlc::Error &e) { -// std::stringstream ss; -// ss << "Error unifying `"; -// ss << PrintType(env, t1, WrapWidth(40)); -// ss << "` and `"; -// ss << PrintType(env, t2, WrapWidth(40)); -// ss << "`: " << e.what(); -// this->fatal_error(ss.str(), sp); -// } -// } - -// // template - -// // Add safe dynamic Array downcast. -// // Add static upcast? - -// // Add to type utils. -// Array type_parameters(const Type &t) { -// Array params; -// auto type = t; -// const TypeQuantifierNode *ty_quant; -// while ((ty_quant = type.as())) { -// params.push_back(ty_quant->id); -// type = ty_quant->boundType; -// } - -// return params; -// } - -// template -// Array ArrayMap(const Array &data, F f) { -// // probably a way to use std::transform. -// Array output; -// for (const I &el : data) { -// output.push_back(f(el)); -// } -// return output; -// } - -// // There are some important questions around generalization -// // that we need to answer. -// Expr generalize(const Environment &env, const Expr &e) { -// if (auto fn_node = e.as()) { -// Typechecker tc(env); -// auto ty = tc.VisitFunction(GetRef(fn_node), true); -// auto ty_params = type_parameters(ty); -// auto params = ArrayMap(fn_node->params, [&](const Param &p) { -// return ParamNode::make(p->id, tc.resolve(p->type)); -// }); -// auto body = tc.resolve(fn_node->body); -// auto ret_type = tc.resolve(fn_node->ret_type); -// auto fn = FunctionNode::make(ty_params, params, ret_type, body); -// // we should check in empty context to ensure typing is preserved. -// // check(env, fn); -// return fn; -// } else { -// throw dmlc::Error("can only apply generalize to a function."); -// } -// } - -TVM_REGISTER_API("relay._type_infer.check_expr") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - Expr e = args[1]; - *ret = Infer(env, e); - }); - -// TVM_REGISTER_API("relay._tyck.check_item") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// Item i = args[1]; -// *ret = check(env, i); -// }); - -TVM_REGISTER_API("relay._type_infer._get_checked_type") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Expr e = args[0]; - *ret = e->checked_type(); - }); - -// TVM_REGISTER_API("relay._tyck.generalize") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// *ret = generalize(args[0], args[1]); -// }); - -IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { - std::shared_ptr n = std::make_shared(); - n->kind = std::move(kind); - return IncompleteType(n); -} - -TVM_REGISTER_API("relay._make.IncompleteType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - int kind = args[0]; - *ret = IncompleteTypeNode::make(static_cast(kind)); - }); - -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const IncompleteTypeNode *node, - tvm::IRPrinter *p) { - p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; +using namespace tvm::runtime; + +struct TypeContext { + std::vector> stack; + + TypeContext() { stack.push_back({}); } + + void insert(const LocalVar &id, const Type &t) { stack.back()[id] = t; } + + Type lookup(const LocalVar &id) { + for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) { + if (frame->find(id) != frame->end()) { + return frame->at(id); + } + } + throw FatalTypeError("Could not resolve local id"); + } + + struct LocalFrame { + TypeContext &tc; + explicit LocalFrame(TypeContext &tc) : tc(tc) { tc.stack.push_back({}); } + ~LocalFrame() { tc.stack.pop_back(); } + }; +}; + +struct CheckedExpr { + Expr expr; + Type type; + CheckedExpr(Expr e, Type t) : expr(e), type(t) {} +}; + +class TypeInferencer : private ExprFunctor { + private: + TypeContext local_stack; + + public: + Environment env; + TypeUnifier unifier; + + // Should be in header? + template + T with_frame(const std::function & f) { + TypeContext::LocalFrame fr(local_stack); + return f(); + } + + TypeInferencer(); + TypeInferencer(Environment env, TypeUnifier unifier) : env(env), + unifier(unifier) {} explicit TypeInferencer(Environment env); + + CheckedExpr Infer(const Expr & expr); + + Type instantiate(Type t, tvm::Array &ty_args); + + void report_error(const std::string & msg, Span sp); + [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); + + Type unify(const Type &t1, const Type &t2, Span sp); + Type resolve(const Type &t); + Expr resolve(const Expr &e); + CheckedExpr VisitFunction(const Function & f, bool generalize); + // Operator CheckOp(Operator op); + // Defn CheckDefn(Defn def); + private: + CheckedExpr VisitExpr_(const LocalVarNode* op) override; + CheckedExpr VisitExpr_(const GlobalVarNode* op) override; + CheckedExpr VisitExpr_(const TupleNode* op) override; + CheckedExpr VisitExpr_(const ParamNode* op) override; + CheckedExpr VisitExpr_(const FunctionNode* op) override; + CheckedExpr VisitExpr_(const CallNode* op) override; + CheckedExpr VisitExpr_(const LetNode* op) override; + CheckedExpr VisitExpr_(const IfNode* op) override; +}; + + TypeInferencer::TypeInferencer() { + this->env = EnvironmentNode::make({}); + this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); + } + + TypeInferencer::TypeInferencer(Environment env) : env(env) { + this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); + } + + CheckedExpr TypeInferencer::Infer(const Expr &expr) { + RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; + CheckedExpr checked_expr = this->VisitExpr(expr); + RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type << std::endl; + Type final_type = this->unifier->subst(checked_expr.type); + RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type << std::endl; + checked_expr.expr->checked_type_ = final_type; + return checked_expr; + } + + CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { + auto var = GetRef(op); + return { var, this->local_stack.lookup(var) }; + } + + CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { + // GlobalVar id = GetRef(op); + // Item item = this->env->lookup(id); + + // if (const OperatorNode *op = item.as()) { + // return op->type; + // } + + // if (const DefnNode *dn = item.as()) { + // Defn def = GetRef(dn); + // return def->type; + // } + + // this->fatal_error("Unhandled case in GlobalId", op->span); + throw Error("hereeee"); + } + + // Type TypeInferencer::VisitExpr_(const OperatorIdNode *op) { + // OperatorId id = GetRef(op); + // Item item = this->env->lookup(id); + + // if (const OperatorNode *pn = item.as()) { + // Operator prim = GetRef(pn); + // return prim->type; + // } else { + // this->fatal_error("internal error in InstrinsicId case", op->span); + // } + // } + + CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { + // Tuple pl = GetRef(op); + + // std::vector field_types; + // for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) + // { + // field_types.push_back(this->Check(*field)); + // } + + // return TupleTypeNode::make(field_types); + throw Error("TupleNode NYI"); + } + + CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *op) { + // Param param = GetRef(op); + // return { resolve(param->type); + throw Error("ParamNode NYI"); + } + + // // We should probably generalize the subst code. + // struct GeneralizeTypeType : TypeFVisitor { + // Map vars_to_id; + // const TypeUnifier &unifier; + + // GeneralizeTypeType(Map vars_to_id, + // const TypeUnifier &unifier) + // : vars_to_id(vars_to_id), unifier(unifier) {} + + // Type VisitType_(const TypeVarNode *op) override { + // auto repr = unifier->subst(GetRef(op)); + // if (auto tvn = repr.as()) { + // auto ty_var = GetRef(tvn); + // if (vars_to_id.find(ty_var) != vars_to_id.end()) { + // return vars_to_id[ty_var]; + // } else { + // return ty_var; + // } + // } else { + // return this->VisitType(repr); + // } + // } + // }; + + // struct GeneralizeTypeExpr : ExprFVisitor<> { + // Map vars_to_id; + // const TypeUnifier &unifier; + + // GeneralizeTypeExpr(const TypeUnifier &unifier, + // Map vars_to_id) + // : vars_to_id(vars_to_id), unifier(unifier) {} + + // Type VisitType(const Type &t) { + // return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); + // } + // }; + + CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { + throw Error("FunctionNode NYI"); + // // enter params into context + // auto fn_type = this->with_frame([&]() { + // std::vector arg_types; + // for (auto arg : f->params) { + // this->Check(arg); + // Type arg_type; + // // if arg type can be simply evaluated, try it + // // should be replaced with symbolic evaluation once it exists, + // // you will not have attr information at this point + // try { + // arg_type = simple_eval_shape(arg->type); + // } catch (const dmlc::Error &e) { + // this->report_error(e.what(), arg->span); + // arg_type = arg->type; + // } + // arg_types.push_back(arg_type); + // this->local_stack.insert(arg->id, arg_type); + // } + + // // typecheck body and ensure that it matches stated return type + // // TODO(sslyu): should the unified return type override the annotated + // one? Type checked_return = this->Check(f->body); Type ret_type = + // resolve(f->ret_type); Type unified = + // this->unify(simple_eval_shape(ret_type), + // simple_eval_shape(checked_return), f->span); + // return TypeArrowNode::make(arg_types, unified); + // }); + // if (generalize) { + // auto free_vars = free_type_vars(resolve(fn_type)); + // std::set dedup_free_vars; + + // for (auto free_var : free_vars) { + // auto repr = this->unifier->subst(free_var); + // if (auto new_free_var_node = repr.as()) { + // dedup_free_vars.insert(GetRef(new_free_var_node)); + // } else { + // // debug(repr); + // throw dmlc::Error( + // "internal error: this list should only contain type var + // nodes"); + // } + // } + + // Map vars_to_id; + + // GenFresh gf; + // for (auto free_var : dedup_free_vars) { + // vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); + // } + + // fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); + // for (std::pair pair : vars_to_id) { + // // NB: In generalization we want to find type variables with + // // *no constraints* on them, and convert them to universally + // quantified + // // variables. + // // + // // i.e the program can be abstracted over the details of *that* type. + + // // For example a program that works irrespective of shape or + // datatype. + + // // In order to do this we find the set of free type variables in the + // // term, and then unify them with the fresh type ids we generate. + // // + // // Remember importantly these type variables still may appear in many + // // places in the program including both types and expressions. + + // // Our method for resolving these is to unify them with the variables + // // as we build the new quanitifer, changing from a program with + // "holes" + // // to one that is properly abstracted over. + + // // Finally later on we can iterate over the whole term and change + // from + // // type variables to these type ids. + // this->unify(pair.first, pair.second, pair.second->span); + // fn_type = TypeQuantifierNode::make(pair.second, fn_type); + // } + // } else { + // for (auto i = f->ty_params.size(); i > 0; i--) { + // auto ty_param = f->ty_params[i - 1]; + // auto ty_param_node = ty_param.as(); + // if (!ty_param_node) { + // throw dmlc::Error("internal error should be TypeParam"); + // } + // auto fresh_tid = + // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); + // fn_type = + // type_subst(fn_type, GetRef(ty_param_node), fresh_tid); + // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); + // } + // } + + // return fn_type; + + } + + CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { + return this->VisitFunction(GetRef(op), false); + } + + // Type TypeInferencer::instantiate(Type t, tvm::Array &ty_args) { + // const TypeQuantifierNode *ty_quant; + // while ((ty_quant = t.as())) { + // TypeParam id = ty_quant->id; + // TypeVar fresh = TypeVarNode::make(id->kind); + // this->unifier->insert(fresh); + // ty_args.push_back(fresh); + // t = type_subst(ty_quant->boundType, id, fresh); + // } + + // if (!check_kind(t)) { + // this->fatal_error("Kind rules broken when instantiating type + // variables", + // t->span); + // } + + // return t; + // } + + CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { + throw Error("CallNode"); + // Call c = GetRef(op); + // Type fn_ty = this->Check(c->fn); + + // RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl + // << "fn_ty=" << fn_ty << std::endl; + + // // for each type id, insert a type variable and unify with the argument + // types + // // in order + // // to obtain the concrete instantiation + // tvm::Array ty_args; + // if (const TypeQuantifierNode *ty_quant = fn_ty.as()) + // { + // fn_ty = instantiate(GetRef(ty_quant), ty_args); + // } + + // if (!fn_ty.as()) { + // this->fatal_error("only expressions with function types can be called", + // c->fn->span); + // } + + // // evaluate all shapes up front (require that types be fully concrete) + // Type evaluated = evaluate_concrete_shape(fn_ty, op->attrs); + // std::vector arg_types; + + // TypeArrow arrow = GetRef(evaluated.as()); + + // // TODO(sslyu): figure out how to handle type ids + // // fn_ty = instantiate(fn_ty, ty_args); + // for (auto arg : c->args) { + // auto ty = this->Check(arg); + // arg_types.push_back(ty); + // } + + // auto type_arity = arrow->arg_types.size(); + // auto number_of_args = arg_types.size(); + // if (type_arity != number_of_args) { + // if (type_arity < number_of_args) { + // this->fatal_error("the function is provided too many arguments", + // c->span); + // } else { + // this->fatal_error("the function is provided too few arguments", + // c->span); + // } + // } + + // for (size_t i = 0; i < arrow->arg_types.size(); i++) { + // this->unify(arrow->arg_types[i], arg_types[i], c->args[i]->span); + // } + + // // After we unify the arguments we should know more about the type + // // arguments, let's run a quick pass over them to find new + // representatives. for (size_t i = 0; i < ty_args.size(); i++) { + // ty_args.Set(i, this->unifier->subst(ty_args[i])); + // } + + // // Write the type arguments into the call node, recording what inference + // // solves. This solution might need some work. + // c->ty_args = ty_args; + + // return arrow->ret_type; + } + + // Type TypeInferencer::VisitExpr_(const DebugNode *op) { + // return this->Check(op->node); + // } + + CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { + Let let = GetRef(op); + + Type checked_ty; + Type annotated_ty = resolve(let->value_type); + + // // if we are let-defining a function, treat it as a let-rec and insert + // // the id with the annotated type in case there is recursion; + // // no such recursion permitted with anything that's not a function! + // if (let->value.as()) { + // with_frame([&]() { + // local_stack.insert(let->id, annotated_ty); + // checked_ty = Check(let->value); + // }); + // } else { + // checked_ty = Check(let->value); + // } + + // ensure annotated type and checked type are compatible + // TODO(sslyu): should the annotated type override the unified one? + Type unified_ty = + this->unify(checked_ty, annotated_ty, let->span); + + return with_frame([&]() { + local_stack.insert(let->var, unified_ty); + return Infer(let->body); }); + } + + // Type TypeInferencer::VisitExpr_(const ReverseNode *op) { + // // apply reverse mode to node and typecheck that instead + // std::shared_ptr gf = std::make_shared(); + // return this->Check(ReverseExpr(env, op->node, gf)); + // } + + // Type TypeInferencer::VisitExpr_(const GradientNode *op) { + // auto node = op->node; + // this->Check(node); + // auto gf = std::make_shared(); + // return FOWithGradientType(node->checked_type()); + // } + + // Type TypeInferencer::VisitExpr_(const ProjectionNode *op) { + // Projection proj = GetRef(op); + + // Type tup_type = this->Check(proj->tuple); + + // const TupleTypeNode *ptn = tup_type.as(); + // if (!ptn) { + // this->fatal_error("Cannot project into non-product type", op->span); + // } + + // TupleType pt = GetRef(ptn); + // size_t field = (size_t)proj->field; + // if (field >= pt->fields.size()) { + // this->fatal_error("Projecting past bounds of product", op->span); + // } + + // return pt->fields[field]; + // } + + CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { + // If ifn = GetRef(op); + + // // Ensure the type of the guard is of Tensor[Bool, ()], + // // that is a rank-0 boolean tensor. + // Type guardType = this->Check(ifn->guard); + // bool is_bool = false; + // bool zero_rank = false; + // if (const TensorTypeNode *ttn = guardType.as()) { + // TensorType tt = GetRef(ttn); + + // if (const BaseTypeNode *btn = tt->dtype.as()) { + // is_bool = btn->type.is_bool(); + // } + + // Type shape = simple_eval_shape(tt->shape); + + // if (const ShapeSeqNode *sn = shape.as()) { + // zero_rank = (sn->shapes.size() == 0); + // } + // } + + // if (!(is_bool && zero_rank)) { + // this->fatal_error("IfNode guard must be a rank 0 bool tensor", + // ifn->guard->span); + // } + + // // unify types of different branches + // Type left = this->Check(ifn->true_b); + // Type right = this->Check(ifn->false_b); + // return this->unify(left, right, ifn->span); + } + + // Type TypeInferencer::VisitExpr_(const RefNode *op) { + // Ref r = GetRef(op); + // Type inner = this->Check(r->expr); + // return RefTypeNode::make(inner); + // } + + // Type TypeInferencer::VisitExpr_(const ReadRefNode *op) { + // ReadRef vr = GetRef(op); + // Type ref_type = this->Check(vr->ref); + + // // reject if not a ref type + // const RefTypeNode *rtn = ref_type.as(); + // if (!rtn) { + // this->fatal_error( + // "the de-reference operation can only be used with references", + // op->span); + // } + + // RefType rt = GetRef(rtn); + // return rt->data_type; + // } + + // Type TypeInferencer::VisitExpr_(const WriteRefNode *op) { + // WriteRef sr = GetRef(op); + // Type ref_type = this->Check(sr->ref); + + // const RefTypeNode *rtn = ref_type.as(); + // if (!rtn) { + // this->fatal_error("Cannot mutate non-ref", op->span); + // } + // RefType rt = GetRef(rtn); + + // // ensure ref type's inner type and expr's type are compatible; return + // unit Type expr_type = this->Check(sr->val); this->unify(rt->data_type, + // expr_type, sr->span); return UnitType(); + // } + + Type TypeInferencer::resolve(const Type &t) { + return ::tvm::relay::resolve(this->unifier, t); + } + + Expr TypeInferencer::resolve(const Expr &e) { + return ::tvm::relay::resolve(this->unifier, e); + } + + // Operator TypeInferencer::CheckOp(Operator op) { + // if (!check_kind(op->type)) { + // report_error("the type of the operator is ill formed", op->type->span); + // } + + // // Fix me + // return op; + // } + + // Defn TypeInferencer::CheckDefn(Defn defn) { + // // This is to handle recursion, but we need to speculatively + // // put it in env, then remove it. + // env->items.insert({defn->id, defn}); + + // Type expected_ty = this->resolve(defn->type); + + // Expr body = defn->body; + + // auto checked_ty = Check(body); + + // try { + // Type uret_type = unify(expected_ty, checked_ty, defn->body->span); + // CHECK(is_fully_resolved(uret_type)); + // // Now let's clean up our work from earlier. + // env->items.erase(defn->id); + // return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); + // } catch (const UnificationError& err) { + // std::string msg = std::string("mismatch between `") + + // PrintType(env, expected_ty, WrapWidth(40)) + "` and + // `" + PrintType(env, checked_ty, WrapWidth(40)) + + // "`"; + // fatal_error(msg, defn->span); + // } + // } + + Expr Infer(const Environment &env, const Expr &e) { + TypeInferencer ti(env); + auto checked_expr = ti.Infer(e); + return checked_expr.expr; + } + + // Item Check(const Environment &env, const Item &i) { + // TypeInferencer tc(env); + + // try { + // if (const DefnNode *defn = i.as()) { + // return tc.CheckDefn(GetRef(defn)); + // } else if (const OperatorNode *op_node = i.as()) { + // return tc.CheckOp(GetRef(op_node)); + // } else { + // throw dmlc::Error("internal error: unknown Item type"); + // } + // } catch (const FatalTypeError &err) { + // env->display_errors(); + // throw dmlc::Error( + // "We encountered a fatal error while type checking your program, + // please " "read above for more details."); + // } + // } + + inline void TypeInferencer::report_error(const std::string &msg, Span sp) { + // this->env->report_error(msg, sp); + } + + void TypeInferencer::fatal_error(const std::string &msg, Span sp) { + // this->env->report_error(msg, sp); + throw FatalTypeError( + "internal error: this exception should" + "be handled and errors reported with Environment::display_errors\n" + + msg); + } + + Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { + try { + return this->unifier->unify(t1, t2); + } catch (const dmlc::Error &e) { + std::stringstream ss; + ss << "Error unifying `"; + ss << t1; + // ss << PrintType(env, t1, WrapWidth(40)); + ss << "` and `"; + ss << t2; + // ss << PrintType(env, t2, WrapWidth(40)); + ss << "`: " << e.what(); + this->fatal_error(ss.str(), sp); + } + } + + // // template + + // // Add safe dynamic Array downcast. + // // Add static upcast? + + // // Add to type utils. + // Array type_parameters(const Type &t) { + // Array params; + // auto type = t; + // const TypeQuantifierNode *ty_quant; + // while ((ty_quant = type.as())) { + // params.push_back(ty_quant->id); + // type = ty_quant->boundType; + // } + + // return params; + // } + + // template + // Array ArrayMap(const Array &data, F f) { + // // probably a way to use std::transform. + // Array output; + // for (const I &el : data) { + // output.push_back(f(el)); + // } + // return output; + // } + + // // There are some important questions around generalization + // // that we need to answer. + // Expr generalize(const Environment &env, const Expr &e) { + // if (auto fn_node = e.as()) { + // TypeInferencer tc(env); + // auto ty = tc.VisitFunction(GetRef(fn_node), true); + // auto ty_params = type_parameters(ty); + // auto params = ArrayMap(fn_node->params, [&](const Param &p) { + // return ParamNode::make(p->id, tc.resolve(p->type)); + // }); + // auto body = tc.resolve(fn_node->body); + // auto ret_type = tc.resolve(fn_node->ret_type); + // auto fn = FunctionNode::make(ty_params, params, ret_type, body); + // // we should check in empty context to ensure typing is preserved. + // // check(env, fn); + // return fn; + // } else { + // throw dmlc::Error("can only apply generalize to a function."); + // } + // } + + TVM_REGISTER_API("relay._type_infer.check_expr") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + Expr e = args[1]; + *ret = Infer(env, e); + }); + + // TVM_REGISTER_API("relay._tyck.check_item") + // .set_body([](TVMArgs args, TVMRetValue *ret) { + // Environment env = args[0]; + // Item i = args[1]; + // *ret = check(env, i); + // }); + + TVM_REGISTER_API("relay._type_infer._get_checked_type") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Expr e = args[0]; + *ret = e->checked_type(); + }); + + // TVM_REGISTER_API("relay._tyck.generalize") + // .set_body([](TVMArgs args, TVMRetValue *ret) { + // *ret = generalize(args[0], args[1]); + // }); + + IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { + std::shared_ptr n = + std::make_shared(); + n->kind = std::move(kind); + return IncompleteType(n); + } + + TVM_REGISTER_API("relay._make.IncompleteType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[0]; + *ret = IncompleteTypeNode::make(static_cast(kind)); + }); + + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IncompleteTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); } // namespace relay -} // namespace tvm +} // namespace relay diff --git a/src/relay/compiler/unifier.h b/src/relay/compiler/unifier.h index cba96ff02451..86ffd664a161 100644 --- a/src/relay/compiler/unifier.h +++ b/src/relay/compiler/unifier.h @@ -99,9 +99,12 @@ class TypeUnifierNode : public Node, TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node); private: - // unify non-typevar with typevar + /*! \brief Unify incomplete type with another type. */ Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); + /*! \brief Implements unification between two types with incomplete portions. */ Type VisitType(const Type & t1, const Type t2) override; + + // Visitor Cases Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override; Type VisitType_(const TensorTypeNode* t1, const Type t2) override; Type VisitType_(const TypeParamNode* t1, const Type t2) override; From 02f6b1e6cace4951649f2cf2b2218b269f8d5f60 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 02:30:53 -0700 Subject: [PATCH 027/136] First simple test passes --- src/relay/compiler/type_infer.cc | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 49c8bbf9627f..7304bdabe486 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -108,6 +108,7 @@ class TypeInferencer : private ExprFunctor { private: CheckedExpr VisitExpr_(const LocalVarNode* op) override; CheckedExpr VisitExpr_(const GlobalVarNode* op) override; + CheckedExpr VisitExpr_(const ConstantNode* op) override; CheckedExpr VisitExpr_(const TupleNode* op) override; CheckedExpr VisitExpr_(const ParamNode* op) override; CheckedExpr VisitExpr_(const FunctionNode* op) override; @@ -157,6 +158,15 @@ class TypeInferencer : private ExprFunctor { throw Error("hereeee"); } + CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { + auto array = const_node->data; + // array->t + // first pass + return { + GetRef(const_node), + TensorTypeNode::make({}, HalideIR::Float(32, 1)) }; + } + // Type TypeInferencer::VisitExpr_(const OperatorIdNode *op) { // OperatorId id = GetRef(op); // Item item = this->env->lookup(id); @@ -423,16 +433,16 @@ class TypeInferencer : private ExprFunctor { Type checked_ty; Type annotated_ty = resolve(let->value_type); - // // if we are let-defining a function, treat it as a let-rec and insert - // // the id with the annotated type in case there is recursion; - // // no such recursion permitted with anything that's not a function! + // if we are let-defining a function, treat it as a let-rec and insert + // the id with the annotated type in case there is recursion; + // no such recursion permitted with anything that's not a function! // if (let->value.as()) { - // with_frame([&]() { - // local_stack.insert(let->id, annotated_ty); - // checked_ty = Check(let->value); - // }); + // with_frame([&]() { + // local_stack.insert(let->id, annotated_ty); + // checked_ty = Check(let->value); + // }); // } else { - // checked_ty = Check(let->value); + checked_ty = Infer(let->value).type; // } // ensure annotated type and checked type are compatible From 7d2624a604ab707ef690c90521e2f2c87df8fdba Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 10:37:34 -0700 Subject: [PATCH 028/136] Iterate towards second test --- include/tvm/relay/compiler/environment.h | 13 +++--- include/tvm/relay/op.h | 1 + python/tvm/relay/ir_builder.py | 54 +++++++++++++++++++++--- src/relay/compiler/environment.cc | 20 ++++----- tests/python/relay/test_typechecker.py | 11 ++++- 5 files changed, 72 insertions(+), 27 deletions(-) diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index ddb7f0dca192..3e108cd8b390 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -42,9 +42,9 @@ class EnvironmentNode : public RelayNode { /*! A map from string names to GlobalIds, ensures global uniqueness. */ InternTable global_map_; /*! A map from string names to Operators, ensures global uniqueness. */ - InternTable operator_map_; + InternTable operators; // /*! \brief A map from file names to source fragments. */ - // SourceMap source_map_; + // SourceMap source_map_ // /*! \brief A list of the errors reported during the current run. */ // std::vector errors_; @@ -64,8 +64,8 @@ class EnvironmentNode : public RelayNode { TVM_DLL static Environment make( std::unordered_map global_funcs); - // Add an item to the Enviroment. - // void add(const Operator& op, bool update = false); + /*! Add an operator to the Enviroment. */ + void register_op(const Operator& op); // void add(const Operator& op, bool update = false); // void try_add(const Item& item, bool update=false); @@ -73,13 +73,10 @@ class EnvironmentNode : public RelayNode { // void remove(const GlobalId& id); // GlobalId global_id(const std::string& str); - // OperatorId operator_id(const std::string& str); + Operator op(const std::string& str); // We can lookup a GlobalId, OperatorId. // Defn lookup(const GlobalId& id); - // Operator lookup(const OperatorId& id); - // Defn lookup_global(const std::string& str); - // Item lookup_operator(const std::string& str); // FileId add_source(std::string file_name, std::string source); // tvm::Array get_operators(); diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index fa152945d38c..2e631df55fd0 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -25,6 +25,7 @@ class Operator; /*! \brief Container for Operator */ class OperatorNode : public ExprNode { public: + std::string name; /*! \brief A type which specifies the relationship between the inputs and outputs * of the operator. */ diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 2b2cdb432b43..35bc31265987 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -5,6 +5,12 @@ from . import expr from . import make as mk +class ExprBuilder(): + def __init__(self, expr): + self.expr = expr + + def __call__(self, *args): + return ExprBuilder(mk.Call(self.expr, list(args), None, None)) def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: """Convert Python values into the appropriate types @@ -29,7 +35,7 @@ def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> expr.Expr: raise Exception("..") else: value = convert(arg, ctxt) - return mk.Constant(value) + return ExprBuilder(mk.Constant(value)) class WithScope(object): """Auxiliary scope with""" @@ -44,6 +50,18 @@ def __enter__(self): def __exit__(self, ptype, value, trace): self._exit_cb() + +class PartialFunc(): + def __init__(self, params, ret_type, body, type_params): + self.params = params + self.ret_type = ret_type + self.body = body + self.type_params = type_params + + def param_ids(self): + return [p.var for p in self.params] + + def _mk_let(bindings, ret_value): let_expr = ret_value for var, value in reversed(list(bindings.items())): @@ -51,12 +69,15 @@ def _mk_let(bindings, ret_value): return let_expr + class IRBuilder(): def __init__(self): self.bindings = [{}] self.scopes = [{}] + self.params = [] self.ret_value = None + def bind(self, name, type, value): lv = mk.LocalVar(name) self.scopes[-1][name] = lv @@ -65,18 +86,33 @@ def bind(self, name, type, value): def let(self, name, value, value_type=None): - if not isinstance(value, expr.Expr): + if not (isinstance(value, expr.Expr) or isinstance(value, ExprBuilder)): value = into_ast(value) + if isinstance(value, ExprBuilder): + value = value.expr + return self.bind(name, value_type, value) - def function(self, params): + def function(self, *params): + relay_params = [] + for name, ty in params: + lv = mk.LocalVar(name) + self.scopes[-1][name] = lv + relay_params.append(mk.Param(lv, ty)) + + # self.params.append(relay_params) + + pfunc = PartialFunc(relay_params, None, None, []) + def _on_exit(): bindings = self.bindings.pop() scope = self.scopes.pop() - import pdb - pdb.set_trace() - return WithScope(None, _on_exit) + # params = self.params.pop() + + + return WithScope(pfunc, _on_exit) + def ret(self, x): if not self.ret_value: @@ -85,6 +121,12 @@ def ret(self, x): raise Exception( "return value already set, a function can only have one return value") + def fn_params(self): + pass + + def op(self, name): + pass + def get(self): """Get the full program""" bindings = self.bindings.pop() diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc index af8f5eeefab7..735ef79ceb3a 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/compiler/environment.cc @@ -24,6 +24,10 @@ Environment EnvironmentNode::make( return Environment(n); } +void EnvironmentNode::register_op(const Operator& op) { + this->operators.Insert(op->name, op); +} + // tvm::PackedFunc EnvironmentNode::jit_for(OperatorId id) { // return this->lookup(id)->compiler; // } @@ -111,18 +115,10 @@ Environment EnvironmentNode::make( // } // } -// Operator EnvironmentNode::lookup(const OperatorId &id) { -// if (operators.find(id) != operators.end()) { -// return operators.at(id); -// } else { -// throw EnvError(std::string("there is no definition of ") + id->name); -// } -// } - -// Item EnvironmentNode::lookup_operator(const std::string &str) { -// OperatorId id = this->operator_id(str); -// return lookup(id); -// } +Operator EnvironmentNode::op(const std::string & op_name) { + // FIX ME + return operators.Lookup(op_name); +} // Defn EnvironmentNode::lookup_global(const std::string &str) { // GlobalId id = this->global_id(str); diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index 5626fd8ce0bc..c6bc1a05ebe9 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -11,8 +11,8 @@ def has_type(expr, typ): return checked_expr.checked_type() == typ def test_monomorphic_let(): + "Program: let x = 1; x" b = IRBuilder() - # Program: let x = 1; x x = b.let('x', 1, value_type=float_type()) b.ret(x) @@ -20,3 +20,12 @@ def test_monomorphic_let(): assert has_type(prog, float_type()) +def test_single_op(): + "Program: fn (x : int32) { let t1 = f(x); t1 }" + b = IRBuilder() + f = b.op('f') + with b.function(('x', float_type())) as func: + x, = func.param_ids() + t1 = b.let('t1', f(x)) + b.ret(t1) + import pdb; pdb.set_trace() From 1ca8c4a31f717bd08e42ef43210cccd8070983dc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 16:55:31 -0700 Subject: [PATCH 029/136] Remove placeholder op.h --- include/tvm/relay/op.h | 48 ------------------------------------------ 1 file changed, 48 deletions(-) delete mode 100644 include/tvm/relay/op.h diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h deleted file mode 100644 index 2e631df55fd0..000000000000 --- a/include/tvm/relay/op.h +++ /dev/null @@ -1,48 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/op.h - * \brief Relay's representation of operators. - */ -#ifndef TVM_RELAY_OP_H_ -#define TVM_RELAY_OP_H_ - -#include "./expr.h" - -namespace tvm { -namespace relay { - - -/*! - * \brief A primitive Relay operator defined externally to Relay. - * - * \note Currently these are expected to be backed by a TVM's operator, - * such as the ones defined in TOPI. - * - * For developers who are familar with the computational graph this - * directly maps to the concept of operators in NNVM. - */ -class Operator; -/*! \brief Container for Operator */ -class OperatorNode : public ExprNode { - public: - std::string name; - /*! \brief A type which specifies the relationship between the inputs and outputs - * of the operator. - */ - Type op_type; - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("op_type", &op_type); - } - - TVM_DLL static Operator make(Type op_type); - - static constexpr const char* _type_key = "relay.Operator"; - TVM_DECLARE_NODE_TYPE_INFO(OperatorNode, OperatorNode); -}; - -RELAY_DEFINE_NODE_REF(Operator, OperatorNode, Expr); - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_EXPR_H_ From 2c6b5219a8b36cc4004039c216797c60ef5d7d7d Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 22 Aug 2018 09:44:43 -0700 Subject: [PATCH 030/136] [OP] Current op system --- include/tvm/base.h | 5 + include/tvm/relay/op.h | 402 ++++++++++++++++++++++++++++ python/tvm/relay/__init__.py | 3 +- python/tvm/relay/make.py | 6 + python/tvm/relay/op.py | 37 +++ src/relay/op.cc | 105 ++++++-- src/relay/op/tensor/elemwise.cc | 23 ++ tests/python/relay/test_relay_op.py | 8 + 8 files changed, 568 insertions(+), 21 deletions(-) create mode 100644 include/tvm/relay/op.h create mode 100644 python/tvm/relay/op.py create mode 100644 src/relay/op/tensor/elemwise.cc create mode 100644 tests/python/relay/test_relay_op.py diff --git a/include/tvm/base.h b/include/tvm/base.h index c2d796b6002c..be848b34cd43 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -134,6 +134,11 @@ struct NodeFactoryReg { */ #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) +#define TVM_REGISTER_NODE_TYPE(TypeName) \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ + ::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \ + .set_body([]() { return std::make_shared(); }) + } // namespace tvm #endif // TVM_BASE_H_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h new file mode 100644 index 000000000000..f7e1cfbbc8c2 --- /dev/null +++ b/include/tvm/relay/op.h @@ -0,0 +1,402 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/op.h + * \brief Primitive operator definition. + */ +#ifndef TVM_RELAY_OP_H_ +#define TVM_RELAY_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "./base.h" +#include "./expr.h" +#include "../attrs.h" + +namespace tvm { +namespace relay { + +// forward declare name. +template +class OpMap; +class GenericOpMap; +class OpRegistry; + +/*! + * \brief Node container of operator structure. + */ +class OpNode : public relay::ExprNode { + public: + /*! \brief name of the operator */ + std::string name; + /*! + * \brief detailed description of the operator + * This can be used to generate docstring automatically for the operator. + */ + std::string description; + /* \brief Information of input arguments to the operator */ + Array arguments; + /*! + * \brief The type key of the attribute field + * This can be empty, in which case it defaults to + */ + std::string attrs_type_key; + /*! + * \brief number of input arguments to the operator, + * -1 means it is variable length + */ + int32_t num_inputs = -1; + /*! + * \brief support level of the operator, + * The lower the more priority it contains. + * This is in analogies to BLAS levels. + */ + int32_t support_level = 10; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("description", &description); + v->Visit("arguments", &arguments); + v->Visit("attrs_type_key", &attrs_type_key); + v->Visit("num_inputs", &num_inputs); + v->Visit("support_level", &support_level); + } + + static constexpr const char* _type_key = "relay.Op"; + TVM_DECLARE_NODE_TYPE_INFO(OpNode, Node); + + private: + // friend class + friend class GenericOpMap; + friend class OpRegistry; + // Program internal unique index of operator. + // Used to help index the program. + uint32_t index_{0}; +}; + +/*! + * \brief Operator reference class. + */ +class Op : public relay::Expr { + public: + /*! \brief default constructor */ + Op() {} + /*! \brief constructor from node pointer */ + explicit Op(std::shared_ptr n) : Expr(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const OpNode* operator->() const; + /*! + * \brief Get additional registered attribute about operators. + * If nothing has been registered, an empty OpMap will be returned. + * \param attr_name The name of the attribute. + * \return An OpMap of specified attr_name. + * \tparam ValueType The type of the attribute. + */ + template + inline static const OpMap& GetAttr(const std::string& attr_name); + /*! + * \brief Get an Op for a given operator name. + * Will raise an error if the op has not been registered. + * \param op_name Name of the operator. + * \return Pointer to a Op, valid throughout program lifetime. + */ + TVM_DLL static const Op& Get(const std::string& op_name); + + /*! \brief specify container node */ + using ContainerType = OpNode; + + private: + /*! + * \brief Get generic attrmap given attr name + * \param key The attribute key + * \return reference to GenericOpMap + */ + TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); +}; + +/*! \brief Helper structure to register operators */ +class OpRegistry { + public: + /*! \return the operator */ + const Op& op() const { + return op_; + } + /*! + * \brief setter function during registration + * Set the description of operator + * \param descr the description string. + * \return reference to self. + */ + inline OpRegistry& describe(const std::string& descr); // NOLINT(*) + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline OpRegistry& add_argument(const std::string &name, + const std::string &type, + const std::string &description); + /*! + * \brief Set the type key of attributes. + * \param type_key The type of of the attrs field.x + * \return reference to self. + */ + inline OpRegistry& set_attrs_type_key(const std::string& type_key); + /*! + * \brief Set the num_inputs + * \param n The number of inputs to be set. + * \return reference to self. + */ + inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) + /*! + * \brief Set the support level of op. + * \param level The support level. + * \return reference to self. + */ + inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) + /*! + * \brief Register additional attributes to operator. + * \param attr_name The name of the attribute. + * \param value The value to be set. + * \param plevel The priority level of this set, + * an higher priority level attribute + * will replace lower priority level attribute. + * Must be bigger than 0. + * + * Cannot set with same plevel twice in the code. + * + * \tparam ValueType The type of the value to be set. + */ + template + inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) + const ValueType& value, + int plevel = 10); + + // set the name of the op to be the same as registry + inline OpRegistry& set_name() { // NOLINT(*) + get()->name = name; + return *this; + } + + private: + friend class ::dmlc::Registry; + // the name + std::string name; + /*! \brief The operator */ + Op op_; + // private constructor + OpRegistry(); + // return internal pointer to op. + inline OpNode* get(); + // update the attribute OpMap + TVM_DLL void UpdateAttr(const std::string& key, + TVMRetValue value, + int plevel); +}; + +/*! + * \brief Generic map to store additional information of Op. + */ +class GenericOpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline const TVMRetValue& operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + inline ValueType get(const Op& op, ValueType def_value) const; + + private: + friend class OpRegistry; + // the attribute field. + std::string attr_name_; + // internal data + std::vector > data_; + // The value + GenericOpMap() = default; +}; + +/*! + * \brief Map used to store meta-information about Op. + * \tparam ValueType The type of the value stored in map. + */ +template +class OpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline ValueType operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + */ + inline ValueType get(const Op& op, ValueType def_value) const; + + private: + friend class Op; + // constructor + explicit OpMap(const GenericOpMap& map) + : map_(map) {} + /*! \brief The internal map field */ + const GenericOpMap& map_; +}; + + +// internal macros to make +#define RELAY_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry & __make_ ## RelayOp + +/*! + * \def RELAY_REGISTER_OP + * \brief Register a new operator, or set attribute of the corresponding op. + * + * \param OpName The name of registry + * + * \code + * + * RELAY_REGISTER_OP("add") + * .describe("add two inputs together") + * .set_num_inputs(2) + * .set_attr("gpu_kernel", AddKernel); + * + * \endcode + */ +#define RELAY_REGISTER_OP(OpName) \ + DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ + ::dmlc::Registry<::tvm::relay::OpRegistry>::Get()->__REGISTER_OR_GET__(OpName).set_name() + +// implementations +inline const OpNode* Op::operator->() const { + return static_cast(node_.get()); +} + +template +inline const OpMap& Op::GetAttr(const std::string& key) { + return OpMap(Op::GetGenericAttr(key)); +} + +inline OpNode* OpRegistry::get() { + return const_cast(op_.operator->()); +} + +inline OpRegistry& OpRegistry::describe(const std::string& descr) { // NOLINT(*) + get()->description = descr; + return *this; +} + +inline OpRegistry& OpRegistry::add_argument(const std::string &name, + const std::string &type, + const std::string &description) { + std::shared_ptr n = std::make_shared(); + n->name = name; + n->type_info = type; + n->description = description; + get()->arguments.push_back(AttrFieldInfo(n)); + return *this; +} + +inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) + get()->num_inputs = n; + return *this; +} + +inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) + get()->support_level = n; + return *this; +} + +template +inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) + const std::string& attr_name, + const ValueType& value, + int plevel) { + CHECK_GT(plevel, 0) + << "plevel in set_attr must be greater than 0"; + TVMRetValue rv; + rv = value; + UpdateAttr(attr_name, rv, plevel); + return *this; +} + +// member functions of OpMap +inline int GenericOpMap::count(const Op& op) const { + if (op.defined()) { + const uint32_t idx = op->index_; + return idx < data_.size() ? (data_[idx].second != 0) : 0; + } else { + return 0; + } +} + +inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + CHECK(idx < data_.size() && data_[idx].second != 0) + << "Attribute " << attr_name_ + << " has not been registered for Operator " << op->name; + return data_[idx].first; +} + +template +inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return value; + } +} + +template +inline int OpMap::count(const Op& op) const { + return map_.count(op); +} + +template +inline ValueType OpMap::operator[](const Op& op) const { + return map_[op]; +} +template +inline ValueType OpMap::get(const Op& op, ValueType def_value) const { + return map_.get(op, def_value); +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_OP_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index c90875db4178..a9446ebed979 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,7 +1,8 @@ -"""Relay namespace.""" +"""The Relay IR namespace containing the IR definition and compiler.""" from . import base from . import type as tpe from . import make +from . import op # Type Type = tpe.Type diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py index bf9ec0e48f64..8cad18c11c46 100644 --- a/python/tvm/relay/make.py +++ b/python/tvm/relay/make.py @@ -1,6 +1,12 @@ from . import _make from . import ir +This module includes MyPy type signatures for all of the +exposed modules. +""" +from __future__ import absolute_import as _abs +from .._ffi.function import _init_api + # Base Constructors Span = _make.Span diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py new file mode 100644 index 000000000000..373d04f984a3 --- /dev/null +++ b/python/tvm/relay/op.py @@ -0,0 +1,37 @@ +"""Relay operators""" +from __future__ import absolute_import as _abs + +import sys +from .._ffi.function import _init_api +from .._ffi.node import convert_to_node +from . import make as _make +from ..make import node as _make_node + +def _create_op(op_name): + op = _GetOp(op_name) + attrs_type_key = op.attrs_type_key + attrs_type_key = attrs_type_key if attrs_type_key else "DictAttrs" + # TODO(tqchen): improve the code build to fix the restriction. + # + # current restriction: + # - pass in args as positional arguments + # - pass in kwargs as keyword argument + def _op_func(*args, **kwargs): + args = convert_to_node(args) + # Need work to make sure constructor matches + return _make.Call(op, args, + attrs = _make.node(attrs_type_key, **kwargs)) + _op_func.__name__ = op.name + return _op_func + + +def _init_ops(): + """Helper function to initialize the operators + """ + module = sys.modules[__name__] + for name in _ListOpNames(): + f = _create_op(name.value) + setattr(module, f.__name__, f) + +_init_api("relay.op", __name__) +_init_ops() diff --git a/src/relay/op.cc b/src/relay/op.cc index 07ad5f0ae4ed..5a4241a182b1 100644 --- a/src/relay/op.cc +++ b/src/relay/op.cc @@ -1,31 +1,96 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file op.cc - * \brief Relay's representation of operators. - */ -#include "tvm/relay/op.h" -#include "tvm/ir_functor.h" +#include +#include +#include + +namespace dmlc { +// enable registry +DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); +} // namespace dmlc namespace tvm { namespace relay { -using tvm::IRPrinter; -using namespace runtime; +// single manager of operator information. +struct OpManager { + // mutex to avoid registration from multiple threads. + std::mutex mutex; + // global operator counter + std::atomic op_counter{0}; + // storage of additional attribute table. + std::unordered_map > attr; + // get singleton of the + static OpManager* Global() { + static OpManager inst; + return &inst; + } +}; + +// find operator by name +const Op& Op::Get(const std::string& name) { + const OpRegistry* reg = dmlc::Registry::Find(name); + CHECK(reg != nullptr) + << "Operator " << name << " is not registered"; + return reg->op(); +} + +OpRegistry::OpRegistry() { + OpManager* mgr = OpManager::Global(); + std::shared_ptr n = std::make_shared(); + n->index_ = mgr->op_counter++; + op_ = Op(n); +} + +// Get attribute map by key +const GenericOpMap& Op::GetGenericAttr(const std::string& key) { + OpManager* mgr = OpManager::Global(); + std::lock_guard lock(mgr->mutex); + auto it = mgr->attr.find(key); + if (it == mgr->attr.end()) { + LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered"; + } + return *it->second.get(); +} -Operator OperatorNode::make(Type op_type) { - std::shared_ptr n = std::make_shared(); - n->op_type = std::move(op_type); - return Operator(n); +void OpRegistry::UpdateAttr( + const std::string& key, TVMRetValue value, int plevel) { + OpManager* mgr = OpManager::Global(); + std::lock_guard lock(mgr->mutex); + std::unique_ptr& op_map = mgr->attr[key]; + if (op_map == nullptr) { + op_map.reset(new GenericOpMap()); + } + uint32_t index = op_->index_; + if (op_map->data_.size() <= index) { + op_map->data_.resize(index + 1, + std::make_pair(TVMRetValue(), 0)); + } + std::pair & p = op_map->data_[index]; + CHECK(p.second != plevel) + << "Attribute " << key + << " of operator " << this->name + << " is already registered with same plevel=" << plevel; + if (p.second < plevel) { + op_map->data_[index] = std::make_pair(value, plevel); + } } -TVM_REGISTER_API("relay._make.Operator").set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = OperatorNode::make(args[0]); -}); +// Frontend APIs +using runtime::TypedPackedFunc; + +TVM_REGISTER_API("relay.op._ListOpNames") +.set_body(TypedPackedFunc()>([]() { + Array ret; + for (const std::string& name : + dmlc::Registry::ListAllNames()) { + ret.push_back(tvm::Expr(name)); + } + return ret; + })); + +TVM_REGISTER_API("relay.op._GetOp") +.set_body(TypedPackedFunc(Op::Get)); + -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const OperatorNode *node, tvm::IRPrinter *p) { - p->stream << "OperatorNode(" << node->op_type << ")"; - }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc new file mode 100644 index 000000000000..8b759bfbc07c --- /dev/null +++ b/src/relay/op/tensor/elemwise.cc @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file elemwise.cc + * \brief Elementwise operators. + */ +#include + +namespace tvm { +namespace relay { + +RELAY_REGISTER_OP("log") +.describe(R"code(Returns the log input array, computed element-wise. + +.. math:: + log(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor."); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_relay_op.py b/tests/python/relay/test_relay_op.py new file mode 100644 index 000000000000..93316da8ec41 --- /dev/null +++ b/tests/python/relay/test_relay_op.py @@ -0,0 +1,8 @@ +from tvm import relay + +def test_op_level1(): + assert relay.op.log + + +if __name__ == "__main__": + test_op_level1() From e1976812ad818b77d7e73901feb43a09d220033f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 Aug 2018 17:04:04 -0700 Subject: [PATCH 031/136] WIP --- include/tvm/relay/compiler/environment.h | 24 ++++------ include/tvm/relay/compiler/intern_table.h | 55 ----------------------- include/tvm/relay/type.h | 2 +- 3 files changed, 10 insertions(+), 71 deletions(-) delete mode 100644 include/tvm/relay/compiler/intern_table.h diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index 3e108cd8b390..536302c31dc6 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -41,22 +41,18 @@ class EnvironmentNode : public RelayNode { private: /*! A map from string names to GlobalIds, ensures global uniqueness. */ InternTable global_map_; - /*! A map from string names to Operators, ensures global uniqueness. */ - InternTable operators; + // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ // /*! \brief A list of the errors reported during the current run. */ // std::vector errors_; public: - // This map contains all items *except* operators. - std::unordered_map items; + /*! \brief A map from ids to all global functions. */ + tvm::Map items; // Options options; - tvm::PackedFunc jit_for(Operator op); - tvm::PackedFunc reverse(Operator op); - EnvironmentNode() {} void VisitAttrs(tvm::AttrVisitor* v) final {} @@ -75,16 +71,14 @@ class EnvironmentNode : public RelayNode { // GlobalId global_id(const std::string& str); Operator op(const std::string& str); - // We can lookup a GlobalId, OperatorId. - // Defn lookup(const GlobalId& id); - // FileId add_source(std::string file_name, std::string source); + /*! \brief Lookup a global function by its name. */ + Function lookup(const GlobalVar& id); - // tvm::Array get_operators(); - // tvm::Array get_defns(); + /*! \brief Add a source fragment to the environment. */ + // FileId add_source(std::string file_name, std::string source); - // void report_error(std::string msg, Span sp); - // void display_errors(); - // void register_shape_ext(ShapeExtension ext); + void report_error(std::string msg, Span sp); + void display_errors(); static constexpr const char* _type_key = "relay.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); diff --git a/include/tvm/relay/compiler/intern_table.h b/include/tvm/relay/compiler/intern_table.h deleted file mode 100644 index 1850e513e5e5..000000000000 --- a/include/tvm/relay/compiler/intern_table.h +++ /dev/null @@ -1,55 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/compiler/intern_table.h - * \brief A table which maps string keys to data. - * - * These are useful for mapping user-readable names - * to globally unique allocations which use pointer - * equality for comparsion. - */ -#ifndef TVM_RELAY_COMPILER_INTERN_TABLE_H_ -#define TVM_RELAY_COMPILER_INTERN_TABLE_H_ - -#include -#include -#include "dmlc/logging.h" - -namespace tvm { -namespace relay { - -struct KeyNotFound : dmlc::Error { - explicit KeyNotFound(std::string msg) : dmlc::Error(msg) {} -}; - -template -class InternTable { -private: - /*! \brief The internal table mapping from strings to T. */ - std::unordered_map table_; - - public: - /*! \brief Insert a new key into the table. - * \note Attempting to reinsert a key triggers an error. - */ - void Insert(const std::string& key, const T& value) { - if (table_.find(key) == table_.end()) { - table_.insert({key, value}); - } else { - throw dmlc::Error( - std::string("you have previously interred a value for: ") + key); - } - } - - /*! \brief Lookup the data in the table. */ - const T& Lookup(std::string key) { - if (table_.find(key) != table_.end()) { - return table_.at(key); - } else { - throw KeyNotFound(std::string("could not find match") + key); - } - } -}; - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_COMPILER_INTERN_TABLE_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 07b047471aba..ef8c4c71f5b7 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -92,7 +92,7 @@ class TensorTypeNode : public BaseTensorTypeNode { /*! \brief Construct a floating-point type */ TVM_DLL static TensorType Float(int bits, int lanes = 1); - /*1 \brief Construct a boolean type */ + /*! \brief Construct a boolean type */ TVM_DLL static TensorType Bool(int lanes = 1); static constexpr const char* _type_key = "relay.TensorType"; From b2fa0237f72061e38be32b5379b7014e96f9834c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Aug 2018 13:04:01 -0700 Subject: [PATCH 032/136] WIP --- cmake/config.cmake | 3 + include/tvm/relay/compiler/environment.h | 20 +++--- include/tvm/relay/compiler/type_infer.h | 2 +- include/tvm/relay/expr_functor.h | 4 +- include/tvm/relay/expr_visitor.h | 6 +- src/relay/compiler/environment.cc | 88 +++++++----------------- 6 files changed, 41 insertions(+), 82 deletions(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index c364a88cce11..e09fdb241bf1 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -19,6 +19,9 @@ # $ make -j8 #-------------------------------------------------------------------- +SET(CMAKE_C_COMPLIER clang) +SET(CMAKE_CXX_COMPILER clang++) + #--------------------------------------------- # Backend runtimes. #--------------------------------------------- diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index 536302c31dc6..5b33e781b399 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -8,7 +8,6 @@ #include #include -#include "tvm/relay/compiler/intern_table.h" #include "../expr.h" #include "../type.h" #include "../op.h" @@ -40,7 +39,7 @@ struct Environment; class EnvironmentNode : public RelayNode { private: /*! A map from string names to GlobalIds, ensures global uniqueness. */ - InternTable global_map_; + tvm::Map global_map_; // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ @@ -61,18 +60,17 @@ class EnvironmentNode : public RelayNode { std::unordered_map global_funcs); /*! Add an operator to the Enviroment. */ - void register_op(const Operator& op); - // void add(const Operator& op, bool update = false); + void register_op(const Op& op); + void add(const GlobalVar& var, const Function & func, bool update = false); + void try_add(const GlobalVar& var, const Function & func, bool update=false); + void update(const GlobalVar& var, const Function & func); + void remove(const GlobalVar& var); - // void try_add(const Item& item, bool update=false); - // void update(const Item& item); - // void remove(const GlobalId& id); - - // GlobalId global_id(const std::string& str); - Operator op(const std::string& str); + GlobalVar GetGlobalVar(const std::string& str); /*! \brief Lookup a global function by its name. */ - Function lookup(const GlobalVar& id); + Function Lookup(const GlobalVar& id); + Function Lookup(const std::string & s); /*! \brief Add a source fragment to the environment. */ // FileId add_source(std::string file_name, std::string source); diff --git a/include/tvm/relay/compiler/type_infer.h b/include/tvm/relay/compiler/type_infer.h index 6d07de1c29e8..c084fb7a109e 100644 --- a/include/tvm/relay/compiler/type_infer.h +++ b/include/tvm/relay/compiler/type_infer.h @@ -24,7 +24,7 @@ Expr Infer(const Environment & env, const Expr & e); /*! \brief Ensures that an operator is well-formed with respect * to Relay's type system. */ -Operator CheckOperator(const Environment & env, const Operator & op); +Op CheckOp(const Environment & env, const Op & op); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 922892e8a7a5..2067b90bd364 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -113,7 +113,7 @@ class ExprFunctor { virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const OperatorNode* op, + virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { throw dmlc::Error(std::string("Do not have a default for ") + op->type_key()); @@ -133,7 +133,7 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); - RELAY_EXPR_FUNCTOR_DISPATCH(OperatorNode); + RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); return vtable; } }; diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 2039414b4238..d1e8a99dc374 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -58,7 +58,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctorVisitExpr(op->false_value, args...); } - void VisitExpr_(const OperatorNode* op, Args... args) override { return; } + void VisitExpr_(const OpNode* op, Args... args) override { return; } }; template @@ -72,8 +72,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor(op); } - Expr VisitExpr_(const OperatorNode* op, Args... args) override { - return GetRef(op); + Expr VisitExpr_(const OpNode* op, Args... args) override { + return GetRef(op); } Expr VisitExpr_(const TupleNode* op, Args... args) override { diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc index 735ef79ceb3a..7ce0785f4f8f 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/compiler/environment.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file environment.cc - * \brief Relay global environment. + * \brief The global environment in Relay. */ #include #include "tvm/relay/compiler/environment.h" @@ -24,34 +24,17 @@ Environment EnvironmentNode::make( return Environment(n); } -void EnvironmentNode::register_op(const Operator& op) { - this->operators.Insert(op->name, op); +GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { + auto global_id = global_map_.find(str); + if (global_id != global_map_.end()) { + return (*global_id).second; + } else { + auto id = GlobalVarNode::make(str); + this->global_map_.Set(str, id); + return id; + } } -// tvm::PackedFunc EnvironmentNode::jit_for(OperatorId id) { -// return this->lookup(id)->compiler; -// } - -// GlobalId EnvironmentNode::global_id(const std::string &str) { -// try { -// return global_map_.Lookup(str); -// } catch (const KeyNotFound &err) { -// GlobalId id = GlobalIdNode::make(str); -// global_map_.Insert(str, id); -// return id; -// } -// } - -// OperatorId EnvironmentNode::operator_id(const std::string &str) { -// try { -// return operator_map_.Lookup(str); -// } catch (const KeyNotFound &err) { -// OperatorId id = OperatorIdNode::make(str); -// operator_map_.Insert(str, id); -// return id; -// } -// } - // // Add a new item to the global environment // // throws an exception if the item already // // exists. @@ -79,15 +62,15 @@ void EnvironmentNode::register_op(const Operator& op) { // } else { // operators.insert({op->id, op}); // } -// } else if (const DefnNode *d = item.as()) { -// auto def = GetRef(d); +// } else if (const FunctionNode *d = item.as()) { +// auto def = GetRef(d); // auto type = def->type; // if (items.find(def->id) != items.end()) { // if (!update) { // throw dmlc::Error("already have definition for XXXX."); // } -// auto old_type = items[def->id].as()->type; +// auto old_type = items[def->id].as()->type; // if (!alpha_eq(type, old_type)) { // throw dmlc::Error( @@ -107,23 +90,18 @@ void EnvironmentNode::register_op(const Operator& op) { // void EnvironmentNode::remove(const GlobalId &id) { this->items.erase(id); } -// Defn EnvironmentNode::lookup(const GlobalId &id) { -// if (items.find(id) != items.end()) { -// return items.at(id); -// } else { -// throw EnvError(std::string("there is no definition of ") + id->name); -// } -// } - -Operator EnvironmentNode::op(const std::string & op_name) { - // FIX ME - return operators.Lookup(op_name); +Function EnvironmentNode::Lookup(const GlobalVar &var) { + if (items.find(var) != items.end()) { + return items.at(var); + } else { + throw Error(std::string("there is no definition of ") + var->name_hint); + } } -// Defn EnvironmentNode::lookup_global(const std::string &str) { -// GlobalId id = this->global_id(str); -// return this->lookup(id); -// } +Function EnvironmentNode::Lookup(const std::string &str) { + GlobalVar id = this->GetGlobalVar(str); + return this->Lookup(id); +} // inline FileId EnvironmentNode::add_source(std::string file_name, // std::string source) { @@ -163,26 +141,6 @@ Operator EnvironmentNode::op(const std::string & op_name) { // } // } -// Array EnvironmentNode::get_operators() { -// std::vector ops; -// for (auto pair : this->operators) { -// ops.push_back(pair.second); -// } -// return Array(ops); -// } - -// Array EnvironmentNode::get_defns() { -// std::vector defns; -// for (auto pair : this->items) { -// defns.push_back(pair.second); -// } -// return Array(defns); -// } - -// void EnvironmentNode::register_shape_ext(ShapeExtension ext) { -// this->shape_exts_.Insert(ext->name, ext); -// } - TVM_REGISTER_API("relay._make.Environment") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = EnvironmentNode::make({}); From bf5e16da8407d6d99b0a21c28889b2ebac33ed24 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Aug 2018 17:14:05 -0700 Subject: [PATCH 033/136] Change over to new Node construction --- python/tvm/relay/__init__.py | 16 ++++- python/tvm/relay/_make.pyi | 91 -------------------------- python/tvm/relay/base.py | 4 ++ python/tvm/relay/expr.py | 29 ++++++++ python/tvm/relay/ir_builder.py | 4 ++ python/tvm/relay/make.py | 75 --------------------- python/tvm/relay/op.py | 2 +- python/tvm/relay/type.py | 59 ++++++++++++++++- tests/python/relay/test_ir_nodes.py | 59 +++++++++-------- tests/python/relay/test_typechecker.py | 4 +- 10 files changed, 143 insertions(+), 200 deletions(-) delete mode 100644 python/tvm/relay/_make.pyi delete mode 100644 python/tvm/relay/make.py diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index a9446ebed979..037d71854689 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,9 +1,12 @@ """The Relay IR namespace containing the IR definition and compiler.""" from . import base from . import type as tpe -from . import make +from . import expr from . import op +# Span +Span = base.Span + # Type Type = tpe.Type TensorType = tpe.TensorType @@ -11,3 +14,14 @@ TypeParam = tpe.TypeParam TypeConstraint = tpe.TypeConstraint FuncType = tpe.FuncType + +# Expr +Constant = expr.Constant +Tuple = expr.Tuple +LocalVar = expr.LocalVar +GlobalVar = expr.GlobalVar +Param = expr.Param +Function = expr.Function +Call = expr.Call +Let = expr.Let +If = expr.If diff --git a/python/tvm/relay/_make.pyi b/python/tvm/relay/_make.pyi deleted file mode 100644 index d94857916319..000000000000 --- a/python/tvm/relay/_make.pyi +++ /dev/null @@ -1,91 +0,0 @@ -# from typing import Dict, List, Any, Callable, TypeVar as PyTypeVar -# import nnvm.relay.ir as ir -# import nnvm.relay.env as env -# import ctypes - -# # Environment -# def Environment(items: Dict[ir.GlobalId, ir.Item]) -> env.Environment: ... - -# # Items TODO(@jroesch) Correct Anys to the right type. -# def Operator(id: ir.OperatorId, tvm_name: str, ty: ir.Type, compiler: Any, fwd_mode: Any, rev_mode: Any) -> ir.Operator: ... -# def Defn(id: ir.GlobalId, ty: ir.Type, body: ir.Function) -> ir.Defn: ... - -# # Types -# def IntType(bits: int, lanes: int) -> ir.Type: ... -# def UIntType(bits: int, lanes: int) -> ir.Type: ... -# def FloatType(bits: int, lanes: int) -> ir.Type: ... -# def BoolType(lanes: int) -> ir.Type: ... -# def TupleType(fields: List[ir.Type]) -> ir.Type: ... -# def TensorType(dtype: ir.Type, shape: ir.Type) -> ir.Type: ... -# def TypeParam(name: str, kind: ir.Kind) -> ir.Type: ... -# def TypeQuantifier(id: ir.TypeId, body: ir.Type) -> ir.Type: ... -# def TypeArrow(left: ir.Type, right: ir.Type) -> ir.Type: ... -# def TypeVar(kind: ir.Kind) -> ir.Type: ... -# def PlaceholderType() -> ir.Type: ... -# def ShapeSeq(shapes: List[ir.Type]) -> ir.ShapeSeq: ... -# def ShapeSingleton(value: int) -> ir.ShapeSingleton: ... -# def ShapeAttr(id: ir.StringLit) -> ir.ShapeAttr: ... -# def ShapeProjection(shape: ir.Type, value: int) -> ir.ShapeProjection: ... -# def ShapeBinaryOp(op: ir.ShapeOp, left: ir.Type, right: ir.Type) -> ir.ShapeBinaryOp: ... -# def ShapeBroadcast(left: ir.Type, right: ir.Type) -> ir.ShapeBroadcast: ... -# def ShapeExtension(name: str, eval: Any) -> ir.ShapeExtension: ... -# def TypeCall(func: ir.Type, args: List[ir.Type]) -> ir.TypeCall: ... -# def RefType(data_type: ir.Type) -> ir.RefType: ... - -# # Expressions -# def Param(id: ir.LocalId, type: ir.Type) -> ir.Param: ... -# def Function(ty_params: List[ir.TypeId], params: List[ir.Param], ret_type: ir.Type, body: ir.Expr) -> ir.Function: ... -# def LocalId(name: str) -> ir.Expr: ... -# def GlobalId(name: str) -> ir.Expr: ... -# def OperatorId(name: str) -> ir.Expr: ... -# def Let(id: ir.LocalId, ty: ir.Type, value: ir.Expr, body: ir.Expr) -> ir.Expr: ... -# def IntLit(value: int) -> ir.IntLit: ... -# def FloatLit(value: float) -> ir.FloatLit: ... -# def TensorLit(value: List[ir.Expr]) -> ir.TensorLit: ... -# def Tuple(fields: List[ir.Expr]) -> ir.Expr: ... -# def BoolLit(value: bool) -> ir.BoolLit: ... -# def StringLit(value: str) -> ir.StringLit: ... -# def Attributes(attrs: Dict[str, ir.Expr]) -> ir.Attributes: ... -# def Call(func: ir.Expr, args: List[ir.Expr], attrs: ir.Attributes) -> ir.Call: ... -# def UnaryOp(op: ir.UOp, arg: ir.Expr) -> ir.Expr: ... -# def BinaryOp(op: ir.BOp, left: ir.Expr, right: ir.Expr) -> ir.Expr: ... -# def Projection(tuple: ir.Expr, field : int) -> ir.Expr: ... -# def Gradient(node: ir.Expr) -> ir.Expr: ... -# def Cast(target: ir.Type, node: ir.Expr) -> ir.Expr: ... -# def Debug(node: ir.Expr) -> ir.Expr: ... -# def Zero(type: ir.Type) -> ir.Expr: ... -# def If(guard: ir.Expr, true_branch: ir.Expr, false_branch: ir.Expr) -> ir.Expr: ... -# def Ref(value: ir.Expr) -> ir.Expr: ... -# def ReadRef(ref: ir.Expr) -> ir.Expr: ... -# def WriteRef(ref: ir.Expr, value: ir.Expr) -> ir.Expr: ... - -# # Values -# def IntValue(value: int) -> ir.TensorValue: ... -# def FloatValue(value: float) -> ir.TensorValue: ... -# def BoolValue(value: bool) -> ir.TensorValue: ... -# def TensorValue(handle: ctypes.c_void_p) -> ir.TensorValue: ... -# def Closure(env: Dict[ir.LocalId, ir.Value], fn: ir.Function) -> ir.Closure: ... - -# # Error Reporting -# def Span(file_id: ir.FileId, lineno: int, col_offset: int) -> ir.NodeBase: ... -# def FileId(file_id: int) -> ir.FileId: ... - -# # Utils -# def _alpha_eq(e1: ir.Expr, e2: ir.Expr) -> bool: ... -# def _type_alpha_eq(e1: ir.Type, e2: ir.Type) -> bool: ... -# def _expr_set_span(e: ir.Expr, sp: ir.Span) -> None: ... -# def _type_set_span(t: ir.Type, sp: ir.Span) -> None: ... -# def _item_set_span(t: ir.Item, sp: ir.Span) -> None: ... -# def Node_hash(n: ir.Node) -> int: ... -# def Operator_is_generic(op: ir.Operator) -> bool: ... - -# # FIXME -# def UnionFind() -> Any: ... -# def TypeUnifier() -> Any: ... - -# T = PyTypeVar('T') -# U = PyTypeVar('U') -# PassFunc = Callable[[env.Environment], Callable[[T], U]] - -# # Passes -# def ItemPass(name: str, pass_func: PassFunc[ir.Item, ir.Item]) -> ir.ItemPass: ... diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 687ba53ac005..ee818617f629 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -3,6 +3,7 @@ from __future__ import absolute_import as _abs from typing import Union from .._ffi.node import NodeBase, register_node as _register_tvm_node +from . import _make NodeBase = NodeBase @@ -25,3 +26,6 @@ class Span(NodeBase): source: "FileSource" lineno: int col_offset: int + + def __init__(self, source, lineno, col_offset): + self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index c17a69dd0dc9..7f5dcbd0beb5 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -7,6 +7,7 @@ from .type import Type, TypeParam from tvm import expr from ._type_infer import _get_checked_type +from . import _make class Expr(NodeBase): """The base type for all Relay exprressions.""" @@ -19,6 +20,9 @@ class Constant(Expr): """ data: tvm.nd.NDArray + def __init__(self, data: tvm.nd.NDArray) -> None: + self.__init_handle_by_constructor__(_make.Constant, data) + @register_relay_node class Tuple(Expr): """A hetereogenous sequence of values. @@ -26,16 +30,26 @@ class Tuple(Expr): """ fields: List[Expr] + def __init__(self, fields: List[Expr]) -> None: + self.__init_handle_by_constructor__(_make.Tuple, fields) + + @register_relay_node class LocalVar(Expr): """A local variable in Relay.""" name_hint: str + def __init__(self, name_hint: str) -> None: + self.__init_handle_by_constructor__(_make.LocalVar, name_hint) + @register_relay_node class GlobalVar(Expr): """A global variable in Relay.""" name_hint: str + def __init__(self, name_hint: str) -> None: + self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) + @register_relay_node class Param(Expr): """A function type in Relay, see tvm/relay/type.h for more details. @@ -43,6 +57,10 @@ class Param(Expr): var: LocalVar type: Type + def __init__(self, var: LocalVar, type: Type) -> None: + self.__init_handle_by_constructor__(_make.Param, var, type) + + @register_relay_node class Function(Expr): type_params: List[TypeParam] @@ -50,11 +68,17 @@ class Function(Expr): ret_type: Type body: Expr + def __init__(self, params: List[Param], ret_type: Type, body: Expr, type_params: List[TypeParam]=[]) -> None: + self.__init_handle_by_constructor__(_make.Function, params, ret_type, body, type_params) + class Call(Expr): op: Expr args: List[Expr] # todo(@jroesch): add attrs + def __init__(self, op: Expr, args: List[Expr], attrs, ty_args) -> None: + self.__init_handle_by_constructor__(_make.Call, op, args, attrs, ty_args) + @register_relay_node class Let(Expr): var: LocalVar @@ -62,6 +86,9 @@ class Let(Expr): body: Expr value_type: Type # should be type nanotation + def __init__(self, var: LocalVar, value: Expr, body: Expr, value_type: Type) -> None: + self.__init_handle_by_constructor__(_make.Let, var, value, body, value_type) + @register_relay_node class If(Expr): cond: Expr @@ -69,3 +96,5 @@ class If(Expr): false_value: Expr span: Span + def __init__(self, cond: Expr, true_value: Expr, false_value: Expr) -> None: + self.__init_handle_by_constructor__(_make.If, cond, true_value, false_value) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 35bc31265987..f24e7baa1483 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -4,6 +4,7 @@ from . import type as ty from . import expr from . import make as mk +from . import op class ExprBuilder(): def __init__(self, expr): @@ -142,6 +143,9 @@ def get(self): return _mk_let(bindings, self.ret_value) +def op(name): + return op._create_op(name) + def bool_dtype(): return 'uint1' diff --git a/python/tvm/relay/make.py b/python/tvm/relay/make.py deleted file mode 100644 index 8cad18c11c46..000000000000 --- a/python/tvm/relay/make.py +++ /dev/null @@ -1,75 +0,0 @@ -from . import _make -from . import ir - -This module includes MyPy type signatures for all of the -exposed modules. -""" -from __future__ import absolute_import as _abs -from .._ffi.function import _init_api - -# Base Constructors -Span = _make.Span - -# Environment -Environment = _make.Environment - -# Type Constructors -TensorType = _make.TensorType -TypeParam = _make.TypeParam -FuncType = _make.FuncType - -# Types -def IntType(bits: int, lanes: int=1) -> ir.Type: - """Constructs a integer base type. - - :param bits: The bit width of the integer type. - :param lanes: The number of vector elements for this datatype. - - """ - return _make.IntType(bits, lanes) - - -def UIntType(bits: int, lanes: int=1) -> ir.Type: - """Constructs a unsigned integer base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.UIntType(bits, lanes) - - -def FloatType(bits: int, lanes: int=1) -> ir.Type: - """Constructs a floating point base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.FloatType(bits, lanes) - - -def BoolType(lanes: int =1) -> ir.Type: - """Constructs a boolean base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.BoolType(lanes) - -# Expr Constructors -Constant = _make.Constant -Tuple = _make.Tuple -LocalVar = _make.LocalVar -GlobalVar = _make.GlobalVar -Param = _make.Param -Function = _make.Function -Call = _make.Call -Let = _make.Let -If = _make.If -IncompleteType = _make.IncompleteType - -# Unifier -UnionFind = _make.UnionFind -TypeUnifier = _make.TypeUnifier - -# Utility Functionality @TODO(jroesch): move to another location -_type_alpha_eq = _make._type_alpha_eq diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py index 373d04f984a3..dae498b66c12 100644 --- a/python/tvm/relay/op.py +++ b/python/tvm/relay/op.py @@ -4,7 +4,7 @@ import sys from .._ffi.function import _init_api from .._ffi.node import convert_to_node -from . import make as _make +from . import _make from ..make import node as _make_node def _create_op(op_name): diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index a04089792282..c7b8964c20e8 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -5,7 +5,7 @@ from .base import Span, NodeBase, register_relay_node from tvm import expr # TODO(@jroesch): move me -from ._make import _type_alpha_eq +from . import _make class Type(NodeBase): """The base type for all Relay types.""" @@ -14,7 +14,7 @@ def __eq__(self, other) -> bool: """Compares two Relay types for structural equivalence using alpha equivalence. """ - return bool(_type_alpha_eq(self, other)) + return bool(_make._type_alpha_eq(self, other)) def __ne__(self, other) -> bool: return not self.__eq__(other) @@ -31,6 +31,9 @@ class TensorType(Type): shape: List[expr.Expr] span: Span + def __init__(self, dtype: str, shape: List[expr.Expr]) -> None: + self.__init_handle_by_constructor__(_make.TensorType,dtype, shape) + class Kind(IntEnum): """The kind of a type parameter, represents a variable shape, base type, type, or dimension. @@ -49,6 +52,9 @@ class TypeParam(Type): kind: Kind span: Span + def __init__(self, var: expr.Var, kind: Kind) -> None: + self.__init_handle_by_constructor__(_make.TypeParam, var, kind) + @register_relay_node class TypeConstraint(Type): """Abstract class representing a type constraint.""" @@ -64,7 +70,54 @@ class FuncType(Type): ret_type: Type span: Span + def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[TypeParam], type_constraints: List[TypeConstraint]) -> None: + self.__init_handle_by_constructor__(_make.FuncType, arg_types, ret_type, type_params, type_constraints) + +@register_relay_node +class TypeCall(Type): + def __init__() -> None: + pass + + @register_relay_node class IncompleteType(Type): """An incomplete type.""" - pass + + def __init__(self, kind: Kind) -> None: + self.__init_handle_by_constructor__(_make.IncompleteType, kind) + +def IntType(bits: int, lanes: int=1) -> Type: + """Constructs a integer base type. + + :param bits: The bit width of the integer type. + :param lanes: The number of vector elements for this datatype. + + """ + return _make.IntType(bits, lanes) + + +def UIntType(bits: int, lanes: int=1) -> Type: + """Constructs a unsigned integer base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.UIntType(bits, lanes) + + +def FloatType(bits: int, lanes: int=1) -> Type: + """Constructs a floating point base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.FloatType(bits, lanes) + + +def BoolType(lanes: int =1) -> Type: + """Constructs a boolean base type. + + :param bits: The bit width of the unsigned type. + :param lanes: The number of vector elements for this datatype. + """ + return _make.BoolType(lanes) diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 26fe06109513..676aa347950b 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -1,14 +1,11 @@ """ test ir""" import tvm from tvm import relay -import tvm.relay.make as mk -from tvm import expr +from tvm.expr import * # Span - - def test_span() -> None: - span = mk.Span(None, 1, 1) + span = relay.Span(None, 1, 1) assert span.source == None assert span.lineno == 1 assert span.col_offset == 1 @@ -19,11 +16,10 @@ def test_span() -> None: # Types - def test_tensor_type() -> None: shape = tvm.convert([1, 2, 3]) dtype = 'float32' - tt = mk.TensorType(shape, dtype) + tt = relay.TensorType(shape, dtype) assert tt.dtype == dtype assert tt.shape == shape assert tt.span == None @@ -31,7 +27,7 @@ def test_tensor_type() -> None: def test_type_param() -> None: - tp = mk.TypeParam('name', relay.Kind.Shape) + tp = relay.TypeParam('name', relay.Kind.Shape) tp.kind == relay.Kind.Shape tp.span # TODO allow us to set span str(tp) @@ -42,7 +38,7 @@ def test_func_type() -> None: type_constraints = tvm.convert([]) # TODO: fill me in arg_types = tvm.convert([]) ret_type = None - tf = mk.FuncType(arg_types, ret_type, type_params, type_constraints) + tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) assert tf.type_params == type_params assert tf.type_constraints == type_constraints assert tf.arg_types == arg_types @@ -54,7 +50,7 @@ def test_func_type() -> None: def test_constant() -> None: arr = tvm.nd.array(10) - const = mk.Constant(arr) + const = relay.Constant(arr) assert const.data == arr assert const.span == None str(const) @@ -62,7 +58,7 @@ def test_constant() -> None: def test_tuple() -> None: fields = tvm.convert([]) - tup = mk.Tuple(fields) + tup = relay.Tuple(fields) assert tup.fields == fields assert tup.span == None str(tup) @@ -70,7 +66,7 @@ def test_tuple() -> None: def test_local_var() -> None: name_hint = 's' - lv = mk.LocalVar(name_hint) + lv = relay.LocalVar(name_hint) lv.name_hint == name_hint # assert lv.span == None todo(@jroesch): what do we do about spans str(lv) @@ -78,16 +74,16 @@ def test_local_var() -> None: def test_global_var() -> None: name_hint = 'g' - gv = mk.GlobalVar(name_hint) + gv = relay.GlobalVar(name_hint) gv.name_hint == name_hint # assert lv.span == None todo(@jroesch): what do we do about spans str(gv) def test_param() -> None: - lv = mk.LocalVar('x') + lv = relay.LocalVar('x') ty = None - param = mk.Param(lv, ty) + param = relay.Param(lv, ty) assert param.var == lv assert param.type == ty assert param.span == None @@ -96,11 +92,11 @@ def test_param() -> None: def test_function() -> None: param_names = ['a', 'b', 'c', 'd'] - params = tvm.convert([mk.Param(mk.LocalVar(n), None) for n in param_names]) + params = tvm.convert([relay.Param(relay.LocalVar(n), None) for n in param_names]) ret_type = None body = None type_params = tvm.convert([]) - fn = mk.Function(params, ret_type, body, type_params) + fn = relay.Function(params, ret_type, body, type_params) assert fn.params == params assert fn.body == body assert fn.type_params == type_params @@ -109,10 +105,10 @@ def test_function() -> None: def test_call() -> None: - op = mk.LocalVar('f') + op = relay.LocalVar('f') arg_names = ['a', 'b', 'c', 'd'] - args = tvm.convert([mk.LocalVar(n) for n in arg_names]) - call = mk.Call(op, args, None, None) + args = tvm.convert([relay.LocalVar(n) for n in arg_names]) + call = relay.Call(op, args, None, None) assert call.op == op assert call.args == args assert call.span == None @@ -120,13 +116,13 @@ def test_call() -> None: def test_let() -> None: - lv = mk.LocalVar('x') + lv = relay.LocalVar('x') ty = None arr = tvm.nd.array(10) - value = mk.Constant(arr) + value = relay.Constant(arr) # I would prefer that the order of arguments # matches syntax let x : t = v in b - let = mk.Let(lv, value, lv, ty) + let = relay.Let(lv, value, lv, ty) assert let.var == lv assert let.value == value assert let.value_type == ty @@ -136,10 +132,10 @@ def test_let() -> None: def test_if() -> None: - cond = mk.LocalVar('cond') - left = mk.LocalVar('left') - right = mk.LocalVar('right') - ife = mk.If(cond, left, right) + cond = relay.LocalVar('cond') + left = relay.LocalVar('left') + right = relay.LocalVar('right') + ife = relay.If(cond, left, right) assert ife.cond == cond assert ife.true_value == left assert ife.false_value == right @@ -152,3 +148,12 @@ def test_if() -> None: test_tensor_type() test_type_param() test_func_type() + test_constant() + test_tuple() + test_local_var() + test_global_var() + test_param() + test_function() + test_call() + test_let() + test_if() diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index c6bc1a05ebe9..bf172eb07935 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -3,7 +3,7 @@ """ import tvm.relay.make as mk from tvm.relay.type_infer import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type +from tvm.relay.ir_builder import IRBuilder, float_type, op def has_type(expr, typ): env = mk.Environment({}) @@ -23,7 +23,7 @@ def test_monomorphic_let(): def test_single_op(): "Program: fn (x : int32) { let t1 = f(x); t1 }" b = IRBuilder() - f = b.op('f') + f = op('log') with b.function(('x', float_type())) as func: x, = func.param_ids() t1 = b.let('t1', f(x)) From b73e3c5fd51a606186d28511a158e62222f50fa9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Aug 2018 17:18:37 -0700 Subject: [PATCH 034/136] Basic tests working --- python/tvm/relay/ir_builder.py | 5 +- tests/python/relay/test_alpha_eq.py | 1 - tests/python/relay/test_typechecker.py | 1 - tests/python/relay/test_unifier.py | 163 ++++++++++++------------- 4 files changed, 83 insertions(+), 87 deletions(-) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index f24e7baa1483..af83c9948be2 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -3,8 +3,7 @@ import tvm from . import type as ty from . import expr -from . import make as mk -from . import op +from . import op as _op class ExprBuilder(): def __init__(self, expr): @@ -144,7 +143,7 @@ def get(self): return _mk_let(bindings, self.ret_value) def op(name): - return op._create_op(name) + return _op._create_op(name) def bool_dtype(): return 'uint1' diff --git a/tests/python/relay/test_alpha_eq.py b/tests/python/relay/test_alpha_eq.py index e4fbbcca93ce..6c0e7779eae6 100644 --- a/tests/python/relay/test_alpha_eq.py +++ b/tests/python/relay/test_alpha_eq.py @@ -1,5 +1,4 @@ """Test alpha-equivalence of expressions and types.""" -from tvm.relay import make as mk # from relay.ir import alpha_eq, ShapeOp, Kind # from relay.typing import TYPE_DEFAULTS # from relay import ir diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index bf172eb07935..6a16aadcb002 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -1,7 +1,6 @@ """Test that type checker correcly computes types for expressions. """ -import tvm.relay.make as mk from tvm.relay.type_infer import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, op diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py index 21889faa51ee..c45e6ac4f732 100644 --- a/tests/python/relay/test_unifier.py +++ b/tests/python/relay/test_unifier.py @@ -3,17 +3,16 @@ between incomplete types. """ import tvm -from tvm.relay import ir +from tvm import relay from tvm.relay.unifier import UnionFind, TypeUnifier from tvm.relay.ir_builder import bool_type, uint_type, int_type, float_type, func_type from tvm.relay import ir_builder as build -import tvm.relay.make as mk def test_insert_and_find(): - uf = mk.UnionFind() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + uf = relay.UnionFind() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) uf.insert(v1) uf.insert(v2) assert uf.find(v1) == v1 @@ -21,9 +20,9 @@ def test_insert_and_find(): def test_insert_error(): - uf = mk.UnionFind() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + uf = relay.UnionFind() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) uf.insert(v1) try: uf.find(v2) @@ -33,10 +32,10 @@ def test_insert_error(): def test_unify(): - uf = mk.UnionFind() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) - v3 = mk.IncompleteType(ir.Kind.Type) + uf = relay.UnionFind() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) + v3 = relay.IncompleteType(ir.Kind.Type) uf.insert(v1) uf.insert(v2) uf.insert(v3) @@ -56,8 +55,8 @@ def test_unify(): def test_unify_multiple_levels(): - uf = mk.UnionFind() - v = [mk.IncompleteType(ir.Kind.Type) for _ in range(9)] + uf = relay.UnionFind() + v = [relay.IncompleteType(ir.Kind.Type) for _ in range(9)] for var in v: uf.insert(var) uf.unify(v[0], v[1]) @@ -94,7 +93,7 @@ def test_unify_multiple_levels(): def unify_types(t1, t2): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() return unifier.unify(t1, t2) # TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work @@ -136,8 +135,8 @@ def test_unify_concrete_func_type(): def test_unify_func_type_with_holes(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.unify(v1, bool_type()) arr1 = func_type([int_type()], bool_type()) @@ -145,7 +144,7 @@ def test_unify_func_type_with_holes(): unified = unifier.unify(arr1, arr2) assert unified == arr1 - v2 = mk.IncompleteType(ir.Kind.BaseType) + v2 = relay.IncompleteType(ir.Kind.BaseType) unifier.insert(v2) unifier.unify(v2, int_type()) arr3 = func_type([v2], bool_type()) @@ -179,10 +178,10 @@ def test_reject_incompatible_func_types(): def test_unify_typevars_with_each_other(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) - v3 = mk.IncompleteType(ir.Kind.Type) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) + v3 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.insert(v3) @@ -194,10 +193,10 @@ def test_unify_typevars_with_each_other(): def test_unify_typevars_with_basetype(): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() bt = bool_type() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unified1 = unifier.unify(v1, bt) @@ -207,10 +206,10 @@ def test_unify_typevars_with_basetype(): def test_unify_compatible_typevars(): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() bt = bool_type() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, bt) @@ -221,9 +220,9 @@ def test_unify_compatible_typevars(): assert unified == bt # def test_unify_incompatible_typevars(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) -# v2 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) +# v2 = relay.IncompleteType(ir.Kind.Type) # bt = bool_type() # tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) # unifier.insert(v1) @@ -238,16 +237,16 @@ def test_unify_compatible_typevars(): # return # def test_unify_typevar_with_quantifier(): -# unifier = mk.TypeUnifier() +# unifier = relay.TypeUnifier() # tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) -# v1 = mk.IncompleteType(ir.Kind.BaseType) +# v1 = relay.IncompleteType(ir.Kind.BaseType) # unifier.insert(v1) # unified = unifier.unify(v1, tq) # assert unified == tq # def test_unify_typevars_inside_concrete_quantifier(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.BaseType) # unifier.insert(v1) # tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) # tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) @@ -258,8 +257,8 @@ def test_unify_compatible_typevars(): def test_unify_concrete_tensors(): bt = build.bool_dtype() shape = tvm.convert([1, 2, 3]) - tt1 = mk.TensorType(shape, bt) - tt2 = mk.TensorType(shape, bt) + tt1 = relay.TensorType(shape, bt) + tt2 = relay.TensorType(shape, bt) unified = unify_types(tt1, tt2) assert unified == tt1 @@ -268,8 +267,8 @@ def test_unify_tensor_shape_reject(): bt = build.bool_dtype() shape1 = tvm.convert([1, 2, 3]) shape2 = tvm.convert([2, 3, 4]) - tt1 = mk.TensorType(shape1, bt) - tt2 = mk.TensorType(shape2, bt) + tt1 = relay.TensorType(shape1, bt) + tt2 = relay.TensorType(shape2, bt) try: unify_types(tt1, tt2) assert False @@ -281,8 +280,8 @@ def test_unify_tensor_dtype_reject(): bt1 = build.bool_dtype() bt2 = build.int_dtype() shape = tvm.convert([1, 2, 3]) - tt1 = mk.TensorType(shape, bt1) - tt2 = mk.TensorType(shape, bt2) + tt1 = relay.TensorType(shape, bt1) + tt2 = relay.TensorType(shape, bt2) try: unify_types(tt1, tt2) assert False @@ -292,15 +291,15 @@ def test_unify_tensor_dtype_reject(): # def test_unify_quantified_tensors(): # x = TypeParam("x", ir.type.Kind.Shape) # y = TypeParam("y", ir.type.Kind.Shape) -# tq1 = TypeQuantifier(x, mk.TensorType(bool_type(), x)) -# tq2 = TypeQuantifier(y, mk.TensorType(bool_type(), y)) +# tq1 = TypeQuantifier(x, relay.TensorType(bool_type(), x)) +# tq2 = TypeQuantifier(y, relay.TensorType(bool_type(), y)) # unified = unify_types(tq1, tq2) # assert unified == tq1 # a = TypeParam("a", ir.type.Kind.BaseType) # b = TypeParam("b", ir.type.Kind.BaseType) -# tq3 = TypeQuantifier(a, mk.TensorType(a, make_shape([1, 2, 3]))) -# tq4 = TypeQuantifier(b, mk.TensorType(b, make_shape([1, 2, 3]))) +# tq3 = TypeQuantifier(a, relay.TensorType(a, make_shape([1, 2, 3]))) +# tq4 = TypeQuantifier(b, relay.TensorType(b, make_shape([1, 2, 3]))) # unified = unify_types(tq3, tq4) # assert unified == tq3 @@ -335,8 +334,8 @@ def test_unify_tensor_dtype_reject(): # return # def test_unify_products_typevar(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.BaseType) # bt = bool_type() # pt1 = TupleType([bt, bt]) # pt2 = TupleType([v1, bt]) @@ -354,14 +353,14 @@ def test_unify_tensor_dtype_reject(): def test_subst_basetype(): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() bt = bool_type() assert bt == unifier.subst(bt) def test_subst_simple_hole(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.BaseType) bt = bool_type() unifier.insert(v1) unifier.unify(v1, bt) @@ -369,9 +368,9 @@ def test_subst_simple_hole(): def test_subst_typevar_for_typevar(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) @@ -380,9 +379,9 @@ def test_subst_typevar_for_typevar(): def test_subst_typevar_for_typevar_comm(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) - v2 = mk.IncompleteType(ir.Kind.Type) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.Type) + v2 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) unifier.insert(v2) @@ -391,15 +390,15 @@ def test_subst_typevar_for_typevar_comm(): def test_subst_concrete_arrow(): - unifier = mk.TypeUnifier() + unifier = relay.TypeUnifier() arr1 = func_type([int_type()], int_type()) assert unifier.subst(arr1) == arr1 def test_subst_arrow_with_holes(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.BaseType) - v2 = mk.IncompleteType(ir.Kind.BaseType) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.BaseType) + v2 = relay.IncompleteType(ir.Kind.BaseType) unifier.insert(v1) unifier.insert(v2) unifier.unify(v1, int_type()) @@ -409,17 +408,17 @@ def test_subst_arrow_with_holes(): assert unifier.subst(arr1) == arr2 # def test_subst_concrete_quantifier(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.BaseType) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.BaseType) # tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) # unifier.insert(v1) # unifier.unify(v1, tq) # assert unifier.subst(v1) == tq # def test_subst_quantifier_with_holes(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) -# v2 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) +# v2 = relay.IncompleteType(ir.Kind.Type) # tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) # intty = int_type() # tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) @@ -431,16 +430,16 @@ def test_subst_arrow_with_holes(): def test_subst_concrete_tensor(): - unifier = mk.TypeUnifier() - v1 = mk.IncompleteType(ir.Kind.Type) + unifier = relay.TypeUnifier() + v1 = relay.IncompleteType(ir.Kind.Type) unifier.insert(v1) - tt = mk.TensorType(tvm.convert([1, 2, 3]), 'uint1') + tt = relay.TensorType(tvm.convert([1, 2, 3]), 'uint1') unifier.unify(v1, tt) assert unifier.subst(v1) == tt # def test_subst_concrete_product(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) # unifier.insert(v1) # bt = bool_type() # pt = TupleType([bt, bt]) @@ -448,16 +447,16 @@ def test_subst_concrete_tensor(): # assert unifier.subst(v1) == pt # def test_subst_product_with_holes(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) -# v2 = mk.IncompleteType(ir.Kind.Type) -# v3 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) +# v2 = relay.IncompleteType(ir.Kind.Type) +# v3 = relay.IncompleteType(ir.Kind.Type) # unifier.insert(v1) # unifier.insert(v2) # unifier.insert(v3) -# tt1 = mk.TensorType(int_type(), tvm.convert([])) -# tt2 = mk.TensorType(FloatType(32), tvm.convert([])) +# tt1 = relay.TensorType(int_type(), tvm.convert([])) +# tt2 = relay.TensorType(FloatType(32), tvm.convert([])) # pt1 = TupleType([tt1, v2, v3]) # unifier.unify(v2, tt2) # unifier.unify(v3, v2) @@ -466,13 +465,13 @@ def test_subst_concrete_tensor(): # assert unifier.subst(v1) == pt2 # def test_subst_concrete_ref(): -# unifier = mk.TypeUnifier() +# unifier = relay.TypeUnifier() # rt = RefType(bool_type()) # assert unifier.subst(rt) == rt # def test_subst_ref_with_hole(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.Type) # unifier.insert(v1) # unifier.unify(v1, bool_type()) @@ -481,9 +480,9 @@ def test_subst_concrete_tensor(): # assert unifier.subst(rt1) == rt2 # def test_typevar_on_lhs(): -# unifier = mk.TypeUnifier() -# v1 = mk.IncompleteType(ir.Kind.BaseType) -# v2 = mk.IncompleteType(ir.Kind.Type) +# unifier = relay.TypeUnifier() +# v1 = relay.IncompleteType(ir.Kind.BaseType) +# v2 = relay.IncompleteType(ir.Kind.Type) # bt = bool_type() # tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) # unifier.insert(v1) From e9f8bd57df2283405481d158be211ec36a57b101 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Aug 2018 17:28:36 -0700 Subject: [PATCH 035/136] Remove unifier from Python interface --- python/tvm/relay/_unifier.py | 5 - python/tvm/relay/_unifier.pyi | 12 - python/tvm/relay/unifier.py | 61 ---- src/relay/compiler/unifier.cc | 80 ----- tests/python/relay/test_unifier.py | 495 ----------------------------- 5 files changed, 653 deletions(-) delete mode 100644 python/tvm/relay/_unifier.py delete mode 100644 python/tvm/relay/_unifier.pyi delete mode 100644 python/tvm/relay/unifier.py delete mode 100644 tests/python/relay/test_unifier.py diff --git a/python/tvm/relay/_unifier.py b/python/tvm/relay/_unifier.py deleted file mode 100644 index 41f5fe374b3e..000000000000 --- a/python/tvm/relay/_unifier.py +++ /dev/null @@ -1,5 +0,0 @@ -"""FFI functions for the Unifier.""" - -from tvm._ffi.function import _init_api - -_init_api("relay._unifier", __name__) diff --git a/python/tvm/relay/_unifier.pyi b/python/tvm/relay/_unifier.pyi deleted file mode 100644 index 6ecd309250a6..000000000000 --- a/python/tvm/relay/_unifier.pyi +++ /dev/null @@ -1,12 +0,0 @@ -from tvm.relay.ir import NodeBase - -class UnionFind(NodeBase): ... -class TypeUnifier(NodeBase): ... - -def UnionFind_insert(self: UnionFind, var: ir.IncompleteType) -> None: ... -def UnionFind_unify(self: UnionFind, var1: ir.IncompleteType, var2: ir.IncompleteType) -> None: ... -def UnionFind_find(self: UnionFind, var: ir.IncompleteType) -> ir.Type: ... - -def TypeUnifier_insert(self: TypeUnifier, var: ir.IncompleteType) -> None: ... -def TypeUnifier_unify(self, type1: ir.Type, type2: ir.Type) -> ir.Type: ... -def TypeUnifier_subst(self, type1: ir.Type) -> ir.Type: ... diff --git a/python/tvm/relay/unifier.py b/python/tvm/relay/unifier.py deleted file mode 100644 index cb818de19c1d..000000000000 --- a/python/tvm/relay/unifier.py +++ /dev/null @@ -1,61 +0,0 @@ -"""The Python interface to Relay's UnionFind and TypeUnifier.""" - -from typing import Dict -from .ir import register_relay_node, NodeBase -from . import ir -from . import _unifier - -@register_relay_node -class UnionFind(NodeBase): - """Python API for UnionFind. - - The UnionFind maintains equality classes of type variables, the - representative of an equality class may be a type (which can) - contain type variables. The TypeUnifier uses this to build a - unification procedure between types. - """ - uf_map: Dict[ir.IncompleteType, ir.IncompleteType] - - def insert(self, var: ir.IncompleteType) -> None: - """Insert a type variable into the union find. - - :param: var: The variable to be inserted. - """ - return _unifier.UnionFind_insert(self, var) - - def unify(self, var: ir.IncompleteType, typ: ir.Type) -> None: - """Unify a type variable with an arbitrary type. - - :param: var: A type variable to be unified. - :param: typ: The type to be unified with. - """ - return _unifier.UnionFind_unify(self, var, typ) - - def find(self, var: ir.IncompleteType) -> ir.IncompleteType: - """Find the representative element of the type var. - - :param: var: The variable to lookup in the union find. - """ - return _unifier.UnionFind_find(self, var) - -@register_relay_node -class TypeUnifier(NodeBase): - """Python API for the TypeUnifier.""" - #pylint: disable=invalid-name - uf: UnionFind - eq_map: Dict[ir.TypeParam, ir.TypeParam] - - def insert(self, var: ir.IncompleteType) -> None: - return _unifier.TypeUnifier_insert(self, var) - - def unify(self, type1: ir.Type, type2: ir.Type) -> ir.Type: - """Unify two types producing the unified type as a result. - - :param: type1: The first type to be unified. - :param: type2: The second type to be unified. - :returns: The unified type. - """ - return _unifier.TypeUnifier_unify(self, type1, type2) - - def subst(self, type1: ir.Type) -> ir.Type: - return _unifier.TypeUnifier_subst(self, type1) diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc index ff46e8e863d1..5c0fbcf3ec71 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/compiler/unifier.cc @@ -369,85 +369,5 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { } - -TVM_REGISTER_API("relay._make.TypeUnifier") - .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() < 3) { - *ret = TypeUnifierNode::make(UnionFindNode::make({})); - } else { - *ret = TypeUnifierNode::make(args[0]); - } - }); - -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const TypeUnifierNode *node, - tvm::IRPrinter *p) { - p->stream << "TypeUnifierNode(" << node->uf << ")"; - }); - -TVM_REGISTER_API("relay._unifier.UnionFind_insert") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - UnionFind uf = args[0]; - uf->insert(args[1]); - } catch (std::exception &e) { - throw UnionFindError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.UnionFind_unify") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - UnionFind uf = args[0]; - uf->unify(args[1], args[2]); - } catch (std::exception &e) { - throw UnionFindError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.UnionFind_find") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - UnionFind uf = args[0]; - *ret = uf->find(args[1]); - } catch (std::exception &e) { - throw UnionFindError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.TypeUnifier_insert") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - TypeUnifier unifier = args[0]; - IncompleteType var = args[1]; - unifier->insert(var); - } catch (std::exception &e) { - throw UnificationError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.TypeUnifier_unify") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - TypeUnifier unifier = args[0]; - Type t1 = args[1]; - Type t2 = args[2]; - *ret = unifier->unify(t1, t2); - } catch (std::exception &e) { - throw UnificationError(e.what()); - } - }); - -TVM_REGISTER_API("relay._unifier.TypeUnifier_subst") - .set_body([](TVMArgs args, TVMRetValue *ret) { - try { - TypeUnifier unifier = args[0]; - Type t = args[1]; - *ret = unifier->subst(t); - } catch (std::exception &e) { - throw SubstitutionError(e.what()); - } - }); - } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_unifier.py b/tests/python/relay/test_unifier.py deleted file mode 100644 index c45e6ac4f732..000000000000 --- a/tests/python/relay/test_unifier.py +++ /dev/null @@ -1,495 +0,0 @@ -""" -Test the type unifier, which solves systems of equations -between incomplete types. -""" -import tvm -from tvm import relay -from tvm.relay.unifier import UnionFind, TypeUnifier -from tvm.relay.ir_builder import bool_type, uint_type, int_type, float_type, func_type -from tvm.relay import ir_builder as build - - -def test_insert_and_find(): - uf = relay.UnionFind() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - uf.insert(v1) - uf.insert(v2) - assert uf.find(v1) == v1 - assert uf.find(v2) == v2 - - -def test_insert_error(): - uf = relay.UnionFind() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - uf.insert(v1) - try: - uf.find(v2) - assert False - except: - return - - -def test_unify(): - uf = relay.UnionFind() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - v3 = relay.IncompleteType(ir.Kind.Type) - uf.insert(v1) - uf.insert(v2) - uf.insert(v3) - uf.unify(v1, v2) - rep = uf.find(v1) - assert (rep == v1 or rep == v2) - assert uf.find(v1) == rep - assert uf.find(v2) == rep - assert uf.find(v3) == v3 - assert v3 != rep - uf.unify(v1, v3) - new_rep = uf.find(v3) - assert (rep == v1 or rep == v2 or rep == v3) - assert uf.find(v1) == new_rep - assert uf.find(v2) == new_rep - assert uf.find(v3) == new_rep - - -def test_unify_multiple_levels(): - uf = relay.UnionFind() - v = [relay.IncompleteType(ir.Kind.Type) for _ in range(9)] - for var in v: - uf.insert(var) - uf.unify(v[0], v[1]) - uf.unify(v[0], v[2]) - uf.unify(v[3], v[4]) - uf.unify(v[4], v[5]) - uf.unify(v[6], v[7]) - uf.unify(v[6], v[8]) - rep1 = uf.find(v[0]) - rep2 = uf.find(v[3]) - rep3 = uf.find(v[6]) - assert (rep1 == v[0] or rep1 == v[1] or rep1 == v[2]) - assert (rep2 == v[3] or rep2 == v[4] or rep2 == v[5]) - assert (rep3 == v[6] or rep3 == v[7] or rep3 == v[8]) - for i in range(3): - assert uf.find(v[i]) == rep1 - assert uf.find(v[i + 3]) == rep2 - assert uf.find(v[i + 6]) == rep3 - # now unify two of the groups - uf.unify(v[1], v[4]) - new_rep1 = uf.find(v[0]) - new_rep2 = uf.find(v[6]) - assert (new_rep1 == v[0] or new_rep1 == v[1] or new_rep1 == v[2] - or new_rep1 == v[3] or new_rep1 == v[4] or new_rep1 == v[5]) - assert (new_rep2 == v[6] or new_rep2 == v[7] or new_rep2 == v[8]) - for i in range(6): - assert uf.find(v[i]) == new_rep1 - for i in range(3): - assert uf.find(v[i + 6]) == new_rep2 - -# We have checked that the basic machinery in the UnionFind works -# and now we will test the type unifier which will fill in holes -# between type equalities by the process of unification. - - -def unify_types(t1, t2): - unifier = relay.TypeUnifier() - return unifier.unify(t1, t2) - -# TODO(sslyu, weberlo, joshpoll): put in isinstance asserts once those work - - -def test_unify_int(): - intty = int_type(1) - unified = unify_types(intty, intty) - assert intty == unified - - -def test_unify_bool(): - boolty = bool_type() - unified = unify_types(boolty, boolty) - assert boolty == unified - - -def test_unify_float(): - floatty = float_type(4) - unified = unify_types(floatty, floatty) - assert floatty == unified - - -def test_unify_incompatible_basetypes(): - bt = bool_type() - intty = int_type(32) - try: - unify_types(bt, intty) - assert False - except: - return - - -def test_unify_concrete_func_type(): - arr1 = func_type([int_type()], int_type()) - arr2 = func_type([int_type()], int_type()) - unified = unify_types(arr1, arr2) - assert unified == arr1 - - -def test_unify_func_type_with_holes(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.BaseType) - unifier.insert(v1) - unifier.unify(v1, bool_type()) - arr1 = func_type([int_type()], bool_type()) - arr2 = func_type([int_type()], v1) - unified = unifier.unify(arr1, arr2) - assert unified == arr1 - - v2 = relay.IncompleteType(ir.Kind.BaseType) - unifier.insert(v2) - unifier.unify(v2, int_type()) - arr3 = func_type([v2], bool_type()) - unified = unifier.unify(arr1, arr3) - assert unified == arr1 - - -def test_reject_incompatible_func_types(): - arr1 = func_type([int_type()], bool_type()) - arr2 = func_type([int_type(), bool_type()], bool_type()) - try: - unify_types(arr1, arr2) - assert False - except: - return - -# def test_unify_concrete_type_quantifiers(): -# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) -# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), int_type()) -# unified = unify_types(tq1, tq2) -# assert unified == tq1 - -# def test_unify_basetype_with_quantifier_error(): -# bt = bool_type() -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) -# try: -# unify_types(bt, tq) -# assert False -# except: -# return - - -def test_unify_typevars_with_each_other(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - v3 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - unifier.insert(v3) - unified = unifier.unify(v1, v2) - assert (unified == v1 or unified == v2) - assert unified != v3 - new_unified = unifier.unify(v1, v3) - assert (new_unified == v1 or new_unified == v2 or new_unified == v3) - - -def test_unify_typevars_with_basetype(): - unifier = relay.TypeUnifier() - bt = bool_type() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - unified1 = unifier.unify(v1, bt) - assert unified1 == bt - unified2 = unifier.unify(v1, v2) - assert unified2 == bt - - -def test_unify_compatible_typevars(): - unifier = relay.TypeUnifier() - bt = bool_type() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - unifier.unify(v1, bt) - unifier.unify(v2, bt) - # because types to which v1 and v2 have been assigned are compatible, - # this should proceed without problems - unified = unifier.unify(v1, v2) - assert unified == bt - -# def test_unify_incompatible_typevars(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# v2 = relay.IncompleteType(ir.Kind.Type) -# bt = bool_type() -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt) -# unifier.insert(v1) -# unifier.insert(v2) -# unifier.unify(v1, bt) -# unifier.unify(v2, tq) -# # bt cannot be unified with tq, so unifying v1 and v2 should give an error -# try: -# unifier.unify(v1, v2) -# assert False -# except: -# return - -# def test_unify_typevar_with_quantifier(): -# unifier = relay.TypeUnifier() -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bool_type()) -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# unifier.insert(v1) -# unified = unifier.unify(v1, tq) -# assert unified == tq - -# def test_unify_typevars_inside_concrete_quantifier(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# unifier.insert(v1) -# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v1) -# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), bool_type()) -# unified = unifier.unify(tq1, tq2) -# assert unified == tq2 - - -def test_unify_concrete_tensors(): - bt = build.bool_dtype() - shape = tvm.convert([1, 2, 3]) - tt1 = relay.TensorType(shape, bt) - tt2 = relay.TensorType(shape, bt) - unified = unify_types(tt1, tt2) - assert unified == tt1 - - -def test_unify_tensor_shape_reject(): - bt = build.bool_dtype() - shape1 = tvm.convert([1, 2, 3]) - shape2 = tvm.convert([2, 3, 4]) - tt1 = relay.TensorType(shape1, bt) - tt2 = relay.TensorType(shape2, bt) - try: - unify_types(tt1, tt2) - assert False - except: - return - - -def test_unify_tensor_dtype_reject(): - bt1 = build.bool_dtype() - bt2 = build.int_dtype() - shape = tvm.convert([1, 2, 3]) - tt1 = relay.TensorType(shape, bt1) - tt2 = relay.TensorType(shape, bt2) - try: - unify_types(tt1, tt2) - assert False - except: - return - -# def test_unify_quantified_tensors(): -# x = TypeParam("x", ir.type.Kind.Shape) -# y = TypeParam("y", ir.type.Kind.Shape) -# tq1 = TypeQuantifier(x, relay.TensorType(bool_type(), x)) -# tq2 = TypeQuantifier(y, relay.TensorType(bool_type(), y)) -# unified = unify_types(tq1, tq2) -# assert unified == tq1 - -# a = TypeParam("a", ir.type.Kind.BaseType) -# b = TypeParam("b", ir.type.Kind.BaseType) -# tq3 = TypeQuantifier(a, relay.TensorType(a, make_shape([1, 2, 3]))) -# tq4 = TypeQuantifier(b, relay.TensorType(b, make_shape([1, 2, 3]))) -# unified = unify_types(tq3, tq4) -# assert unified == tq3 - -# def test_unify_concrete_products(): -# bt = bool_type() -# intty = int_type() -# pt1 = TupleType([bt, intty]) -# pt2 = TupleType([bt, intty]) -# unified = unify_types(pt1, pt2) -# assert unified == pt1 - -# def test_unify_products_reject_size(): -# bt = bool_type() -# intty = IntType(32) -# pt1 = TupleType([bt, bt, intty]) -# pt2 = TupleType([bt, intty]) -# try: -# unify_types(pt1, pt2) -# assert False -# except: -# return - -# def test_unify_products_reject_member(): -# bt = bool_type() -# intty = int_type() -# pt1 = TupleType([bt, bt]) -# pt2 = TupleType([bt, intty]) -# try: -# unify_types(pt1, pt2) -# assert False -# except: -# return - -# def test_unify_products_typevar(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# bt = bool_type() -# pt1 = TupleType([bt, bt]) -# pt2 = TupleType([v1, bt]) -# unifier.insert(v1) -# unified = unifier.unify(pt1, pt2) -# assert unified == pt1 - -# def test_unify_quantified_products(): -# x = TypeParam("x", ir.Kind.Type) -# y = TypeParam("y", ir.Kind.Type) -# p1 = TypeQuantifier(x, TupleType([int_type(), x])) -# p2 = TypeQuantifier(y, TupleType([int_type(), y])) -# unified = unify_types(p1, p2) -# assert unified == p1 - - -def test_subst_basetype(): - unifier = relay.TypeUnifier() - bt = bool_type() - assert bt == unifier.subst(bt) - - -def test_subst_simple_hole(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.BaseType) - bt = bool_type() - unifier.insert(v1) - unifier.unify(v1, bt) - assert unifier.subst(v1) == bt - - -def test_subst_typevar_for_typevar(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - - unifier.unify(v1, v2) - assert unifier.subst(v1) == unifier.subst(v2) - - -def test_subst_typevar_for_typevar_comm(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.Type) - v2 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - unifier.insert(v2) - - unifier.unify(v2, v1) - assert unifier.subst(v1) == unifier.subst(v2) - - -def test_subst_concrete_arrow(): - unifier = relay.TypeUnifier() - arr1 = func_type([int_type()], int_type()) - assert unifier.subst(arr1) == arr1 - - -def test_subst_arrow_with_holes(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.BaseType) - v2 = relay.IncompleteType(ir.Kind.BaseType) - unifier.insert(v1) - unifier.insert(v2) - unifier.unify(v1, int_type()) - unifier.unify(v2, bool_type()) - arr1 = func_type([v1], v2) - arr2 = func_type([int_type()], bool_type()) - assert unifier.subst(arr1) == arr2 - -# def test_subst_concrete_quantifier(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), int_type()) -# unifier.insert(v1) -# unifier.unify(v1, tq) -# assert unifier.subst(v1) == tq - -# def test_subst_quantifier_with_holes(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# v2 = relay.IncompleteType(ir.Kind.Type) -# tq1 = TypeQuantifier(TypeParam("id1", ir.Kind.Type), v2) -# intty = int_type() -# tq2 = TypeQuantifier(TypeParam("id2", ir.Kind.Type), intty) - # unifier.insert(v1) - # unifier.insert(v2) - # unifier.unify(v2, intty) - # unifier.unify(v1, tq1) - # assert unifier.subst(v1) == tq2 - - -def test_subst_concrete_tensor(): - unifier = relay.TypeUnifier() - v1 = relay.IncompleteType(ir.Kind.Type) - unifier.insert(v1) - tt = relay.TensorType(tvm.convert([1, 2, 3]), 'uint1') - unifier.unify(v1, tt) - assert unifier.subst(v1) == tt - -# def test_subst_concrete_product(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# unifier.insert(v1) -# bt = bool_type() -# pt = TupleType([bt, bt]) -# unifier.unify(v1, pt) -# assert unifier.subst(v1) == pt - -# def test_subst_product_with_holes(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# v2 = relay.IncompleteType(ir.Kind.Type) -# v3 = relay.IncompleteType(ir.Kind.Type) -# unifier.insert(v1) -# unifier.insert(v2) -# unifier.insert(v3) - -# tt1 = relay.TensorType(int_type(), tvm.convert([])) -# tt2 = relay.TensorType(FloatType(32), tvm.convert([])) -# pt1 = TupleType([tt1, v2, v3]) -# unifier.unify(v2, tt2) -# unifier.unify(v3, v2) -# unifier.unify(v1, pt1) -# pt2 = TupleType([tt1, tt2, tt2]) -# assert unifier.subst(v1) == pt2 - -# def test_subst_concrete_ref(): -# unifier = relay.TypeUnifier() -# rt = RefType(bool_type()) -# assert unifier.subst(rt) == rt - -# def test_subst_ref_with_hole(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.Type) -# unifier.insert(v1) - -# unifier.unify(v1, bool_type()) -# rt1 = RefType(v1) -# rt2 = RefType(bool_type()) -# assert unifier.subst(rt1) == rt2 - -# def test_typevar_on_lhs(): -# unifier = relay.TypeUnifier() -# v1 = relay.IncompleteType(ir.Kind.BaseType) -# v2 = relay.IncompleteType(ir.Kind.Type) -# bt = bool_type() -# tq = TypeQuantifier(TypeParam("id1", ir.Kind.Type), bt, bt) -# unifier.insert(v1) -# unifier.insert(v2) -# unified1 = unifier.unify(bt, v1) -# assert unified1 == bt -# unified2 = unifier.unify(tq, v2) -# assert unified2 == tq -# assert unifier.subst(v1) == bt -# assert unifier.subst(v2) == tq From 920768f882412bcf1d2cc3aaf021d61876035720 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 24 Aug 2018 10:51:02 -0700 Subject: [PATCH 036/136] [OP] Structral refactor --- include/tvm/base.h | 6 --- include/tvm/relay/expr.h | 6 ++- include/tvm/relay/op.h | 1 - python/tvm/relay/__init__.py | 6 +++ python/tvm/relay/op.py | 4 +- python/tvm/relay/op/__init__.py | 6 +++ python/tvm/relay/op/_make.py | 4 ++ python/tvm/relay/op/_tensor.py | 4 ++ python/tvm/relay/op/registry.py | 1 + python/tvm/relay/op/tensor.py | 60 +++++++++++++++++++++++++++++ src/relay/compiler/unifier.cc | 2 +- src/relay/{ => ir}/base.cc | 0 src/relay/{ => ir}/expr.cc | 0 src/relay/{ => ir}/op.cc | 0 src/relay/{ => ir}/type.cc | 0 src/relay/op/tensor/elemwise.cc | 46 ++++++++++++++++++++-- tests/python/relay/test_relay_op.py | 8 +++- 17 files changed, 137 insertions(+), 17 deletions(-) create mode 100644 python/tvm/relay/op/__init__.py create mode 100644 python/tvm/relay/op/_make.py create mode 100644 python/tvm/relay/op/_tensor.py create mode 100644 python/tvm/relay/op/registry.py create mode 100644 python/tvm/relay/op/tensor.py rename src/relay/{ => ir}/base.cc (100%) rename src/relay/{ => ir}/expr.cc (100%) rename src/relay/{ => ir}/op.cc (100%) rename src/relay/{ => ir}/type.cc (100%) diff --git a/include/tvm/base.h b/include/tvm/base.h index be848b34cd43..464259bc0527 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -134,11 +134,5 @@ struct NodeFactoryReg { */ #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) -#define TVM_REGISTER_NODE_TYPE(TypeName) \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ - ::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \ - .set_body([]() { return std::make_shared(); }) - - } // namespace tvm #endif // TVM_BASE_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index a29c8486ffb6..a4a683297ea5 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -282,8 +282,10 @@ class CallNode : public ExprNode { v->Visit("span", &span); } - TVM_DLL static Call make(Expr op, Array args, Attrs attrs, - Array ty_args); + TVM_DLL static Call make(Expr op, + Array args, + Attrs attrs = Attrs(), + Array ty_args = Array()); static constexpr const char* _type_key = "relay.Call"; TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index f7e1cfbbc8c2..cae3d9db6920 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -276,7 +276,6 @@ class OpMap { const GenericOpMap& map_; }; - // internal macros to make #define RELAY_REGISTER_VAR_DEF \ static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry & __make_ ## RelayOp diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 037d71854689..f94b572f6b44 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -4,6 +4,10 @@ from . import expr from . import op +# import all operators in the loop namespace +from .op import * + + # Span Span = base.Span @@ -18,6 +22,7 @@ # Expr Constant = expr.Constant Tuple = expr.Tuple +# TODO: GlobalVar, LocalVar-> var LocalVar = expr.LocalVar GlobalVar = expr.GlobalVar Param = expr.Param @@ -25,3 +30,4 @@ Call = expr.Call Let = expr.Let If = expr.If +Var = LocalVar diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py index dae498b66c12..d54edf47c5ee 100644 --- a/python/tvm/relay/op.py +++ b/python/tvm/relay/op.py @@ -2,7 +2,7 @@ from __future__ import absolute_import as _abs import sys -from .._ffi.function import _init_api + from .._ffi.node import convert_to_node from . import _make from ..make import node as _make_node @@ -33,5 +33,5 @@ def _init_ops(): f = _create_op(name.value) setattr(module, f.__name__, f) -_init_api("relay.op", __name__) + _init_ops() diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py new file mode 100644 index 000000000000..02e49ec40ff8 --- /dev/null +++ b/python/tvm/relay/op/__init__.py @@ -0,0 +1,6 @@ +"""Relay core operators.""" +# operator defs +from .tensor import * + +# operator registry +from . import _tensor diff --git a/python/tvm/relay/op/_make.py b/python/tvm/relay/op/_make.py new file mode 100644 index 000000000000..79c86cbb0254 --- /dev/null +++ b/python/tvm/relay/op/_make.py @@ -0,0 +1,4 @@ +"""Constructor APIs""" +from ..._ffi.function import _init_api + +_init_api("relay.op._make", __name__) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py new file mode 100644 index 000000000000..08dedee0923c --- /dev/null +++ b/python/tvm/relay/op/_tensor.py @@ -0,0 +1,4 @@ +"""Backend compiler related feature regsitration""" + + + diff --git a/python/tvm/relay/op/registry.py b/python/tvm/relay/op/registry.py new file mode 100644 index 000000000000..d7426429ef6f --- /dev/null +++ b/python/tvm/relay/op/registry.py @@ -0,0 +1 @@ +"""Mechanism to work with operator registry.""" diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py new file mode 100644 index 000000000000..7155db3a4cd5 --- /dev/null +++ b/python/tvm/relay/op/tensor.py @@ -0,0 +1,60 @@ +"""Basic tensor operations.""" +from __future__ import absolute_import as _abs +from . import _make + +# We create a wrapper function for each operator in the +# python side to call into the positional _make.OpName function. +# +# We make this decision so that we can: +# - Have declare python docstring for each function +# - Enable keyword arguments easily +# - Not put too much burden on FFI to support complicated features +# like default value and keyword arguments + + +def log(data): + """Take log of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.log(data) + + +def exp(data): + """Take exp of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.exp(data) + + +def sqrt(data): + """Take sqrt of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.sqrt(data) diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc index 5c0fbcf3ec71..b7cc296cc5db 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/compiler/unifier.cc @@ -9,7 +9,7 @@ #include "tvm/relay/compiler/alpha_eq.h" #include "./unifier.h" #include "./type_visitor.h" -#include "./type_subst.h" +//#include "./type_subst.h" // #include "tvm/relay/typeck/kindchecker.h" namespace tvm { diff --git a/src/relay/base.cc b/src/relay/ir/base.cc similarity index 100% rename from src/relay/base.cc rename to src/relay/ir/base.cc diff --git a/src/relay/expr.cc b/src/relay/ir/expr.cc similarity index 100% rename from src/relay/expr.cc rename to src/relay/ir/expr.cc diff --git a/src/relay/op.cc b/src/relay/ir/op.cc similarity index 100% rename from src/relay/op.cc rename to src/relay/ir/op.cc diff --git a/src/relay/type.cc b/src/relay/ir/type.cc similarity index 100% rename from src/relay/type.cc rename to src/relay/ir/type.cc diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 8b759bfbc07c..79301a7fac24 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -3,21 +3,59 @@ * \file elemwise.cc * \brief Elementwise operators. */ +#include #include namespace tvm { namespace relay { -RELAY_REGISTER_OP("log") +// Quick helper macro +// - Expose a positional make function to construct the node. +// - Register op to the registry. +// +// We make the decision to always only expose positional argument. +// We will do rewrapping in the frontend to support language +// sugars such as keyword arguments and default value. +// +#define RELAY_REGISTER_UNARY_OP(OpName) \ + TVM_REGISTER_API("relay.op._make." OpName) \ + .set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") + + +RELAY_REGISTER_UNARY_OP("log") .describe(R"code(Returns the log input array, computed element-wise. .. math:: log(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor."); +.set_support_level(1); + + +RELAY_REGISTER_UNARY_OP("exp") +.describe(R"code(Returns the exp input array, computed element-wise. + +.. math:: + \exp(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1); + + +RELAY_REGISTER_UNARY_OP("sqrt") +.describe(R"code(Returns the sqrt input array, computed element-wise. + +.. math:: + sqrt(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_relay_op.py b/tests/python/relay/test_relay_op.py index 93316da8ec41..4235dd918d93 100644 --- a/tests/python/relay/test_relay_op.py +++ b/tests/python/relay/test_relay_op.py @@ -1,7 +1,13 @@ from tvm import relay def test_op_level1(): - assert relay.op.log + x = relay.Var("x") + + for op_name in ["log", "exp", "sqrt"]: + y = getattr(relay, op_name)(x) + assert y.op.name == op_name + assert y.op.support_level == 1 + assert y.args[0] == x if __name__ == "__main__": From 4cec4f119cd355ebcf1b1641ff0560e7360cbfc0 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 14:34:28 -0700 Subject: [PATCH 037/136] Add type_subst back --- src/relay/compiler/type_subst.cc | 39 ++++++++++++++++++++++++++++++++ src/relay/compiler/type_subst.h | 19 ++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 src/relay/compiler/type_subst.cc create mode 100644 src/relay/compiler/type_subst.h diff --git a/src/relay/compiler/type_subst.cc b/src/relay/compiler/type_subst.cc new file mode 100644 index 000000000000..6650f59bad51 --- /dev/null +++ b/src/relay/compiler/type_subst.cc @@ -0,0 +1,39 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_subst.cc + * \brief Function for substituting a concrete type in place of a type ID + */ +#include "./type_subst.h" +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +struct TypeSubst : TypeFVisitor { + tvm::Map subst_map; + + explicit TypeSubst(tvm::Map subst_map) + : subst_map(subst_map) {} + + Type VisitType_(const TypeParamNode *op) override { + auto id = GetRef(op); + if (subst_map.find(id) != subst_map.end()) { + return this->subst_map[id]; + } else { + return id; + } + } +}; + +Type type_subst(const Type &type, const TypeParam &target, const Type &subst) { + TypeSubst ty_sub({ {target, subst} }); + return ty_sub.VisitType(type); +} + +Type type_subst(const Type &type, tvm::Map subst_map) { + TypeSubst ty_sub(subst_map); + return ty_sub.VisitType(type); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/compiler/type_subst.h b/src/relay/compiler/type_subst.h new file mode 100644 index 000000000000..0bf0de5a4b85 --- /dev/null +++ b/src/relay/compiler/type_subst.h @@ -0,0 +1,19 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file typeck/type_subst.h + * \brief Utility function for substituting types + */ +#ifndef TVM_RELAY_TYPECK_TYPE_SUBST_H_ +#define TVM_RELAY_TYPECK_TYPE_SUBST_H_ + +#include "tvm/relay/ir.h" + +namespace tvm { +namespace relay { + +Type type_subst(const Type & type, const TypeParam & target, const Type & subst); +Type type_subst(const Type &type, tvm::Map subst_map); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TYPECK_TYPE_SUBST_H_ From b1ba347464bb6c1455d50abca650bb8ab2427f25 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 14:35:06 -0700 Subject: [PATCH 038/136] Clean up code while refactoring inference --- include/tvm/relay/compiler/environment.h | 15 +- include/tvm/relay/expr.h | 5 + include/tvm/relay/op.h | 20 ++- python/tvm/relay/env.py | 105 ++++-------- python/tvm/relay/ir_builder.py | 45 +++--- python/tvm/relay/op.py | 11 +- src/relay/compiler/environment.cc | 2 +- src/relay/compiler/type_infer.cc | 195 ++++++++++++----------- src/relay/ir/expr.cc | 6 +- src/relay/op/tensor/elemwise.cc | 5 +- tests/python/relay/test_typechecker.py | 9 +- 11 files changed, 214 insertions(+), 204 deletions(-) diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index 5b33e781b399..d5a3ddd73f77 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -40,6 +40,7 @@ class EnvironmentNode : public RelayNode { private: /*! A map from string names to GlobalIds, ensures global uniqueness. */ tvm::Map global_map_; + tvm::Map type_func_map_; // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ @@ -57,14 +58,12 @@ class EnvironmentNode : public RelayNode { void VisitAttrs(tvm::AttrVisitor* v) final {} TVM_DLL static Environment make( - std::unordered_map global_funcs); - - /*! Add an operator to the Enviroment. */ - void register_op(const Op& op); - void add(const GlobalVar& var, const Function & func, bool update = false); - void try_add(const GlobalVar& var, const Function & func, bool update=false); - void update(const GlobalVar& var, const Function & func); - void remove(const GlobalVar& var); + tvm::Map global_funcs); + + void Add(const GlobalVar& var, const Function & func, bool update = false); + void TryAdd(const GlobalVar& var, const Function & func, bool update=false); + void Update(const GlobalVar& var, const Function & func); + void Remove(const GlobalVar& var); GlobalVar GetGlobalVar(const std::string& str); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index a4a683297ea5..ff11a41a6e5f 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -364,6 +364,11 @@ class IfNode : public ExprNode { RELAY_DEFINE_NODE_REF(If, IfNode, Expr); +// template +// T Downcast(U u) { + +// } + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index cae3d9db6920..be81f54ecd69 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -14,6 +14,7 @@ #include #include "./base.h" +#include "./type.h" #include "./expr.h" #include "../attrs.h" @@ -33,6 +34,8 @@ class OpNode : public relay::ExprNode { public: /*! \brief name of the operator */ std::string name; + + Type op_type; /*! * \brief detailed description of the operator * This can be used to generate docstring automatically for the operator. @@ -67,7 +70,7 @@ class OpNode : public relay::ExprNode { } static constexpr const char* _type_key = "relay.Op"; - TVM_DECLARE_NODE_TYPE_INFO(OpNode, Node); + TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); private: // friend class @@ -145,6 +148,13 @@ class OpRegistry { inline OpRegistry& add_argument(const std::string &name, const std::string &type, const std::string &description); + /*! + * \brief Attach the type function corresponding to the return type. + * \param ty_func The type function to register for the return type. + * \return reference to self. + */ + inline OpRegistry& add_type_func(const std::string & type_func_name); + /*! * \brief Set the type key of attributes. * \param type_key The type of of the attrs field.x @@ -329,6 +339,14 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, return *this; } + inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name) { + auto type_func = TypeFunctionNode::make(type_func_name, 0); + for (auto arg : get()->arguments) { + std::cout << arg << std::endl; + } + return *this; + } + inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) get()->num_inputs = n; return *this; diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 9bd63476f1fb..c63197fa8509 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -1,98 +1,57 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import """A global environment storing everything needed to interpret or compile a Realy program.""" from typing import Union, List -from relay.ir import register_relay_node, NodeBase -from relay.ir import GlobalId, OperatorId, Item, FileId, Span, ShapeExtension -from relay.ir import Operator, Defn -from relay._env import * +from .base import register_relay_node, NodeBase +from . import _make +# from relay.ir import GlobalId, OperatorId, Item, FileId, Span, ShapeExtension +# from relay.ir import Operator, Defn +# from relay._env import * import tvm # Move me to C++ if possible. __tgt_host__ = __tgt__ = "llvm" __relay_tvm_context__ = tvm.cpu() -ADD_ID = "__add__" -SUB_ID = "__sub__" -MUL_ID = "__mul__" -DIV_ID = "__div__" -NEG_ID = "__neg__" -LT_ID = "__lt__" -LE_ID = "__le__" -GT_ID = "__gt__" -GE_ID = "__ge__" -EQ_ID = "__eq__" -NE_ID = "__ne__" - @register_relay_node class Environment(NodeBase): """The global Relay environment containing definitions, primitives, options, and more. """ - def add(self, item: Item) -> None: - return Environment_add(self, item) - - def global_id(self, name: str) -> GlobalId: - return Environment_global_id(self, name) - - def operator_id(self, name: str) -> OperatorId: - return Environment_operator_id(self, name) - - def lookup(self, ident: Union[GlobalId, OperatorId]) -> Item: - if isinstance(ident, OperatorId): - return Environment_lookup_operator(self, ident) - else: - return Environment_lookup_global(self, ident) - - def add_source(self, file_name: str, source: str) -> FileId: - return Environment_add_source(self, file_name, source) - - def report_error(self, message: str, span: Span) -> None: - return Environment_report_error(self, message, span) - - def register_shape_ext(self, ext: ShapeExtension) -> None: - return Environment_register_shape_ext(self, ext) - - def display_errors(self) -> None: - return Environment_display_errors(self) - - def operators(self) -> List[Operator]: - return Environment_get_operators(self) - - def defns(self) -> List[Defn]: - return Environment_get_defns(self) - - def tvm_context(self): - return __relay_tvm_context__ + def __init__(self, funcs) -> None: + self.__init_handle_by_constructor__(_make.Environment, funcs) - def add_id(self) -> OperatorId: - return self.operator_id(ADD_ID) + # def add(self, item: Item) -> None: + # return Environment_add(self, item) - def sub_id(self) -> OperatorId: - return self.operator_id(SUB_ID) + # def global_id(self, name: str) -> GlobalId: + # return Environment_global_id(self, name) - def mul_id(self) -> OperatorId: - return self.operator_id(MUL_ID) + # def operator_id(self, name: str) -> OperatorId: + # return Environment_operator_id(self, name) - def div_id(self) -> OperatorId: - return self.operator_id(DIV_ID) + # def lookup(self, ident: Union[GlobalId, OperatorId]) -> Item: + # if isinstance(ident, OperatorId): + # return Environment_lookup_operator(self, ident) + # else: + # return Environment_lookup_global(self, ident) - def neg_id(self) -> OperatorId: - return self.operator_id(NEG_ID) + # def add_source(self, file_name: str, source: str) -> FileId: + # return Environment_add_source(self, file_name, source) - def lt_id(self) -> OperatorId: - return self.operator_id(LT_ID) + # def report_error(self, message: str, span: Span) -> None: + # return Environment_report_error(self, message, span) - def le_id(self) -> OperatorId: - return self.operator_id(LE_ID) + # def register_shape_ext(self, ext: ShapeExtension) -> None: + # return Environment_register_shape_ext(self, ext) - def gt_id(self) -> OperatorId: - return self.operator_id(GT_ID) + # def display_errors(self) -> None: + # return Environment_display_errors(self) - def ge_id(self) -> OperatorId: - return self.operator_id(GE_ID) + # def operators(self) -> List[Operator]: + # return Environment_get_operators(self) - def eq_id(self) -> OperatorId: - return self.operator_id(EQ_ID) + # def defns(self) -> List[Defn]: + # return Environment_get_defns(self) - def ne_id(self) -> OperatorId: - return self.operator_id(NE_ID) + # def tvm_context(self): + # return __relay_tvm_context__ diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index af83c9948be2..07927aef7d24 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,8 +1,8 @@ from typing import Any import numpy as np import tvm -from . import type as ty -from . import expr +from .type import FloatType, IntType, BoolType, UIntType, FuncType +from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function from . import op as _op class ExprBuilder(): @@ -10,7 +10,7 @@ def __init__(self, expr): self.expr = expr def __call__(self, *args): - return ExprBuilder(mk.Call(self.expr, list(args), None, None)) + return ExprBuilder(Call(self.expr, list(args), None, None)) def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: """Convert Python values into the appropriate types @@ -30,12 +30,12 @@ def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: # raise Exception(f"can't convert {type(arg)} to a Relay AST") raise Exception(f"unsupported argument type {type(arg)}") -def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> expr.Expr: +def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: if isinstance(arg, tuple): raise Exception("..") else: value = convert(arg, ctxt) - return ExprBuilder(mk.Constant(value)) + return ExprBuilder(Constant(value)) class WithScope(object): """Auxiliary scope with""" @@ -61,11 +61,18 @@ def __init__(self, params, ret_type, body, type_params): def param_ids(self): return [p.var for p in self.params] + def to_func(self): + return Function( + self.params, + self.ret_type, + self.body, + self.type_params) + def _mk_let(bindings, ret_value): let_expr = ret_value for var, value in reversed(list(bindings.items())): - let_expr = mk.Let(var, value, let_expr, None) + let_expr = Let(var, value, let_expr, None) return let_expr @@ -79,14 +86,14 @@ def __init__(self): def bind(self, name, type, value): - lv = mk.LocalVar(name) + lv = LocalVar(name) self.scopes[-1][name] = lv self.bindings[-1][lv] = value return lv def let(self, name, value, value_type=None): - if not (isinstance(value, expr.Expr) or isinstance(value, ExprBuilder)): + if not (isinstance(value, Expr) or isinstance(value, ExprBuilder)): value = into_ast(value) if isinstance(value, ExprBuilder): @@ -97,9 +104,9 @@ def let(self, name, value, value_type=None): def function(self, *params): relay_params = [] for name, ty in params: - lv = mk.LocalVar(name) + lv = LocalVar(name) self.scopes[-1][name] = lv - relay_params.append(mk.Param(lv, ty)) + relay_params.append(Param(lv, ty)) # self.params.append(relay_params) @@ -108,7 +115,10 @@ def function(self, *params): def _on_exit(): bindings = self.bindings.pop() scope = self.scopes.pop() - # params = self.params.pop() + ret_value = self.ret_value + body = _mk_let(bindings, ret_value) + self.ret_value = None + pfunc.body = body return WithScope(pfunc, _on_exit) @@ -124,9 +134,6 @@ def ret(self, x): def fn_params(self): pass - def op(self, name): - pass - def get(self): """Get the full program""" bindings = self.bindings.pop() @@ -152,16 +159,16 @@ def int_dtype(): return 'uint1' def int_type(bits=32, lanes=1): - return mk.IntType(bits, lanes) + return IntType(bits, lanes) def uint_type(bits=32, lanes=1): - return mk.UIntType(bits, lanes) + return UIntType(bits, lanes) def float_type(bits=32, lanes=1): - return mk.FloatType(bits, lanes) + return FloatType(bits, lanes) def bool_type(lanes=1): - return mk.BoolType(lanes) + return BoolType(lanes) def func_type(args, ret_type, type_params=[], type_constraints=[]): - return mk.FuncType(args, ret_type, type_params, type_constraints) + return FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py index d54edf47c5ee..d36a433e1e85 100644 --- a/python/tvm/relay/op.py +++ b/python/tvm/relay/op.py @@ -6,6 +6,12 @@ from .._ffi.node import convert_to_node from . import _make from ..make import node as _make_node +from .expr import Expr, Call +from .base import register_relay_node + +@register_relay_node +class Op(Expr): + pass def _create_op(op_name): op = _GetOp(op_name) @@ -19,8 +25,9 @@ def _create_op(op_name): def _op_func(*args, **kwargs): args = convert_to_node(args) # Need work to make sure constructor matches - return _make.Call(op, args, - attrs = _make.node(attrs_type_key, **kwargs)) + # can support kwargs later + attrs = _make_node(attrs_type_key, **kwargs) + return Call(op, args, None, []) _op_func.__name__ = op.name return _op_func diff --git a/src/relay/compiler/environment.cc b/src/relay/compiler/environment.cc index 7ce0785f4f8f..a1c6b31076e3 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/compiler/environment.cc @@ -18,7 +18,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; Environment EnvironmentNode::make( - std::unordered_map global_funcs) { + tvm::Map global_funcs) { std::shared_ptr n = std::make_shared(); n->items = std::move(global_funcs); return Environment(n); diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 7304bdabe486..40f14b517951 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -94,7 +94,7 @@ class TypeInferencer : private ExprFunctor { CheckedExpr Infer(const Expr & expr); - Type instantiate(Type t, tvm::Array &ty_args); + Type instantiate(FuncType fn_ty, tvm::Array &ty_args); void report_error(const std::string & msg, Span sp); [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); @@ -192,10 +192,9 @@ class TypeInferencer : private ExprFunctor { throw Error("TupleNode NYI"); } - CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *op) { - // Param param = GetRef(op); - // return { resolve(param->type); - throw Error("ParamNode NYI"); + CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { + auto rtype = resolve(param->type); + return { ParamNode::make(param->var, rtype), rtype }; } // // We should probably generalize the subst code. @@ -236,25 +235,32 @@ class TypeInferencer : private ExprFunctor { // }; CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { - throw Error("FunctionNode NYI"); - // // enter params into context - // auto fn_type = this->with_frame([&]() { - // std::vector arg_types; - // for (auto arg : f->params) { - // this->Check(arg); - // Type arg_type; - // // if arg type can be simply evaluated, try it - // // should be replaced with symbolic evaluation once it exists, - // // you will not have attr information at this point - // try { - // arg_type = simple_eval_shape(arg->type); - // } catch (const dmlc::Error &e) { - // this->report_error(e.what(), arg->span); - // arg_type = arg->type; - // } - // arg_types.push_back(arg_type); - // this->local_stack.insert(arg->id, arg_type); - // } + // First we add the parameters to the context allowing us to check their + // types. + + // TODO(@jroesch): support polymorphism + + std::vector param_types; + std::vector params; + + return this->with_frame([&]() -> CheckedExpr { + for (auto param : f->params) { + CheckedExpr checked_param = this->Infer(param); + Type arg_type; + param_types.push_back(checked_param.type); + params.push_back(GetRef(checked_param.expr.as())); + this->local_stack.insert(param->var, checked_param.type); + } + + auto checked_body = this->Infer(f->body); + auto inferred_rtype = checked_body.type; + auto annotated_rtype = resolve(f->ret_type); + + auto unified_rtype = this->unify(inferred_rtype, annotated_rtype, f->span); + + return { FunctionNode::make(params, unified_rtype, checked_body.expr, {}), + FuncTypeNode::make(param_types, unified_rtype, {}, {}) }; + }); // // typecheck body and ensure that it matches stated return type // // TODO(sslyu): should the unified return type override the annotated @@ -332,95 +338,96 @@ class TypeInferencer : private ExprFunctor { // } // return fn_type; - } CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { return this->VisitFunction(GetRef(op), false); } - // Type TypeInferencer::instantiate(Type t, tvm::Array &ty_args) { - // const TypeQuantifierNode *ty_quant; - // while ((ty_quant = t.as())) { - // TypeParam id = ty_quant->id; - // TypeVar fresh = TypeVarNode::make(id->kind); - // this->unifier->insert(fresh); - // ty_args.push_back(fresh); - // t = type_subst(ty_quant->boundType, id, fresh); - // } + Type TypeInferencer::instantiate(FuncType fn_ty, tvm::Array &ty_args) { + // const TypeQuantifierNode *ty_quant; + // while ((ty_quant = t.as())) { + // TypeParam id = ty_quant->id; + // TypeVar fresh = TypeVarNode::make(id->kind); + // this->unifier->insert(fresh); + // ty_args.push_back(fresh); + // t = type_subst(ty_quant->boundType, id, fresh); + // } - // if (!check_kind(t)) { - // this->fatal_error("Kind rules broken when instantiating type - // variables", - // t->span); - // } + // if (!check_kind(t)) { + // this->fatal_error("Kind rules broken when instantiating type + // variables", + // t->span); + // } - // return t; - // } + // return t; + } CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { - throw Error("CallNode"); - // Call c = GetRef(op); - // Type fn_ty = this->Check(c->fn); - - // RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl - // << "fn_ty=" << fn_ty << std::endl; - - // // for each type id, insert a type variable and unify with the argument - // types - // // in order - // // to obtain the concrete instantiation - // tvm::Array ty_args; - // if (const TypeQuantifierNode *ty_quant = fn_ty.as()) - // { - // fn_ty = instantiate(GetRef(ty_quant), ty_args); - // } + Call c = GetRef(op); - // if (!fn_ty.as()) { - // this->fatal_error("only expressions with function types can be called", - // c->fn->span); - // } + auto checked_op = this->Infer(c->op); - // // evaluate all shapes up front (require that types be fully concrete) - // Type evaluated = evaluate_concrete_shape(fn_ty, op->attrs); - // std::vector arg_types; + RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl + << "fn_ty=" << fn_ty << std::endl; - // TypeArrow arrow = GetRef(evaluated.as()); - // // TODO(sslyu): figure out how to handle type ids - // // fn_ty = instantiate(fn_ty, ty_args); - // for (auto arg : c->args) { - // auto ty = this->Check(arg); - // arg_types.push_back(ty); - // } + auto fn_ty_node = checked_op.expr.as(); - // auto type_arity = arrow->arg_types.size(); - // auto number_of_args = arg_types.size(); - // if (type_arity != number_of_args) { - // if (type_arity < number_of_args) { - // this->fatal_error("the function is provided too many arguments", - // c->span); - // } else { - // this->fatal_error("the function is provided too few arguments", - // c->span); - // } - // } + if (!fn_ty_node) { + this->fatal_error("only expressions with function types can be called", c->fn->span); + } - // for (size_t i = 0; i < arrow->arg_types.size(); i++) { - // this->unify(arrow->arg_types[i], arg_types[i], c->args[i]->span); - // } + // We now have a function type. + FuncType fn_ty = GetRef(fn_ty_node); - // // After we unify the arguments we should know more about the type - // // arguments, let's run a quick pass over them to find new - // representatives. for (size_t i = 0; i < ty_args.size(); i++) { - // ty_args.Set(i, this->unifier->subst(ty_args[i])); - // } + tvm::Array ty_args; + if (ty_args.size() != 0) { + throw Error("found manually suplied type args, not supported"); + } + + fn_ty = instantiate(fn_ty, ty_args); + + std::vector arg_types; + + + // TODO(sslyu): figure out how to handle type ids + // fn_ty = instantiate(fn_ty, ty_args); + for (auto arg : c->args) { + auto checked_arg = this->Infer(arg); + arg_types.push_back(checked_arg.type); + } + + auto type_arity = fn_ty->arg_types.size(); + auto number_of_args = arg_types.size(); + + if (type_arity != number_of_args) { + if (type_arity < number_of_args) { + this->fatal_error("the function is provided too many arguments", + c->span); + } else { + this->fatal_error("the function is provided too few arguments", + c->span); + } + } + + for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { + this->unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); + } + + // After we unify the arguments we should know more about the type + // arguments, let's run a quick pass over them to find new + // representatives. + + for (size_t i = 0; i < ty_args.size(); i++) { + ty_args.Set(i, this->unifier->subst(ty_args[i])); + } - // // Write the type arguments into the call node, recording what inference - // // solves. This solution might need some work. - // c->ty_args = ty_args; + // Write the type arguments into the call node, recording what inference + // solves. This solution might need some work. + c->ty_args = ty_args; - // return arrow->ret_type; + return { new_call, call_type } } // Type TypeInferencer::VisitExpr_(const DebugNode *op) { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 38df81940e48..3a3ef1b52604 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -114,7 +114,11 @@ TVM_REGISTER_API("relay._make.Function") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionNode *node, tvm::IRPrinter *p) { - p->stream << "FunctionNode(TODO)"; + p->stream << "FunctionNode(" << + node->params << ", " << + node->ret_type << ", " << + node->body << ", " << + node->type_params << ")"; }); Call CallNode::make(Expr op, Array args, Attrs attrs, diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 79301a7fac24..50c864650ff4 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -36,6 +36,9 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1); +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.add_type_func("Broadcast"); RELAY_REGISTER_UNARY_OP("exp") @@ -57,5 +60,5 @@ RELAY_REGISTER_UNARY_OP("sqrt") )code" TVM_ADD_FILELINE) .set_support_level(1); -} // namespace relay +} // namespace relayv } // namespace tvm diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index 6a16aadcb002..9c050ecd62d0 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -2,10 +2,11 @@ for expressions. """ from tvm.relay.type_infer import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type, op +from tvm.relay.ir_builder import IRBuilder, float_type, op, func_type +from tvm.relay.env import Environment def has_type(expr, typ): - env = mk.Environment({}) + env = Environment({}) checked_expr = check_expr(env, expr) return checked_expr.checked_type() == typ @@ -20,11 +21,11 @@ def test_monomorphic_let(): def test_single_op(): - "Program: fn (x : int32) { let t1 = f(x); t1 }" + "Program: fn (x : float32) { let t1 = f(x); t1 }" b = IRBuilder() f = op('log') with b.function(('x', float_type())) as func: x, = func.param_ids() t1 = b.let('t1', f(x)) b.ret(t1) - import pdb; pdb.set_trace() + assert has_type(func.to_func(), func_type([float_type()], float_type())) From 3a4ff266273c594ac4bb64df0b9a8d22fdf3bfb8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 14:35:19 -0700 Subject: [PATCH 039/136] Restore old TVM backend code --- python/tvm/relay/tvm_rts_backend.py | 239 ++++++++++++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 python/tvm/relay/tvm_rts_backend.py diff --git a/python/tvm/relay/tvm_rts_backend.py b/python/tvm/relay/tvm_rts_backend.py new file mode 100644 index 000000000000..137230ace63a --- /dev/null +++ b/python/tvm/relay/tvm_rts_backend.py @@ -0,0 +1,239 @@ +"""A compiler from Relay programs to TVM's graph runtime. +""" +import json +from typing import Dict, Any, List, Tuple + +import attr + +from relay.frontend import get_env +from . import ir +from .tyck import get_checked_type +from .opt import AbstractExprVisitor, compile_ops_to_module +from ._make import Operator_is_generic + + +@attr.s(auto_attribs=True) +class NodeRef: + ident: int + index: int = 0 + version: int = 0 + + def to_json(self) -> Any: + return [self.ident, self.index, self.version] + + +@attr.s(auto_attribs=True) +class Node(): + name: str + attrs: Dict[str, Any] + is_output: bool + + def to_json(self) -> Any: + raise Exception("Abstract method, please implement me.") + + +@attr.s(auto_attribs=True) +class InputNode(Node): + """An input node in the graph representation we lower to before NNVM's graph.""" + is_output: bool = False + + def to_json(self): + return { + "op": "null", + "name": self.name, + "inputs": [] + } + + +@attr.s(auto_attribs=True) +class OpNode(Node): + """An operator node in the graph representation we lower to before NNVM's graph.""" + op_name: str + inputs: List[NodeRef] + op_attrs: Dict[str, Any] + is_output: bool = False + + def to_json(self) -> Any: + attrs = dict.copy(self.op_attrs) + # Extend ops with extra info. + attrs['func_name'] = self.op_name + # When do we flatten? + attrs['flatten_data'] = "0" + # Fix me! + attrs['num_inputs'] = str(len(self.inputs)) + attrs['num_outputs'] = "1" + + return { + "op": "tvm_op", + "name": self.name, + "attrs": attrs, + "inputs": self.inputs + } + + +def from_tensor(typ: ir.TensorType) -> Tuple[str, List[int]]: + dtype = typ.dtype.dtype + shape = typ.shape + dims = [] + for dim in shape.shapes: + dims.append(dim.value) + return dtype, dims + + +class TVMRTSCompiler(AbstractExprVisitor[NodeRef]): + """The compiler from Relay to the TVM runtime system.""" + nodes: List[Node] + id_map: Dict[ir.LocalId, NodeRef] + + def __init__(self) -> None: + self.nodes = [] + self.id_map = {} + + def add_node(self, node: Node) -> NodeRef: + self.nodes.append(node) + ident = len(self.nodes) - 1 + return NodeRef(ident) + + def add_binding(self, ident: ir.LocalId, ref: NodeRef) -> None: + self.id_map[ident] = ref + + def let_bind(self, ident: ir.LocalId, node: Node) -> NodeRef: + ref = self.add_node(node) + self.add_binding(ident, ref) + return ref + + def get_node(self, ref: NodeRef) -> Node: + return self.nodes[ref.ident] + + def lookup(self, ident: ir.LocalId) -> NodeRef: + return self.id_map[ident] + + def compile(self, func: ir.Function) -> None: + """Compile a single function into a graph.""" + # TODO: (@jroesch) Restore me + # assert len(fn.ty_params) == 0 + + # First we convert all the parameters into input nodes. + params = func.params + + for param in params: + dtype, shape = from_tensor(param.type) + node = InputNode(f"{param.id.name}", { + "shape": shape, + "dtype": dtype, + }) + self.let_bind(param.id, node) + + # Then we compile the body into a graph which can depend + # on input variables. + output_ref = self.visit(func.body) + + # Finally we retreive return value of program, which will + # become our output node. + self.get_node(output_ref).is_output = True + + def visit_let(self, let: ir.Let) -> NodeRef: + """Visit the Let binding, by first traversing its value, + then setting the metadata on the returned NodeRef. + + Finally visit the body, and return the NodeRef corresponding + to it. + """ + ident = let.id + val = let.value + body = let.body + + # Need to add type info? + val_ref = self.visit(val) + dtype, shape = from_tensor(get_checked_type(val)) + val_node = self.get_node(val_ref) + val_node.attrs["dtype"] = dtype + val_node.attrs["shape"] = shape + self.add_binding(ident, val_ref) + return self.visit(body) + + def visit_local_id(self, ident: ir.LocalId) -> NodeRef: + return self.lookup(ident) + + def visit_call(self, call: ir.Call) -> NodeRef: + inputs = [] + for arg in call.args: + inputs.append(self.visit(arg).to_json()) + + # need to deal with name mangle + op_name = call.fn.name + op_node = OpNode("call_name", {}, op_name, inputs, {}) + return self.add_node(op_node) + + def to_json(self) -> str: + """Convert the sequence of nodes stored by the compiler into the + JSON format defined in: https://docs.tvm.ai/dev/nnvm_json_spec.html. + """ + nodes = [] + # First we compute "nodes" field. + for node in self.nodes: + nodes.append(node.to_json()) + + arg_nodes = [] + heads = [] + # Compute "arg_nodes" and "heads" fields. + for i, node in enumerate(self.nodes): + if isinstance(node, InputNode): + arg_nodes.append(i) + + if node.is_output: + # Need to fix this. + heads.append(NodeRef(i).to_json()) + + # Compute "node_row_ptr". + # TODO + + # Compute "attrs" field. + attrs = {} + + # A + shapes = [] + storage_ids = [] + dtype = [] + dltype = [] + + for i, node in enumerate(self.nodes): + storage_ids.append(i) + shapes.append(node.attrs['shape']) + if node.attrs['dtype'] == 'float32': + dtype.append(0) + dltype.append('float32') + + attrs["shape"] = ["list_shape", shapes] + attrs["storage_id"] = ["list_int", storage_ids] + attrs["dtype"] = ["list_int", dtype] + attrs["dltype"] = ["list_str", dltype] + + json_dict = { + "nodes": nodes, + "arg_nodes": arg_nodes, + "heads": heads, + "attrs": attrs + } + + return json.dumps(json_dict) + + +def compile_to_tvm(func): + """Compile a single function to the components needed by the + TVM RTS. + """ + env = get_env() + iids = [] + + # Why do I need to call items? + for op in env.operators(): + if not Operator_is_generic(op): + iids.append(op.id) + + # TODO(@jroesch): Need to write test case for this + mod = compile_ops_to_module(env, iids) + comp = TVMRTSCompiler() + comp.compile(func) + graph_json = comp.to_json() + return graph_json, mod, None # params currently isn't supported by API From b91ffb5acd1eef2aaf834484f9a0c46f210e8ec2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 15:33:44 -0700 Subject: [PATCH 040/136] Type checker is working for one op case --- python/tvm/relay/expr.py | 1 + python/tvm/relay/ir_builder.py | 3 - python/tvm/relay/op.py | 44 ----- python/tvm/relay/op/__init__.py | 6 + src/relay/compiler/resolve.cc | 2 + src/relay/compiler/type_infer.cc | 256 ++++++++++--------------- src/relay/op/tensor/elemwise.cc | 4 +- tests/python/relay/test_typechecker.py | 10 +- 8 files changed, 113 insertions(+), 213 deletions(-) delete mode 100644 python/tvm/relay/op.py diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 7f5dcbd0beb5..e98d74f3da88 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -71,6 +71,7 @@ class Function(Expr): def __init__(self, params: List[Param], ret_type: Type, body: Expr, type_params: List[TypeParam]=[]) -> None: self.__init_handle_by_constructor__(_make.Function, params, ret_type, body, type_params) +@register_relay_node class Call(Expr): op: Expr args: List[Expr] diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 07927aef7d24..8bd225bd4de1 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -149,9 +149,6 @@ def get(self): return _mk_let(bindings, self.ret_value) -def op(name): - return _op._create_op(name) - def bool_dtype(): return 'uint1' diff --git a/python/tvm/relay/op.py b/python/tvm/relay/op.py deleted file mode 100644 index d36a433e1e85..000000000000 --- a/python/tvm/relay/op.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Relay operators""" -from __future__ import absolute_import as _abs - -import sys - -from .._ffi.node import convert_to_node -from . import _make -from ..make import node as _make_node -from .expr import Expr, Call -from .base import register_relay_node - -@register_relay_node -class Op(Expr): - pass - -def _create_op(op_name): - op = _GetOp(op_name) - attrs_type_key = op.attrs_type_key - attrs_type_key = attrs_type_key if attrs_type_key else "DictAttrs" - # TODO(tqchen): improve the code build to fix the restriction. - # - # current restriction: - # - pass in args as positional arguments - # - pass in kwargs as keyword argument - def _op_func(*args, **kwargs): - args = convert_to_node(args) - # Need work to make sure constructor matches - # can support kwargs later - attrs = _make_node(attrs_type_key, **kwargs) - return Call(op, args, None, []) - _op_func.__name__ = op.name - return _op_func - - -def _init_ops(): - """Helper function to initialize the operators - """ - module = sys.modules[__name__] - for name in _ListOpNames(): - f = _create_op(name.value) - setattr(module, f.__name__, f) - - -_init_ops() diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 02e49ec40ff8..ad2f54929aed 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -4,3 +4,9 @@ # operator registry from . import _tensor +from ..expr import Expr +from ..base import register_relay_node + +@register_relay_node +class Op(Expr): + pass diff --git a/src/relay/compiler/resolve.cc b/src/relay/compiler/resolve.cc index 2d3e84dc2160..236722b23387 100644 --- a/src/relay/compiler/resolve.cc +++ b/src/relay/compiler/resolve.cc @@ -53,6 +53,7 @@ struct ResolveTypeExpr : ExprFVisitor<> { // term, then resolve e's old type and write // it back into the new node. auto new_e = ExprFVisitor::VisitExpr(e); + CHECK(e->checked_type_.defined()); auto resolved_cty = VisitType(e->checked_type_); new_e->checked_type_ = resolved_cty; return new_e; @@ -64,6 +65,7 @@ struct ResolveTypeExpr : ExprFVisitor<> { }; Type resolve(const TypeUnifier &unifier, const Type &ty) { + CHECK(ty.defined()); return ResolveTypeType(unifier).VisitType(ty); } diff --git a/src/relay/compiler/type_infer.cc b/src/relay/compiler/type_infer.cc index 40f14b517951..e2e5999e7341 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/compiler/type_infer.cc @@ -27,6 +27,7 @@ #include "./incomplete_type.h" #include "./unifier.h" #include "./resolve.h" +#include "./type_subst.h" // #include "tvm/relay/alpha_eq.h" // #include "tvm/relay/debug.h" // #include "tvm/relay/first_order_reverse_ad.h" @@ -71,6 +72,7 @@ struct CheckedExpr { Expr expr; Type type; CheckedExpr(Expr e, Type t) : expr(e), type(t) {} + CheckedExpr() {} }; class TypeInferencer : private ExprFunctor { @@ -94,7 +96,7 @@ class TypeInferencer : private ExprFunctor { CheckedExpr Infer(const Expr & expr); - Type instantiate(FuncType fn_ty, tvm::Array &ty_args); + FuncType instantiate(FuncType fn_ty, tvm::Array &ty_args); void report_error(const std::string & msg, Span sp); [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); @@ -103,7 +105,7 @@ class TypeInferencer : private ExprFunctor { Type resolve(const Type &t); Expr resolve(const Expr &e); CheckedExpr VisitFunction(const Function & f, bool generalize); - // Operator CheckOp(Operator op); + void CheckOp(Op op); // Defn CheckDefn(Defn def); private: CheckedExpr VisitExpr_(const LocalVarNode* op) override; @@ -115,6 +117,7 @@ class TypeInferencer : private ExprFunctor { CheckedExpr VisitExpr_(const CallNode* op) override; CheckedExpr VisitExpr_(const LetNode* op) override; CheckedExpr VisitExpr_(const IfNode* op) override; + CheckedExpr VisitExpr_(const OpNode* op) override; }; TypeInferencer::TypeInferencer() { @@ -145,7 +148,7 @@ class TypeInferencer : private ExprFunctor { // GlobalVar id = GetRef(op); // Item item = this->env->lookup(id); - // if (const OperatorNode *op = item.as()) { + // if (const OpNode *op = item.as()) { // return op->type; // } @@ -167,12 +170,12 @@ class TypeInferencer : private ExprFunctor { TensorTypeNode::make({}, HalideIR::Float(32, 1)) }; } - // Type TypeInferencer::VisitExpr_(const OperatorIdNode *op) { - // OperatorId id = GetRef(op); + // Type TypeInferencer::VisitExpr_(const OpIdNode *op) { + // OpId id = GetRef(op); // Item item = this->env->lookup(id); - // if (const OperatorNode *pn = item.as()) { - // Operator prim = GetRef(pn); + // if (const OpNode *pn = item.as()) { + // Op prim = GetRef(pn); // return prim->type; // } else { // this->fatal_error("internal error in InstrinsicId case", op->span); @@ -344,15 +347,20 @@ class TypeInferencer : private ExprFunctor { return this->VisitFunction(GetRef(op), false); } - Type TypeInferencer::instantiate(FuncType fn_ty, tvm::Array &ty_args) { - // const TypeQuantifierNode *ty_quant; - // while ((ty_quant = t.as())) { - // TypeParam id = ty_quant->id; - // TypeVar fresh = TypeVarNode::make(id->kind); - // this->unifier->insert(fresh); - // ty_args.push_back(fresh); - // t = type_subst(ty_quant->boundType, id, fresh); - // } + FuncType TypeInferencer::instantiate(FuncType fn_ty, tvm::Array &ty_args) { + tvm::Map subst_map; + + // Build a subsitituion map up from the function type and type arguments. + // Eventually allow the type vars to be passed in. + for (auto ty_param : fn_ty->type_params) { + IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); + this->unifier->insert(fresh); + ty_args.push_back(fresh); + subst_map.Set(ty_param, fresh); + } + + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); + inst_ty = type_subst(fn_ty, subst_map); // if (!check_kind(t)) { // this->fatal_error("Kind rules broken when instantiating type @@ -360,7 +368,7 @@ class TypeInferencer : private ExprFunctor { // t->span); // } - // return t; + return GetRef(inst_ty.as()); } CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { @@ -369,13 +377,13 @@ class TypeInferencer : private ExprFunctor { auto checked_op = this->Infer(c->op); RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl - << "fn_ty=" << fn_ty << std::endl; + << "fn_ty=" << checked_op.type << std::endl; - auto fn_ty_node = checked_op.expr.as(); + auto fn_ty_node = checked_op.type.as(); if (!fn_ty_node) { - this->fatal_error("only expressions with function types can be called", c->fn->span); + this->fatal_error("only expressions with function types can be called", c->op->span); } // We now have a function type. @@ -389,13 +397,12 @@ class TypeInferencer : private ExprFunctor { fn_ty = instantiate(fn_ty, ty_args); std::vector arg_types; + std::vector checked_args; - - // TODO(sslyu): figure out how to handle type ids - // fn_ty = instantiate(fn_ty, ty_args); for (auto arg : c->args) { auto checked_arg = this->Infer(arg); arg_types.push_back(checked_arg.type); + checked_args.push_back(checked_arg.expr); } auto type_arity = fn_ty->arg_types.size(); @@ -423,164 +430,100 @@ class TypeInferencer : private ExprFunctor { ty_args.Set(i, this->unifier->subst(ty_args[i])); } - // Write the type arguments into the call node, recording what inference - // solves. This solution might need some work. - c->ty_args = ty_args; + auto new_call = CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args); - return { new_call, call_type } + return { new_call, fn_ty->ret_type }; } - // Type TypeInferencer::VisitExpr_(const DebugNode *op) { - // return this->Check(op->node); - // } - CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { Let let = GetRef(op); - Type checked_ty; + CheckedExpr checked_value; Type annotated_ty = resolve(let->value_type); - // if we are let-defining a function, treat it as a let-rec and insert - // the id with the annotated type in case there is recursion; - // no such recursion permitted with anything that's not a function! - // if (let->value.as()) { - // with_frame([&]() { - // local_stack.insert(let->id, annotated_ty); - // checked_ty = Check(let->value); - // }); - // } else { - checked_ty = Infer(let->value).type; - // } - // ensure annotated type and checked type are compatible - // TODO(sslyu): should the annotated type override the unified one? + // If we are let-defining a function, we want to be able to + // recursively name the function in order to support recursive + // local definitions. + if (let->value.as()) { + with_frame([&]() { + local_stack.insert(let->var, annotated_ty); + checked_value = Infer(let->value); + }); + } else { + checked_value = Infer(let->value); + } + Type unified_ty = - this->unify(checked_ty, annotated_ty, let->span); + this->unify(checked_value.type, annotated_ty, let->span); + + // Update type context with unified type now that we have + // solved this equation. + local_stack.insert(let->var, unified_ty); - return with_frame([&]() { + auto checked_body = with_frame([&]() { local_stack.insert(let->var, unified_ty); return Infer(let->body); }); - } - // Type TypeInferencer::VisitExpr_(const ReverseNode *op) { - // // apply reverse mode to node and typecheck that instead - // std::shared_ptr gf = std::make_shared(); - // return this->Check(ReverseExpr(env, op->node, gf)); - // } - - // Type TypeInferencer::VisitExpr_(const GradientNode *op) { - // auto node = op->node; - // this->Check(node); - // auto gf = std::make_shared(); - // return FOWithGradientType(node->checked_type()); - // } - - // Type TypeInferencer::VisitExpr_(const ProjectionNode *op) { - // Projection proj = GetRef(op); - - // Type tup_type = this->Check(proj->tuple); - - // const TupleTypeNode *ptn = tup_type.as(); - // if (!ptn) { - // this->fatal_error("Cannot project into non-product type", op->span); - // } + auto checked_let = LetNode::make( + let->var, + checked_value.expr, + checked_body.expr, + let->value_type); - // TupleType pt = GetRef(ptn); - // size_t field = (size_t)proj->field; - // if (field >= pt->fields.size()) { - // this->fatal_error("Projecting past bounds of product", op->span); - // } - - // return pt->fields[field]; - // } + return { checked_let, checked_body.type }; + } CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { - // If ifn = GetRef(op); - - // // Ensure the type of the guard is of Tensor[Bool, ()], - // // that is a rank-0 boolean tensor. - // Type guardType = this->Check(ifn->guard); - // bool is_bool = false; - // bool zero_rank = false; - // if (const TensorTypeNode *ttn = guardType.as()) { - // TensorType tt = GetRef(ttn); - - // if (const BaseTypeNode *btn = tt->dtype.as()) { - // is_bool = btn->type.is_bool(); - // } - - // Type shape = simple_eval_shape(tt->shape); - - // if (const ShapeSeqNode *sn = shape.as()) { - // zero_rank = (sn->shapes.size() == 0); - // } - // } - - // if (!(is_bool && zero_rank)) { - // this->fatal_error("IfNode guard must be a rank 0 bool tensor", - // ifn->guard->span); - // } + If ifn = GetRef(op); + + // Ensure the type of the guard is of Tensor[Bool, ()], + // that is a rank-0 boolean tensor. + auto checked_cond = this->Infer(ifn->cond); + auto cond_type = checked_cond.type; + + if (const TensorTypeNode *tt_node = cond_type.as()) { + TensorType tt = GetRef(tt_node); + if (tt->dtype.is_bool() && tt->shape.size() == 0) { + auto checked_true = this->Infer(ifn->true_value); + auto checked_false = this->Infer(ifn->false_value); + auto unified_type = this->unify(checked_true.type, checked_false.type, ifn->span); + auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, checked_false.expr); + return { checked_if, unified_type }; + } + } - // // unify types of different branches - // Type left = this->Check(ifn->true_b); - // Type right = this->Check(ifn->false_b); - // return this->unify(left, right, ifn->span); + this->fatal_error("if-then-else guard must be a rank-0 boolean tensor", + ifn->cond->span); } - // Type TypeInferencer::VisitExpr_(const RefNode *op) { - // Ref r = GetRef(op); - // Type inner = this->Check(r->expr); - // return RefTypeNode::make(inner); - // } - - // Type TypeInferencer::VisitExpr_(const ReadRefNode *op) { - // ReadRef vr = GetRef(op); - // Type ref_type = this->Check(vr->ref); - - // // reject if not a ref type - // const RefTypeNode *rtn = ref_type.as(); - // if (!rtn) { - // this->fatal_error( - // "the de-reference operation can only be used with references", - // op->span); - // } - - // RefType rt = GetRef(rtn); - // return rt->data_type; - // } - - // Type TypeInferencer::VisitExpr_(const WriteRefNode *op) { - // WriteRef sr = GetRef(op); - // Type ref_type = this->Check(sr->ref); - - // const RefTypeNode *rtn = ref_type.as(); - // if (!rtn) { - // this->fatal_error("Cannot mutate non-ref", op->span); - // } - // RefType rt = GetRef(rtn); - - // // ensure ref type's inner type and expr's type are compatible; return - // unit Type expr_type = this->Check(sr->val); this->unify(rt->data_type, - // expr_type, sr->span); return UnitType(); - // } + CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op) { + return { GetRef(op), FuncTypeNode::make({}, TensorTypeNode::Int(32), {}, {} )}; + } Type TypeInferencer::resolve(const Type &t) { - return ::tvm::relay::resolve(this->unifier, t); + if (t.defined()) { + return ::tvm::relay::resolve(this->unifier, t); + } else { + return IncompleteTypeNode::make(TypeParamNode::Kind::kType); + } } Expr TypeInferencer::resolve(const Expr &e) { + CHECK(e.defined()); return ::tvm::relay::resolve(this->unifier, e); } - // Operator TypeInferencer::CheckOp(Operator op) { - // if (!check_kind(op->type)) { - // report_error("the type of the operator is ill formed", op->type->span); - // } + void TypeInferencer::CheckOp(Op op) { + throw Error("NYI"); + // if (!check_kind(op->type)) { + // report_error("the type of the operator is ill formed", op->type->span); + // } - // // Fix me - // return op; - // } + // // Fix me + // return op; + } // Defn TypeInferencer::CheckDefn(Defn defn) { // // This is to handle recursion, but we need to speculatively @@ -620,8 +563,8 @@ class TypeInferencer : private ExprFunctor { // try { // if (const DefnNode *defn = i.as()) { // return tc.CheckDefn(GetRef(defn)); - // } else if (const OperatorNode *op_node = i.as()) { - // return tc.CheckOp(GetRef(op_node)); + // } else if (const OpNode *op_node = i.as()) { + // return tc.CheckOp(GetRef(op_node)); // } else { // throw dmlc::Error("internal error: unknown Item type"); // } @@ -717,13 +660,6 @@ class TypeInferencer : private ExprFunctor { *ret = Infer(env, e); }); - // TVM_REGISTER_API("relay._tyck.check_item") - // .set_body([](TVMArgs args, TVMRetValue *ret) { - // Environment env = args[0]; - // Item i = args[1]; - // *ret = check(env, i); - // }); - TVM_REGISTER_API("relay._type_infer._get_checked_type") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr e = args[0]; diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 50c864650ff4..d1d3e01ed9a6 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -35,9 +35,7 @@ RELAY_REGISTER_UNARY_OP("log") log(x) )code" TVM_ADD_FILELINE) -.set_support_level(1); -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(1) .add_type_func("Broadcast"); diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_typechecker.py index 9c050ecd62d0..d111bba9dfbf 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_typechecker.py @@ -2,8 +2,9 @@ for expressions. """ from tvm.relay.type_infer import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type, op, func_type +from tvm.relay.ir_builder import IRBuilder, float_type, func_type from tvm.relay.env import Environment +from tvm.relay.op import log def has_type(expr, typ): env = Environment({}) @@ -23,9 +24,12 @@ def test_monomorphic_let(): def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" b = IRBuilder() - f = op('log') with b.function(('x', float_type())) as func: x, = func.param_ids() - t1 = b.let('t1', f(x)) + t1 = b.let('t1', log(x)) b.ret(t1) assert has_type(func.to_func(), func_type([float_type()], float_type())) + +if __name__ == "__main__": + test_monomorphic_let() + test_single_op() From 26755ec41877ecb747ef8fc64972535fd39f8062 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 24 Aug 2018 16:03:53 -0700 Subject: [PATCH 041/136] op step 2 --- include/tvm/relay/op.h | 18 +++++-- python/tvm/relay/__init__.py | 8 +-- python/tvm/relay/expr.py | 2 + python/tvm/relay/op/__init__.py | 3 ++ python/tvm/relay/op/op.py | 77 +++++++++++++++++++++++++++++ python/tvm/relay/op/registry.py | 1 - src/relay/ir/op.cc | 62 ++++++++++++++++++----- tests/python/relay/test_relay_op.py | 13 +++++ 8 files changed, 164 insertions(+), 20 deletions(-) create mode 100644 python/tvm/relay/op/op.py delete mode 100644 python/tvm/relay/op/registry.py diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index be81f54ecd69..15c55dee52c0 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -103,7 +103,7 @@ class Op : public relay::Expr { * \tparam ValueType The type of the attribute. */ template - inline static const OpMap& GetAttr(const std::string& attr_name); + inline static OpMap GetAttr(const std::string& attr_name); /*! * \brief Get an Op for a given operator name. * Will raise an error if the op has not been registered. @@ -193,9 +193,13 @@ class OpRegistry { // set the name of the op to be the same as registry inline OpRegistry& set_name() { // NOLINT(*) - get()->name = name; + if (get()->name.length() == 0) { + get()->name = name; + } return *this; } + /*! \return The global single retistry */ + TVM_DLL static ::dmlc::Registry* Registry(); private: friend class ::dmlc::Registry; @@ -307,7 +311,7 @@ class OpMap { */ #define RELAY_REGISTER_OP(OpName) \ DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ - ::dmlc::Registry<::tvm::relay::OpRegistry>::Get()->__REGISTER_OR_GET__(OpName).set_name() + ::tvm::relay::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name() // implementations inline const OpNode* Op::operator->() const { @@ -315,7 +319,7 @@ inline const OpNode* Op::operator->() const { } template -inline const OpMap& Op::GetAttr(const std::string& key) { +inline OpMap Op::GetAttr(const std::string& key) { return OpMap(Op::GetGenericAttr(key)); } @@ -352,6 +356,12 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) return *this; } +inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) + const std::string& type_key) { + get()->attrs_type_key = type_key; + return *this; +} + inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) get()->support_level = n; return *this; diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index f94b572f6b44..019d7c19a865 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -4,10 +4,6 @@ from . import expr from . import op -# import all operators in the loop namespace -from .op import * - - # Span Span = base.Span @@ -31,3 +27,7 @@ Let = expr.Let If = expr.If Var = LocalVar + +# Operators +from .op import Op +from .op.tensor import * diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index e98d74f3da88..41066829e2f3 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -99,3 +99,5 @@ class If(Expr): def __init__(self, cond: Expr, true_value: Expr, false_value: Expr) -> None: self.__init_handle_by_constructor__(_make.If, cond, true_value, false_value) + + diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index ad2f54929aed..3d87d78fe633 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,5 +1,8 @@ """Relay core operators.""" # operator defs +from .op import get, register, Op + +# Operators from .tensor import * # operator registry diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py new file mode 100644 index 000000000000..4540b19f5ccf --- /dev/null +++ b/python/tvm/relay/op/op.py @@ -0,0 +1,77 @@ +"""The base node types for the Relay language.""" +from ..._ffi.function import _init_api + +from ..base import register_relay_node +from ..expr import Expr +from ..._ffi.function import Function +from ...api import convert + +@register_relay_node +class Op(Expr): + def __init__(self): + raise RuntimeError("Cannot create op, use get instead") + + def get_attr(self, attr_name): + """Get additional attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name. + + Returns + ------- + value : object + The attribute value + """ + return _OpGetAttr(self, attr_name) + + +def get(op_name): + """Get the Op for a given name + + Parameters + ---------- + op_name : str + The operator name + + Returns + ------- + op : Op + The op of the corresponding name + """ + return _GetOp(op_name) + + +def register(op_name, attr_key, value=None, level=10): + """Register an operator property of an operator. + + + Parameters + ---------- + op_name : str + The name of operator + + attr_key : str + The attribute name. + + value : object, optional + The value to set + + level : int, optional + The priority level + + Returns + ------- + fregister : function + Register function if value is not specified. + """ + def _register(v): + """internal register function""" + _Register(op_name, attr_key, v, level) + return v + return _register(value) if value else _register + + +_init_api("relay.op", __name__) + diff --git a/python/tvm/relay/op/registry.py b/python/tvm/relay/op/registry.py deleted file mode 100644 index d7426429ef6f..000000000000 --- a/python/tvm/relay/op/registry.py +++ /dev/null @@ -1 +0,0 @@ -"""Mechanism to work with operator registry.""" diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 5a4241a182b1..664947425b53 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -10,6 +10,10 @@ DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); namespace tvm { namespace relay { +::dmlc::Registry* OpRegistry::Registry() { + return ::dmlc::Registry::Get(); +} + // single manager of operator information. struct OpManager { // mutex to avoid registration from multiple threads. @@ -18,6 +22,8 @@ struct OpManager { std::atomic op_counter{0}; // storage of additional attribute table. std::unordered_map > attr; + // frontend functions + std::vector frontend_funcs; // get singleton of the static OpManager* Global() { static OpManager inst; @@ -75,22 +81,56 @@ void OpRegistry::UpdateAttr( } // Frontend APIs -using runtime::TypedPackedFunc; - TVM_REGISTER_API("relay.op._ListOpNames") -.set_body(TypedPackedFunc()>([]() { - Array ret; - for (const std::string& name : - dmlc::Registry::ListAllNames()) { - ret.push_back(tvm::Expr(name)); - } - return ret; - })); +.set_body_typed()>([]() { + Array ret; + for (const std::string& name : + dmlc::Registry::ListAllNames()) { + ret.push_back(tvm::Expr(name)); + } + return ret; + }); TVM_REGISTER_API("relay.op._GetOp") -.set_body(TypedPackedFunc(Op::Get)); +.set_body_typed(Op::Get); +TVM_REGISTER_API("relay.op._OpGetAttr") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto op_map = Op::GetAttr(attr_name); + if (op_map.count(op)) { + *rv = op_map[op]; + } + }); + + +TVM_REGISTER_API("relay.op._Register") +.set_body([](TVMArgs args, TVMRetValue* rv) { + std::string op_name = args[0]; + std::string attr_key = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); + // enable resgiteration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + reg.set_attrs_type_key(value); + } else { + // normal attr table override. + if (args[2].type_code() == kFuncHandle) { + // do an eager copy of the PackedFunc + PackedFunc f = args[2]; + // If we get a function from frontend, avoid deleting it. + OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); + reg.set_attr(attr_key, f, plevel); + } else { + reg.set_attr(attr_key, args[2], plevel); + } + } + }); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_relay_op.py b/tests/python/relay/test_relay_op.py index 4235dd918d93..1f95a3f72c15 100644 --- a/tests/python/relay/test_relay_op.py +++ b/tests/python/relay/test_relay_op.py @@ -1,5 +1,16 @@ from tvm import relay +def test_op_attr(): + log_op = relay.op.get("log") + + @relay.op.register("exp", "ftest") + def test(x): + return x + 1 + + assert log_op.num_inputs == 1 + assert log_op.get_attr("ftest") is None + assert relay.op.get("exp").get_attr("ftest")(1) == 2 + def test_op_level1(): x = relay.Var("x") @@ -11,4 +22,6 @@ def test_op_level1(): if __name__ == "__main__": + test_op_attr() test_op_level1() + From f8540ec40adecd0212ffed6e0d97c7acba8e6752 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Aug 2018 16:11:07 -0700 Subject: [PATCH 042/136] WIP --- include/tvm/relay/compiler/environment.h | 2 +- include/tvm/relay/op.h | 6 +++--- include/tvm/relay/type.h | 22 ++++++++++++---------- src/codegen/spirv/ir_builder.cc | 2 +- src/relay/compiler/alpha_eq.cc | 2 +- src/relay/compiler/type_functor.h | 4 ++-- src/relay/compiler/type_visitor.h | 2 +- src/relay/compiler/unifier.cc | 2 +- src/relay/compiler/unifier.h | 2 +- src/relay/ir/type.cc | 14 +++++++------- src/relay/op/tensor/elemwise.cc | 4 +++- 11 files changed, 33 insertions(+), 29 deletions(-) diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/compiler/environment.h index d5a3ddd73f77..2ec8ca8af933 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/compiler/environment.h @@ -40,7 +40,7 @@ class EnvironmentNode : public RelayNode { private: /*! A map from string names to GlobalIds, ensures global uniqueness. */ tvm::Map global_map_; - tvm::Map type_func_map_; + tvm::Map type_func_map_; // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 15c55dee52c0..630a231ebb54 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -153,7 +153,7 @@ class OpRegistry { * \param ty_func The type function to register for the return type. * \return reference to self. */ - inline OpRegistry& add_type_func(const std::string & type_func_name); + inline OpRegistry& add_type_func(const std::string & type_func_name, TypeRelationFn type_fn); /*! * \brief Set the type key of attributes. @@ -343,8 +343,8 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, return *this; } - inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name) { - auto type_func = TypeFunctionNode::make(type_func_name, 0); + inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { + auto type_func = TypeRelationNode::make(type_func_name, 0); for (auto arg : get()->arguments) { std::cout << arg << std::endl; } diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index ef8c4c71f5b7..68ed411a23ed 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -210,40 +210,42 @@ class FuncTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); +using TypeRelationFn = std::function(const Array&, int)>; + /*! - * \brief Opaque type inference function. + * \brief Opaque type relation, is an input-output relation on types. */ -class TypeFunction; +class TypeRelation; /*! - * \brief TypeFunction container. + * \brief TypeRelation container. * \note This node is not directly serializable. * The type function need to be lookedup in the environment. */ -class TypeFunctionNode : public RelayNode { +class TypeRelationNode : public RelayNode { public: /*! \brief The name of the function */ std::string name; /*! \brief Number of input type arguments, can be -1, which means VarArgs */ int num_args; /*! - * \brief The type function, + * \brief The function on input and output variables which * this is not directly serializable, * need to be looked-up in the environment. */ - mutable std::function& arg_types)> func_; + TypeRelationFn func_; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name", &name); v->Visit("num_args", &num_args); } - TVM_DLL static TypeFunction make(std::string name, int num_args); + TVM_DLL static TypeRelation make(std::string name, int num_args); - static constexpr const char* _type_key = "relay.TypeFunction"; - TVM_DECLARE_NODE_TYPE_INFO(TypeFunctionNode, RelayNode); + static constexpr const char* _type_key = "relay.TypeRelation"; + TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, RelayNode); }; -RELAY_DEFINE_NODE_REF(TypeFunction, TypeFunctionNode, Type); +RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, Type); /*! * \brief Call a type function with some number of arguments. diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 41cb48c5854b..87987dbf08e9 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -41,7 +41,7 @@ void IRBuilder::InitPreDefs() { t_void_.id = id_counter_++; ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_); t_void_func_.id = id_counter_++; - ib_.Begin(spv::OpTypeFunction) + ib_.Begin(spv::OpTypeRelation) .AddSeq(t_void_func_, t_void_).Commit(&global_); } diff --git a/src/relay/compiler/alpha_eq.cc b/src/relay/compiler/alpha_eq.cc index 688a93ae73fc..d4f1d888fb69 100644 --- a/src/relay/compiler/alpha_eq.cc +++ b/src/relay/compiler/alpha_eq.cc @@ -92,7 +92,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TypeFunctionNode *op, const Type &t2) override { + void VisitType_(const TypeRelationNode *op, const Type &t2) override { } // void VisitType_(const TupleTypeNode *op, const Type &t2) override { // if (const TupleTypeNode *pt = t2.as()) { diff --git a/src/relay/compiler/type_functor.h b/src/relay/compiler/type_functor.h index 3840c902bfe8..5de56837ca10 100644 --- a/src/relay/compiler/type_functor.h +++ b/src/relay/compiler/type_functor.h @@ -63,7 +63,7 @@ class TypeFunctor { virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeFunctionNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; @@ -81,7 +81,7 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); - RELAY_TYPE_FUNCTOR_DISPATCH(TypeFunctionNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); return vtable; diff --git a/src/relay/compiler/type_visitor.h b/src/relay/compiler/type_visitor.h index 60ae810a6b96..c98ff3ab8958 100644 --- a/src/relay/compiler/type_visitor.h +++ b/src/relay/compiler/type_visitor.h @@ -47,7 +47,7 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { } } - void VisitType_(const TypeFunctionNode* op, Args... args) override {} + void VisitType_(const TypeRelationNode* op, Args... args) override {} void VisitType_(const IncompleteTypeNode* op, Args... args) override {} }; diff --git a/src/relay/compiler/unifier.cc b/src/relay/compiler/unifier.cc index b7cc296cc5db..2f728a104530 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/compiler/unifier.cc @@ -325,7 +325,7 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { // throw UnificationError("Cannot unify TupleTypeNode"); // } -Type TypeUnifierNode::VisitType_(const TypeFunctionNode *sen1, const Type t2) { +Type TypeUnifierNode::VisitType_(const TypeRelationNode *sen1, const Type t2) { // ShapeExtension sh_ext1 = GetRef(sen1); // if (const IncompleteTypeNode *tvn2 = t2.as()) { diff --git a/src/relay/compiler/unifier.h b/src/relay/compiler/unifier.h index 86ffd664a161..40583b16a55a 100644 --- a/src/relay/compiler/unifier.h +++ b/src/relay/compiler/unifier.h @@ -110,7 +110,7 @@ class TypeUnifierNode : public Node, Type VisitType_(const TypeParamNode* t1, const Type t2) override; Type VisitType_(const FuncTypeNode* t1, const Type t2) override; // Type VisitType_(const TupleTypeNode* t1, const Type t2) override; - Type VisitType_(const TypeFunctionNode* s1, const Type t2) override; + Type VisitType_(const TypeRelationNode* s1, const Type t2) override; Type VisitType_(const TypeCallNode* s1, const Type t2) override; }; diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 2b6647a5807e..d9e2737225ec 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -116,22 +116,22 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->type_constraints << ")"; }); -TypeFunction TypeFunctionNode::make(std::string name, int num_args) { - std::shared_ptr n = std::make_shared(); +TypeRelation TypeRelationNode::make(std::string name, int num_args) { + std::shared_ptr n = std::make_shared(); n->name = std::move(name); n->num_args = std::move(num_args); - return TypeFunction(n); + return TypeRelation(n); } -TVM_REGISTER_API("relay._make.TypeFunction") +TVM_REGISTER_API("relay._make.TypeRelation") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TypeFunctionNode::make(args[0], args[1]); + *ret = TypeRelationNode::make(args[0], args[1]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const TypeFunctionNode *node, + .set_dispatch([](const TypeRelationNode *node, tvm::IRPrinter *p) { - p->stream << "TypeFunctionNode(" << node->name << ", " << node->num_args << ")"; + p->stream << "TypeRelationNode(" << node->name << ", " << node->num_args << ")"; }); TypeCall TypeCallNode::make(Type func, Array args) { diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index d1d3e01ed9a6..05e1cbd57b13 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -36,7 +36,9 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Broadcast"); +.add_type_func("Log", [](const Array & t, int num_args) { + return t; +}); RELAY_REGISTER_UNARY_OP("exp") From 0325fa63c9e4125a195eb95ad4e1c69e89d31066 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 26 Aug 2018 21:54:04 -0700 Subject: [PATCH 043/136] Address comments from Friday --- .../tvm/relay/{compiler => }/environment.h | 27 +++++++++---------- include/tvm/relay/expr_functor.h | 3 ++- include/tvm/relay/ir.h | 20 -------------- .../tvm/relay/{compiler => pass}/alpha_eq.h | 9 ++++--- .../tvm/relay/{compiler => pass}/type_infer.h | 14 +++++----- python/tvm/relay/ir.py | 18 ------------- src/relay/{compiler => ir}/environment.cc | 2 +- src/relay/{compiler => pass}/alpha_eq.cc | 8 +++--- .../{compiler => pass}/incomplete_type.h | 8 +++--- src/relay/{compiler => pass}/resolve.cc | 10 +++---- src/relay/{compiler => pass}/resolve.h | 6 ++--- src/relay/{compiler => pass}/type_functor.h | 8 +++--- src/relay/{compiler => pass}/type_infer.cc | 12 ++++----- src/relay/{compiler => pass}/type_subst.cc | 12 ++++----- src/relay/{compiler => pass}/type_subst.h | 8 +++--- src/relay/{compiler => pass}/type_visitor.h | 0 src/relay/{compiler => pass}/unifier.cc | 22 ++++++++------- src/relay/{compiler => pass}/unifier.h | 10 +++---- 18 files changed, 80 insertions(+), 117 deletions(-) rename include/tvm/relay/{compiler => }/environment.h (83%) delete mode 100644 include/tvm/relay/ir.h rename include/tvm/relay/{compiler => pass}/alpha_eq.h (52%) rename include/tvm/relay/{compiler => pass}/type_infer.h (69%) delete mode 100644 python/tvm/relay/ir.py rename src/relay/{compiler => ir}/environment.cc (99%) rename src/relay/{compiler => pass}/alpha_eq.cc (97%) rename src/relay/{compiler => pass}/incomplete_type.h (82%) rename src/relay/{compiler => pass}/resolve.cc (92%) rename src/relay/{compiler => pass}/resolve.h (79%) rename src/relay/{compiler => pass}/type_functor.h (95%) rename src/relay/{compiler => pass}/type_infer.cc (98%) rename src/relay/{compiler => pass}/type_subst.cc (66%) rename src/relay/{compiler => pass}/type_subst.h (54%) rename src/relay/{compiler => pass}/type_visitor.h (100%) rename src/relay/{compiler => pass}/unifier.cc (96%) rename src/relay/{compiler => pass}/unifier.h (95%) diff --git a/include/tvm/relay/compiler/environment.h b/include/tvm/relay/environment.h similarity index 83% rename from include/tvm/relay/compiler/environment.h rename to include/tvm/relay/environment.h index 2ec8ca8af933..ff8e596059b5 100644 --- a/include/tvm/relay/compiler/environment.h +++ b/include/tvm/relay/environment.h @@ -1,18 +1,17 @@ /*! * Copyright (c) 2018 by Contributors - * \file environment.h - * \brief The global environment containing + * \file tvm/relay/environment.h + * \brief The global environment, contains global state of Relay program. */ #ifndef TVM_RELAY_ENVIRONMENT_H_ #define TVM_RELAY_ENVIRONMENT_H_ #include #include -#include "../expr.h" -#include "../type.h" -#include "../op.h" -#include "../error.h" -// #include "tvm/relay/options.h" +#include "./expr.h" +#include "./type.h" +#include "./op.h" +#include "./error.h" // #include "tvm/relay/source_map.h" namespace tvm { @@ -38,10 +37,8 @@ struct Environment; class EnvironmentNode : public RelayNode { private: - /*! A map from string names to GlobalIds, ensures global uniqueness. */ + /*! \brief A map from string names to global variables ensures global uniqueness. */ tvm::Map global_map_; - tvm::Map type_func_map_; - // /*! \brief A map from file names to source fragments. */ // SourceMap source_map_ // /*! \brief A list of the errors reported during the current run. */ @@ -51,8 +48,6 @@ class EnvironmentNode : public RelayNode { /*! \brief A map from ids to all global functions. */ tvm::Map items; - // Options options; - EnvironmentNode() {} void VisitAttrs(tvm::AttrVisitor* v) final {} @@ -67,15 +62,17 @@ class EnvironmentNode : public RelayNode { GlobalVar GetGlobalVar(const std::string& str); - /*! \brief Lookup a global function by its name. */ + /*! \brief Lookup a global function by its variable. */ Function Lookup(const GlobalVar& id); + + /*! \brief Lookup a global function by its string name */ Function Lookup(const std::string & s); /*! \brief Add a source fragment to the environment. */ // FileId add_source(std::string file_name, std::string source); - void report_error(std::string msg, Span sp); - void display_errors(); + void ReportError(std::string msg, Span sp); + void DisplayErrors(); static constexpr const char* _type_key = "relay.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 2067b90bd364..e37a454eee41 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -9,7 +9,8 @@ #include #include -#include "ir.h" +#include "./expr.h" +#include "./op.h" namespace tvm { namespace relay { diff --git a/include/tvm/relay/ir.h b/include/tvm/relay/ir.h deleted file mode 100644 index 73c275cf1c98..000000000000 --- a/include/tvm/relay/ir.h +++ /dev/null @@ -1,20 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/ir.h - * \brief The Relay intermediate representation's core data structures. - */ -#ifndef TVM_RELAY_IR_H_ -#define TVM_RELAY_IR_H_ - -#include "./base.h" -#include "./type.h" -#include "./expr.h" -#include "./op.h" - -// namespace tvm { -// namespace relay { - -// } // namespace relay -// } // namespace tvm - -#endif // TVM_RELAY_IR_H_ diff --git a/include/tvm/relay/compiler/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h similarity index 52% rename from include/tvm/relay/compiler/alpha_eq.h rename to include/tvm/relay/pass/alpha_eq.h index ba91afc21015..caa2f93c31a7 100644 --- a/include/tvm/relay/compiler/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -1,18 +1,19 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/alpha_eq.h - * \brief Check expressions & types for structural equivalence. + * \brief Check expressions and types for structural equivalence. */ #ifndef TVM_RELAY_ALPHA_EQ_H_ #define TVM_RELAY_ALPHA_EQ_H_ -#include "tvm/relay/ir.h" +#include "tvm/relay/type.h" +#include "tvm/relay/expr.h" namespace tvm { namespace relay { -bool alpha_eq(const Expr & e1, const Expr & e2); -bool alpha_eq(const Type & t1, const Type & t2); +bool AlphaEqual(const Expr & e1, const Expr & e2); +bool AlphaEqual(const Type & t1, const Type & t2); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/compiler/type_infer.h b/include/tvm/relay/pass/type_infer.h similarity index 69% rename from include/tvm/relay/compiler/type_infer.h rename to include/tvm/relay/pass/type_infer.h index c084fb7a109e..9a8ab2bc6a8b 100644 --- a/include/tvm/relay/compiler/type_infer.h +++ b/include/tvm/relay/pass/type_infer.h @@ -1,16 +1,16 @@ /*! * Copyright (c) 2018 by Contributors - * \file tvm/relay/type_infer.h + * \file tvm/relay/pass/type_infer.h * \brief Perform type inference and checking on Relay programs. * * The pass produces a new expression with its checked_type * field populated and incomplete types resolved. */ -#ifndef TVM_RELAY_COMPILER_TYPECHECKER_H_ -#define TVM_RELAY_COMPILER_TYPECHECKER_H_ +#ifndef TVM_RELAY_PASS__TYPECHECKER_H_ +#define TVM_RELAY_PASS__TYPECHECKER_H_ -#include "tvm/relay/ir.h" -#include "tvm/relay/compiler/environment.h" +#include "tvm/relay/expr.h" +#include "tvm/relay/environment.h" namespace tvm { namespace relay { @@ -19,7 +19,7 @@ namespace relay { * with unambigous type information filled in, as well as it's * checked type field populated with the result type. */ -Expr Infer(const Environment & env, const Expr & e); +Expr InferType(const Environment & env, const Expr & e); /*! \brief Ensures that an operator is well-formed with respect * to Relay's type system. @@ -28,4 +28,4 @@ Op CheckOp(const Environment & env, const Op & op); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_COMPILER_TYPECHECKER_H_ +#endif // TVM_RELAY_PASS_TYPECHECKER_H_ diff --git a/python/tvm/relay/ir.py b/python/tvm/relay/ir.py deleted file mode 100644 index a95f29abe6de..000000000000 --- a/python/tvm/relay/ir.py +++ /dev/null @@ -1,18 +0,0 @@ -from . import base -from . import type as ty -from . import expr - -# Base -register_relay_node = base.register_relay_node -NodeBase = base.NodeBase - -# Type -Type = ty.Type -TensorType = ty.Type -Kind = ty.Kind -TypeParam = ty.TypeParam -TypeConstraint = ty.TypeConstraint -FuncType = ty.FuncType -IncompleteType = ty.IncompleteType - -# Expr diff --git a/src/relay/compiler/environment.cc b/src/relay/ir/environment.cc similarity index 99% rename from src/relay/compiler/environment.cc rename to src/relay/ir/environment.cc index a1c6b31076e3..8c155e3bc1bd 100644 --- a/src/relay/compiler/environment.cc +++ b/src/relay/ir/environment.cc @@ -4,7 +4,7 @@ * \brief The global environment in Relay. */ #include -#include "tvm/relay/compiler/environment.h" +#include "tvm/relay/environment.h" // #include "tvm/relay/alpha_eq.h" // #include "tvm/relay/debug.h" // #include "tvm/relay/typeck/typechecker.h" diff --git a/src/relay/compiler/alpha_eq.cc b/src/relay/pass/alpha_eq.cc similarity index 97% rename from src/relay/compiler/alpha_eq.cc rename to src/relay/pass/alpha_eq.cc index d4f1d888fb69..5247bb5beaef 100644 --- a/src/relay/compiler/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -1,9 +1,9 @@ /*! * Copyright (c) 2018 by Contributors - * \file alpha_eq.cc + * \file src/tvm/relay/pass/alpha_eq.cc * \brief Compute the set of variables not bound in the expression. */ -#include "tvm/relay/compiler/alpha_eq.h" +#include "tvm/relay/pass/alpha_eq.h" #include "tvm/relay/expr_visitor.h" #include "./type_visitor.h" @@ -134,7 +134,7 @@ struct TypeAlphaEq : TypeVisitor { // } }; -bool alpha_eq(const Type &t1, const Type &t2) { +bool AlphaEqual(const Type &t1, const Type &t2) { TypeAlphaEq aeq; aeq.VisitType(t1, t2); return aeq.equal; @@ -277,7 +277,7 @@ TVM_REGISTER_API("relay._make._type_alpha_eq") .set_body([](TVMArgs args, TVMRetValue *ret) { Type t1 = args[0]; Type t2 = args[1]; - *ret = alpha_eq(t1, t2); + *ret = AlphaEqual(t1, t2); }); } // namespace relay diff --git a/src/relay/compiler/incomplete_type.h b/src/relay/pass/incomplete_type.h similarity index 82% rename from src/relay/compiler/incomplete_type.h rename to src/relay/pass/incomplete_type.h index f31a2efdf78d..3967b4e58657 100644 --- a/src/relay/compiler/incomplete_type.h +++ b/src/relay/pass/incomplete_type.h @@ -4,10 +4,10 @@ * \brief A way to defined arbitrary function signature with dispatch on types. */ -#ifndef TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H -#define TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H +#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H +#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H -#include "tvm/relay/ir.h" +#include namespace tvm { namespace relay { @@ -37,4 +37,4 @@ RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_COMPILER_INCOMPLETE_TYPE_H +#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H diff --git a/src/relay/compiler/resolve.cc b/src/relay/pass/resolve.cc similarity index 92% rename from src/relay/compiler/resolve.cc rename to src/relay/pass/resolve.cc index 236722b23387..e86368854060 100644 --- a/src/relay/compiler/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -1,18 +1,18 @@ /*! * Copyright (c) 2018 by Contributors - * \file unifier.cc - * \brief Data structures for type unification + * \file resolve.cc + * \brief Resolve incomplete types to complete types. */ +#include +#include #include "./resolve.h" #include "./type_visitor.h" -#include "tvm/relay/expr_visitor.h" -#include "tvm/relay/ir.h" namespace tvm { namespace relay { -// We should probably generalize the subst code. +// TODO(@jroesch): We should probably generalize the subst code. struct ResolveTypeType : TypeFVisitor { const TypeUnifier &unifier; diff --git a/src/relay/compiler/resolve.h b/src/relay/pass/resolve.h similarity index 79% rename from src/relay/compiler/resolve.h rename to src/relay/pass/resolve.h index b4e164df6287..5f6cc328a239 100644 --- a/src/relay/compiler/resolve.h +++ b/src/relay/pass/resolve.h @@ -1,13 +1,13 @@ /*! * Copyright (c) 2018 by Contributors - * \file tvm/relay/options.h - * \brief Global options for the Relay IR. + * \file tvm/relay/resolve.h + * \brief Resolve incomplete types to complete types. */ #ifndef TVM_RELAY_TYPECK_RESOLVE_H_ #define TVM_RELAY_TYPECK_RESOLVE_H_ #include -#include "tvm/relay/ir.h" +#include #include "./unifier.h" namespace tvm { diff --git a/src/relay/compiler/type_functor.h b/src/relay/pass/type_functor.h similarity index 95% rename from src/relay/compiler/type_functor.h rename to src/relay/pass/type_functor.h index 5de56837ca10..9adc1a08860e 100644 --- a/src/relay/compiler/type_functor.h +++ b/src/relay/pass/type_functor.h @@ -3,11 +3,11 @@ * \file type_functor.h * \brief A way to defined arbitrary function signature with dispatch on types. */ -#ifndef TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ -#define TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ +#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_ +#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_ #include -#include "tvm/relay/ir.h" +#include #include "./incomplete_type.h" namespace tvm { @@ -90,4 +90,4 @@ class TypeFunctor { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_COMPILER_TYPE_FUNCTOR_H_ +#endif // TVM_RELAY_PASS_TYPE_FUNCTOR_H_ diff --git a/src/relay/compiler/type_infer.cc b/src/relay/pass/type_infer.cc similarity index 98% rename from src/relay/compiler/type_infer.cc rename to src/relay/pass/type_infer.cc index e2e5999e7341..d84b3f96d426 100644 --- a/src/relay/compiler/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -20,10 +20,10 @@ * constraints we will trigger an error. */ -#include "tvm/relay/logging.h" -#include "tvm/relay/compiler/type_infer.h" -#include "tvm/relay/error.h" -#include "tvm/relay/expr_functor.h" +#include +#include +#include +#include #include "./incomplete_type.h" #include "./unifier.h" #include "./resolve.h" @@ -335,7 +335,7 @@ class TypeInferencer : private ExprFunctor { // auto fresh_tid = // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); // fn_type = - // type_subst(fn_type, GetRef(ty_param_node), fresh_tid); + // TypeSubst(fn_type, GetRef(ty_param_node), fresh_tid); // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); // } // } @@ -360,7 +360,7 @@ class TypeInferencer : private ExprFunctor { } Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); - inst_ty = type_subst(fn_ty, subst_map); + inst_ty = TypeSubst(fn_ty, subst_map); // if (!check_kind(t)) { // this->fatal_error("Kind rules broken when instantiating type diff --git a/src/relay/compiler/type_subst.cc b/src/relay/pass/type_subst.cc similarity index 66% rename from src/relay/compiler/type_subst.cc rename to src/relay/pass/type_subst.cc index 6650f59bad51..91713976bcaa 100644 --- a/src/relay/compiler/type_subst.cc +++ b/src/relay/pass/type_subst.cc @@ -9,10 +9,10 @@ namespace tvm { namespace relay { -struct TypeSubst : TypeFVisitor { +struct TypeSubstV : TypeFVisitor { tvm::Map subst_map; - explicit TypeSubst(tvm::Map subst_map) + explicit TypeSubstV(tvm::Map subst_map) : subst_map(subst_map) {} Type VisitType_(const TypeParamNode *op) override { @@ -25,13 +25,13 @@ struct TypeSubst : TypeFVisitor { } }; -Type type_subst(const Type &type, const TypeParam &target, const Type &subst) { - TypeSubst ty_sub({ {target, subst} }); +Type TypeSubst(const Type &type, const TypeParam &target, const Type &subst) { + TypeSubstV ty_sub({ {target, subst} }); return ty_sub.VisitType(type); } -Type type_subst(const Type &type, tvm::Map subst_map) { - TypeSubst ty_sub(subst_map); +Type TypeSubst(const Type &type, tvm::Map subst_map) { + TypeSubstV ty_sub(subst_map); return ty_sub.VisitType(type); } diff --git a/src/relay/compiler/type_subst.h b/src/relay/pass/type_subst.h similarity index 54% rename from src/relay/compiler/type_subst.h rename to src/relay/pass/type_subst.h index 0bf0de5a4b85..3c248fdce3b7 100644 --- a/src/relay/compiler/type_subst.h +++ b/src/relay/pass/type_subst.h @@ -1,18 +1,18 @@ /*! * Copyright (c) 2018 by Contributors * \file typeck/type_subst.h - * \brief Utility function for substituting types + * \brief Utility functions for substituting types. */ #ifndef TVM_RELAY_TYPECK_TYPE_SUBST_H_ #define TVM_RELAY_TYPECK_TYPE_SUBST_H_ -#include "tvm/relay/ir.h" +#include namespace tvm { namespace relay { -Type type_subst(const Type & type, const TypeParam & target, const Type & subst); -Type type_subst(const Type &type, tvm::Map subst_map); +Type TypeSubst(const Type & type, const TypeParam & target, const Type & subst); +Type TypeSubst(const Type &type, tvm::Map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/compiler/type_visitor.h b/src/relay/pass/type_visitor.h similarity index 100% rename from src/relay/compiler/type_visitor.h rename to src/relay/pass/type_visitor.h diff --git a/src/relay/compiler/unifier.cc b/src/relay/pass/unifier.cc similarity index 96% rename from src/relay/compiler/unifier.cc rename to src/relay/pass/unifier.cc index 2f728a104530..c6a4e7dfba6d 100644 --- a/src/relay/compiler/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -1,12 +1,14 @@ /*! * Copyright (c) 2018 by Contributors - * \file unifier.cc - * \brief Data structures for type unification + * \file tvm/src/relay/pass/unifier.cc + * \brief The type unifier which solves a system of equations between + * incomplete types. */ -#include "tvm/relay/ir.h" -#include "tvm/relay/logging.h" -#include "tvm/relay/compiler/alpha_eq.h" +#include +#include +#include +#include #include "./unifier.h" #include "./type_visitor.h" //#include "./type_subst.h" @@ -32,8 +34,8 @@ void UnionFindNode::debug() { } } -void UnionFindNode::assertAlphaEq(const Type & l, const Type & r) { - if (!alpha_eq(l, r)) { +void UnionFindNode::AssertAlphaEqual(const Type & l, const Type & r) { + if (!AlphaEqual(l, r)) { std::stringstream ss; ss << "Incompatible parent types in UF:" << l << " and " << r; throw UnionFindError(ss.str()); @@ -71,7 +73,7 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { } // if both parents are not type vars themselves, check alpha-equality - assertAlphaEq(parent1, parent2); + AssertAlphaEqual(parent1, parent2); return; } @@ -83,7 +85,7 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { return; } - assertAlphaEq(parent1, t); + AssertAlphaEqual(parent1, t); } Type UnionFindNode::find(const IncompleteType &v) { @@ -274,7 +276,7 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { if (const TensorTypeNode *ttn2 = rt2.as()) { TensorType tt2 = GetRef(ttn2); - if (!alpha_eq(tt1, tt2)) { + if (!AlphaEqual(tt1, tt2)) { throw UnificationError("dtypes do not match"); } diff --git a/src/relay/compiler/unifier.h b/src/relay/pass/unifier.h similarity index 95% rename from src/relay/compiler/unifier.h rename to src/relay/pass/unifier.h index 40583b16a55a..aecc428cb6a9 100644 --- a/src/relay/compiler/unifier.h +++ b/src/relay/pass/unifier.h @@ -1,15 +1,15 @@ /*! * Copyright (c) 2018 by Contributors - * \file unifier.h + * \file include/tvm/relay/pass/unifier.h * \brief The type unifier which solves a system of equations between * incomplete types. */ -#ifndef TVM_RELAY_COMPILER_UNIFIER_H_ -#define TVM_RELAY_COMPILER_UNIFIER_H_ +#ifndef TVM_RELAY_PASS_UNIFIER_H_ +#define TVM_RELAY_PASS_UNIFIER_H_ #include +#include #include "./type_functor.h" -#include "tvm/relay/ir.h" namespace tvm { namespace relay { @@ -50,7 +50,7 @@ class UnionFindNode : public Node { void debug(); - void assertAlphaEq(const Type& l, const Type& r); + void AssertAlphaEqual(const Type& l, const Type& r); static constexpr const char* _type_key = "relay.UnionFind"; TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, Node); From 7ad860fa0f609916f7c776c8c75af6210bacbcff Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 26 Aug 2018 21:55:56 -0700 Subject: [PATCH 044/136] Repair tests --- python/tvm/relay/op/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 3d87d78fe633..d54f47e25197 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -9,7 +9,3 @@ from . import _tensor from ..expr import Expr from ..base import register_relay_node - -@register_relay_node -class Op(Expr): - pass From 53536633dc3ab687600e8423c4e5f22a10620eb7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Aug 2018 16:34:55 -0700 Subject: [PATCH 045/136] Work on type relation --- include/tvm/relay/op.h | 19 +++++++- python/tvm/relay/op/tensor.py | 17 +++++++ src/relay/op/tensor/elemwise.cc | 25 ++++++++--- src/relay/op/type_relations.cc | 45 +++++++++++++++++++ src/relay/op/type_relations.h | 22 +++++++++ src/relay/pass/alpha_eq.cc | 44 ++++++++++-------- src/relay/pass/type_infer.cc | 32 +++++++------ src/relay/pass/type_visitor.h | 4 ++ ...ecker.py => test_tyck_eval_integration.py} | 16 ++++++- 9 files changed, 184 insertions(+), 40 deletions(-) create mode 100644 src/relay/op/type_relations.cc create mode 100644 src/relay/op/type_relations.h rename tests/python/relay/{test_typechecker.py => test_tyck_eval_integration.py} (65%) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 630a231ebb54..c91955460f82 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -345,9 +345,26 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { auto type_func = TypeRelationNode::make(type_func_name, 0); + + std::vector type_params; + std::vector arg_types; + // TODO (@jroesch: revise type generation strategy + int i = 0; for (auto arg : get()->arguments) { - std::cout << arg << std::endl; + std::string name = "t"; + name += std::to_string(i++); + auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); + type_params.push_back(param); + arg_types.push_back(param); } + + + auto type_result = TypeCallNode::make(type_func, arg_types); + + auto func_type = FuncTypeNode::make(arg_types, type_result, type_params, {}); + + get()->op_type = func_type; + return *this; } diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 7155db3a4cd5..aa9ce6bf42e9 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -58,3 +58,20 @@ def sqrt(data): The computed result. """ return _make.sqrt(data) + +def add(lhs, rhs): + """Take sqrt of data. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.add(lhs, rhs) diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 05e1cbd57b13..700e9185ccba 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -5,6 +5,7 @@ */ #include #include +#include "../type_relations.h" namespace tvm { namespace relay { @@ -36,9 +37,7 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Log", [](const Array & t, int num_args) { - return t; -}); +.add_type_func("Log", IdentityRel); RELAY_REGISTER_UNARY_OP("exp") @@ -48,7 +47,8 @@ RELAY_REGISTER_UNARY_OP("exp") \exp(x) )code" TVM_ADD_FILELINE) -.set_support_level(1); +.set_support_level(1) +.add_type_func("Exp", IdentityRel); RELAY_REGISTER_UNARY_OP("sqrt") @@ -58,7 +58,22 @@ RELAY_REGISTER_UNARY_OP("sqrt") sqrt(x) )code" TVM_ADD_FILELINE) -.set_support_level(1); +.set_support_level(1) +.add_type_func("Sqrt", IdentityRel); + +// Addition +TVM_REGISTER_API("relay.op._make.add") + .set_body_typed([](Expr lhs, Expr rhs) { + static const Op& op = Op::Get("add"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("add") + .set_num_inputs(2) + .add_argument("lhs", "Tensor", "The left hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") + .set_support_level(1) + .add_type_func("Broadcast", BroadcastRel); } // namespace relayv } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc new file mode 100644 index 000000000000..a5ba1dc14b5f --- /dev/null +++ b/src/relay/op/type_relations.cc @@ -0,0 +1,45 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_relations.cc + * \brief A set of utilities and common functionality + * for type relations. + */ +#include +#include +#include "../pass/incomplete_type.h" + +namespace tvm { +namespace relay { + +TensorType as_ttype(const Type & t) { + if (auto tt_node = t.as()) { + return GetRef(tt_node); + } else { + return TensorType(nullptr); + } +} + +Array IdentityRel(const Array & types, int num_args) { + CHECK(types.size() == 1); + auto t1 = as_ttype(types[0]); + if (t1 && types[1].as()) { + return {t1, t1}; + } else { + return types; + } +} + +Array BroadcastRel(const Array & types, int num_args) { + std::cout << "Inside of Broadcast" << std::endl; + CHECK(types.size() == 0); + if (auto t1 = as_ttype(types[0])) { + if (auto t2 = as_ttype(types[1])) { + return types; + } + } + return types; +} + + +} // namespace relayv +} // namespace tvm diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h new file mode 100644 index 000000000000..f2c4876705b6 --- /dev/null +++ b/src/relay/op/type_relations.h @@ -0,0 +1,22 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/op/type_relations.h + * \brief A set of utilities and common functionality + * for type relations. + */ +#ifndef TVM_RELAY_TYPECK_RESOLVE_H_ +#define TVM_RELAY_TYPECK_RESOLVE_H_ + +#include +#include + +namespace tvm { +namespace relay { + +Array IdentityRel(const Array & types, int num_args); +Array BroadcastRel(const Array & types, int num_args); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TYPECK_RESOLVE_H_ diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 5247bb5beaef..555d4f2db99d 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -92,8 +92,14 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TypeRelationNode *op, const Type &t2) override { + void VisitType_(const TypeRelationNode *tr1, const Type &t2) override { + if (const TypeRelationNode *tr2 = t2.as()) { + equal = tr1 == tr2; + } else { + equal = false; + } } + // void VisitType_(const TupleTypeNode *op, const Type &t2) override { // if (const TupleTypeNode *pt = t2.as()) { // if (op->fields.size() != pt->fields.size()) { @@ -112,26 +118,26 @@ struct TypeAlphaEq : TypeVisitor { // } // } -// void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { -// TypeCall tycall = GetRef(tyn1); -// if (const TypeCallNode *tyn2 = t2.as()) { -// if (tycall->func != tyn2->func) { -// equal = false; -// return; -// } + void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { + TypeCall tycall = GetRef(tyn1); + if (const TypeCallNode *tyn2 = t2.as()) { + if (tycall->func != tyn2->func) { + equal = false; + return; + } -// if (tycall->args.size() != tyn2->args.size()) { -// equal = false; -// return; -// } + if (tycall->args.size() != tyn2->args.size()) { + equal = false; + return; + } -// for (size_t i = 0U; i < tycall->args.size(); i++) { -// this->VisitType(tycall->args[i], tyn2->args[i]); -// } -// } else { -// equal = false; -// } -// } + for (size_t i = 0U; i < tycall->args.size(); i++) { + this->VisitType(tycall->args[i], tyn2->args[i]); + } + } else { + equal = false; + } + } }; bool AlphaEqual(const Type &t1, const Type &t2) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index d84b3f96d426..b9cfd5837c4b 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -28,17 +28,8 @@ #include "./unifier.h" #include "./resolve.h" #include "./type_subst.h" -// #include "tvm/relay/alpha_eq.h" -// #include "tvm/relay/debug.h" -// #include "tvm/relay/first_order_reverse_ad.h" -// #include "tvm/relay/free_type_vars.h" -// #include "tvm/relay/gen_fresh.h" -// #include "tvm/relay/ir.h" -// #include "tvm/relay/pretty_printer.h" -// #include "tvm/relay/reverse_ad.h" -// #include "tvm/relay/type_visitor.h" +#include "./type_visitor.h" // #include "tvm/relay/typeck/kindchecker.h" -// #include "tvm/relay/typeck/shape_evaluator.h" namespace tvm { namespace relay { @@ -68,6 +59,12 @@ struct TypeContext { }; }; +struct TypeNormalizer : TypeFVisitor { + TypeUnifier unifier; + TypeNormalizer(const TypeUnifier & unifier) : unifier(unifier) {} + // Type VisitType_( +}; + struct CheckedExpr { Expr expr; Type type; @@ -98,6 +95,8 @@ class TypeInferencer : private ExprFunctor { FuncType instantiate(FuncType fn_ty, tvm::Array &ty_args); + Type Normalize(const Type & t); + void report_error(const std::string & msg, Span sp); [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); @@ -129,11 +128,17 @@ class TypeInferencer : private ExprFunctor { this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); } + Type TypeInferencer::Normalize(const Type & t) { + auto nt = this->resolve(t); + auto normalizer = TypeNormalizer(this->unifier); + return normalizer.VisitType(nt); + } + CheckedExpr TypeInferencer::Infer(const Expr &expr) { RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; CheckedExpr checked_expr = this->VisitExpr(expr); RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type << std::endl; - Type final_type = this->unifier->subst(checked_expr.type); + Type final_type = Normalize(checked_expr.type); RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type << std::endl; checked_expr.expr->checked_type_ = final_type; return checked_expr; @@ -498,8 +503,9 @@ class TypeInferencer : private ExprFunctor { ifn->cond->span); } - CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op) { - return { GetRef(op), FuncTypeNode::make({}, TensorTypeNode::Int(32), {}, {} )}; + CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { + auto op = GetRef(op_node); + return { op, op->op_type }; } Type TypeInferencer::resolve(const Type &t) { diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index c98ff3ab8958..68dba76644c3 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -85,6 +85,10 @@ struct TypeFVisitor : TypeFunctor { // return TupleTypeNode::make(new_fields); // } + Type VisitType_(const TypeRelationNode* op) override { + return GetRef(op); + } + Type VisitType_(const TypeCallNode* op) override { auto func = this->VisitType(op->func); std::vector new_args; diff --git a/tests/python/relay/test_typechecker.py b/tests/python/relay/test_tyck_eval_integration.py similarity index 65% rename from tests/python/relay/test_typechecker.py rename to tests/python/relay/test_tyck_eval_integration.py index d111bba9dfbf..d96682fbfda4 100644 --- a/tests/python/relay/test_typechecker.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -4,11 +4,12 @@ from tvm.relay.type_infer import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, func_type from tvm.relay.env import Environment -from tvm.relay.op import log +from tvm.relay.op import log, add def has_type(expr, typ): env = Environment({}) checked_expr = check_expr(env, expr) + import pdb; pdb.set_trace() return checked_expr.checked_type() == typ def test_monomorphic_let(): @@ -30,6 +31,17 @@ def test_single_op(): b.ret(t1) assert has_type(func.to_func(), func_type([float_type()], float_type())) +def test_dual_op(): + "Program: fn (x : float32) { let t1 = f(x); let t2 = g(t1, x); t1 }" + b = IRBuilder() + with b.function(('x', float_type())) as func: + x, = func.param_ids() + t1 = b.let('t1', log(x)) + t2 = b.let('t2', add(t1, x)) + b.ret(t2) + assert has_type(func.to_func(), func_type([float_type()], float_type())) + if __name__ == "__main__": - test_monomorphic_let() + # test_monomorphic_let() test_single_op() + test_dual_op() From ac254c392326034adce40ee8e940fc7b8117db48 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 Aug 2018 21:22:32 -0700 Subject: [PATCH 046/136] Fix find and replace bug --- src/codegen/spirv/ir_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 87987dbf08e9..41cb48c5854b 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -41,7 +41,7 @@ void IRBuilder::InitPreDefs() { t_void_.id = id_counter_++; ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_); t_void_func_.id = id_counter_++; - ib_.Begin(spv::OpTypeRelation) + ib_.Begin(spv::OpTypeFunction) .AddSeq(t_void_func_, t_void_).Commit(&global_); } From b79f63adb1992360b599364ab9150ae2f9daded1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 11:54:35 -0700 Subject: [PATCH 047/136] Add normalization for type relations --- include/tvm/relay/base.h | 17 +++--- include/tvm/relay/expr_visitor.h | 2 +- include/tvm/relay/op.h | 59 +++++++++++++------ include/tvm/relay/pass/alpha_eq.h | 4 +- include/tvm/relay/type.h | 5 +- src/relay/ir/expr.cc | 35 +++++++---- src/relay/ir/type.cc | 24 ++++---- src/relay/op/tensor/elemwise.cc | 10 ++++ src/relay/op/type_relations.cc | 2 +- src/relay/pass/type_infer.cc | 48 +++++++++------ .../relay/test_tyck_eval_integration.py | 2 +- 11 files changed, 130 insertions(+), 78 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 3b31aae52617..f25d6e6532df 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -49,15 +49,16 @@ using NodeEqual = ::tvm::NodeEqual; * \param NodeName The internal contrainer name. * \param NodeRefBase The base type. */ -#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ - class TypeName : public NodeRefBase { \ - public: \ - TypeName() {} \ +#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ + class TypeName : public NodeRefBase { \ + public: \ + TypeName() {} \ explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - using ContainerType = NodeName; \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + operator bool() { return this->defined(); } \ + using ContainerType = NodeName; \ }; diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index d1e8a99dc374..8803aa5ae48f 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -7,7 +7,7 @@ #ifndef TVM_RELAY_EXPR_VISITOR_H_ #define TVM_RELAY_EXPR_VISITOR_H_ -#include "tvm/relay/expr_functor.h" +#include namespace tvm { namespace relay { diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index c91955460f82..0e5483174c53 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -155,6 +155,15 @@ class OpRegistry { */ inline OpRegistry& add_type_func(const std::string & type_func_name, TypeRelationFn type_fn); + /*! + * \brief Attach the type function corresponding to the return type. + * \param ty_func The type function to register for the return type. + * \return reference to self. + */ + inline OpRegistry& add_type_func( + const std::string & type_func_name, + std::function(const Array &, int)> type_fn); + /*! * \brief Set the type key of attributes. * \param type_key The type of of the attrs field.x @@ -343,30 +352,44 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, return *this; } - inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { - auto type_func = TypeRelationNode::make(type_func_name, 0); +inline OpRegistry& OpRegistry::add_type_func( + const std::string & type_func_name, + std::function(const Array &, int)> type_fn) { + auto pfunc = runtime::TypedPackedFunc(const Array &, int)>(type_fn); + return add_type_func(type_func_name, pfunc); +} - std::vector type_params; - std::vector arg_types; - // TODO (@jroesch: revise type generation strategy - int i = 0; - for (auto arg : get()->arguments) { - std::string name = "t"; - name += std::to_string(i++); - auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); - type_params.push_back(param); - arg_types.push_back(param); - } +inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { + auto type_func = TypeRelationNode::make(type_func_name, 0, type_fn); + std::vector type_params; + std::vector arg_types; - auto type_result = TypeCallNode::make(type_func, arg_types); + // Add inputs. + int i = 0; + for (auto arg : get()->arguments) { + std::string name = "in"; + name += std::to_string(i++); + auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); + type_params.push_back(param); + arg_types.push_back(param); + } + + auto ty_call_args = Array(arg_types); + + // Add output type. + auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType); + type_params.push_back(out_param); + ty_call_args.push_back(out_param); - auto func_type = FuncTypeNode::make(arg_types, type_result, type_params, {}); + auto type_result = TypeCallNode::make(type_func, ty_call_args); - get()->op_type = func_type; + auto func_type = FuncTypeNode::make(arg_types, type_result, type_params, {}); - return *this; - } + get()->op_type = func_type; + + return *this; +} inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) get()->num_inputs = n; diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h index caa2f93c31a7..9f3c2138a440 100644 --- a/include/tvm/relay/pass/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -6,8 +6,8 @@ #ifndef TVM_RELAY_ALPHA_EQ_H_ #define TVM_RELAY_ALPHA_EQ_H_ -#include "tvm/relay/type.h" -#include "tvm/relay/expr.h" +#include "../type.h" +#include "../expr.h" namespace tvm { namespace relay { diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 68ed411a23ed..498053f4f9bb 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -210,7 +210,7 @@ class FuncTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); -using TypeRelationFn = std::function(const Array&, int)>; +using TypeRelationFn = runtime::TypedPackedFunc(const Array&, int)>; /*! * \brief Opaque type relation, is an input-output relation on types. @@ -239,7 +239,7 @@ class TypeRelationNode : public RelayNode { v->Visit("num_args", &num_args); } - TVM_DLL static TypeRelation make(std::string name, int num_args); + TVM_DLL static TypeRelation make(std::string name, int num_args, TypeRelationFn func_); static constexpr const char* _type_key = "relay.TypeRelation"; TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, RelayNode); @@ -258,6 +258,7 @@ class TypeCallNode : public TypeNode { public: /*! \brief The type function to be called. */ Type func; + /*! \brief The type arguments to the type function. */ tvm::Array args; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3a3ef1b52604..2b235e8b01ad 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -1,10 +1,10 @@ /*! * Copyright (c) 2018 by Contributors - * \file expr.cc + * \file src/tvm/ir/expr.cc * \brief The expression AST nodes of Relay. */ -#include "tvm/relay/expr.h" -#include "tvm/ir_functor.h" +#include +#include namespace tvm { namespace relay { @@ -29,6 +29,19 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "ConstantNode(TODO)"; }); +TensorType ConstantNode::tensor_type() const { + auto dl_dtype = data->dtype; + auto dtype = HalideIR::Type(static_cast(dl_dtype.code), + dl_dtype.bits, dl_dtype.lanes); + + Array shape; + for (int i = 0; i < data->ndim; i++) { + shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), data->shape[i])); + } + + return TensorTypeNode::make(shape, dtype); +} + Tuple TupleNode::make(tvm::Array fields) { std::shared_ptr n = std::make_shared(); n->fields = std::move(fields); @@ -114,11 +127,8 @@ TVM_REGISTER_API("relay._make.Function") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionNode *node, tvm::IRPrinter *p) { - p->stream << "FunctionNode(" << - node->params << ", " << - node->ret_type << ", " << - node->body << ", " << - node->type_params << ")"; + p->stream << "FunctionNode(" << node->params << ", " << node->ret_type + << ", " << node->body << ", " << node->type_params << ")"; }); Call CallNode::make(Expr op, Array args, Attrs attrs, @@ -158,7 +168,8 @@ TVM_REGISTER_API("relay._make.Let") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { - p->stream << "LetNode(" << node->var << node->value << node->body << node->value_type << ")"; + p->stream << "LetNode(" << node->var << node->value << node->body + << node->value_type << ")"; }); If IfNode::make(Expr cond, Expr true_value, Expr false_value) { @@ -175,10 +186,8 @@ TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { - p->stream << "IfNode(" << - node->cond << ", " << - node->true_value << - node->false_value << ")"; + p->stream << "IfNode(" << node->cond << ", " << node->true_value + << node->false_value << ")"; }); } // namespace relay diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index d9e2737225ec..abed09a69d7b 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -1,11 +1,10 @@ /*! * Copyright (c) 2018 by Contributors - * \file type.cc + * \file src/tvm/ir/type.cc * \brief The type system AST nodes of Relay. */ -#include "tvm/relay/type.h" -#include "tvm/ir_functor.h" - +#include +#include namespace tvm { namespace relay { @@ -42,7 +41,6 @@ TVM_REGISTER_API("relay._make.TensorType") *ret = TensorTypeNode::make(shape, args[1]); }); - TVM_REGISTER_API("relay._make.IntType") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = TensorTypeNode::Int(args[0], args[1]); @@ -91,10 +89,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->kind << ")"; }); - FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints) { + tvm::Array type_params, + tvm::Array type_constraints) { std::shared_ptr n = std::make_shared(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); @@ -116,22 +113,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->type_constraints << ")"; }); -TypeRelation TypeRelationNode::make(std::string name, int num_args) { +TypeRelation TypeRelationNode::make(std::string name, int num_args, TypeRelationFn func) { std::shared_ptr n = std::make_shared(); n->name = std::move(name); n->num_args = std::move(num_args); + n->func_ = std::move(func); return TypeRelation(n); } TVM_REGISTER_API("relay._make.TypeRelation") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TypeRelationNode::make(args[0], args[1]); + *ret = TypeRelationNode::make(args[0], args[1], args[2]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TypeRelationNode *node, - tvm::IRPrinter *p) { - p->stream << "TypeRelationNode(" << node->name << ", " << node->num_args << ")"; + tvm::IRPrinter *p) { + p->stream << "TypeRelationNode(" << node->name << ", " << node->num_args + << ")"; }); TypeCall TypeCallNode::make(Type func, Array args) { @@ -152,6 +151,5 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; }); - } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 700e9185ccba..cd90705c6476 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -39,6 +39,9 @@ RELAY_REGISTER_UNARY_OP("log") .set_support_level(1) .add_type_func("Log", IdentityRel); +// data : Tensor[shape, dtype] +// result: Tensor[shape, dtype] + RELAY_REGISTER_UNARY_OP("exp") .describe(R"code(Returns the exp input array, computed element-wise. @@ -75,5 +78,12 @@ RELAY_REGISTER_OP("add") .set_support_level(1) .add_type_func("Broadcast", BroadcastRel); + // def broadcast(s1, s2): + // ... + // + // input1: Tensor[dtype, s1] + // input2: Tensor[dtype, s2] + // output: Tensor[dtype, broadcast(s1, s2)] + } // namespace relayv } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index a5ba1dc14b5f..68fe2c51a365 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -20,7 +20,7 @@ TensorType as_ttype(const Type & t) { } Array IdentityRel(const Array & types, int num_args) { - CHECK(types.size() == 1); + CHECK(types.size() == 2); auto t1 = as_ttype(types[0]); if (t1 && types[1].as()) { return {t1, t1}; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index b9cfd5837c4b..72fa9cb14bdc 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -62,7 +62,34 @@ struct TypeContext { struct TypeNormalizer : TypeFVisitor { TypeUnifier unifier; TypeNormalizer(const TypeUnifier & unifier) : unifier(unifier) {} - // Type VisitType_( + + Type VisitType_(const TypeCallNode * ty_call_node) { + auto ty_call = GetRef(ty_call_node); + + auto all_concrete = true; + for (auto arg : ty_call->args) { + all_concrete = all_concrete && !arg.as(); + } + + if (all_concrete) { + return ty_call->args[ty_call->args.size() - 1]; + } else { + if (auto ty_rel_node = ty_call->func.as()) { + // NB: we substract 1 for the output argument. + auto new_args = ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); + CHECK(new_args.size() == ty_call->args.size()); + tvm::Array final_args; + + for (int i = 0; i < new_args.size(); i++) { + final_args.push_back(unifier->unify(ty_call->args[i], new_args[i])); + } + + return TypeCallNode::make(ty_call->func, final_args); + } else { + CHECK(false); + } + } + } }; struct CheckedExpr { @@ -167,26 +194,9 @@ class TypeInferencer : private ExprFunctor { } CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { - auto array = const_node->data; - // array->t - // first pass - return { - GetRef(const_node), - TensorTypeNode::make({}, HalideIR::Float(32, 1)) }; + return { GetRef(const_node), const_node->tensor_type() }; } - // Type TypeInferencer::VisitExpr_(const OpIdNode *op) { - // OpId id = GetRef(op); - // Item item = this->env->lookup(id); - - // if (const OpNode *pn = item.as()) { - // Op prim = GetRef(pn); - // return prim->type; - // } else { - // this->fatal_error("internal error in InstrinsicId case", op->span); - // } - // } - CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { // Tuple pl = GetRef(op); diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index d96682fbfda4..e94158cd44e2 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -13,7 +13,7 @@ def has_type(expr, typ): return checked_expr.checked_type() == typ def test_monomorphic_let(): - "Program: let x = 1; x" + "Program: let x = 1; return x" b = IRBuilder() x = b.let('x', 1, value_type=float_type()) b.ret(x) From 23e64bc8dcf58a3e9a23f1a424a663c53f4eb516 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 15:09:39 -0700 Subject: [PATCH 048/136] Iterating on Broadcast --- include/tvm/relay/type.h | 28 ++++++++- python/tvm/relay/ir_builder.py | 5 +- python/tvm/relay/type.py | 6 +- src/relay/ir/type.cc | 18 ++++++ src/relay/op/type_relations.cc | 58 ++++++++++++++++++- src/relay/pass/type_infer.cc | 16 +++-- .../relay/test_tyck_eval_integration.py | 19 +++--- 7 files changed, 129 insertions(+), 21 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 498053f4f9bb..a6c801c382de 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -269,8 +269,6 @@ class TypeCallNode : public TypeNode { v->Visit("args", &args); } - Type eval() const; - TVM_DLL static TypeCall make(Type func, tvm::Array args); static constexpr const char* _type_key = "relay.TypeCall"; @@ -279,6 +277,32 @@ class TypeCallNode : public TypeNode { RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); +/*! + * \brief The type of tuple values. + */ +class TupleType; +/*! + * \brief TupleType container. + */ +class TupleTypeNode : public TypeNode { + public: + /*! \brief The type of each field in the tuple. */ + tvm::Array fields; + + TupleTypeNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("fields", &fields); + } + + TVM_DLL static TupleType make(tvm::Array fields); + + static constexpr const char* _type_key = "relay.TypeTuple"; + TVM_DECLARE_NODE_TYPE_INFO(TypeTupleNode, TypeNode); +}; + +RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); + // The following fields contains advanced typing // Only keep the class name and reserved for future usage. class GenericTensorType; diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 8bd225bd4de1..a9cb02a19025 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,7 +1,7 @@ from typing import Any import numpy as np import tvm -from .type import FloatType, IntType, BoolType, UIntType, FuncType +from .type import FloatType, IntType, BoolType, UIntType, FuncType, TensorType from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function from . import op as _op @@ -167,5 +167,8 @@ def float_type(bits=32, lanes=1): def bool_type(lanes=1): return BoolType(lanes) +def tensor_type(*shape, dtype='float32'): + return TensorType(tvm.convert(shape), dtype) + def func_type(args, ret_type, type_params=[], type_constraints=[]): return FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index c7b8964c20e8..c9a96de4889d 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -27,12 +27,12 @@ def same_as(self, other) -> bool: class TensorType(Type): """A concrete TensorType in Relay, see tvm/relay/type.h for more details. """ - dtype: str shape: List[expr.Expr] + dtype: str span: Span - def __init__(self, dtype: str, shape: List[expr.Expr]) -> None: - self.__init_handle_by_constructor__(_make.TensorType,dtype, shape) + def __init__(self, shape: List[expr.Expr], dtype: str) -> None: + self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) class Kind(IntEnum): """The kind of a type parameter, represents a variable shape, diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index abed09a69d7b..1faa9ede8638 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -151,5 +151,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; }); +TypeCall TupleTypeNode::make(Array fields) { + std::shared_ptr n = std::make_shared(); + n->fields = std::move(fields); + return TupleType(n); +} + +TVM_REGISTER_API("relay._make.TupleType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleTypeNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TupleTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "TupleTypeNode(" << node->fields << ")"; + }); + + } // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 68fe2c51a365..883b8ecc946d 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -4,6 +4,7 @@ * \brief A set of utilities and common functionality * for type relations. */ +#include #include #include #include "../pass/incomplete_type.h" @@ -19,6 +20,13 @@ TensorType as_ttype(const Type & t) { } } +// TODO(@jroesch) what size value do we extract? +int to_int(const tvm::Expr & e) { + auto imm = e.as(); + CHECK(imm); + return imm->value; +} + Array IdentityRel(const Array & types, int num_args) { CHECK(types.size() == 2); auto t1 = as_ttype(types[0]); @@ -29,12 +37,56 @@ Array IdentityRel(const Array & types, int num_args) { } } +static Type ConcreteBroadcast(const TensorType & t1, const TensorType & t2) { + RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 << std::endl; + auto sh1 = t1->shape; + auto sh2 = t2->shape; + RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2 << std::endl; + CHECK(sh1.size() > 0); + CHECK(sh2.size() > 0); + + auto suffix_len = static_cast(std::min(sh1.size(), sh2.size())); + auto full_len = static_cast(std::max(sh1.size(), sh2.size())); + + std::cout << "Length" << suffix_len << full_len << (full_len - suffix_len - 1) << std::endl; + auto lower_bound = full_len - suffix_len - 1; + + for (int64_t i = full_len - 1; i > lower_bound; i--) { + std::cout << "Index i=" << i << std::endl; + auto dim1 = to_int(sh1[i]); + auto dim2 = to_int(sh2[i]); + if (dim1 != dim2) { + CHECK(false); + } + } + + Array larger; + Array smaller; + + for (int i = 0; i < (full_len - suffix_len); i++) { + smaller.push_back(tvm::ir::IntImm::make(1)); + } + + if (sh1.size() < sh2.size()) { + + } else if (sh1.size() > sh2.size()) { + + } else { + + } + + for (int i = 0; i < suffix_len - full_len; i++) { + + } + + return t1; +} + Array BroadcastRel(const Array & types, int num_args) { - std::cout << "Inside of Broadcast" << std::endl; - CHECK(types.size() == 0); + CHECK(types.size() == 3); if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { - return types; + return { t1, t2, ConcreteBroadcast(t1, t2) }; } } return types; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 72fa9cb14bdc..cc91176feb62 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -66,22 +66,28 @@ struct TypeNormalizer : TypeFVisitor { Type VisitType_(const TypeCallNode * ty_call_node) { auto ty_call = GetRef(ty_call_node); - auto all_concrete = true; + Array normalized_args; + for (auto arg : ty_call->args) { + normalized_args.push_back(VisitType(arg)); + } + + auto all_concrete = true; + for (auto arg : normalized_args) { all_concrete = all_concrete && !arg.as(); } if (all_concrete) { - return ty_call->args[ty_call->args.size() - 1]; + return normalized_args[normalized_args.size() - 1]; } else { if (auto ty_rel_node = ty_call->func.as()) { // NB: we substract 1 for the output argument. auto new_args = ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); - CHECK(new_args.size() == ty_call->args.size()); + CHECK(new_args.size() == normalized_args.size()); tvm::Array final_args; for (int i = 0; i < new_args.size(); i++) { - final_args.push_back(unifier->unify(ty_call->args[i], new_args[i])); + final_args.push_back(unifier->unify(normalized_args[i], new_args[i])); } return TypeCallNode::make(ty_call->func, final_args); @@ -606,7 +612,7 @@ class TypeInferencer : private ExprFunctor { Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { try { - return this->unifier->unify(t1, t2); + return Normalize(this->unifier->unify(t1, t2)); } catch (const dmlc::Error &e) { std::stringstream ss; ss << "Error unifying `"; diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index e94158cd44e2..6e5b64ee846e 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -2,7 +2,7 @@ for expressions. """ from tvm.relay.type_infer import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type, func_type +from tvm.relay.ir_builder import IRBuilder, float_type, func_type, tensor_type from tvm.relay.env import Environment from tvm.relay.op import log, add @@ -15,12 +15,11 @@ def has_type(expr, typ): def test_monomorphic_let(): "Program: let x = 1; return x" b = IRBuilder() - x = b.let('x', 1, value_type=float_type()) + x = b.let('x', 1.0, value_type=float_type(64)) b.ret(x) prog = b.get() - assert has_type(prog, float_type()) - + assert has_type(prog, float_type(64)) def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" @@ -32,9 +31,15 @@ def test_single_op(): assert has_type(func.to_func(), func_type([float_type()], float_type())) def test_dual_op(): - "Program: fn (x : float32) { let t1 = f(x); let t2 = g(t1, x); t1 }" + """Program: + fn (x : Tensor[f32, (10, 10)]) { + let t1 = log(x); + let t2 = add(t1, x); + return t1; + } + """ b = IRBuilder() - with b.function(('x', float_type())) as func: + with b.function(('x', tensor_type(10, 10))) as func: x, = func.param_ids() t1 = b.let('t1', log(x)) t2 = b.let('t2', add(t1, x)) @@ -43,5 +48,5 @@ def test_dual_op(): if __name__ == "__main__": # test_monomorphic_let() - test_single_op() + # test_single_op() test_dual_op() From 918d081c28a05eff2818edf31c766717fe925d47 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 15:36:16 -0700 Subject: [PATCH 049/136] Address CR feedback --- include/tvm/relay/base.h | 9 +++- include/tvm/relay/environment.h | 24 ++++++----- include/tvm/relay/expr.h | 21 +++++----- include/tvm/relay/expr_functor.h | 38 +++-------------- include/tvm/relay/type.h | 4 +- src/relay/ir/type.cc | 2 +- src/relay/op/type_relations.cc | 2 +- src/relay/pass/resolve.cc | 6 +-- src/relay/pass/resolve.h | 6 +-- src/relay/pass/type_functor.h | 2 + src/relay/pass/type_infer.cc | 4 +- src/relay/pass/type_visitor.h | 70 ++++++++++++++++---------------- 12 files changed, 86 insertions(+), 102 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f25d6e6532df..092f5ceb8fc3 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/base.h - * \brief Base data structure for relay. + * \brief Base classes for the Relay IR. */ #ifndef TVM_RELAY_BASE_H_ #define TVM_RELAY_BASE_H_ @@ -13,7 +13,12 @@ namespace tvm { /*! - * \brief Relay: high level functional IR + * \brief Relay: a high level functional IR for TVM. + * + * This namespace contains the abstract syntax tree, and other + * essential data structures for the Relay IR. + * + * You can find more about Relay by reading the language reference. */ namespace relay { /*! diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index ff8e596059b5..ce874103a0a1 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -1,7 +1,8 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/environment.h - * \brief The global environment, contains global state of Relay program. + * \brief The global environment: contains information needed to + * compile & optimize Relay programs. */ #ifndef TVM_RELAY_ENVIRONMENT_H_ #define TVM_RELAY_ENVIRONMENT_H_ @@ -21,18 +22,21 @@ struct Environment; /*! \brief The global environment of Relay programs. * - * The global environment contains all the global - * information needed to compile a Relay program, - * including the set of operators, the set of - * global functions, and configuration options. + * The global environment contains the global + * information needed to compile a Relay program. + * + * It contains all global functions, and configuration + * options. * * Many operations require acess to the global - * Environment. We mostly pass the argument by value - * in a functional style as an explicit argument. + * Environment. We pass the Environment by value + * in a functional style as an explicit argument, + * but we will mutate the Environment while optimizing + * Relay programs. * - * This means users can construct custom environments - * easily, for example a fresh environment for each - * thread while auto-tuning. + * The functional style allows users to construct custom + * environments easily, for example each thread can store + * an Environment while auto-tuning. * */ class EnvironmentNode : public RelayNode { diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index ff11a41a6e5f..5fe91702a29f 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/expr.h - * \brief The Relay IR expression nodes. + * \brief Relay expression language. */ #ifndef TVM_RELAY_EXPR_H_ #define TVM_RELAY_EXPR_H_ @@ -16,11 +16,8 @@ namespace tvm { namespace relay { -// TOD0(@jroesch): best way to define? -class TypeInferencer; - /*! - * \brief Relay expression. + * \brief A Relay expression. */ class Expr; /*! @@ -28,7 +25,6 @@ class Expr; */ class ExprNode : public RelayNode { public: - // private: /*! * \brief Stores the result of type inference(type checking). * @@ -48,7 +44,6 @@ class ExprNode : public RelayNode { static constexpr const char* _type_key = "relay.Expr"; TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); - friend class TypeInferencer; }; RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); @@ -68,8 +63,6 @@ class ConstantNode : public ExprNode { /*! \brief The data of the tensor */ runtime::NDArray data; - // TODO(tqchen) add the function after we get TensorType constructor - // TODO(tqchen) create simple TensorType constructor for concrete types. /*! \return The corresponding tensor type of the data */ TensorType tensor_type() const; @@ -335,6 +328,12 @@ RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); /*! * \brief Condition expression + * + * Unlike traditional statement `if`s, the if evalutes + * to the result of the branch taken. + * + * let x = if (true) { 1 } else { 0 }; // x is 1 + * let y = if (false) { 1 } else { 0 }; // y is 0 */ class If; /*! \brief container of If */ @@ -342,9 +341,9 @@ class IfNode : public ExprNode { public: /*! \brief The condition */ Expr cond; - /*! \brief The value to take when condition is true */ + /*! \brief The expression evaluated when condition is true. */ Expr true_value; - /*! \brief The value to take when condition is false */ + /*! \brief The expression evaluated when condition is false */ Expr false_value; IfNode() {} diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e37a454eee41..4632733cbcfc 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -1,8 +1,8 @@ /*! * Copyright (c) 2018 by Contributors - * \file expr_functor.h - * \brief A more powerful Visitor that enables defining arbitrary function - * signatures with dispatch on first argument. + * \file tvm/relay/expr_functor.h + * \brief A more powerful visitor which enables defining arbitrary function + * signatures with type based dispatch on first argument. */ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_ @@ -19,36 +19,8 @@ namespace relay { * \brief A dynamical functor that dispatches on in the first Expr argument. * You can use this as a more powerful Visitor, since it allows you to * define function signatures of Visit Function. - * - * This helps you to avoid to book-keep return value of Visitor via state, - * which can cause bugs easily when state is incorrectly maintained. - * - * \code - * // A functor that set variable to b. and calculate results. - * class MyExprFunctor - * : public ir::ExprFunctor { - * public: - * int VisitExpr_(const Variable* op, int b) final { - * return b; - * } - * int VisitExpr_(const IntImm* op, int b) final { - * return op->value; - * } - * int VisitExpr_(const Add* op, int b) final { - * return Visit(op->a, b) + Visit(op->b, b); - * } - * }; - * MyExprFunctor f; - * Var x("x"); - * CHECK_EQ(f(x + 1, 2), 3); - * \endcode - * - * \note Why do we need this more powerful Functor: - * - * We often need to implement a transformer tasks. - * Say we want to take Expr and transform it to some analysis result, - * This easily be done incorrectly using plain Visitor. See IRVisitor's - * document for possible error cases. + * + * \sa tvm/ir_functor.h * * \tparam FType function signiture * This type if only defined for FType with function signiture R(const Expr&, diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index a6c801c382de..5d579b661280 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -298,10 +298,10 @@ class TupleTypeNode : public TypeNode { TVM_DLL static TupleType make(tvm::Array fields); static constexpr const char* _type_key = "relay.TypeTuple"; - TVM_DECLARE_NODE_TYPE_INFO(TypeTupleNode, TypeNode); + TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); +RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type); // The following fields contains advanced typing // Only keep the class name and reserved for future usage. diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 1faa9ede8638..e29f3cbde4c1 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -151,7 +151,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; }); -TypeCall TupleTypeNode::make(Array fields) { +TupleType TupleTypeNode::make(Array fields) { std::shared_ptr n = std::make_shared(); n->fields = std::move(fields); return TupleType(n); diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 883b8ecc946d..56b139731178 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -64,7 +64,7 @@ static Type ConcreteBroadcast(const TensorType & t1, const TensorType & t2) { Array smaller; for (int i = 0; i < (full_len - suffix_len); i++) { - smaller.push_back(tvm::ir::IntImm::make(1)); + // smaller.push_back(tvm::ir::IntImm::make(1)); } if (sh1.size() < sh2.size()) { diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index e86368854060..f18a67bcffc9 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -64,12 +64,12 @@ struct ResolveTypeExpr : ExprFVisitor<> { } }; -Type resolve(const TypeUnifier &unifier, const Type &ty) { +Type Resolve(const TypeUnifier &unifier, const Type &ty) { CHECK(ty.defined()); return ResolveTypeType(unifier).VisitType(ty); } -Expr resolve(const TypeUnifier &unifier, const Expr &expr) { +Expr Resolve(const TypeUnifier &unifier, const Expr &expr) { return ResolveTypeExpr(unifier).VisitExpr(expr); } @@ -91,7 +91,7 @@ struct FullyResolved : TypeVisitor<> { } }; -bool is_fully_resolved(const Type &t) { +bool IsFullyResolved(const Type &t) { auto fr = FullyResolved(); fr.VisitType(t); return fr.incomplete; diff --git a/src/relay/pass/resolve.h b/src/relay/pass/resolve.h index 5f6cc328a239..495c9658238a 100644 --- a/src/relay/pass/resolve.h +++ b/src/relay/pass/resolve.h @@ -13,9 +13,9 @@ namespace tvm { namespace relay { -Type resolve(const TypeUnifier & unifier, const Type & ty); -Expr resolve(const TypeUnifier & unifier, const Expr & expr); -bool is_fully_resolved(const Type & t); +Type Resolve(const TypeUnifier & unifier, const Type & ty); +Expr Resolve(const TypeUnifier & unifier, const Expr & expr); +bool IsFullyResolved(const Type & t); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h index 9adc1a08860e..9180703b49e8 100644 --- a/src/relay/pass/type_functor.h +++ b/src/relay/pass/type_functor.h @@ -65,6 +65,7 @@ class TypeFunctor { virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Node* op, Args...) { @@ -83,6 +84,7 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); return vtable; } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index cc91176feb62..5e9e784dbe83 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -526,7 +526,7 @@ class TypeInferencer : private ExprFunctor { Type TypeInferencer::resolve(const Type &t) { if (t.defined()) { - return ::tvm::relay::resolve(this->unifier, t); + return ::tvm::relay::Resolve(this->unifier, t); } else { return IncompleteTypeNode::make(TypeParamNode::Kind::kType); } @@ -534,7 +534,7 @@ class TypeInferencer : private ExprFunctor { Expr TypeInferencer::resolve(const Expr &e) { CHECK(e.defined()); - return ::tvm::relay::resolve(this->unifier, e); + return ::tvm::relay::Resolve(this->unifier, e); } void TypeInferencer::CheckOp(Op op) { diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index 68dba76644c3..f3c0f9a74fb7 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -22,7 +22,7 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { void VisitType_(const TypeParamNode* op, Args... args) override {} void VisitType_(const FuncTypeNode* op, Args... args) override { - // fix me handle poly + // TODO(@jroesch): handle poly // this->VisitType(op->var, args...); // this->VisitType(op->boundType, args...); for (auto arg_type : op->arg_types) { @@ -33,11 +33,11 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { void VisitType_(const TensorTypeNode* op, Args... args) override {} - // void VisitType_(const TupleTypeNode* op, Args... args) override { - // for (const Type& t : op->fields) { - // this->VisitType(t, args...); - // } - // } + void VisitType_(const TupleTypeNode* op, Args... args) override { + for (const Type& t : op->fields) { + this->VisitType(t, args...); + } + } void VisitType_(const TypeCallNode* op, Args... args) override { this->VisitType(op->func, args...); @@ -63,46 +63,48 @@ struct TypeFVisitor : TypeFunctor { } Type VisitType_(const FuncTypeNode* op) override { + // TODO (@jroesch): handle poly + // auto new_id = this->VisitType(op->var); // if (const TypeParamNode* tin = new_id.as()) { // return TypeQuantifierNode::make(GetRef(tin), // this->VisitType(op->boundType)); - std::vector args; - for (auto arg_type : op->arg_types) { - args.push_back(VisitType(arg_type)); - } - - return FuncTypeNode::make(tvm::Array(args), - VisitType(op->ret_type), {}, {}); // fix me + std::vector args; + for (auto arg_type : op->arg_types) { + args.push_back(VisitType(arg_type)); } - // Type VisitType_(const TupleTypeNode* op) override { - // std::vector new_fields; - // for (const Type& t : op->fields) { - // new_fields.push_back(this->VisitType(t)); - // } - // return TupleTypeNode::make(new_fields); - // } - - Type VisitType_(const TypeRelationNode* op) override { - return GetRef(op); - } + return FuncTypeNode::make(tvm::Array(args), VisitType(op->ret_type), + {}, {}); // fix me + } - Type VisitType_(const TypeCallNode* op) override { - auto func = this->VisitType(op->func); - std::vector new_args; - for (const Type& t : op->args) { - new_args.push_back(this->VisitType(t)); + Type VisitType_(const TupleTypeNode* op) override { + std::vector new_fields; + for (const Type& t : op->fields) { + new_fields.push_back(this->VisitType(t)); } - return TypeCallNode::make(func, new_args); + return TupleTypeNode::make(new_fields); } - Type VisitType_(const IncompleteTypeNode* op) override { - return GetRef(op); + Type VisitType_(const TypeRelationNode* op) override { + return GetRef(op); + } + + Type VisitType_(const TypeCallNode* op) override { + auto func = this->VisitType(op->func); + std::vector new_args; + for (const Type& t : op->args) { + new_args.push_back(this->VisitType(t)); } - }; + return TypeCallNode::make(func, new_args); + } + + Type VisitType_(const IncompleteTypeNode* op) override { + return GetRef(op); + } +}; } // namespace relay -} // namespace relay +} // namespace tvm #endif // TVM_RELAY_TYPE_VISITOR_H_ From 5f9529f27662f5447f371338cf63994ae0467d4a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 15:46:45 -0700 Subject: [PATCH 050/136] Address more CR feedback --- include/tvm/relay/expr_visitor.h | 9 ++++++--- include/tvm/relay/logging.h | 2 +- python/tvm/relay/__init__.py | 1 - 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 8803aa5ae48f..e15f25a39eb3 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -1,8 +1,11 @@ /*! * Copyright (c) 2018 by Contributors - * \file expr_visitor.h - * \brief A simple visitor wrapper around ExprFunctor designed for visitors which - * maintain mutable state. + * \file tvm/relay/expr_visitor.h + * \brief A simple visitor wrapper around ExprFunctor. + * + * Exposes two visitors with default traversal strategies, one + * which doesn't compute a result but can mutate internal state, + * and another which functionally builds a new Expr. */ #ifndef TVM_RELAY_EXPR_VISITOR_H_ #define TVM_RELAY_EXPR_VISITOR_H_ diff --git a/include/tvm/relay/logging.h b/include/tvm/relay/logging.h index 99cfc44de6cb..c53cd15ee72e 100644 --- a/include/tvm/relay/logging.h +++ b/include/tvm/relay/logging.h @@ -8,10 +8,10 @@ #ifndef TVM_RELAY_LOGGING_H_ #define TVM_RELAY_LOGGING_H_ +#include #include #include #include -#include "dmlc/logging.h" namespace tvm { namespace relay { diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 019d7c19a865..c36b9bcf8357 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -18,7 +18,6 @@ # Expr Constant = expr.Constant Tuple = expr.Tuple -# TODO: GlobalVar, LocalVar-> var LocalVar = expr.LocalVar GlobalVar = expr.GlobalVar Param = expr.Param From 66f01ad33bbbf053efbadd70a674bf7839481f8a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:24:47 -0700 Subject: [PATCH 051/136] Add SourceMap and clean up environment.h --- include/tvm/relay/environment.h | 12 ++--- include/tvm/relay/source_map.h | 44 +++++++++++++++++ src/relay/source_map.cc | 88 +++++++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 6 deletions(-) create mode 100644 include/tvm/relay/source_map.h create mode 100644 src/relay/source_map.cc diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index ce874103a0a1..43be0ab8c912 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -13,7 +13,7 @@ #include "./type.h" #include "./op.h" #include "./error.h" -// #include "tvm/relay/source_map.h" +#include "tvm/relay/source_map.h" namespace tvm { namespace relay { @@ -43,10 +43,10 @@ class EnvironmentNode : public RelayNode { private: /*! \brief A map from string names to global variables ensures global uniqueness. */ tvm::Map global_map_; - // /*! \brief A map from file names to source fragments. */ - // SourceMap source_map_ - // /*! \brief A list of the errors reported during the current run. */ - // std::vector errors_; + /*! \brief A map from file names to source fragments. */ + SourceMap source_map_; + /*! \brief A list of the errors reported during the current run. */ + std::vector errors_; public: /*! \brief A map from ids to all global functions. */ @@ -73,7 +73,7 @@ class EnvironmentNode : public RelayNode { Function Lookup(const std::string & s); /*! \brief Add a source fragment to the environment. */ - // FileId add_source(std::string file_name, std::string source); + SourceName AddSource(std::string file_name, std::string source); void ReportError(std::string msg, Span sp); void DisplayErrors(); diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h new file mode 100644 index 000000000000..71bf93aa1ed9 --- /dev/null +++ b/include/tvm/relay/source_map.h @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file source_map.h + * \brief A representation of source files and a data structure for + * storing them. + */ +#ifndef TVM_RELAY_SOURCE_MAP_H_ +#define TVM_RELAY_SOURCE_MAP_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +struct SourceFragment { + std::string file_name; + std::vector source_lines; + + SourceFragment(std::string file_name, std::string source); + + SourceFragment(const SourceFragment& sf) { + this->file_name = sf.file_name; + this->source_lines = sf.source_lines; + } + + std::string SourceAt(Span sp, int lines); +}; + +/*! \brief Maps from FileId's to a SourceFragment. + */ +class SourceMap { + /*! \brief Map from unique token to a fragment of a source file. */ + std::unordered_map map_; + public: + SourceMap() : map_() {} + SourceName AddSource(std::string file_name, std::string source); + const SourceFragment & GetSource(SourceName id) const; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_SOURCE_MAP_H_ \ No newline at end of file diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc new file mode 100644 index 000000000000..0db80fd30339 --- /dev/null +++ b/src/relay/source_map.cc @@ -0,0 +1,88 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file source_map.cc + * \brief Source maps for Relay. + */ + +#include +#include +#include + +namespace tvm { +namespace relay { + +using tvm::IRPrinter; +using namespace tvm::runtime; + +SourceFragment::SourceFragment(std::string file_name, std::string source) + : file_name(file_name), source_lines({}) { + RELAY_LOG(INFO)<< "SourceFragment::SourceFragment source=" << source << std::endl; + std::stringstream source_stream; + source_stream.str(source.c_str()); + std::string line; + + while (std::getline(source_stream, line)) { + RELAY_LOG(INFO) << "SourceFragment::SourceFragment: line=" << line << std::endl; + std::string copy(line); + source_lines.push_back(copy); + } +} + +std::string SourceFragment::SourceAt(Span sp, int max_lines) { + std::stringstream out; + + // We need to move from 1 based indexing to zero based indexing. + int starting_line = sp->lineno; + + if (starting_line >= static_cast(this->source_lines.size())) { + throw dmlc::Error("SourceFragment: index out of bounds"); + } + + auto lines = std::max(static_cast(max_lines), source_lines.size() - starting_line); + + for (size_t i = 0; i < lines; i++) { + out << std::endl << this->source_lines.at(starting_line + i); + } + + auto source_slice = out.str(); + + RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice << std::endl; + return source_slice; +} + +SourceName SourceMap::AddSource(std::string file_name, std::string source) { + auto new_id = SourceNameNode::make(file_name); + SourceFragment sfile(file_name, source); + this->map_.insert({new_id, sfile}); + return new_id; +} + +SourceName SourceNameNode::make(std::string name) { + std::shared_ptr n = std::make_shared(); + n->name = std::move(name); + return SourceName(n); +} + +static SourceFragment DUMMY_SOURCE = SourceFragment("DUMMY_FILE", "DUMMY_SOURCE"); + +SourceFragment const &SourceMap::GetSource(SourceName id) const { + auto item = map_.find(id); + if (item != map_.end()) { + return (*item).second; + } else { + return DUMMY_SOURCE; + } +} + +TVM_REGISTER_API("relay._make.SourceName") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { + *ret = SourceNameNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { + p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; + }); + +} // namespace relay +} // namespace tvm \ No newline at end of file From e0b9ed7b0323cbf0cd66801fbc1c06dcf55415b8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:35:43 -0700 Subject: [PATCH 052/136] Reogranize a bit --- include/tvm/relay/expr.h | 4 ++-- src/relay/ir/base.cc | 18 ++++++++++++++++++ src/relay/source_map.cc | 22 ++-------------------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 5fe91702a29f..ddac633f9d09 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include "./base.h" #include "./type.h" @@ -223,8 +224,7 @@ class FunctionNode : public ExprNode { RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); -// TODO(tqchen) change Expr to Attr after we introduce Attr system. -using Attrs = tvm::Map; +using Attrs = tvm::Attrs; /*! * \brief Call corresponds to operator invocation. diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 5fdf96ded224..d48b9a4c3e0c 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -12,6 +12,24 @@ namespace relay { using tvm::IRPrinter; using namespace tvm::runtime; +SourceName SourceNameNode::make(std::string name) { + std::shared_ptr n = std::make_shared(); + n->name = std::move(name); + return SourceName(n); +} + +// TVM_REGISTER_API("relay._make.SourceName") +// .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { +// *ret = SourceNameNode::make(args[0]); +// }); + +// This causes a crash? + +// TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +// .set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { +// p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; +// }); + Span SpanNode::make(SourceName source, int lineno, int col_offset) { std::shared_ptr n = std::make_shared(); n->source = std::move(source); diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc index 0db80fd30339..a1b3627bccc8 100644 --- a/src/relay/source_map.cc +++ b/src/relay/source_map.cc @@ -57,32 +57,14 @@ SourceName SourceMap::AddSource(std::string file_name, std::string source) { return new_id; } -SourceName SourceNameNode::make(std::string name) { - std::shared_ptr n = std::make_shared(); - n->name = std::move(name); - return SourceName(n); -} - -static SourceFragment DUMMY_SOURCE = SourceFragment("DUMMY_FILE", "DUMMY_SOURCE"); - -SourceFragment const &SourceMap::GetSource(SourceName id) const { +const SourceFragment& SourceMap::GetSource(SourceName id) const { auto item = map_.find(id); if (item != map_.end()) { return (*item).second; } else { - return DUMMY_SOURCE; + throw dmlc::Error("could not find requested source fragment"); } } -TVM_REGISTER_API("relay._make.SourceName") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { - *ret = SourceNameNode::make(args[0]); - }); - -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { - p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; - }); - } // namespace relay } // namespace tvm \ No newline at end of file From b29e17adf3502e973899252eb2941bb836488686 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:45:02 -0700 Subject: [PATCH 053/136] Kill dead code in env.py --- python/tvm/relay/_env.pyi | 15 +------------- python/tvm/relay/env.py | 43 --------------------------------------- 2 files changed, 1 insertion(+), 57 deletions(-) diff --git a/python/tvm/relay/_env.pyi b/python/tvm/relay/_env.pyi index d14e726e5443..c6b5d0f6c4bd 100644 --- a/python/tvm/relay/_env.pyi +++ b/python/tvm/relay/_env.pyi @@ -2,17 +2,4 @@ from typing import Union, Tuple, Dict, List from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId from relay.ir import ShapeExtension, Operator, Defn -class Environment(NodeBase): ... - -def Environment_add(self: Environment, func: GlobalId) -> None: ... -def Environment_global_id(self: Environment, name: str) -> GlobalId: ... -def Environment_operator_id(self: Environment, name: str) -> OperatorId: ... -def Environment_lookup_global(self: Environment, id: GlobalId) -> Item: ... -def Environment_lookup_operator(self: Environment, id: OperatorId) -> Item: ... -def Environment_remove_global(self: Environment, id: GlobalId) -> Item: ... -def Environment_add_source(self: Environment, file_name: str, source: str) -> FileId: ... -def Environment_report_error(self: Environment, message: str, span: Span) -> None: ... -def Environment_display_errors(self: Environment) -> None: ... -def Environment_register_shape_ext(self: Environment, shape_ext: ShapeExtension) -> None: ... -def Environment_get_operators(self: Environment) -> List[Operator]: ... -def Environment_get_defns(self: Environment) -> List[Defn]: ... +class Environment(NodeBase): ... \ No newline at end of file diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index c63197fa8509..4de5a0c02772 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -3,15 +3,8 @@ from typing import Union, List from .base import register_relay_node, NodeBase from . import _make -# from relay.ir import GlobalId, OperatorId, Item, FileId, Span, ShapeExtension -# from relay.ir import Operator, Defn -# from relay._env import * import tvm -# Move me to C++ if possible. -__tgt_host__ = __tgt__ = "llvm" -__relay_tvm_context__ = tvm.cpu() - @register_relay_node class Environment(NodeBase): """The global Relay environment containing definitions, @@ -19,39 +12,3 @@ class Environment(NodeBase): """ def __init__(self, funcs) -> None: self.__init_handle_by_constructor__(_make.Environment, funcs) - - # def add(self, item: Item) -> None: - # return Environment_add(self, item) - - # def global_id(self, name: str) -> GlobalId: - # return Environment_global_id(self, name) - - # def operator_id(self, name: str) -> OperatorId: - # return Environment_operator_id(self, name) - - # def lookup(self, ident: Union[GlobalId, OperatorId]) -> Item: - # if isinstance(ident, OperatorId): - # return Environment_lookup_operator(self, ident) - # else: - # return Environment_lookup_global(self, ident) - - # def add_source(self, file_name: str, source: str) -> FileId: - # return Environment_add_source(self, file_name, source) - - # def report_error(self, message: str, span: Span) -> None: - # return Environment_report_error(self, message, span) - - # def register_shape_ext(self, ext: ShapeExtension) -> None: - # return Environment_register_shape_ext(self, ext) - - # def display_errors(self) -> None: - # return Environment_display_errors(self) - - # def operators(self) -> List[Operator]: - # return Environment_get_operators(self) - - # def defns(self) -> List[Defn]: - # return Environment_get_defns(self) - - # def tvm_context(self): - # return __relay_tvm_context__ From 04496801364afa1dd51cd4a1716fd8f0e3f4ee64 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:45:31 -0700 Subject: [PATCH 054/136] Fix commit mistake --- cmake/config.cmake | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index e09fdb241bf1..c364a88cce11 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -19,9 +19,6 @@ # $ make -j8 #-------------------------------------------------------------------- -SET(CMAKE_C_COMPLIER clang) -SET(CMAKE_CXX_COMPILER clang++) - #--------------------------------------------- # Backend runtimes. #--------------------------------------------- From f151ea910b1c6ce2e3a7218de66ec12ecf8fdfe3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:48:05 -0700 Subject: [PATCH 055/136] Move type_infer into pass.py --- python/tvm/relay/{_type_infer.py => _pass.py} | 0 python/tvm/relay/{_type_infer.pyi => _pass.pyi} | 0 python/tvm/relay/{type_infer.py => pass.py} | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename python/tvm/relay/{_type_infer.py => _pass.py} (100%) rename python/tvm/relay/{_type_infer.pyi => _pass.pyi} (100%) rename python/tvm/relay/{type_infer.py => pass.py} (78%) diff --git a/python/tvm/relay/_type_infer.py b/python/tvm/relay/_pass.py similarity index 100% rename from python/tvm/relay/_type_infer.py rename to python/tvm/relay/_pass.py diff --git a/python/tvm/relay/_type_infer.pyi b/python/tvm/relay/_pass.pyi similarity index 100% rename from python/tvm/relay/_type_infer.pyi rename to python/tvm/relay/_pass.pyi diff --git a/python/tvm/relay/type_infer.py b/python/tvm/relay/pass.py similarity index 78% rename from python/tvm/relay/type_infer.py rename to python/tvm/relay/pass.py index 17938dfdcbc4..9d7902686928 100644 --- a/python/tvm/relay/type_infer.py +++ b/python/tvm/relay/pass.py @@ -1,6 +1,6 @@ #pylint: disable-all -from . import _type_infer +from . import _pass check_expr = _type_infer.check_expr # generalize = _type_infer.generalize From 78a8d64c67ae48a595063758753a14acbb757ed5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:54:18 -0700 Subject: [PATCH 056/136] Reorganize passes a bit --- include/tvm/relay/pass.h | 23 +++++++++++ include/tvm/relay/pass/type_infer.h | 10 +---- python/tvm/relay/_pass.py | 2 +- src/relay/pass/type_infer.cc | 59 ++--------------------------- 4 files changed, 29 insertions(+), 65 deletions(-) create mode 100644 include/tvm/relay/pass.h diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h new file mode 100644 index 000000000000..89f3dd48fc31 --- /dev/null +++ b/include/tvm/relay/pass.h @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/pass.h + * \brief The set of Relay passes written in C++. + */ +#ifndef TVM_RELAY_PASS_H_ +#define TVM_RELAY_PASS_H_ + +#include "tvm/relay/expr.h" +#include "tvm/relay/environment.h" + +namespace tvm { +namespace relay { + +/*! The result of type checking an expression is a new expression + * with unambigous type information filled in, as well as it's + * checked type field populated with the result type. + */ +Expr InferType(const Environment & env, const Expr & e); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_TYPECHECKER_H_ \ No newline at end of file diff --git a/include/tvm/relay/pass/type_infer.h b/include/tvm/relay/pass/type_infer.h index 9a8ab2bc6a8b..2b860a5e89ef 100644 --- a/include/tvm/relay/pass/type_infer.h +++ b/include/tvm/relay/pass/type_infer.h @@ -6,8 +6,8 @@ * The pass produces a new expression with its checked_type * field populated and incomplete types resolved. */ -#ifndef TVM_RELAY_PASS__TYPECHECKER_H_ -#define TVM_RELAY_PASS__TYPECHECKER_H_ +#ifndef TVM_RELAY_PASS_TYPECHECKER_H_ +#define TVM_RELAY_PASS_TYPECHECKER_H_ #include "tvm/relay/expr.h" #include "tvm/relay/environment.h" @@ -15,12 +15,6 @@ namespace tvm { namespace relay { -/*! The result of type checking an expression is a new expression - * with unambigous type information filled in, as well as it's - * checked type field populated with the result type. - */ -Expr InferType(const Environment & env, const Expr & e); - /*! \brief Ensures that an operator is well-formed with respect * to Relay's type system. */ diff --git a/python/tvm/relay/_pass.py b/python/tvm/relay/_pass.py index 7213769a4164..052ba6d4a0fb 100644 --- a/python/tvm/relay/_pass.py +++ b/python/tvm/relay/_pass.py @@ -2,4 +2,4 @@ from tvm._ffi.function import _init_api -_init_api("relay._type_infer", __name__) +_init_api("relay._pass", __name__) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 5e9e784dbe83..746323fc6d56 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -626,73 +626,20 @@ class TypeInferencer : private ExprFunctor { } } - // // template - - // // Add safe dynamic Array downcast. - // // Add static upcast? - - // // Add to type utils. - // Array type_parameters(const Type &t) { - // Array params; - // auto type = t; - // const TypeQuantifierNode *ty_quant; - // while ((ty_quant = type.as())) { - // params.push_back(ty_quant->id); - // type = ty_quant->boundType; - // } - - // return params; - // } - - // template - // Array ArrayMap(const Array &data, F f) { - // // probably a way to use std::transform. - // Array output; - // for (const I &el : data) { - // output.push_back(f(el)); - // } - // return output; - // } - - // // There are some important questions around generalization - // // that we need to answer. - // Expr generalize(const Environment &env, const Expr &e) { - // if (auto fn_node = e.as()) { - // TypeInferencer tc(env); - // auto ty = tc.VisitFunction(GetRef(fn_node), true); - // auto ty_params = type_parameters(ty); - // auto params = ArrayMap(fn_node->params, [&](const Param &p) { - // return ParamNode::make(p->id, tc.resolve(p->type)); - // }); - // auto body = tc.resolve(fn_node->body); - // auto ret_type = tc.resolve(fn_node->ret_type); - // auto fn = FunctionNode::make(ty_params, params, ret_type, body); - // // we should check in empty context to ensure typing is preserved. - // // check(env, fn); - // return fn; - // } else { - // throw dmlc::Error("can only apply generalize to a function."); - // } - // } - - TVM_REGISTER_API("relay._type_infer.check_expr") + TVM_REGISTER_API("relay._pass.check_expr") .set_body([](TVMArgs args, TVMRetValue *ret) { Environment env = args[0]; Expr e = args[1]; *ret = Infer(env, e); }); - TVM_REGISTER_API("relay._type_infer._get_checked_type") + // TODO(@jroesch): put in a better namespace. + TVM_REGISTER_API("relay._pass._get_checked_type") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr e = args[0]; *ret = e->checked_type(); }); - // TVM_REGISTER_API("relay._tyck.generalize") - // .set_body([](TVMArgs args, TVMRetValue *ret) { - // *ret = generalize(args[0], args[1]); - // }); - IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { std::shared_ptr n = std::make_shared(); From 1dbf6368559b228ff071a26037c1159fb5c8e723 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 16:57:12 -0700 Subject: [PATCH 057/136] More cleaning --- python/tvm/relay/type.py | 37 ------------------------------------- src/relay/ir/type.cc | 26 +++----------------------- 2 files changed, 3 insertions(+), 60 deletions(-) diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index c9a96de4889d..70e4666e96f9 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -4,7 +4,6 @@ from enum import IntEnum from .base import Span, NodeBase, register_relay_node from tvm import expr -# TODO(@jroesch): move me from . import _make class Type(NodeBase): @@ -85,39 +84,3 @@ class IncompleteType(Type): def __init__(self, kind: Kind) -> None: self.__init_handle_by_constructor__(_make.IncompleteType, kind) - -def IntType(bits: int, lanes: int=1) -> Type: - """Constructs a integer base type. - - :param bits: The bit width of the integer type. - :param lanes: The number of vector elements for this datatype. - - """ - return _make.IntType(bits, lanes) - - -def UIntType(bits: int, lanes: int=1) -> Type: - """Constructs a unsigned integer base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.UIntType(bits, lanes) - - -def FloatType(bits: int, lanes: int=1) -> Type: - """Constructs a floating point base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.FloatType(bits, lanes) - - -def BoolType(lanes: int =1) -> Type: - """Constructs a boolean base type. - - :param bits: The bit width of the unsigned type. - :param lanes: The number of vector elements for this datatype. - """ - return _make.BoolType(lanes) diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index e29f3cbde4c1..2975c60cc0c1 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -41,26 +41,6 @@ TVM_REGISTER_API("relay._make.TensorType") *ret = TensorTypeNode::make(shape, args[1]); }); -TVM_REGISTER_API("relay._make.IntType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TensorTypeNode::Int(args[0], args[1]); - }); - -TVM_REGISTER_API("relay._make.UIntType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TensorTypeNode::UInt(args[0], args[1]); - }); - -TVM_REGISTER_API("relay._make.BoolType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TensorTypeNode::Bool(args[0]); - }); - -TVM_REGISTER_API("relay._make.FloatType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TensorTypeNode::Float(args[0], args[1]); - }); - TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TensorTypeNode *node, tvm::IRPrinter *p) { @@ -113,7 +93,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->type_constraints << ")"; }); -TypeRelation TypeRelationNode::make(std::string name, int num_args, TypeRelationFn func) { +TypeRelation TypeRelationNode::make(std::string name, int num_args, + TypeRelationFn func) { std::shared_ptr n = std::make_shared(); n->name = std::move(name); n->num_args = std::move(num_args); @@ -164,10 +145,9 @@ TVM_REGISTER_API("relay._make.TupleType") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TupleTypeNode *node, - tvm::IRPrinter *p) { + tvm::IRPrinter *p) { p->stream << "TupleTypeNode(" << node->fields << ")"; }); - } // namespace relay } // namespace tvm From 3445b9ffa47d7544c745f5c859993f10f5aa9f61 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 17:00:55 -0700 Subject: [PATCH 058/136] Remove dead code --- python/tvm/relay/pass.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/pass.py b/python/tvm/relay/pass.py index 9d7902686928..8c352e58843d 100644 --- a/python/tvm/relay/pass.py +++ b/python/tvm/relay/pass.py @@ -3,4 +3,3 @@ from . import _pass check_expr = _type_infer.check_expr -# generalize = _type_infer.generalize From d716e6ced593409b13715b01f0424caf1710302d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 17:06:56 -0700 Subject: [PATCH 059/136] Clean up code in Unifier --- src/relay/pass/unifier.cc | 69 ++++++++++++++++----------------------- src/relay/pass/unifier.h | 2 +- 2 files changed, 29 insertions(+), 42 deletions(-) diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index c6a4e7dfba6d..4d986ad79ab1 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -11,7 +11,6 @@ #include #include "./unifier.h" #include "./type_visitor.h" -//#include "./type_subst.h" // #include "tvm/relay/typeck/kindchecker.h" namespace tvm { @@ -298,52 +297,40 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { throw UnificationError("Cannot unify TensorTypeNode"); } -// Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { -// TupleType pt1 = GetRef(t1); +Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { + TupleType pt1 = GetRef(t1); -// // for typevar, remap and attempt to unify if already defined -// if (const IncompleteTypeNode *tvn2 = rt2.as()) { -// return this->unifyWithIncompleteType(pt1, GetRef(tvn2)); -// } + // When unifying tuple types we just solve each field in order. + if (const TupleTypeNode *ptn2 = rt2.as()) { + TupleType pt2 = GetRef(ptn2); -// // for other product types, unify item by item -// if (const TupleTypeNode *ptn2 = rt2.as()) { -// TupleType pt2 = GetRef(ptn2); - -// std::vector unified_fields; -// if (pt1->fields.size() != pt2->fields.size()) { -// throw UnificationError("Product types are of different dimensions"); -// } - -// for (size_t i = 0U; i < pt1->fields.size(); i++) { -// Type unified = this->VisitType(pt1->fields[i], pt2->fields[i]); -// unified_fields.push_back(unified); -// } - -// return TupleTypeNode::make(unified_fields); -// } - -// // otherwise cannot unify -// throw UnificationError("Cannot unify TupleTypeNode"); -// } + std::vector unified_fields; + if (pt1->fields.size() != pt2->fields.size()) { + throw UnificationError("Product types are of different dimensions"); + } -Type TypeUnifierNode::VisitType_(const TypeRelationNode *sen1, const Type t2) { -// ShapeExtension sh_ext1 = GetRef(sen1); + for (size_t i = 0U; i < pt1->fields.size(); i++) { + Type unified = this->VisitType(pt1->fields[i], pt2->fields[i]); + unified_fields.push_back(unified); + } -// if (const IncompleteTypeNode *tvn2 = t2.as()) { -// return this->unifyWithIncompleteType(sh_ext1, GetRef(tvn2)); -// } + return TupleTypeNode::make(unified_fields); + } -// // will only attempt to unify with binary op with same op -// if (const ShapeExtensionNode *sen2 = t2.as()) { -// if (sh_ext1->name != sen2->name) { -// throw UnificationError( -// "Cannot unify shape projections of different index"); -// } -// } + // otherwise cannot unify + throw UnificationError("Cannot unify TupleTypeNode"); +} -// return sh_ext1; - return t2; +Type TypeUnifierNode::VisitType_(const TypeRelationNode *tr1, const Type t2) { + if (const TypeRelationNode *tr2 = t2.as()) { + if (tr1 == tr2) { + return GetRef(tr1); + } else { + throw UnificationError("Cannot unify different type relations"); + } + } else { + throw UnificationError("Cannot unify type relation with another type of type"); + } } Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index aecc428cb6a9..5a4adea5c44e 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -109,7 +109,7 @@ class TypeUnifierNode : public Node, Type VisitType_(const TensorTypeNode* t1, const Type t2) override; Type VisitType_(const TypeParamNode* t1, const Type t2) override; Type VisitType_(const FuncTypeNode* t1, const Type t2) override; - // Type VisitType_(const TupleTypeNode* t1, const Type t2) override; + Type VisitType_(const TupleTypeNode* t1, const Type t2) override; Type VisitType_(const TypeRelationNode* s1, const Type t2) override; Type VisitType_(const TypeCallNode* s1, const Type t2) override; }; From 33ef935d9cba8ff559b086b2d3f8690fb4a1eefd Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 17:21:37 -0700 Subject: [PATCH 060/136] Clean up environment.h --- include/tvm/relay/environment.h | 3 +- src/relay/ir/environment.cc | 203 ++++++++++---------------------- 2 files changed, 60 insertions(+), 146 deletions(-) diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 43be0ab8c912..aa41882db46e 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -50,7 +50,7 @@ class EnvironmentNode : public RelayNode { public: /*! \brief A map from ids to all global functions. */ - tvm::Map items; + tvm::Map functions; EnvironmentNode() {} @@ -60,7 +60,6 @@ class EnvironmentNode : public RelayNode { tvm::Map global_funcs); void Add(const GlobalVar& var, const Function & func, bool update = false); - void TryAdd(const GlobalVar& var, const Function & func, bool update=false); void Update(const GlobalVar& var, const Function & func); void Remove(const GlobalVar& var); diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 8c155e3bc1bd..63a42b9d0e3e 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -5,11 +5,7 @@ */ #include #include "tvm/relay/environment.h" -// #include "tvm/relay/alpha_eq.h" -// #include "tvm/relay/debug.h" -// #include "tvm/relay/typeck/typechecker.h" // #include "tvm/relay/util/rang.h" -// #include "tvm/runtime/packed_func_ext.h" namespace tvm { namespace relay { @@ -20,7 +16,7 @@ using namespace tvm::runtime; Environment EnvironmentNode::make( tvm::Map global_funcs) { std::shared_ptr n = std::make_shared(); - n->items = std::move(global_funcs); + n->functions = std::move(global_funcs); return Environment(n); } @@ -35,10 +31,13 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { } } -// // Add a new item to the global environment -// // throws an exception if the item already -// // exists. -// void EnvironmentNode::add(const Item &unchecked_item, bool update) { +/*! \brief Add a new item to the global environment + * \note if the update flag is not set adding a duplicate + * definition will trigger an exception, otherwise we will + * update the definition if and only if it is type compatible. + */ +void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool update) { + throw Error("NYI"); // // Type check the item before we add it to the environment. // auto env = GetRef(this); // Item item = check(env, unchecked_item); @@ -85,14 +84,22 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { // throw EnvError("internal error: unknown item type, unreachable code"); // } // } +} -// void EnvironmentNode::update(const Item &item) { return this->add(item, true); } +void EnvironmentNode::Update(const GlobalVar& var, const Function & func) { + this->Add(var, func, true); +} -// void EnvironmentNode::remove(const GlobalId &id) { this->items.erase(id); } +void EnvironmentNode::Remove(const GlobalVar&) { + // Clarify with @tqchen about how to use COW to do this. + throw Error("NYI"); + // this->items.erase(id); +} Function EnvironmentNode::Lookup(const GlobalVar &var) { - if (items.find(var) != items.end()) { - return items.at(var); + auto func = functions.find(var); + if (func != functions.end()) { + return (*func).second; } else { throw Error(std::string("there is no definition of ") + var->name_hint); } @@ -103,143 +110,51 @@ Function EnvironmentNode::Lookup(const std::string &str) { return this->Lookup(id); } -// inline FileId EnvironmentNode::add_source(std::string file_name, -// std::string source) { -// return this->source_map_.add_source(file_name, source); -// } - -// void EnvironmentNode::report_error(std::string msg, Span sp) { -// this->errors_.push_back(Error(msg, sp)); -// } - -// void EnvironmentNode::display_errors() { -// for (auto err : this->errors_) { -// auto sp = err.sp; -// auto source_file = this->source_map_.GetSource(err.sp->file_id); -// auto file_name = source_file.file_name; -// auto source_at_span = source_file.SourceAt(err.sp, 1); -// std::string error_marker = "error:"; -// auto line_info = -// std::to_string(sp->lineno) + ":" + std::to_string(sp->col_offset); - -// std::cout << rang::style::bold << rang::fg::red << error_marker -// << rang::fg::reset << file_name << ":" << line_info -// << rang::style::reset << " " << source_at_span << std::endl; - -// // Build the cursor. - -// // Fix this code, hardwired to compute alignment of pointer. -// size_t spaces = error_marker.size() + line_info.size() + file_name.size() + -// sp->col_offset - 3; - -// std::string cursor = "~~~~^~~~~"; -// for (size_t i = 0; i < spaces; i++) { -// std::cout << " "; -// } -// std::cout << rang::fg::red << cursor << " " << err.msg << rang::style::reset -// << std::endl; -// } -// } +inline SourceName EnvironmentNode::AddSource(std::string file_name, + std::string source) { + throw Error("need to restore error handling"); + // return this->source_map_.add_source(file_name, source); +} -TVM_REGISTER_API("relay._make.Environment") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = EnvironmentNode::make({}); - }); +void EnvironmentNode::ReportError(std::string msg, Span sp) { + throw Error("need to restore error handling"); + // this->errors_.push_back(Error(msg, sp)); +} -// TVM_REGISTER_API("relay._env.Environment_add") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// Item item = args[1]; -// env->add(item, true); // REMOVE ME -// }); - -// TVM_REGISTER_API("relay._env.Environment_lookup_global") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// GlobalId id = args[1]; -// *ret = env->lookup(id); -// }); - -// TVM_REGISTER_API("relay._env.Environment_lookup_operator") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// OperatorId id = args[1]; -// *ret = env->lookup(id); -// }); - -// // TVM_REGISTER_API("relay._env.Environment_remove_global") -// // .set_body([](TVMArgs args, TVMRetValue *ret) { -// // Environment env = args[0]; -// // GlobalId id = args[1]; -// // env->remove(id); -// // }); - -// TVM_REGISTER_API("relay._env.Environment_global_id") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string str = args[1]; -// *ret = env->global_id(str); -// }); - -// TVM_REGISTER_API("relay._env.Environment_operator_id") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string str = args[1]; -// *ret = env->operator_id(str); -// }); - -// TVM_REGISTER_API("relay._env.Environment_register_shape_ext") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// ShapeExtension ext = args[1]; -// env->register_shape_ext(ext); -// }); - -// TVM_REGISTER_API("relay._env.Environment_register_primitive") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string str = args[1]; -// *ret = env->global_id(str); -// }); - -// TVM_REGISTER_API("relay._env.Environment_add_source") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string file_name = args[1]; -// std::string source_name = args[2]; -// *ret = env->add_source(file_name, source_name); -// }); - -// TVM_REGISTER_API("relay._env.Environment_report_error") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// std::string msg = args[1]; -// Span sp = args[2]; -// env->report_error(msg, sp); -// }); - -// TVM_REGISTER_API("relay._env.Environment_display_errors") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// return env->display_errors(); -// }); - -// TVM_REGISTER_API("relay._env.Environment_get_operators") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// *ret = env->get_operators(); -// }); - -// TVM_REGISTER_API("relay._env.Environment_get_defns") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Environment env = args[0]; -// *ret = env->get_defns(); -// }); +void EnvironmentNode::DisplayErrors() { + throw Error("need to restore error printing"); + // for (auto err : this->errors_) { + // auto sp = err.sp; + // auto source_file = this->source_map_.GetSource(err.sp->file_id); + // auto file_name = source_file.file_name; + // auto source_at_span = source_file.SourceAt(err.sp, 1); + // std::string error_marker = "error:"; + // auto line_info = + // std::to_string(sp->lineno) + ":" + std::to_string(sp->col_offset); + + // std::cout << rang::style::bold << rang::fg::red << error_marker + // << rang::fg::reset << file_name << ":" << line_info + // << rang::style::reset << " " << source_at_span << std::endl; + + // // Build the cursor. + + // // Fix this code, hardwired to compute alignment of pointer. + // size_t spaces = error_marker.size() + line_info.size() + file_name.size() + + // sp->col_offset - 3; + + // std::string cursor = "~~~~^~~~~"; + // for (size_t i = 0; i < spaces; i++) { + // std::cout << " "; + // } + // std::cout << rang::fg::red << cursor << " " << err.msg << rang::style::reset + // << std::endl; + // } +} TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const EnvironmentNode *node, tvm::IRPrinter *p) { - p->stream << "EnvironmentNode(todo)"; // << node->items << ")"; + p->stream << "EnvironmentNode( " << node->functions << ")"; }); } // namespace relay From 3c3d719a5a81cb771a12025a0e262527679e4ae1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 28 Aug 2018 17:29:30 -0700 Subject: [PATCH 061/136] Fix up Python imports --- python/tvm/relay/{_pass.py => _ir_pass.py} | 2 +- python/tvm/relay/{_pass.pyi => _ir_pass.pyi} | 0 python/tvm/relay/expr.py | 2 +- python/tvm/relay/ir_builder.py | 21 ++++++++++++------- python/tvm/relay/ir_pass.py | 5 +++++ python/tvm/relay/pass.py | 5 ----- src/relay/ir/environment.cc | 5 +++++ src/relay/pass/type_infer.cc | 4 ++-- .../relay/test_tyck_eval_integration.py | 2 +- 9 files changed, 29 insertions(+), 17 deletions(-) rename python/tvm/relay/{_pass.py => _ir_pass.py} (72%) rename python/tvm/relay/{_pass.pyi => _ir_pass.pyi} (100%) create mode 100644 python/tvm/relay/ir_pass.py delete mode 100644 python/tvm/relay/pass.py diff --git a/python/tvm/relay/_pass.py b/python/tvm/relay/_ir_pass.py similarity index 72% rename from python/tvm/relay/_pass.py rename to python/tvm/relay/_ir_pass.py index 052ba6d4a0fb..61fdcfa38c2f 100644 --- a/python/tvm/relay/_pass.py +++ b/python/tvm/relay/_ir_pass.py @@ -2,4 +2,4 @@ from tvm._ffi.function import _init_api -_init_api("relay._pass", __name__) +_init_api("relay._ir_pass", __name__) diff --git a/python/tvm/relay/_pass.pyi b/python/tvm/relay/_ir_pass.pyi similarity index 100% rename from python/tvm/relay/_pass.pyi rename to python/tvm/relay/_ir_pass.pyi diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 41066829e2f3..4f558210fb11 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -6,7 +6,7 @@ from .base import Span, NodeBase, register_relay_node from .type import Type, TypeParam from tvm import expr -from ._type_infer import _get_checked_type +from ._ir_pass import _get_checked_type from . import _make class Expr(NodeBase): diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index a9cb02a19025..b5ca6428c897 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,7 +1,7 @@ from typing import Any import numpy as np import tvm -from .type import FloatType, IntType, BoolType, UIntType, FuncType, TensorType +from .type import FuncType, TensorType from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function from . import op as _op @@ -152,20 +152,27 @@ def get(self): def bool_dtype(): return 'uint1' -def int_dtype(): - return 'uint1' +def int_dtype(bits=32): + return f'int1{bits}' + +def float_dtype(bits=32): + return f'float{bits}' +def uint_dtype(bits=32): + return f'fuint{bits}' + def int_type(bits=32, lanes=1): - return IntType(bits, lanes) + # TODO(@jroesch, @tqchen) How do we set lanes? + return TensorType(tvm.convert([]), int_dtype(bits)) def uint_type(bits=32, lanes=1): - return UIntType(bits, lanes) + return TensorType(tvm.convert([]), uint_dtype(bits)) def float_type(bits=32, lanes=1): - return FloatType(bits, lanes) + return TensorType(tvm.convert([]), float_dtype(bits)) def bool_type(lanes=1): - return BoolType(lanes) + return TensorType(tvm.convert([]), bool_dtype(bits)) def tensor_type(*shape, dtype='float32'): return TensorType(tvm.convert(shape), dtype) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py new file mode 100644 index 000000000000..ad7a68eac392 --- /dev/null +++ b/python/tvm/relay/ir_pass.py @@ -0,0 +1,5 @@ +#pylint: disable-all + +from . import _ir_pass + +check_expr = _ir_pass.check_expr diff --git a/python/tvm/relay/pass.py b/python/tvm/relay/pass.py deleted file mode 100644 index 8c352e58843d..000000000000 --- a/python/tvm/relay/pass.py +++ /dev/null @@ -1,5 +0,0 @@ -#pylint: disable-all - -from . import _pass - -check_expr = _type_infer.check_expr diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 63a42b9d0e3e..cb8afd002c51 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -151,6 +151,11 @@ void EnvironmentNode::DisplayErrors() { // } } +TVM_REGISTER_API("relay._make.Environment") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = EnvironmentNode::make(args[0]); + }); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const EnvironmentNode *node, tvm::IRPrinter *p) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 746323fc6d56..383196f49be9 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -626,7 +626,7 @@ class TypeInferencer : private ExprFunctor { } } - TVM_REGISTER_API("relay._pass.check_expr") + TVM_REGISTER_API("relay._ir_pass.check_expr") .set_body([](TVMArgs args, TVMRetValue *ret) { Environment env = args[0]; Expr e = args[1]; @@ -634,7 +634,7 @@ class TypeInferencer : private ExprFunctor { }); // TODO(@jroesch): put in a better namespace. - TVM_REGISTER_API("relay._pass._get_checked_type") + TVM_REGISTER_API("relay._ir_pass._get_checked_type") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr e = args[0]; *ret = e->checked_type(); diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 6e5b64ee846e..72fd995fd22e 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -1,7 +1,7 @@ """Test that type checker correcly computes types for expressions. """ -from tvm.relay.type_infer import check_expr +from tvm.relay.ir_pass import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, func_type, tensor_type from tvm.relay.env import Environment from tvm.relay.op import log, add From c075bd32dae0a2bbb6b7dd519edc9533ebc734e6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 30 Aug 2018 22:50:00 -0700 Subject: [PATCH 062/136] Add first pass add broadcast inference --- src/relay/op/type_relations.cc | 33 ++++++++++++++----- .../relay/test_tyck_eval_integration.py | 4 +-- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 56b139731178..d97b8f96e85c 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -24,6 +24,7 @@ TensorType as_ttype(const Type & t) { int to_int(const tvm::Expr & e) { auto imm = e.as(); CHECK(imm); + std::cout << "TYPE: " << imm << imm->type << std::endl; return imm->value; } @@ -60,26 +61,41 @@ static Type ConcreteBroadcast(const TensorType & t1, const TensorType & t2) { } } - Array larger; - Array smaller; + Array larger; + Array smaller; for (int i = 0; i < (full_len - suffix_len); i++) { - // smaller.push_back(tvm::ir::IntImm::make(1)); + smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1)); } if (sh1.size() < sh2.size()) { - + for (auto sh : sh1) { + smaller.push_back(sh); + } + larger = sh2; } else if (sh1.size() > sh2.size()) { - + for (auto sh : sh1) { + larger.push_back(sh); + } + smaller = sh2; } else { - + larger = sh1; + smaller = sh2; } - for (int i = 0; i < suffix_len - full_len; i++) { + CHECK(larger.size() == smaller.size()); + Array out_shape; + for (int i = 0; i < smaller.size(); i++) { + auto left = smaller[i].as(); + auto right = larger[i].as(); + CHECK(left); + CHECK(right); + int64_t dim = std::max(left->value, right->value); + out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); } - return t1; + return TensorTypeNode::make(out_shape, t1->dtype); } Array BroadcastRel(const Array & types, int num_args) { @@ -89,6 +105,7 @@ Array BroadcastRel(const Array & types, int num_args) { return { t1, t2, ConcreteBroadcast(t1, t2) }; } } + return types; } diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 72fd995fd22e..e928cd5cb76a 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -47,6 +47,6 @@ def test_dual_op(): assert has_type(func.to_func(), func_type([float_type()], float_type())) if __name__ == "__main__": - # test_monomorphic_let() - # test_single_op() + test_monomorphic_let() + test_single_op() test_dual_op() From c080d7b90f33fbb3382ef4abeadcb98615157f45 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 16:05:34 -0700 Subject: [PATCH 063/136] Add ability to build and check a global --- include/tvm/relay/environment.h | 4 + python/tvm/relay/env.py | 16 + python/tvm/relay/ir_builder.py | 41 +- src/relay/ir/environment.cc | 125 ++- src/relay/ir/expr.cc | 4 +- src/relay/pass/type_infer.cc | 961 +++++++++--------- .../relay/test_tyck_eval_integration.py | 25 +- 7 files changed, 630 insertions(+), 546 deletions(-) diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index aa41882db46e..5ad7ba8e0010 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -63,6 +63,7 @@ class EnvironmentNode : public RelayNode { void Update(const GlobalVar& var, const Function & func); void Remove(const GlobalVar& var); + /*! \brief Lookup a global function by its variable. */ GlobalVar GetGlobalVar(const std::string& str); /*! \brief Lookup a global function by its variable. */ @@ -70,6 +71,9 @@ class EnvironmentNode : public RelayNode { /*! \brief Lookup a global function by its string name */ Function Lookup(const std::string & s); + + // TODO(@jroesch, @tqchen): what are the semantics here + void Merge(const Environment & env); /*! \brief Add a source fragment to the environment. */ SourceName AddSource(std::string file_name, std::string source); diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 4de5a0c02772..186ee8854c35 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -3,6 +3,7 @@ from typing import Union, List from .base import register_relay_node, NodeBase from . import _make +from . import _env import tvm @register_relay_node @@ -12,3 +13,18 @@ class Environment(NodeBase): """ def __init__(self, funcs) -> None: self.__init_handle_by_constructor__(_make.Environment, funcs) + + def add(self, var, func) -> None: + if isinstance(var, str): + var = _env.Environment_GetGlobalVar(self, var) + + _env.Environment_Add(self, var, func) + + def merge(self, other): + return _env.Environment_Merge(self, other) + + def lookup(self, var): + if isinstance(var, str): + return _env.Environment_Lookup_str(self, var) + else: + return _env.Environment_Lookup(self, var) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index b5ca6428c897..50ebeb1bb12d 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -3,6 +3,7 @@ import tvm from .type import FuncType, TensorType from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function +from .env import Environment from . import op as _op class ExprBuilder(): @@ -83,6 +84,7 @@ def __init__(self): self.scopes = [{}] self.params = [] self.ret_value = None + self.env = Environment({}) def bind(self, name, type, value): @@ -93,6 +95,9 @@ def bind(self, name, type, value): def let(self, name, value, value_type=None): + if isinstance(value, Param): + value = value.var + if not (isinstance(value, Expr) or isinstance(value, ExprBuilder)): value = into_ast(value) @@ -131,8 +136,29 @@ def ret(self, x): raise Exception( "return value already set, a function can only have one return value") - def fn_params(self): - pass + def param(self, name, ty=None): + if not ty: + ty = float_type() + + return Param(LocalVar(name), ty) + + # def params(*args): + # i = 0 + # while i < args.size(): + # arg = args[i] + # if isinstance(arg, str): + + + def decl(self, name: str, *params): + decl_builder = IRBuilder() + + def _on_exit(): + exp, sub_env = decl_builder.get() + self.env.add(name, Function(params, None, exp)) + self.env.merge(sub_env) + + return WithScope(decl_builder, _on_exit) + def get(self): """Get the full program""" @@ -140,14 +166,15 @@ def get(self): scope = self.scopes.pop() if self.bindings: - raise Exception("...") + raise Exception("IRBuilder: binding error") + if self.scopes: - raise Exception("...") + raise Exception("IRBuilder: scoping error") - if not self.ret_value: - raise Exception("...") + if bindings and scope and not self.ret_value: + raise Exception("IRBuilder: no return value set") - return _mk_let(bindings, self.ret_value) + return _mk_let(bindings, self.ret_value), self.env def bool_dtype(): return 'uint1' diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index cb8afd002c51..7861fb58820b 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -4,14 +4,18 @@ * \brief The global environment in Relay. */ #include -#include "tvm/relay/environment.h" +#include +#include +#include +#include +#include "./../pass/resolve.h" // #include "tvm/relay/util/rang.h" namespace tvm { namespace relay { using tvm::IRPrinter; -using namespace tvm::runtime; +using namespace runtime; Environment EnvironmentNode::make( tvm::Map global_funcs) { @@ -37,53 +41,35 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { * update the definition if and only if it is type compatible. */ void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool update) { - throw Error("NYI"); -// // Type check the item before we add it to the environment. -// auto env = GetRef(this); -// Item item = check(env, unchecked_item); - -// if (const OperatorNode *op_node = item.as()) { -// Operator op = GetRef(op_node); -// auto type = op->type; -// if (operators.find(op->id) != operators.end()) { -// if (!update) { -// throw dmlc::Error("already have definition for XXXX."); -// } - -// auto old_type = operators[op->id]->type; - -// if (!alpha_eq(type, old_type)) { -// throw dmlc::Error( -// "Environment#update changes type, not possible in this mode."); -// } - -// operators.insert({op->id, op}); -// } else { -// operators.insert({op->id, op}); -// } -// } else if (const FunctionNode *d = item.as()) { -// auto def = GetRef(d); -// auto type = def->type; -// if (items.find(def->id) != items.end()) { -// if (!update) { -// throw dmlc::Error("already have definition for XXXX."); -// } - -// auto old_type = items[def->id].as()->type; - -// if (!alpha_eq(type, old_type)) { -// throw dmlc::Error( -// "Environment#update changes type, not possible in this mode."); -// } - -// this->items.insert({def->id, def}); -// } else { -// this->items.insert({def->id, def}); -// } -// } else { -// throw EnvError("internal error: unknown item type, unreachable code"); -// } -// } + // Type check the item before we add it to the environment. + auto env = GetRef(this); + Expr checked_expr = InferType(env, func); + + if (const FunctionNode *func_node = checked_expr.as()) { + auto checked_func = GetRef(func_node); + auto type = checked_func->checked_type(); + + CHECK(IsFullyResolved(type)); + + if (functions.find(var) != functions.end()) { + if (!update) { + throw dmlc::Error("already have definition for XXXX."); + } + + auto old_type = functions[var].as()->checked_type(); + + if (!AlphaEqual(type, old_type)) { + throw dmlc::Error( + "Environment#update changes type, not possible in this mode."); + } + + this->functions.Set(var, checked_func); + } else { + this->functions.Set(var, checked_func); + } + } else { + throw Error("internal error: unknown item type, unreachable code"); + } } void EnvironmentNode::Update(const GlobalVar& var, const Function & func) { @@ -110,6 +96,13 @@ Function EnvironmentNode::Lookup(const std::string &str) { return this->Lookup(id); } +void EnvironmentNode::Merge(const Environment & env) { + for (auto pair : env->functions) { + this->functions.Set(pair.first, pair.second); + } +} + + inline SourceName EnvironmentNode::AddSource(std::string file_name, std::string source) { throw Error("need to restore error handling"); @@ -156,6 +149,40 @@ TVM_REGISTER_API("relay._make.Environment") *ret = EnvironmentNode::make(args[0]); }); +TVM_REGISTER_API("relay._env.Environment_Add") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Add(args[1], args[2], false); + }); + +TVM_REGISTER_API("relay._env.Environment_GetGlobalVar") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + *ret = env->GetGlobalVar(args[1]); + }); + +TVM_REGISTER_API("relay._env.Environment_Lookup") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + GlobalVar var = args[1]; + *ret = env->Lookup(var); + }); + +TVM_REGISTER_API("relay._env.Environment_Lookup_str") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string var_name = args[1]; + auto var = env->GetGlobalVar(var_name); + *ret = env->Lookup(var); + }); + +TVM_REGISTER_API("relay._env.Environment_Merge") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Merge(args[1]); + }); + + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const EnvironmentNode *node, tvm::IRPrinter *p) { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 2b235e8b01ad..47d253e91c21 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -168,8 +168,8 @@ TVM_REGISTER_API("relay._make.Let") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { - p->stream << "LetNode(" << node->var << node->value << node->body - << node->value_type << ")"; + p->stream << "LetNode(" << node->var << ", " << node->value + << ", " << node->body << ", " << node->value_type << ")"; }); If IfNode::make(Expr cond, Expr true_value, Expr false_value) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 383196f49be9..514df129503a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -21,14 +21,14 @@ */ #include +#include #include #include -#include #include "./incomplete_type.h" -#include "./unifier.h" #include "./resolve.h" #include "./type_subst.h" #include "./type_visitor.h" +#include "./unifier.h" // #include "tvm/relay/typeck/kindchecker.h" namespace tvm { @@ -61,9 +61,9 @@ struct TypeContext { struct TypeNormalizer : TypeFVisitor { TypeUnifier unifier; - TypeNormalizer(const TypeUnifier & unifier) : unifier(unifier) {} + TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {} - Type VisitType_(const TypeCallNode * ty_call_node) { + Type VisitType_(const TypeCallNode *ty_call_node) { auto ty_call = GetRef(ty_call_node); Array normalized_args; @@ -71,7 +71,7 @@ struct TypeNormalizer : TypeFVisitor { for (auto arg : ty_call->args) { normalized_args.push_back(VisitType(arg)); } - + auto all_concrete = true; for (auto arg : normalized_args) { all_concrete = all_concrete && !arg.as(); @@ -82,7 +82,8 @@ struct TypeNormalizer : TypeFVisitor { } else { if (auto ty_rel_node = ty_call->func.as()) { // NB: we substract 1 for the output argument. - auto new_args = ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); + auto new_args = + ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); CHECK(new_args.size() == normalized_args.size()); tvm::Array final_args; @@ -110,554 +111,544 @@ class TypeInferencer : private ExprFunctor { TypeContext local_stack; public: - Environment env; - TypeUnifier unifier; - - // Should be in header? - template - T with_frame(const std::function & f) { - TypeContext::LocalFrame fr(local_stack); - return f(); - } - - TypeInferencer(); - TypeInferencer(Environment env, TypeUnifier unifier) : env(env), - unifier(unifier) {} explicit TypeInferencer(Environment env); - - CheckedExpr Infer(const Expr & expr); - - FuncType instantiate(FuncType fn_ty, tvm::Array &ty_args); - - Type Normalize(const Type & t); - - void report_error(const std::string & msg, Span sp); - [[ noreturn ]] void fatal_error(const std::string & msg, Span sp); - - Type unify(const Type &t1, const Type &t2, Span sp); - Type resolve(const Type &t); - Expr resolve(const Expr &e); - CheckedExpr VisitFunction(const Function & f, bool generalize); - void CheckOp(Op op); - // Defn CheckDefn(Defn def); - private: - CheckedExpr VisitExpr_(const LocalVarNode* op) override; - CheckedExpr VisitExpr_(const GlobalVarNode* op) override; - CheckedExpr VisitExpr_(const ConstantNode* op) override; - CheckedExpr VisitExpr_(const TupleNode* op) override; - CheckedExpr VisitExpr_(const ParamNode* op) override; - CheckedExpr VisitExpr_(const FunctionNode* op) override; - CheckedExpr VisitExpr_(const CallNode* op) override; - CheckedExpr VisitExpr_(const LetNode* op) override; - CheckedExpr VisitExpr_(const IfNode* op) override; - CheckedExpr VisitExpr_(const OpNode* op) override; -}; - - TypeInferencer::TypeInferencer() { - this->env = EnvironmentNode::make({}); - this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); - } - - TypeInferencer::TypeInferencer(Environment env) : env(env) { - this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); - } - - Type TypeInferencer::Normalize(const Type & t) { - auto nt = this->resolve(t); - auto normalizer = TypeNormalizer(this->unifier); - return normalizer.VisitType(nt); - } - - CheckedExpr TypeInferencer::Infer(const Expr &expr) { - RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; - CheckedExpr checked_expr = this->VisitExpr(expr); - RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type << std::endl; - Type final_type = Normalize(checked_expr.type); - RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type << std::endl; - checked_expr.expr->checked_type_ = final_type; - return checked_expr; - } + Environment env; + TypeUnifier unifier; - CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { - auto var = GetRef(op); - return { var, this->local_stack.lookup(var) }; + // Should be in header? + template + T with_frame(const std::function &f) { + TypeContext::LocalFrame fr(local_stack); + return f(); } - CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { - // GlobalVar id = GetRef(op); - // Item item = this->env->lookup(id); + TypeInferencer(); + TypeInferencer(Environment env, TypeUnifier unifier) + : env(env), unifier(unifier) {} + explicit TypeInferencer(Environment env); - // if (const OpNode *op = item.as()) { - // return op->type; - // } + CheckedExpr Infer(const Expr &expr); - // if (const DefnNode *dn = item.as()) { - // Defn def = GetRef(dn); - // return def->type; - // } + FuncType instantiate(FuncType fn_ty, tvm::Array &ty_args); - // this->fatal_error("Unhandled case in GlobalId", op->span); - throw Error("hereeee"); - } - - CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { - return { GetRef(const_node), const_node->tensor_type() }; - } + Type Normalize(const Type &t); - CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { - // Tuple pl = GetRef(op); + void report_error(const std::string &msg, Span sp); + [[noreturn]] void fatal_error(const std::string &msg, Span sp); - // std::vector field_types; - // for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) - // { - // field_types.push_back(this->Check(*field)); - // } + Type unify(const Type &t1, const Type &t2, Span sp); + Type resolve(const Type &t); + Expr resolve(const Expr &e); + CheckedExpr VisitFunction(const Function &f, bool generalize); + void CheckOp(Op op); + // Defn CheckDefn(Defn def); + private: + CheckedExpr VisitExpr_(const LocalVarNode *op) override; + CheckedExpr VisitExpr_(const GlobalVarNode *op) override; + CheckedExpr VisitExpr_(const ConstantNode *op) override; + CheckedExpr VisitExpr_(const TupleNode *op) override; + CheckedExpr VisitExpr_(const ParamNode *op) override; + CheckedExpr VisitExpr_(const FunctionNode *op) override; + CheckedExpr VisitExpr_(const CallNode *op) override; + CheckedExpr VisitExpr_(const LetNode *op) override; + CheckedExpr VisitExpr_(const IfNode *op) override; + CheckedExpr VisitExpr_(const OpNode *op) override; +}; - // return TupleTypeNode::make(field_types); - throw Error("TupleNode NYI"); +TypeInferencer::TypeInferencer() { + this->env = EnvironmentNode::make({}); + this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); +} + +TypeInferencer::TypeInferencer(Environment env) : env(env) { + this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); +} + +Type TypeInferencer::Normalize(const Type &t) { + auto nt = this->resolve(t); + auto normalizer = TypeNormalizer(this->unifier); + return normalizer.VisitType(nt); +} + +CheckedExpr TypeInferencer::Infer(const Expr &expr) { + RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; + CheckedExpr checked_expr = this->VisitExpr(expr); + RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type + << std::endl; + Type final_type = Normalize(checked_expr.type); + RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type + << std::endl; + checked_expr.expr->checked_type_ = final_type; + return checked_expr; +} + +CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { + auto var = GetRef(op); + return {var, this->local_stack.lookup(var)}; +} + +CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { + GlobalVar var = GetRef(op); + Expr e = this->env->Lookup(var); + return { var, e->checked_type() }; +} + +CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { + return {GetRef(const_node), const_node->tensor_type()}; +} + +CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { + Tuple pl = GetRef(op); + + std::vector field_exprs; + std::vector field_types; + for (auto field = pl->fields.begin(); field != pl->fields.end(); field++) { + auto checked_field = Infer(*field); + field_exprs.push_back(checked_field.expr); + field_types.push_back(checked_field.type); } - CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { - auto rtype = resolve(param->type); - return { ParamNode::make(param->var, rtype), rtype }; - } + return { TupleNode::make(field_exprs), TupleTypeNode::make(field_types) }; +} + +CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { + auto rtype = resolve(param->type); + return {ParamNode::make(param->var, rtype), rtype}; +} + +// // We should probably generalize the subst code. +// struct GeneralizeTypeType : TypeFVisitor { +// Map vars_to_id; +// const TypeUnifier &unifier; + +// GeneralizeTypeType(Map vars_to_id, +// const TypeUnifier &unifier) +// : vars_to_id(vars_to_id), unifier(unifier) {} + +// Type VisitType_(const TypeVarNode *op) override { +// auto repr = unifier->subst(GetRef(op)); +// if (auto tvn = repr.as()) { +// auto ty_var = GetRef(tvn); +// if (vars_to_id.find(ty_var) != vars_to_id.end()) { +// return vars_to_id[ty_var]; +// } else { +// return ty_var; +// } +// } else { +// return this->VisitType(repr); +// } +// } +// }; + +// struct GeneralizeTypeExpr : ExprFVisitor<> { +// Map vars_to_id; +// const TypeUnifier &unifier; + +// GeneralizeTypeExpr(const TypeUnifier &unifier, +// Map vars_to_id) +// : vars_to_id(vars_to_id), unifier(unifier) {} + +// Type VisitType(const Type &t) { +// return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); +// } +// }; + +CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { + // First we add the parameters to the context allowing us to check their + // types. + + // TODO(@jroesch): support polymorphism + + std::vector param_types; + std::vector params; + + return this->with_frame([&]() -> CheckedExpr { + for (auto param : f->params) { + CheckedExpr checked_param = this->Infer(param); + Type arg_type; + param_types.push_back(checked_param.type); + params.push_back(GetRef(checked_param.expr.as())); + this->local_stack.insert(param->var, checked_param.type); + } - // // We should probably generalize the subst code. - // struct GeneralizeTypeType : TypeFVisitor { - // Map vars_to_id; - // const TypeUnifier &unifier; - - // GeneralizeTypeType(Map vars_to_id, - // const TypeUnifier &unifier) - // : vars_to_id(vars_to_id), unifier(unifier) {} - - // Type VisitType_(const TypeVarNode *op) override { - // auto repr = unifier->subst(GetRef(op)); - // if (auto tvn = repr.as()) { - // auto ty_var = GetRef(tvn); - // if (vars_to_id.find(ty_var) != vars_to_id.end()) { - // return vars_to_id[ty_var]; - // } else { - // return ty_var; - // } + auto checked_body = this->Infer(f->body); + auto inferred_rtype = checked_body.type; + auto annotated_rtype = resolve(f->ret_type); + + auto unified_rtype = this->unify(inferred_rtype, annotated_rtype, f->span); + + return {FunctionNode::make(params, unified_rtype, checked_body.expr, {}), + FuncTypeNode::make(param_types, unified_rtype, {}, {})}; + }); + + // // typecheck body and ensure that it matches stated return type + // // TODO(sslyu): should the unified return type override the annotated + // one? Type checked_return = this->Check(f->body); Type ret_type = + // resolve(f->ret_type); Type unified = + // this->unify(simple_eval_shape(ret_type), + // simple_eval_shape(checked_return), f->span); + // return TypeArrowNode::make(arg_types, unified); + // }); + // if (generalize) { + // auto free_vars = free_type_vars(resolve(fn_type)); + // std::set dedup_free_vars; + + // for (auto free_var : free_vars) { + // auto repr = this->unifier->subst(free_var); + // if (auto new_free_var_node = repr.as()) { + // dedup_free_vars.insert(GetRef(new_free_var_node)); // } else { - // return this->VisitType(repr); + // // debug(repr); + // throw dmlc::Error( + // "internal error: this list should only contain type var + // nodes"); // } // } - // }; - // struct GeneralizeTypeExpr : ExprFVisitor<> { // Map vars_to_id; - // const TypeUnifier &unifier; - // GeneralizeTypeExpr(const TypeUnifier &unifier, - // Map vars_to_id) - // : vars_to_id(vars_to_id), unifier(unifier) {} - - // Type VisitType(const Type &t) { - // return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); + // GenFresh gf; + // for (auto free_var : dedup_free_vars) { + // vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); // } - // }; - - CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { - // First we add the parameters to the context allowing us to check their - // types. - - // TODO(@jroesch): support polymorphism - - std::vector param_types; - std::vector params; - - return this->with_frame([&]() -> CheckedExpr { - for (auto param : f->params) { - CheckedExpr checked_param = this->Infer(param); - Type arg_type; - param_types.push_back(checked_param.type); - params.push_back(GetRef(checked_param.expr.as())); - this->local_stack.insert(param->var, checked_param.type); - } - - auto checked_body = this->Infer(f->body); - auto inferred_rtype = checked_body.type; - auto annotated_rtype = resolve(f->ret_type); - - auto unified_rtype = this->unify(inferred_rtype, annotated_rtype, f->span); - - return { FunctionNode::make(params, unified_rtype, checked_body.expr, {}), - FuncTypeNode::make(param_types, unified_rtype, {}, {}) }; - }); - - // // typecheck body and ensure that it matches stated return type - // // TODO(sslyu): should the unified return type override the annotated - // one? Type checked_return = this->Check(f->body); Type ret_type = - // resolve(f->ret_type); Type unified = - // this->unify(simple_eval_shape(ret_type), - // simple_eval_shape(checked_return), f->span); - // return TypeArrowNode::make(arg_types, unified); - // }); - // if (generalize) { - // auto free_vars = free_type_vars(resolve(fn_type)); - // std::set dedup_free_vars; - - // for (auto free_var : free_vars) { - // auto repr = this->unifier->subst(free_var); - // if (auto new_free_var_node = repr.as()) { - // dedup_free_vars.insert(GetRef(new_free_var_node)); - // } else { - // // debug(repr); - // throw dmlc::Error( - // "internal error: this list should only contain type var - // nodes"); - // } - // } - - // Map vars_to_id; - - // GenFresh gf; - // for (auto free_var : dedup_free_vars) { - // vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); - // } - - // fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); - // for (std::pair pair : vars_to_id) { - // // NB: In generalization we want to find type variables with - // // *no constraints* on them, and convert them to universally - // quantified - // // variables. - // // - // // i.e the program can be abstracted over the details of *that* type. - - // // For example a program that works irrespective of shape or - // datatype. - - // // In order to do this we find the set of free type variables in the - // // term, and then unify them with the fresh type ids we generate. - // // - // // Remember importantly these type variables still may appear in many - // // places in the program including both types and expressions. - - // // Our method for resolving these is to unify them with the variables - // // as we build the new quanitifer, changing from a program with - // "holes" - // // to one that is properly abstracted over. - - // // Finally later on we can iterate over the whole term and change - // from - // // type variables to these type ids. - // this->unify(pair.first, pair.second, pair.second->span); - // fn_type = TypeQuantifierNode::make(pair.second, fn_type); - // } - // } else { - // for (auto i = f->ty_params.size(); i > 0; i--) { - // auto ty_param = f->ty_params[i - 1]; - // auto ty_param_node = ty_param.as(); - // if (!ty_param_node) { - // throw dmlc::Error("internal error should be TypeParam"); - // } - // auto fresh_tid = - // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); - // fn_type = - // TypeSubst(fn_type, GetRef(ty_param_node), fresh_tid); - // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); - // } - // } - - // return fn_type; - } - CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { - return this->VisitFunction(GetRef(op), false); - } + // fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); + // for (std::pair pair : vars_to_id) { + // // NB: In generalization we want to find type variables with + // // *no constraints* on them, and convert them to universally + // quantified + // // variables. + // // + // // i.e the program can be abstracted over the details of *that* type. + + // // For example a program that works irrespective of shape or + // datatype. + + // // In order to do this we find the set of free type variables in the + // // term, and then unify them with the fresh type ids we generate. + // // + // // Remember importantly these type variables still may appear in many + // // places in the program including both types and expressions. + + // // Our method for resolving these is to unify them with the variables + // // as we build the new quanitifer, changing from a program with + // "holes" + // // to one that is properly abstracted over. + + // // Finally later on we can iterate over the whole term and change + // from + // // type variables to these type ids. + // this->unify(pair.first, pair.second, pair.second->span); + // fn_type = TypeQuantifierNode::make(pair.second, fn_type); + // } + // } else { + // for (auto i = f->ty_params.size(); i > 0; i--) { + // auto ty_param = f->ty_params[i - 1]; + // auto ty_param_node = ty_param.as(); + // if (!ty_param_node) { + // throw dmlc::Error("internal error should be TypeParam"); + // } + // auto fresh_tid = + // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); + // fn_type = + // TypeSubst(fn_type, GetRef(ty_param_node), fresh_tid); + // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); + // } + // } - FuncType TypeInferencer::instantiate(FuncType fn_ty, tvm::Array &ty_args) { - tvm::Map subst_map; + // return fn_type; +} - // Build a subsitituion map up from the function type and type arguments. - // Eventually allow the type vars to be passed in. - for (auto ty_param : fn_ty->type_params) { - IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); - this->unifier->insert(fresh); - ty_args.push_back(fresh); - subst_map.Set(ty_param, fresh); - } +CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { + return this->VisitFunction(GetRef(op), false); +} - Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); - inst_ty = TypeSubst(fn_ty, subst_map); +FuncType TypeInferencer::instantiate(FuncType fn_ty, + tvm::Array &ty_args) { + tvm::Map subst_map; - // if (!check_kind(t)) { - // this->fatal_error("Kind rules broken when instantiating type - // variables", - // t->span); - // } - - return GetRef(inst_ty.as()); + // Build a subsitituion map up from the function type and type arguments. + // Eventually allow the type vars to be passed in. + for (auto ty_param : fn_ty->type_params) { + IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); + this->unifier->insert(fresh); + ty_args.push_back(fresh); + subst_map.Set(ty_param, fresh); } - CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { - Call c = GetRef(op); + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); + inst_ty = TypeSubst(fn_ty, subst_map); - auto checked_op = this->Infer(c->op); + // if (!check_kind(t)) { + // this->fatal_error("Kind rules broken when instantiating type + // variables", + // t->span); + // } - RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl - << "fn_ty=" << checked_op.type << std::endl; + return GetRef(inst_ty.as()); +} +CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { + Call c = GetRef(op); - auto fn_ty_node = checked_op.type.as(); + auto checked_op = this->Infer(c->op); - if (!fn_ty_node) { - this->fatal_error("only expressions with function types can be called", c->op->span); - } + RELAY_LOG(INFO) << "TypeInferencer::VisitExpr_ op=" << c << std::endl + << "fn_ty=" << checked_op.type << std::endl; - // We now have a function type. - FuncType fn_ty = GetRef(fn_ty_node); - - tvm::Array ty_args; - if (ty_args.size() != 0) { - throw Error("found manually suplied type args, not supported"); - } + auto fn_ty_node = checked_op.type.as(); - fn_ty = instantiate(fn_ty, ty_args); + if (!fn_ty_node) { + this->fatal_error("only expressions with function types can be called", + c->op->span); + } - std::vector arg_types; - std::vector checked_args; + // We now have a function type. + FuncType fn_ty = GetRef(fn_ty_node); - for (auto arg : c->args) { - auto checked_arg = this->Infer(arg); - arg_types.push_back(checked_arg.type); - checked_args.push_back(checked_arg.expr); - } + tvm::Array ty_args; + if (ty_args.size() != 0) { + throw Error("found manually suplied type args, not supported"); + } - auto type_arity = fn_ty->arg_types.size(); - auto number_of_args = arg_types.size(); + fn_ty = instantiate(fn_ty, ty_args); - if (type_arity != number_of_args) { - if (type_arity < number_of_args) { - this->fatal_error("the function is provided too many arguments", - c->span); - } else { - this->fatal_error("the function is provided too few arguments", - c->span); - } - } + std::vector arg_types; + std::vector checked_args; - for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); - } + for (auto arg : c->args) { + auto checked_arg = this->Infer(arg); + arg_types.push_back(checked_arg.type); + checked_args.push_back(checked_arg.expr); + } - // After we unify the arguments we should know more about the type - // arguments, let's run a quick pass over them to find new - // representatives. + auto type_arity = fn_ty->arg_types.size(); + auto number_of_args = arg_types.size(); - for (size_t i = 0; i < ty_args.size(); i++) { - ty_args.Set(i, this->unifier->subst(ty_args[i])); + if (type_arity != number_of_args) { + if (type_arity < number_of_args) { + this->fatal_error("the function is provided too many arguments", c->span); + } else { + this->fatal_error("the function is provided too few arguments", c->span); } + } - auto new_call = CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args); - - return { new_call, fn_ty->ret_type }; + for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { + this->unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); } - CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { - Let let = GetRef(op); + // After we unify the arguments we should know more about the type + // arguments, let's run a quick pass over them to find new + // representatives. - CheckedExpr checked_value; - Type annotated_ty = resolve(let->value_type); + for (size_t i = 0; i < ty_args.size(); i++) { + ty_args.Set(i, this->unifier->subst(ty_args[i])); + } + auto new_call = + CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args); - // If we are let-defining a function, we want to be able to - // recursively name the function in order to support recursive - // local definitions. - if (let->value.as()) { - with_frame([&]() { - local_stack.insert(let->var, annotated_ty); - checked_value = Infer(let->value); - }); - } else { - checked_value = Infer(let->value); - } + return {new_call, fn_ty->ret_type}; +} - Type unified_ty = - this->unify(checked_value.type, annotated_ty, let->span); +CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { + Let let = GetRef(op); - // Update type context with unified type now that we have - // solved this equation. - local_stack.insert(let->var, unified_ty); + CheckedExpr checked_value; + Type annotated_ty = resolve(let->value_type); - auto checked_body = with_frame([&]() { - local_stack.insert(let->var, unified_ty); - return Infer(let->body); + // If we are let-defining a function, we want to be able to + // recursively name the function in order to support recursive + // local definitions. + if (let->value.as()) { + with_frame([&]() { + local_stack.insert(let->var, annotated_ty); + checked_value = Infer(let->value); }); - - auto checked_let = LetNode::make( - let->var, - checked_value.expr, - checked_body.expr, - let->value_type); - - return { checked_let, checked_body.type }; + } else { + checked_value = Infer(let->value); } - CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { - If ifn = GetRef(op); - - // Ensure the type of the guard is of Tensor[Bool, ()], - // that is a rank-0 boolean tensor. - auto checked_cond = this->Infer(ifn->cond); - auto cond_type = checked_cond.type; - - if (const TensorTypeNode *tt_node = cond_type.as()) { - TensorType tt = GetRef(tt_node); - if (tt->dtype.is_bool() && tt->shape.size() == 0) { - auto checked_true = this->Infer(ifn->true_value); - auto checked_false = this->Infer(ifn->false_value); - auto unified_type = this->unify(checked_true.type, checked_false.type, ifn->span); - auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, checked_false.expr); - return { checked_if, unified_type }; - } - } + Type unified_ty = this->unify(checked_value.type, annotated_ty, let->span); - this->fatal_error("if-then-else guard must be a rank-0 boolean tensor", - ifn->cond->span); - } + // Update type context with unified type now that we have + // solved this equation. + local_stack.insert(let->var, unified_ty); - CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { - auto op = GetRef(op_node); - return { op, op->op_type }; - } - - Type TypeInferencer::resolve(const Type &t) { - if (t.defined()) { - return ::tvm::relay::Resolve(this->unifier, t); - } else { - return IncompleteTypeNode::make(TypeParamNode::Kind::kType); + auto checked_body = with_frame([&]() { + local_stack.insert(let->var, unified_ty); + return Infer(let->body); + }); + + auto checked_let = LetNode::make(let->var, checked_value.expr, + checked_body.expr, let->value_type); + + return {checked_let, checked_body.type}; +} + +CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { + If ifn = GetRef(op); + + // Ensure the type of the guard is of Tensor[Bool, ()], + // that is a rank-0 boolean tensor. + auto checked_cond = this->Infer(ifn->cond); + auto cond_type = checked_cond.type; + + if (const TensorTypeNode *tt_node = cond_type.as()) { + TensorType tt = GetRef(tt_node); + if (tt->dtype.is_bool() && tt->shape.size() == 0) { + auto checked_true = this->Infer(ifn->true_value); + auto checked_false = this->Infer(ifn->false_value); + auto unified_type = + this->unify(checked_true.type, checked_false.type, ifn->span); + auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, + checked_false.expr); + return {checked_if, unified_type}; } } - Expr TypeInferencer::resolve(const Expr &e) { - CHECK(e.defined()); - return ::tvm::relay::Resolve(this->unifier, e); - } + this->fatal_error("if-then-else guard must be a rank-0 boolean tensor", + ifn->cond->span); +} - void TypeInferencer::CheckOp(Op op) { - throw Error("NYI"); - // if (!check_kind(op->type)) { - // report_error("the type of the operator is ill formed", op->type->span); - // } +CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { + auto op = GetRef(op_node); + return {op, op->op_type}; +} - // // Fix me - // return op; +Type TypeInferencer::resolve(const Type &t) { + if (t.defined()) { + return ::tvm::relay::Resolve(this->unifier, t); + } else { + return IncompleteTypeNode::make(TypeParamNode::Kind::kType); } +} - // Defn TypeInferencer::CheckDefn(Defn defn) { - // // This is to handle recursion, but we need to speculatively - // // put it in env, then remove it. - // env->items.insert({defn->id, defn}); - - // Type expected_ty = this->resolve(defn->type); - - // Expr body = defn->body; - - // auto checked_ty = Check(body); - - // try { - // Type uret_type = unify(expected_ty, checked_ty, defn->body->span); - // CHECK(is_fully_resolved(uret_type)); - // // Now let's clean up our work from earlier. - // env->items.erase(defn->id); - // return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); - // } catch (const UnificationError& err) { - // std::string msg = std::string("mismatch between `") + - // PrintType(env, expected_ty, WrapWidth(40)) + "` and - // `" + PrintType(env, checked_ty, WrapWidth(40)) + - // "`"; - // fatal_error(msg, defn->span); - // } - // } +Expr TypeInferencer::resolve(const Expr &e) { + CHECK(e.defined()); + return ::tvm::relay::Resolve(this->unifier, e); +} - Expr Infer(const Environment &env, const Expr &e) { - TypeInferencer ti(env); - auto checked_expr = ti.Infer(e); - return checked_expr.expr; - } - - // Item Check(const Environment &env, const Item &i) { - // TypeInferencer tc(env); - - // try { - // if (const DefnNode *defn = i.as()) { - // return tc.CheckDefn(GetRef(defn)); - // } else if (const OpNode *op_node = i.as()) { - // return tc.CheckOp(GetRef(op_node)); - // } else { - // throw dmlc::Error("internal error: unknown Item type"); - // } - // } catch (const FatalTypeError &err) { - // env->display_errors(); - // throw dmlc::Error( - // "We encountered a fatal error while type checking your program, - // please " "read above for more details."); - // } +void TypeInferencer::CheckOp(Op op) { + throw Error("NYI"); + // if (!check_kind(op->type)) { + // report_error("the type of the operator is ill formed", op->type->span); // } - inline void TypeInferencer::report_error(const std::string &msg, Span sp) { - // this->env->report_error(msg, sp); - } - - void TypeInferencer::fatal_error(const std::string &msg, Span sp) { - // this->env->report_error(msg, sp); - throw FatalTypeError( - "internal error: this exception should" - "be handled and errors reported with Environment::display_errors\n" + - msg); + // // Fix me + // return op; +} + +// Defn TypeInferencer::CheckDefn(Defn defn) { +// // This is to handle recursion, but we need to speculatively +// // put it in env, then remove it. +// env->items.insert({defn->id, defn}); + +// Type expected_ty = this->resolve(defn->type); + +// Expr body = defn->body; + +// auto checked_ty = Check(body); + +// try { +// Type uret_type = unify(expected_ty, checked_ty, defn->body->span); +// CHECK(is_fully_resolved(uret_type)); +// // Now let's clean up our work from earlier. +// env->items.erase(defn->id); +// return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); +// } catch (const UnificationError& err) { +// std::string msg = std::string("mismatch between `") + +// PrintType(env, expected_ty, WrapWidth(40)) + "` and +// `" + PrintType(env, checked_ty, WrapWidth(40)) + +// "`"; +// fatal_error(msg, defn->span); +// } +// } + +Expr InferType(const Environment &env, const Expr &e) { + TypeInferencer ti(env); + auto checked_expr = ti.Infer(e); + return checked_expr.expr; +} + +// Item Check(const Environment &env, const Item &i) { +// TypeInferencer tc(env); + +// try { +// if (const DefnNode *defn = i.as()) { +// return tc.CheckDefn(GetRef(defn)); +// } else if (const OpNode *op_node = i.as()) { +// return tc.CheckOp(GetRef(op_node)); +// } else { +// throw dmlc::Error("internal error: unknown Item type"); +// } +// } catch (const FatalTypeError &err) { +// env->display_errors(); +// throw dmlc::Error( +// "We encountered a fatal error while type checking your program, +// please " "read above for more details."); +// } +// } + +inline void TypeInferencer::report_error(const std::string &msg, Span sp) { + // this->env->report_error(msg, sp); +} + +void TypeInferencer::fatal_error(const std::string &msg, Span sp) { + // this->env->report_error(msg, sp); + throw FatalTypeError( + "internal error: this exception should" + "be handled and errors reported with Environment::display_errors\n" + + msg); +} + +Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { + try { + return Normalize(this->unifier->unify(t1, t2)); + } catch (const dmlc::Error &e) { + std::stringstream ss; + ss << "Error unifying `"; + ss << t1; + // ss << PrintType(env, t1, WrapWidth(40)); + ss << "` and `"; + ss << t2; + // ss << PrintType(env, t2, WrapWidth(40)); + ss << "`: " << e.what(); + this->fatal_error(ss.str(), sp); } +} - Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { - try { - return Normalize(this->unifier->unify(t1, t2)); - } catch (const dmlc::Error &e) { - std::stringstream ss; - ss << "Error unifying `"; - ss << t1; - // ss << PrintType(env, t1, WrapWidth(40)); - ss << "` and `"; - ss << t2; - // ss << PrintType(env, t2, WrapWidth(40)); - ss << "`: " << e.what(); - this->fatal_error(ss.str(), sp); - } - } +TVM_REGISTER_API("relay._ir_pass.check_expr") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + Expr e = args[1]; + *ret = InferType(env, e); + }); - TVM_REGISTER_API("relay._ir_pass.check_expr") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - Expr e = args[1]; - *ret = Infer(env, e); - }); - - // TODO(@jroesch): put in a better namespace. - TVM_REGISTER_API("relay._ir_pass._get_checked_type") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Expr e = args[0]; - *ret = e->checked_type(); - }); - - IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { - std::shared_ptr n = - std::make_shared(); - n->kind = std::move(kind); - return IncompleteType(n); - } +// TODO(@jroesch): put in a better namespace. +TVM_REGISTER_API("relay._ir_pass._get_checked_type") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Expr e = args[0]; + *ret = e->checked_type(); + }); - TVM_REGISTER_API("relay._make.IncompleteType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - int kind = args[0]; - *ret = IncompleteTypeNode::make(static_cast(kind)); - }); +IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { + std::shared_ptr n = + std::make_shared(); + n->kind = std::move(kind); + return IncompleteType(n); +} + +TVM_REGISTER_API("relay._make.IncompleteType") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[0]; + *ret = IncompleteTypeNode::make(static_cast(kind)); + }); - TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const IncompleteTypeNode *node, - tvm::IRPrinter *p) { - p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; - }); +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const IncompleteTypeNode *node, + tvm::IRPrinter *p) { + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); } // namespace relay -} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index e928cd5cb76a..bf95992d952f 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -6,12 +6,14 @@ from tvm.relay.env import Environment from tvm.relay.op import log, add -def has_type(expr, typ): - env = Environment({}) +def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) - import pdb; pdb.set_trace() return checked_expr.checked_type() == typ +def decl_has_type(env, name, typ): + func = env.lookup(name) + return func.checked_type() == typ + def test_monomorphic_let(): "Program: let x = 1; return x" b = IRBuilder() @@ -46,7 +48,24 @@ def test_dual_op(): b.ret(t2) assert has_type(func.to_func(), func_type([float_type()], float_type())) + +def test_decl(): + """Program: + def f(x : Tensor[f32, (10, 10)]) { + let lx = log(x); + return lx; + } + """ + b = IRBuilder() + x = b.param('x') + with b.decl('f', x) as d: + lx = d.let('lx', log(x)) + d.ret(lx) + _, env = b.get() + assert decl_has_type(env, 'f', func_type([float_type()], float_type())) + if __name__ == "__main__": test_monomorphic_let() test_single_op() test_dual_op() + test_decl() From 124381f60e825fe7579aba6818a80a37c11e7af8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 17:01:11 -0700 Subject: [PATCH 064/136] Address first round of CR comments --- include/tvm/relay/base.h | 2 +- include/tvm/relay/environment.h | 12 +++--- include/tvm/relay/error.h | 4 ++ include/tvm/relay/expr.h | 2 - include/tvm/relay/pass/alpha_eq.h | 42 +++++++++++++++++-- include/tvm/relay/pass/type_infer.h | 6 +-- include/tvm/relay/source_map.h | 2 +- src/relay/pass/unifier.h | 5 ++- src/relay/source_map.cc | 4 +- .../relay/test_tyck_eval_integration.py | 2 +- 10 files changed, 59 insertions(+), 22 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 092f5ceb8fc3..e78c4b28e9ca 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -24,7 +24,7 @@ namespace relay { /*! * \brief we always used NodeRef for referencing nodes. * - * By default, NodePtr is a std::shared_ptr of node + * By default, NodeRef is a std::shared_ptr of node */ using NodeRef = tvm::NodeRef; diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 5ad7ba8e0010..ca5b8ac90df4 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -7,13 +7,13 @@ #ifndef TVM_RELAY_ENVIRONMENT_H_ #define TVM_RELAY_ENVIRONMENT_H_ +#include +#include +#include +#include +#include #include #include -#include "./expr.h" -#include "./type.h" -#include "./op.h" -#include "./error.h" -#include "tvm/relay/source_map.h" namespace tvm { namespace relay { @@ -28,7 +28,7 @@ struct Environment; * It contains all global functions, and configuration * options. * - * Many operations require acess to the global + * Many operations require access to the global * Environment. We pass the Environment by value * in a functional style as an explicit argument, * but we will mutate the Environment while optimizing diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 4f6a27d209c8..433c08abfd58 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -16,6 +16,10 @@ struct Error : dmlc::Error { Error(std::string msg) : dmlc::Error(msg) {} }; +struct InternalError : Error { + InternalError(std::string msg) : Error(msg) {} +}; + struct SpannedError { std::string msg; Span sp; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index ddac633f9d09..8ea3980dad46 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -224,8 +224,6 @@ class FunctionNode : public ExprNode { RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); -using Attrs = tvm::Attrs; - /*! * \brief Call corresponds to operator invocation. * Corresponds to the operator in computational graph terminology. diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h index 9f3c2138a440..51b5b4dd8b70 100644 --- a/include/tvm/relay/pass/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -6,14 +6,48 @@ #ifndef TVM_RELAY_ALPHA_EQ_H_ #define TVM_RELAY_ALPHA_EQ_H_ -#include "../type.h" -#include "../expr.h" +#include +#include namespace tvm { namespace relay { -bool AlphaEqual(const Expr & e1, const Expr & e2); -bool AlphaEqual(const Type & t1, const Type & t2); +/*! \brief Compare two expressions for structural equivalence. + + This comparsion operator respects scoping and compares + expressions without regard to variable choice. + + For example: `let x = 1 in x` is equal to `let y = 1 in y`. + + See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + for more details. + + \param e1 The left hand expression. + \param e2 The right hand expression. + + \return true if equal, otherwise false + +*/ +bool AlphaEqual(const Expr& e1, const Expr& e2); + +/*! \brief Compare two types for structural equivalence. + + This comparsion operator respects scoping and compares + expressions without regard to variable choice. + + For example: `forall s, Tensor[f32, s]` is equal to + `forall w, Tensor[f32, w]`. + + See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + for more details. + + \param t1 The left hand type. + \param t2 The right hand type. + + \return true if equal, otherwise false + +*/ +bool AlphaEqual(const Type& t1, const Type& t2); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/pass/type_infer.h b/include/tvm/relay/pass/type_infer.h index 2b860a5e89ef..a75eac6cc0da 100644 --- a/include/tvm/relay/pass/type_infer.h +++ b/include/tvm/relay/pass/type_infer.h @@ -6,8 +6,8 @@ * The pass produces a new expression with its checked_type * field populated and incomplete types resolved. */ -#ifndef TVM_RELAY_PASS_TYPECHECKER_H_ -#define TVM_RELAY_PASS_TYPECHECKER_H_ +#ifndef TVM_RELAY_PASS_TYPE_INFER_H_ +#define TVM_RELAY_PASS_TYPE_INFER_H_ #include "tvm/relay/expr.h" #include "tvm/relay/environment.h" @@ -22,4 +22,4 @@ Op CheckOp(const Environment & env, const Op & op); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_TYPECHECKER_H_ +#endif // TVM_RELAY_PASS_TYPE_INFER_H_ diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h index 71bf93aa1ed9..a4dbc20b30ff 100644 --- a/include/tvm/relay/source_map.h +++ b/include/tvm/relay/source_map.h @@ -18,7 +18,7 @@ struct SourceFragment { std::string file_name; std::vector source_lines; - SourceFragment(std::string file_name, std::string source); + SourceFragment(const std::string& file_name, const std::string& source); SourceFragment(const SourceFragment& sf) { this->file_name = sf.file_name; diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index 5a4adea5c44e..64485768c2f0 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -61,8 +61,9 @@ class UnionFind : public NodeRef { UnionFind() {} explicit UnionFind(std::shared_ptr p) : NodeRef(p) {} - // no const so that union find can be mutable as a member of unifier - inline UnionFindNode* operator->() const { + // The union find structure is mutable so we do not use the standard macros + // and expose the pointer via `->`. + UnionFindNode* operator->() const { return static_cast(node_.get()); } diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc index a1b3627bccc8..d784c7946954 100644 --- a/src/relay/source_map.cc +++ b/src/relay/source_map.cc @@ -14,7 +14,7 @@ namespace relay { using tvm::IRPrinter; using namespace tvm::runtime; -SourceFragment::SourceFragment(std::string file_name, std::string source) +SourceFragment::SourceFragment(const std::string& file_name, const std::string& source) : file_name(file_name), source_lines({}) { RELAY_LOG(INFO)<< "SourceFragment::SourceFragment source=" << source << std::endl; std::stringstream source_stream; @@ -28,7 +28,7 @@ SourceFragment::SourceFragment(std::string file_name, std::string source) } } -std::string SourceFragment::SourceAt(Span sp, int max_lines) { +std::string SourceFragment::SourceAt(Span sp, int max_lines = 1) { std::stringstream out; // We need to move from 1 based indexing to zero based indexing. diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index bf95992d952f..7d42448a175b 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -20,7 +20,7 @@ def test_monomorphic_let(): x = b.let('x', 1.0, value_type=float_type(64)) b.ret(x) - prog = b.get() + prog, _ = b.get() assert has_type(prog, float_type(64)) def test_single_op(): From 8db9d577940302fa30102b3a3ce7c9db0bdb8e81 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 17:01:45 -0700 Subject: [PATCH 065/136] Add skeleton for kind checker --- include/tvm/relay/pass.h | 20 +++++++++++++++-- src/relay/pass/kind_check.cc | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 src/relay/pass/kind_check.cc diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 89f3dd48fc31..0d73ea2ce976 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -6,8 +6,8 @@ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ -#include "tvm/relay/expr.h" -#include "tvm/relay/environment.h" +#include +#include namespace tvm { namespace relay { @@ -18,6 +18,22 @@ namespace relay { */ Expr InferType(const Environment & env, const Expr & e); +/*! + * \brief Check that types are well formed by applying "kinding rules". + * + * This pass ensures we do not do things that violate the design of the + * type system when writing down types. + * + * For example tensors are not allowed to contain functions in Relay. + * + * We check this by ensuring the `dtype` field of a Tensor always contains + * a data type such as `int`, `float`, `uint`. + * + * \param env The global environment. + * \param t The type to check. + */ +void KindCheck(const Environment& env, const Type& t); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_TYPECHECKER_H_ \ No newline at end of file diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc new file mode 100644 index 000000000000..65b0b087131c --- /dev/null +++ b/src/relay/pass/kind_check.cc @@ -0,0 +1,42 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file kindchecker.cc + * + * \brief Check that types are well formed by applying "kinding rules". + * + * This pass ensures we do not do things that violate the design of the + * type system when writing down types. + * + * For example tensors are not allowed to contain functions in Relay. + * + * We check this by ensuring the `dtype` field of a Tensor always + * contains a data type such as `int`, `float`, `uint`. + */ +#include +#include +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +using namespace tvm::runtime; + +struct KindChecker : TypeVisitor<> { + bool valid; + + KindChecker() : valid(true) {} + + bool Check(const Type &t) { + this->VisitType(t); + return valid; + } +}; + +bool KindCheck(const Type &t) { + KindChecker kc; + return kc.Check(t); +} + +} // namespace relay +} // namespace tvm \ No newline at end of file From cd32bb9a3fa3a9a7195e01746542e6921ca519e3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 17:08:58 -0700 Subject: [PATCH 066/136] Tweak docs in pass.h --- include/tvm/relay/pass.h | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 0d73ea2ce976..6d9761daa925 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -6,17 +6,26 @@ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ -#include #include +#include namespace tvm { namespace relay { -/*! The result of type checking an expression is a new expression - * with unambigous type information filled in, as well as it's - * checked type field populated with the result type. +/*! \brief Infer the type of an expression with the provided environment. + * + * The result of type checking is a new expression with unambigous + * type information filled in, as well as it's checked type field + * populated with the result type. + * + * \param env The environment used for global settings and referencing + * global functions. + * + * \param e The expression to type check. + * + * \return A type checked expression with its checked_type field populated. */ -Expr InferType(const Environment & env, const Expr & e); +Expr InferType(const Environment& env, const Expr& e); /*! * \brief Check that types are well formed by applying "kinding rules". @@ -28,7 +37,7 @@ Expr InferType(const Environment & env, const Expr & e); * * We check this by ensuring the `dtype` field of a Tensor always contains * a data type such as `int`, `float`, `uint`. - * + * * \param env The global environment. * \param t The type to check. */ From 4feda67c7836fba7a978272c5728951506d56f3d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 31 Aug 2018 17:17:39 -0700 Subject: [PATCH 067/136] Refactor kind_check.{h, cc} --- include/tvm/relay/pass.h | 3 ++- include/tvm/relay/pass/type_infer.h | 25 ------------------------- src/relay/ir/environment.cc | 1 - src/relay/pass/kind_check.cc | 2 +- src/relay/pass/type_infer.cc | 9 ++------- 5 files changed, 5 insertions(+), 35 deletions(-) delete mode 100644 include/tvm/relay/pass/type_infer.h diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 6d9761daa925..738c6033147c 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -40,8 +40,9 @@ Expr InferType(const Environment& env, const Expr& e); * * \param env The global environment. * \param t The type to check. + * \return true if the rules are satisified otherwise false */ -void KindCheck(const Environment& env, const Type& t); +bool KindCheck(const Environment& env, const Type& t); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/pass/type_infer.h b/include/tvm/relay/pass/type_infer.h deleted file mode 100644 index a75eac6cc0da..000000000000 --- a/include/tvm/relay/pass/type_infer.h +++ /dev/null @@ -1,25 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/pass/type_infer.h - * \brief Perform type inference and checking on Relay programs. - * - * The pass produces a new expression with its checked_type - * field populated and incomplete types resolved. - */ -#ifndef TVM_RELAY_PASS_TYPE_INFER_H_ -#define TVM_RELAY_PASS_TYPE_INFER_H_ - -#include "tvm/relay/expr.h" -#include "tvm/relay/environment.h" - -namespace tvm { -namespace relay { - -/*! \brief Ensures that an operator is well-formed with respect - * to Relay's type system. - */ -Op CheckOp(const Environment & env, const Op & op); - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_PASS_TYPE_INFER_H_ diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 7861fb58820b..4c17f7cdbc89 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -7,7 +7,6 @@ #include #include #include -#include #include "./../pass/resolve.h" // #include "tvm/relay/util/rang.h" diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 65b0b087131c..c3823c8c3a35 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -33,7 +33,7 @@ struct KindChecker : TypeVisitor<> { } }; -bool KindCheck(const Type &t) { +bool KindCheck(const Environment& env, const Type &t) { KindChecker kc; return kc.Check(t); } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 514df129503a..894139c10b53 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -23,13 +23,12 @@ #include #include #include -#include +#include #include "./incomplete_type.h" #include "./resolve.h" #include "./type_subst.h" #include "./type_visitor.h" #include "./unifier.h" -// #include "tvm/relay/typeck/kindchecker.h" namespace tvm { namespace relay { @@ -378,11 +377,7 @@ FuncType TypeInferencer::instantiate(FuncType fn_ty, Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); inst_ty = TypeSubst(fn_ty, subst_map); - // if (!check_kind(t)) { - // this->fatal_error("Kind rules broken when instantiating type - // variables", - // t->span); - // } + CHECK(KindCheck(this->env, inst_ty)); return GetRef(inst_ty.as()); } From 4e747c2e20a9416cfd04f6f42baa2cc48228773e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 2 Sep 2018 18:48:08 -0700 Subject: [PATCH 068/136] Improve type checking, can check control-flow-y program. --- include/tvm/relay/environment.h | 2 +- include/tvm/relay/expr.h | 2 + include/tvm/relay/pass.h | 1 + python/tvm/relay/env.py | 3 + python/tvm/relay/expr.py | 6 +- python/tvm/relay/ir_builder.py | 106 ++++++++----- python/tvm/relay/op/tensor.py | 20 +++ src/relay/ir/environment.cc | 11 +- src/relay/ir/expr.cc | 9 ++ src/relay/op/tensor/elemwise.cc | 35 +++++ src/relay/op/type_relations.cc | 143 ++++++++++-------- src/relay/op/type_relations.h | 1 + src/relay/pass/type_infer.cc | 72 +++------ src/relay/pass/unifier.cc | 9 +- .../relay/test_tyck_eval_integration.py | 38 ++++- 15 files changed, 299 insertions(+), 159 deletions(-) diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index ca5b8ac90df4..da782900fac5 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -78,7 +78,7 @@ class EnvironmentNode : public RelayNode { /*! \brief Add a source fragment to the environment. */ SourceName AddSource(std::string file_name, std::string source); - void ReportError(std::string msg, Span sp); + void AddDiagnostic(SpannedError); void DisplayErrors(); static constexpr const char* _type_key = "relay.Environment"; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 8ea3980dad46..7fd81ee0481b 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -215,6 +215,8 @@ class FunctionNode : public ExprNode { v->Visit("span", &span); } + Type fn_type() const; + TVM_DLL static Function make(tvm::Array params, Type ret_type, Expr body, tvm::Array ty_params); diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 738c6033147c..f92596c41179 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -26,6 +26,7 @@ namespace relay { * \return A type checked expression with its checked_type field populated. */ Expr InferType(const Environment& env, const Expr& e); +Expr InferType(const Environment& env, const GlobalVar & v, const Function & e); /*! * \brief Check that types are well formed by applying "kinding rules". diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 186ee8854c35..ee64ef6ce814 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -23,6 +23,9 @@ def add(self, var, func) -> None: def merge(self, other): return _env.Environment_Merge(self, other) + def global_var(self, var): + return _env.Environment_GetGlobalVar(self, var) + def lookup(self, var): if isinstance(var, str): return _env.Environment_Lookup_str(self, var) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 4f558210fb11..ec0cfd55ad62 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -9,7 +9,11 @@ from ._ir_pass import _get_checked_type from . import _make -class Expr(NodeBase): +class ExprBuilder(): + def __call__(self, *args, **kwargs): + return Call(self, args, None, None) + +class Expr(NodeBase, ExprBuilder): """The base type for all Relay exprressions.""" def checked_type(self): return _get_checked_type(self) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 50ebeb1bb12d..563c512639bc 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -2,27 +2,20 @@ import numpy as np import tvm from .type import FuncType, TensorType -from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function +from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function, If from .env import Environment from . import op as _op -class ExprBuilder(): - def __init__(self, expr): - self.expr = expr - - def __call__(self, *args): - return ExprBuilder(Call(self.expr, list(args), None, None)) - def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: """Convert Python values into the appropriate types for the Relay evaluator. """ if isinstance(arg, int): - return tvm.nd.array(arg, ctxt) + return tvm.nd.array(np.array(arg, dtype='int32'), ctxt) elif isinstance(arg, float): return tvm.nd.array(arg, ctxt) elif isinstance(arg, bool): - return tvm.nd.array(arg, ctxt) + return tvm.nd.array(np.array(arg, dtype='float32'), ctxt) elif isinstance(arg, np.ndarray): return tvm.nd.array(arg, ctxt) elif isinstance(arg, tvm.ndarray.NDArray): @@ -36,10 +29,10 @@ def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: raise Exception("..") else: value = convert(arg, ctxt) - return ExprBuilder(Constant(value)) + return Constant(value) class WithScope(object): - """Auxiliary scope with""" + """A wrapper for builder methods which introduce scoping.""" def __init__(self, enter_value, exit_cb): self._enter_value = enter_value @@ -49,7 +42,10 @@ def __enter__(self): return self._enter_value def __exit__(self, ptype, value, trace): - self._exit_cb() + if value: + raise value + else: + self._exit_cb() class PartialFunc(): @@ -77,15 +73,28 @@ def _mk_let(bindings, ret_value): return let_expr - class IRBuilder(): def __init__(self): self.bindings = [{}] self.scopes = [{}] self.params = [] - self.ret_value = None + self.ret_values = [None] self.env = Environment({}) + def enter_scope(self, params=[]): + self.bindings.append({}) + self.scopes.append({}) + self.params.append(params) + self.ret_values.append(None) + + + def exit_scope(self): + bindings = self.bindings.pop() + scopes = self.scopes.pop() + params = self.params.pop() + ret_value = self.ret_values.pop() + return bindings, scopes, params, ret_value + def bind(self, name, type, value): lv = LocalVar(name) @@ -98,12 +107,9 @@ def let(self, name, value, value_type=None): if isinstance(value, Param): value = value.var - if not (isinstance(value, Expr) or isinstance(value, ExprBuilder)): + if not isinstance(value, Expr): value = into_ast(value) - if isinstance(value, ExprBuilder): - value = value.expr - return self.bind(name, value_type, value) def function(self, *params): @@ -115,27 +121,52 @@ def function(self, *params): # self.params.append(relay_params) + self.enter_scope() + pfunc = PartialFunc(relay_params, None, None, []) def _on_exit(): - bindings = self.bindings.pop() - scope = self.scopes.pop() - ret_value = self.ret_value + bindings, scope, params, ret_value = self.exit_scope() body = _mk_let(bindings, ret_value) - self.ret_value = None pfunc.body = body - return WithScope(pfunc, _on_exit) def ret(self, x): - if not self.ret_value: - self.ret_value = x + if not self.ret_values[-1]: + self.ret_values[-1] = x else: raise Exception( "return value already set, a function can only have one return value") + def if_scope(self, cond): + self.enter_scope() + + def _on_exit(): + bindings, _, _, ret_value = self.exit_scope() + assert self.ret_values[-1] is None + true_branch = _mk_let(bindings, ret_value) + self.ret_values[-1] = If(cond, true_branch, None) + + return WithScope(10, _on_exit) + + + def else_scope(self): + self.enter_scope() + + def _on_exit(): + bindings, _, _, ret_value = self.exit_scope() + partial_if = self.ret_values[-1] + assert isinstance(partial_if, If) and partial_if.false_value is None + false_branch = _mk_let(bindings, ret_value) + self.ret_values[-1] = If( + partial_if.cond, + partial_if.true_value, + false_branch) + + return WithScope(10, _on_exit) + def param(self, name, ty=None): if not ty: ty = float_type() @@ -148,18 +179,21 @@ def param(self, name, ty=None): # arg = args[i] # if isinstance(arg, str): + def global_var(self, name: str): + return self.env.global_var(name) - def decl(self, name: str, *params): - decl_builder = IRBuilder() + def decl(self, name: str, *params, ret_type=None): + self.enter_scope() def _on_exit(): - exp, sub_env = decl_builder.get() - self.env.add(name, Function(params, None, exp)) - self.env.merge(sub_env) - - return WithScope(decl_builder, _on_exit) + bindings, _, _, ret_value = self.exit_scope() + exp = _mk_let(bindings, ret_value) + self.env.add(name, Function(params, ret_type, exp)) + return WithScope(10, _on_exit) + + # def while_loop(cond) def get(self): """Get the full program""" bindings = self.bindings.pop() @@ -171,16 +205,16 @@ def get(self): if self.scopes: raise Exception("IRBuilder: scoping error") - if bindings and scope and not self.ret_value: + if bindings and scope and not self.ret_values: raise Exception("IRBuilder: no return value set") - return _mk_let(bindings, self.ret_value), self.env + return _mk_let(bindings, self.ret_values[-1]), self.env def bool_dtype(): return 'uint1' def int_dtype(bits=32): - return f'int1{bits}' + return f'int{bits}' def float_dtype(bits=32): return f'float{bits}' diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index aa9ce6bf42e9..d0c1b88eb240 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -75,3 +75,23 @@ def add(lhs, rhs): The computed result. """ return _make.add(lhs, rhs) + +def subtract(lhs, rhs): + """Take sqrt of data. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.add(lhs, rhs) + +def equal(lhs, rhs): + return _make.equal(lhs, rhs) \ No newline at end of file diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 4c17f7cdbc89..a1a754615350 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -42,7 +42,8 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool update) { // Type check the item before we add it to the environment. auto env = GetRef(this); - Expr checked_expr = InferType(env, func); + + Expr checked_expr = InferType(env, var, func); if (const FunctionNode *func_node = checked_expr.as()) { auto checked_func = GetRef(func_node); @@ -104,13 +105,11 @@ void EnvironmentNode::Merge(const Environment & env) { inline SourceName EnvironmentNode::AddSource(std::string file_name, std::string source) { - throw Error("need to restore error handling"); - // return this->source_map_.add_source(file_name, source); + return this->source_map_.AddSource(file_name, source); } -void EnvironmentNode::ReportError(std::string msg, Span sp) { - throw Error("need to restore error handling"); - // this->errors_.push_back(Error(msg, sp)); +void EnvironmentNode::AddDiagnostic(SpannedError error) { + this->errors_.push_back(error); } void EnvironmentNode::DisplayErrors() { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 47d253e91c21..8dce7c054c8e 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -119,6 +119,15 @@ Function FunctionNode::make(tvm::Array params, Type ret_type, Expr body, return Function(n); } +Type FunctionNode::fn_type() const { + Array param_types; + for (auto param : this->params) { + param_types.push_back(param->type); + } + + return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); +} + TVM_REGISTER_API("relay._make.Function") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index cd90705c6476..76adfbbfb968 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -85,5 +85,40 @@ RELAY_REGISTER_OP("add") // input2: Tensor[dtype, s2] // output: Tensor[dtype, broadcast(s1, s2)] +// Addition +TVM_REGISTER_API("relay.op._make.subtract") + .set_body_typed([](Expr lhs, Expr rhs) { + static const Op& op = Op::Get("subtract"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("subtract") + .set_num_inputs(2) + .add_argument("lhs", "Tensor", "The left hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") + .set_support_level(1) + .add_type_func("BroadcastComp", BroadcastCompRel); + + // def broadcast(s1, s2): + // ... + // + // input1: Tensor[dtype, s1] + // input2: Tensor[dtype, s2] + // output: Tensor[dtype, broadcast(s1, s2)] + +// Addition +TVM_REGISTER_API("relay.op._make.equal") + .set_body_typed([](Expr lhs, Expr rhs) { + static const Op& op = Op::Get("equal"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + }); + +RELAY_REGISTER_OP("equal") + .set_num_inputs(2) + .add_argument("lhs", "Tensor", "The left hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") + .set_support_level(1) + .add_type_func("BroadcastComp", BroadcastCompRel); + } // namespace relayv } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index d97b8f96e85c..32d81a1d445e 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -4,15 +4,15 @@ * \brief A set of utilities and common functionality * for type relations. */ -#include #include +#include #include #include "../pass/incomplete_type.h" namespace tvm { namespace relay { -TensorType as_ttype(const Type & t) { +TensorType as_ttype(const Type& t) { if (auto tt_node = t.as()) { return GetRef(tt_node); } else { @@ -21,94 +21,115 @@ TensorType as_ttype(const Type & t) { } // TODO(@jroesch) what size value do we extract? -int to_int(const tvm::Expr & e) { +int to_int(const tvm::Expr& e) { auto imm = e.as(); CHECK(imm); std::cout << "TYPE: " << imm << imm->type << std::endl; return imm->value; } -Array IdentityRel(const Array & types, int num_args) { - CHECK(types.size() == 2); - auto t1 = as_ttype(types[0]); - if (t1 && types[1].as()) { - return {t1, t1}; - } else { - return types; - } +Array IdentityRel(const Array& types, int num_args) { + CHECK(types.size() == 2); + auto t1 = as_ttype(types[0]); + if (t1 && types[1].as()) { + return {t1, t1}; + } else { + return types; + } } -static Type ConcreteBroadcast(const TensorType & t1, const TensorType & t2) { - RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 << std::endl; +static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, + DataType output_dtype) { + RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 + << std::endl; auto sh1 = t1->shape; auto sh2 = t2->shape; - RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2 << std::endl; - CHECK(sh1.size() > 0); - CHECK(sh2.size() > 0); - - auto suffix_len = static_cast(std::min(sh1.size(), sh2.size())); - auto full_len = static_cast(std::max(sh1.size(), sh2.size())); - - std::cout << "Length" << suffix_len << full_len << (full_len - suffix_len - 1) << std::endl; - auto lower_bound = full_len - suffix_len - 1; - - for (int64_t i = full_len - 1; i > lower_bound; i--) { - std::cout << "Index i=" << i << std::endl; - auto dim1 = to_int(sh1[i]); - auto dim2 = to_int(sh2[i]); - if (dim1 != dim2) { - CHECK(false); + RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2 + << std::endl; + if (sh1.size() == 0 && sh2.size() == 0) { + return TensorTypeNode::make({}, output_dtype); + // We have non-zero shapes so broadcast rules apply. + } else { + auto suffix_len = static_cast(std::min(sh1.size(), sh2.size())); + auto full_len = static_cast(std::max(sh1.size(), sh2.size())); + + std::cout << "Length" << suffix_len << full_len + << (full_len - suffix_len - 1) << std::endl; + auto lower_bound = full_len - suffix_len - 1; + + for (int64_t i = full_len - 1; i > lower_bound; i--) { + std::cout << "Index i=" << i << std::endl; + auto dim1 = to_int(sh1[i]); + auto dim2 = to_int(sh2[i]); + if (dim1 != dim2) { + CHECK(false); + } } - } - Array larger; - Array smaller; + Array larger; + Array smaller; - for (int i = 0; i < (full_len - suffix_len); i++) { - smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1)); - } + for (int i = 0; i < (full_len - suffix_len); i++) { + smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1)); + } - if (sh1.size() < sh2.size()) { - for (auto sh : sh1) { - smaller.push_back(sh); + if (sh1.size() < sh2.size()) { + for (auto sh : sh1) { + smaller.push_back(sh); + } + larger = sh2; + } else if (sh1.size() > sh2.size()) { + for (auto sh : sh1) { + larger.push_back(sh); + } + smaller = sh2; + } else { + larger = sh1; + smaller = sh2; } - larger = sh2; - } else if (sh1.size() > sh2.size()) { - for (auto sh : sh1) { - larger.push_back(sh); + + CHECK(larger.size() == smaller.size()); + + Array out_shape; + for (int i = 0; i < smaller.size(); i++) { + auto left = smaller[i].as(); + auto right = larger[i].as(); + CHECK(left); + CHECK(right); + int64_t dim = std::max(left->value, right->value); + out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); } - smaller = sh2; - } else { - larger = sh1; - smaller = sh2; - } - CHECK(larger.size() == smaller.size()); + return TensorTypeNode::make(out_shape, output_dtype); + } +} - Array out_shape; - for (int i = 0; i < smaller.size(); i++) { - auto left = smaller[i].as(); - auto right = larger[i].as(); - CHECK(left); - CHECK(right); - int64_t dim = std::max(left->value, right->value); - out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); +Array BroadcastRel(const Array& types, int num_args) { + CHECK(types.size() == 3); + if (auto t1 = as_ttype(types[0])) { + if (auto t2 = as_ttype(types[1])) { + std::cout << t1->dtype << t2->dtype << std::endl; + CHECK(t1->dtype == t2->dtype); + return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)}; + } } - return TensorTypeNode::make(out_shape, t1->dtype); + return types; } -Array BroadcastRel(const Array & types, int num_args) { +/* A relation which specifies broadcasting rules for operations which + compute boolean results. +*/ +Array BroadcastCompRel(const Array& types, int num_args) { CHECK(types.size() == 3); if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { - return { t1, t2, ConcreteBroadcast(t1, t2) }; + return {t1, t2, ConcreteBroadcast(t1, t2, HalideIR::Bool())}; } } return types; } - -} // namespace relayv +} // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index f2c4876705b6..71c98fef7da1 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -15,6 +15,7 @@ namespace relay { Array IdentityRel(const Array & types, int num_args); Array BroadcastRel(const Array & types, int num_args); +Array BroadcastCompRel(const Array & types, int num_args); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 894139c10b53..1adfb95d1e15 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -188,7 +188,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { GlobalVar var = GetRef(op); Expr e = this->env->Lookup(var); - return { var, e->checked_type() }; + return {var, e->checked_type()}; } CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { @@ -206,7 +206,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { field_types.push_back(checked_field.type); } - return { TupleNode::make(field_exprs), TupleTypeNode::make(field_types) }; + return {TupleNode::make(field_exprs), TupleTypeNode::make(field_types)}; } CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { @@ -488,21 +488,14 @@ CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { auto checked_cond = this->Infer(ifn->cond); auto cond_type = checked_cond.type; - if (const TensorTypeNode *tt_node = cond_type.as()) { - TensorType tt = GetRef(tt_node); - if (tt->dtype.is_bool() && tt->shape.size() == 0) { - auto checked_true = this->Infer(ifn->true_value); - auto checked_false = this->Infer(ifn->false_value); - auto unified_type = - this->unify(checked_true.type, checked_false.type, ifn->span); - auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, - checked_false.expr); - return {checked_if, unified_type}; - } - } - - this->fatal_error("if-then-else guard must be a rank-0 boolean tensor", - ifn->cond->span); + this->unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()), ifn->cond->span); + auto checked_true = this->Infer(ifn->true_value); + auto checked_false = this->Infer(ifn->false_value); + auto unified_type = + this->unify(checked_true.type, checked_false.type, ifn->span); + auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, + checked_false.expr); + return {checked_if, unified_type}; } CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { @@ -510,7 +503,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { return {op, op->op_type}; } -Type TypeInferencer::resolve(const Type &t) { +Type TypeInferencer::resolve(const Type& t) { if (t.defined()) { return ::tvm::relay::Resolve(this->unifier, t); } else { @@ -518,21 +511,11 @@ Type TypeInferencer::resolve(const Type &t) { } } -Expr TypeInferencer::resolve(const Expr &e) { +Expr TypeInferencer::resolve(const Expr& e) { CHECK(e.defined()); return ::tvm::relay::Resolve(this->unifier, e); } -void TypeInferencer::CheckOp(Op op) { - throw Error("NYI"); - // if (!check_kind(op->type)) { - // report_error("the type of the operator is ill formed", op->type->span); - // } - - // // Fix me - // return op; -} - // Defn TypeInferencer::CheckDefn(Defn defn) { // // This is to handle recursion, but we need to speculatively // // put it in env, then remove it. @@ -565,31 +548,24 @@ Expr InferType(const Environment &env, const Expr &e) { return checked_expr.expr; } -// Item Check(const Environment &env, const Item &i) { -// TypeInferencer tc(env); +Expr InferType(const Environment &env, const GlobalVar & var, const Function & func) { + TypeInferencer ti(env); + auto func_copy = FunctionNode::make(func->params, func->ret_type, func->body, func->type_params); + func_copy->checked_type_ = ti.resolve(func_copy->fn_type()); + env->functions.Set(var, func_copy); + auto checked_expr = ti.Infer(func); + auto map_node = env->functions.CopyOnWrite(); + map_node->data.erase(var.node_); + return checked_expr.expr; +} -// try { -// if (const DefnNode *defn = i.as()) { -// return tc.CheckDefn(GetRef(defn)); -// } else if (const OpNode *op_node = i.as()) { -// return tc.CheckOp(GetRef(op_node)); -// } else { -// throw dmlc::Error("internal error: unknown Item type"); -// } -// } catch (const FatalTypeError &err) { -// env->display_errors(); -// throw dmlc::Error( -// "We encountered a fatal error while type checking your program, -// please " "read above for more details."); -// } -// } inline void TypeInferencer::report_error(const std::string &msg, Span sp) { - // this->env->report_error(msg, sp); + this->env->AddDiagnostic({msg, sp}); } void TypeInferencer::fatal_error(const std::string &msg, Span sp) { - // this->env->report_error(msg, sp); + this->env->AddDiagnostic({msg, sp}); throw FatalTypeError( "internal error: this exception should" "be handled and errors reported with Environment::display_errors\n" + diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 4d986ad79ab1..4558f6a24919 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -180,6 +180,12 @@ Type TypeUnifierNode::VisitType(const Type & t1, const Type t2) { // When the right hand size is a type variable immediately unify. if (const IncompleteTypeNode *tvn2 = t2.as()) { return this->unifyWithIncompleteType(t1, GetRef(tvn2)); + // The TypeCallNode case is special and not symmetric. + // + // We flip the arguments so we hit the TypeCall and other case in there is + // ever a type call. + } else if (const TypeCallNode *tvn2 = t2.as()) { + return TypeFunctor::VisitType(t2, t1); } else { return TypeFunctor::VisitType(t1, t2); } @@ -353,7 +359,8 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { return TypeCallNode::make(unified_func, new_args); } else { - throw UnificationError("Cannot unify call with non-call"); + auto args = ty_call1->args; + return this->VisitType(args[args.size() - 1], t2); } } diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 7d42448a175b..1ae78441e166 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -2,9 +2,10 @@ for expressions. """ from tvm.relay.ir_pass import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type, func_type, tensor_type +from tvm.relay.ir_builder import IRBuilder, float_type, int_type +from tvm.relay.ir_builder import func_type, tensor_type, into_ast from tvm.relay.env import Environment -from tvm.relay.op import log, add +from tvm.relay.op import log, add, equal, subtract def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) @@ -40,6 +41,7 @@ def test_dual_op(): return t1; } """ + pass b = IRBuilder() with b.function(('x', tensor_type(10, 10))) as func: x, = func.param_ids() @@ -56,16 +58,42 @@ def f(x : Tensor[f32, (10, 10)]) { return lx; } """ + pass b = IRBuilder() x = b.param('x') - with b.decl('f', x) as d: - lx = d.let('lx', log(x)) - d.ret(lx) + with b.decl('f', x): + lx = b.let('lx', log(x)) + b.ret(lx) _, env = b.get() assert decl_has_type(env, 'f', func_type([float_type()], float_type())) +def test_recursion(): + """ + Program: + def f(n: i32, data: f32) -> f32 { + if (n == 0) { + return f(n - 1, log(data)); + } else { + return data; + } + } + f(2, 10000); + """ + b = IRBuilder() + f = b.global_var('f') + n = b.param('n', ty=int_type()) + data = b.param('data', ty=float_type()) + with b.decl(f, n, data): + with b.if_scope(equal(n, into_ast(0.0))): + b.ret(f(subtract(n, into_ast(1)), log(data))) + with b.else_scope(): + b.ret(data) + b.ret(f(into_ast(2.0), into_ast(10000.0))) + assert decl_has_type(b.env, 'f', func_type([int_type(), float_type()], float_type())) + if __name__ == "__main__": test_monomorphic_let() test_single_op() test_dual_op() test_decl() + test_recursion() From f25865eb1acd97ea1880b7892f16bad4313c3bc1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 2 Sep 2018 18:52:14 -0700 Subject: [PATCH 069/136] Rename TVM compiler to to_tvm.py --- python/tvm/relay/ir_builder.py | 2 +- python/tvm/relay/{tvm_rts_backend.py => to_tvm.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename python/tvm/relay/{tvm_rts_backend.py => to_tvm.py} (100%) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 563c512639bc..a0a8c2e008da 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -220,7 +220,7 @@ def float_dtype(bits=32): return f'float{bits}' def uint_dtype(bits=32): - return f'fuint{bits}' + return f'uint{bits}' def int_type(bits=32, lanes=1): # TODO(@jroesch, @tqchen) How do we set lanes? diff --git a/python/tvm/relay/tvm_rts_backend.py b/python/tvm/relay/to_tvm.py similarity index 100% rename from python/tvm/relay/tvm_rts_backend.py rename to python/tvm/relay/to_tvm.py From 8941d419cfd2a9a2cfafa20143b4bbcfb75c6088 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 4 Sep 2018 13:49:44 -0700 Subject: [PATCH 070/136] Begin work on lowering Relay to TVM --- include/tvm/relay/base.h | 12 + include/tvm/relay/environment.h | 32 ++- include/tvm/relay/expr.h | 1 + include/tvm/relay/expr_visitor.h | 110 +++++---- include/tvm/relay/op.h | 3 +- python/tvm/relay/env.py | 3 + python/tvm/relay/ir_builder.py | 17 +- python/tvm/relay/ir_pass.py | 227 +++++++++++++++++- python/tvm/relay/op/__init__.py | 2 +- python/tvm/relay/op/op.py | 42 +++- python/tvm/relay/to_tvm.py | 49 ++-- src/relay/ir/environment.cc | 55 +++-- src/relay/ir/op.cc | 54 +++++ src/relay/pass/resolve.cc | 3 +- src/relay/pass/type_infer.cc | 7 +- .../relay/test_tyck_eval_integration.py | 53 +++- 16 files changed, 539 insertions(+), 131 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index e78c4b28e9ca..09f3a94e1edb 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -154,6 +154,18 @@ RefType GetRef(const NodeType* ptr) { return RefType(const_cast(ptr)->shared_from_this()); } +/*! + * \brief Get PackedFunction from global registry and + * report error if it does not exist + * \param name The name of the function. + * \return The created PackedFunc. + */ +inline const PackedFunc& GetPackedFunc(const std::string& name) { + const PackedFunc* pf = tvm::runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; + return *pf; +} + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index da782900fac5..29cde295398d 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -7,11 +7,11 @@ #ifndef TVM_RELAY_ENVIRONMENT_H_ #define TVM_RELAY_ENVIRONMENT_H_ +#include #include -#include #include -#include #include +#include #include #include @@ -24,8 +24,8 @@ struct Environment; * * The global environment contains the global * information needed to compile a Relay program. - * - * It contains all global functions, and configuration + * + * It contains all global functions, and configuration * options. * * Many operations require access to the global @@ -34,14 +34,15 @@ struct Environment; * but we will mutate the Environment while optimizing * Relay programs. * - * The functional style allows users to construct custom + * The functional style allows users to construct custom * environments easily, for example each thread can store * an Environment while auto-tuning. * */ class EnvironmentNode : public RelayNode { private: - /*! \brief A map from string names to global variables ensures global uniqueness. */ + /*! \brief A map from string names to global variables ensures global + * uniqueness. */ tvm::Map global_map_; /*! \brief A map from file names to source fragments. */ SourceMap source_map_; @@ -56,11 +57,10 @@ class EnvironmentNode : public RelayNode { void VisitAttrs(tvm::AttrVisitor* v) final {} - TVM_DLL static Environment make( - tvm::Map global_funcs); + TVM_DLL static Environment make(tvm::Map global_funcs); - void Add(const GlobalVar& var, const Function & func, bool update = false); - void Update(const GlobalVar& var, const Function & func); + void Add(const GlobalVar& var, const Function& func, bool update = false); + void Update(const GlobalVar& var, const Function& func); void Remove(const GlobalVar& var); /*! \brief Lookup a global function by its variable. */ @@ -70,14 +70,20 @@ class EnvironmentNode : public RelayNode { Function Lookup(const GlobalVar& id); /*! \brief Lookup a global function by its string name */ - Function Lookup(const std::string & s); - + Function Lookup(const std::string& s); + // TODO(@jroesch, @tqchen): what are the semantics here - void Merge(const Environment & env); + void Merge(const Environment& env); /*! \brief Add a source fragment to the environment. */ SourceName AddSource(std::string file_name, std::string source); + using Transformer = runtime::TypedPackedFunc< + runtime::TypedPackedFunc(const Environment&)>; + + /*! \brief Apply a function over every function in the global environment. */ + void Transform(Transformer tranformer); + void AddDiagnostic(SpannedError); void DisplayErrors(); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 7fd81ee0481b..a882b7cc1ea7 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -271,6 +271,7 @@ class CallNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("op", &op); v->Visit("args", &args); + v->Visit("attrs", &attrs); v->Visit("type_args", &type_args); v->Visit("span", &span); } diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index e15f25a39eb3..6f2a7f98542a 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -15,95 +15,99 @@ namespace tvm { namespace relay { -template -class ExprVisitor : public ::tvm::relay::ExprFunctor { +class ExprVisitor : public ::tvm::relay::ExprFunctor { public: - void VisitExpr_(const LocalVarNode* op, Args... args) override { return; } + void VisitExpr_(const LocalVarNode* op) override { return; } - void VisitExpr_(const GlobalVarNode* op, Args... args) override { return; } + void VisitExpr_(const GlobalVarNode* op) override { return; } - void VisitExpr_(const ConstantNode* op, Args... args) override { return; } + void VisitExpr_(const ConstantNode* op) override { return; } - void VisitExpr_(const TupleNode* op, Args... args) override { + void VisitExpr_(const TupleNode* op) override { for (auto field : op->fields) { - this->VisitExpr(field, args...); + this->VisitExpr(field); } } - void VisitExpr_(const ParamNode* op, Args... args) override { - this->VisitExpr(op->var, args...); + void VisitExpr_(const ParamNode* op) override { + this->VisitExpr(op->var); } - void VisitExpr_(const FunctionNode* op, Args... args) override { + void VisitExpr_(const FunctionNode* op) override { for (auto param : op->params) { - this->VisitExpr(param, args...); + this->VisitExpr(param); } - this->VisitExpr(op->body, args...); + this->VisitExpr(op->body); } - void VisitExpr_(const CallNode* op, Args... args) override { - this->VisitExpr(op->op, args...); + void VisitExpr_(const CallNode* op) override { + this->VisitExpr(op->op); + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + for (auto arg : op->args) { - this->VisitExpr(arg, args...); + this->VisitExpr(arg); } } - void VisitExpr_(const LetNode* op, Args... args) override { - this->VisitExpr(op->var, args...); - this->VisitExpr(op->value, args...); - this->VisitExpr(op->body, args...); + void VisitExpr_(const LetNode* op) override { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + this->VisitExpr(op->body); } - void VisitExpr_(const IfNode* op, Args... args) override { - this->VisitExpr(op->cond, args...); - this->VisitExpr(op->true_value, args...); - this->VisitExpr(op->false_value, args...); + void VisitExpr_(const IfNode* op) override { + this->VisitExpr(op->cond); + this->VisitExpr(op->true_value); + this->VisitExpr(op->false_value); } - void VisitExpr_(const OpNode* op, Args... args) override { return; } + void VisitExpr_(const OpNode* op) override { return; } + + virtual void VisitType(const Type& t) {} }; -template -class ExprFVisitor : public ::tvm::relay::ExprFunctor { +class ExprFVisitor : public ::tvm::relay::ExprFunctor { public: - Expr VisitExpr_(const LocalVarNode* op, Args... args) override { + Expr VisitExpr_(const LocalVarNode* op) override { return GetRef(op); } - Expr VisitExpr_(const GlobalVarNode* op, Args... args) override { + Expr VisitExpr_(const GlobalVarNode* op) override { return GetRef(op); } - Expr VisitExpr_(const OpNode* op, Args... args) override { + Expr VisitExpr_(const OpNode* op) override { return GetRef(op); } - Expr VisitExpr_(const TupleNode* op, Args... args) override { + Expr VisitExpr_(const TupleNode* op) override { tvm::Array fields; for (auto field : op->fields) { - fields.push_back(this->VisitExpr(field, args...)); + fields.push_back(this->VisitExpr(field)); } return TupleNode::make(fields); } - Expr VisitExpr_(const ParamNode* op, Args... args) override { - Expr var_expr = this->VisitExpr(op->var, args...); + Expr VisitExpr_(const ParamNode* op) override { + Expr var_expr = this->VisitExpr(op->var); if (const LocalVarNode* var_node = var_expr.as()) { auto var = GetRef(var_node); - auto type = this->VisitType(op->type, args...); + auto type = this->VisitType(op->type); return ParamNode::make(var, type); } else { throw dmlc::Error("the default param visitor has bug"); } } - Expr VisitExpr_(const FunctionNode* op, Args... args) override { + Expr VisitExpr_(const FunctionNode* op) override { tvm::Array ty_params; for (auto ty : op->type_params) { - Type ty_param_type = VisitType(ty, args...); + Type ty_param_type = VisitType(ty); if (auto ty_param = ty_param_type.as()) { auto ty_param_ref = GetRef(ty_param); ty_params.push_back(ty_param_ref); @@ -114,7 +118,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor params; for (auto param : op->params) { - Expr param_expr = this->VisitExpr(param, args...); + Expr param_expr = this->VisitExpr(param); if (const ParamNode* param_node = param_expr.as()) { auto param = GetRef(param_node); params.push_back(param); @@ -123,23 +127,23 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitType(op->ret_type, args...); - auto body = this->VisitExpr(op->body, args...); + auto ret_type = this->VisitType(op->ret_type); + auto body = this->VisitExpr(op->body); return FunctionNode::make(params, ret_type, body, ty_params); } - Expr VisitExpr_(const CallNode* call_node, Args... args) override { - auto fn = this->VisitExpr(call_node->op, args...); + Expr VisitExpr_(const CallNode* call_node) override { + auto fn = this->VisitExpr(call_node->op); tvm::Array ty_args; for (auto ty_arg : call_node->type_args) { - auto new_ty_arg = this->VisitType(ty_arg, args...); + auto new_ty_arg = this->VisitType(ty_arg); ty_args.push_back(new_ty_arg); } tvm::Array call_args; for (auto arg : call_node->args) { - call_args.push_back(this->VisitExpr(arg, args...)); + call_args.push_back(this->VisitExpr(arg)); } auto call = CallNode::make(fn, call_args, call_node->attrs, ty_args); @@ -147,27 +151,27 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctorVisitExpr(op->var, args...); + Expr VisitExpr_(const LetNode* op) override { + Expr var_expr = this->VisitExpr(op->var); if (const LocalVarNode* var_node = var_expr.as()) { auto var = GetRef(var_node); - auto type = this->VisitType(op->value_type, args...); - auto value = this->VisitExpr(op->value, args...); - auto body = this->VisitExpr(op->body, args...); + auto type = this->VisitType(op->value_type); + auto value = this->VisitExpr(op->value); + auto body = this->VisitExpr(op->body); return LetNode::make(var, value, body, type); } else { throw dmlc::Error("the default let visitor has error"); } } - Expr VisitExpr_(const IfNode* op, Args... args) override { - auto guard = this->VisitExpr(op->cond, args...); - auto true_b = this->VisitExpr(op->true_value, args...); - auto false_b = this->VisitExpr(op->false_value, args...); + Expr VisitExpr_(const IfNode* op) override { + auto guard = this->VisitExpr(op->cond); + auto true_b = this->VisitExpr(op->true_value); + auto false_b = this->VisitExpr(op->false_value); return IfNode::make(guard, true_b, false_b); } - virtual Type VisitType(const Type& t, Args... args) { return t; } + virtual Type VisitType(const Type& t) { return t; } }; } // namespace relay diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 0e5483174c53..2e8d090f6625 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -34,7 +34,7 @@ class OpNode : public relay::ExprNode { public: /*! \brief name of the operator */ std::string name; - + /*! \brief the type of the operator */ Type op_type; /*! * \brief detailed description of the operator @@ -62,6 +62,7 @@ class OpNode : public relay::ExprNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name", &name); + v->Visit("op_type", &op_type); v->Visit("description", &description); v->Visit("arguments", &arguments); v->Visit("attrs_type_key", &attrs_type_key); diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index ee64ef6ce814..86c9ac794b4e 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -31,3 +31,6 @@ def lookup(self, var): return _env.Environment_Lookup_str(self, var) else: return _env.Environment_Lookup(self, var) + + def transform(self, transformer): + _env.Environment_Transform(self, transformer) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index a0a8c2e008da..098eb474c6ee 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -25,8 +25,12 @@ def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: raise Exception(f"unsupported argument type {type(arg)}") def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: - if isinstance(arg, tuple): + if isinstance(arg, Expr): + return arg + elif isinstance(arg, tuple): raise Exception("..") + elif isinstance(arg, PartialFunc): + return arg.to_func() else: value = convert(arg, ctxt) return Constant(value) @@ -114,10 +118,11 @@ def let(self, name, value, value_type=None): def function(self, *params): relay_params = [] - for name, ty in params: - lv = LocalVar(name) - self.scopes[-1][name] = lv - relay_params.append(Param(lv, ty)) + for param in params: + name = param.var + ty = param.type + self.scopes[-1][name.name_hint] = name + relay_params.append(Param(name, ty)) # self.params.append(relay_params) @@ -135,7 +140,7 @@ def _on_exit(): def ret(self, x): if not self.ret_values[-1]: - self.ret_values[-1] = x + self.ret_values[-1] = into_ast(x) else: raise Exception( "return value already set, a function can only have one return value") diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index ad7a68eac392..70d5f09237d8 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -1,5 +1,230 @@ -#pylint: disable-all +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +The optimizer for Relay. +Exposes an interface for configuring the optimizer and scripting +it directly in Python. +""" +from typing import TypeVar, Generic, Union +from typing import Dict, Tuple, List, Callable +import tvm + +from .expr import Expr +from .expr import Function, Let, Call, LocalVar +from .expr import GlobalVar, If, Constant +from .type import Type +from .env import Environment +from .op import Op +# import relay.make as relay_mk +# from relay import ir +# from relay.env import Environment +# from relay.tyck import check_expr +# from relay.first_order_reverse_ad import fo_with_gradient +# from relay.anf import to_anf from . import _ir_pass +# Expose checking expression, should rename to infer_type. check_expr = _ir_pass.check_expr + +# # pylint: disable=invalid-name +# concretize = _opt.concretize + +# # pylint: disable=invalid-name +# optimize = _opt.optimize + +# # pylint: disable=invalid-name +# type_specialize = _opt.type_specialize + +# # pylint: disable=invalid-name +# compile_ops_to_module = _opt.compile_ops_to_module + + +@tvm.register_func("relay.mangle") +def mangle(name: str, types: List[Type]) -> str: + for typ in types: + name += str(typ) + "_" + return name + +T = TypeVar('T') +class AbstractExprVisitor(Generic[T]): + """A functional visitor over Expr in Python.""" + + # pylint: disable=no-else-return + def visit(self, expr: Expr) -> T: + """Apply the visitor to an expression.""" + if isinstance(expr, Function): + return self.visit_function(expr) + elif isinstance(expr, Call): + return self.visit_call(expr) + elif isinstance(expr, Let): + return self.visit_let(expr) + elif isinstance(expr, LocalVar): + return self.visit_local_var(expr) + elif isinstance(expr, GlobalVar): + return self.visit_global_var(expr) + elif isinstance(expr, If): + return self.visit_if(expr) + elif isinstance(expr, Tuple): + return self.visit_tuple(expr) + elif isinstance(expr, Constant): + return self.visit_constant(expr) + else: + raise Exception(f"warning unhandled case: {type(expr)}") + + def visit_function(self, _: Function) -> T: + raise Exception("Abstract method please implement me.") + + def visit_let(self, _: Let) -> T: + raise Exception("Abstract method please implement me.") + + def visit_call(self, _: Call) -> T: + raise Exception("Abstract method please implement me.") + + def visit_local_id(self, _: LocalVar) -> T: + raise Exception("Abstract method please implement me.") + + def visit_type(self, typ: Type) -> Type: + return typ + + def visit_if(self, _: If) -> T: + raise Exception("Abstract method please implement me.") + + def visit_tuple(self, _: Tuple) -> T: + raise Exception("Abstract method please implement me.") + + def visit_constant(self, _: Constant) -> T: + raise Exception("Abstract method please implement me.") + + def visit_global_var(self, _: GlobalVar) -> T: + raise Exception("Abstract method please implement me.") + + @classmethod + def to_pass(cls) -> Callable[[Environment], Callable[[GlobalVar, Function], Function]]: + def _outer_wrapper(env): + visitor = cls(env) + def _inner_wrapper(var, func): + return visitor.visit(func) + return _inner_wrapper + return _outer_wrapper + +class ExprVisitor(AbstractExprVisitor[Expr]): + """A functional visitor over Expr in Python.""" + + def visit_function(self, fn: Function) -> Expr: + new_body = self.visit(fn.body) + return Function( + list(fn.params), + fn.ret_type, new_body, + fn.type_params) + + def visit_let(self, let: Let) -> Expr: + new_var = self.visit(let.var) + new_value_type = self.visit_type(let.value_type) + new_val = self.visit(let.value) + new_body = self.visit(let.body) + return Let(new_var, new_val, new_body, new_value_type) + + def visit_call(self, call: Call) -> Expr: + new_fn = self.visit(call.fn) + new_args = [self.visit(arg) for arg in call.args] + return Call(new_fn, new_args, call.attrs) + + def visit_local_var(self, local_var: LocalVar) -> Expr: + return local_var + + def visit_global_id(self, global_var: GlobalVar) -> Expr: + return global_var + + def visit_if(self, ite: If) -> Expr: + return If( + self.visit(ite.guard), + self.visit(ite.true_b), + self.visit(ite.false_b)) + + def visit_tuple(self, tup: Tuple) -> Expr: + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_constant(self, const: Constant) -> Expr: + return const + +MMCacheKey = Tuple[GlobalVar, List[Type]] + +class Monomorphize(ExprVisitor): + """A monomorphization pass. + + Implements what is known as "monomorphization" in + classic compiler literature. This pass removes + polymorphism replacing calls to functions and + operators with type specialized versions. + """ + monomorph_map: Dict[MMCacheKey, Union[Op, Function]] + + # pylint: disable=super-init-not-called + def __init__(self, env: Environment) -> None: + self.env = env + # Stores (GlobalVar, Type), should eventually store attributes. + self.monomorph_map = {} + + # pylint: disable=no-else-return + def visit_call(self, call: Call) -> Expr: + import pdb; pdb.set_trace() + # cache_key = (call.fn, call.ty_args) + # if isinstance(call.fn, OperatorId): + # if cache_key in self.monomorph_map: + # op = self.monomorph_map[cache_key] + # new_args = [self.visit(arg) for arg in call.args] + # return Call(op, new_args, call.attrs) + # else: + # new_name = mangle(call.fn.name, call.ty_args) + # new_id = self.env.operator_id(new_name) + # self.monomorph_map[cache_key] = new_id + # op = self.env.lookup(call.fn) + # for arg in call.ty_args: + # if isinstance(arg, TypeParam): + # return call # raise Exception("...") # Fix me in the morning!!! + # new_op = concretize(new_id, op, call.ty_args, call.attrs) + # self.monomorph_map[cache_key] = new_op.id + # self.env.add(new_op) + # new_args = [self.visit(arg) for arg in call.args] + # return Call(new_op.id, new_args, call.attrs) + # elif isinstance(call.fn, GlobalVar): + # if cache_key in self.monomorph_map: + # op_name = self.monomorph_map[cache_key] + # new_args = [self.visit(arg) for arg in call.args] + # return Call(op_name, new_args, call.attrs) + # else: + # defn = self.env.lookup(call.fn) + # new_id = self.env.global_id(defn.id.name + str(1)) + # cache_key = (call.fn, call.ty_args) + # self.monomorph_map[cache_key] = new_id + # new_body = self.visit(type_specialize(call.ty_args, defn.body)) + # new_body = Function( + # [], new_body.params, new_body.ret_type, new_body.body) + # new_ty = check_expr(self.env, new_body) + # # TODO(@jroesch): move into C++ + # # TODO(@joresch): implement and call name mangler + # defn = Defn(new_id, new_ty, new_body) + # self.env.add(defn) + # self.visit_item(defn) + # return Call(new_id, call.args, call.attrs) + # elif isinstance(call.fn, Function): + # new_args = [self.visit(arg) for arg in call.args] + # new_func = type_specialize(call.ty_args, call.fn) + # new_func = self.visit(new_func) + # new_func = Function([], + # new_func.params, + # new_func.ret_type, + # new_func.body) + # check_expr(self.env, new_func) + # return Call(new_func, call.args, call.attrs) + # else: + # new_fn = self.visit(call.fn) + # new_args = [self.visit(arg) for arg in call.args] + # return Call(new_fn, new_args, call.attrs) + + +# TODO(@jroesch): Fix up my type +__tgt_host__ = __tgt__ = "llvm" +__relay_tvm_context__ = tvm.cpu() + diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index d54f47e25197..47ebc5501cab 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,6 +1,6 @@ """Relay core operators.""" # operator defs -from .op import get, register, Op +from .op import get, register, Op, compile_ops # Operators from .tensor import * diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 4540b19f5ccf..d351e6cdc88d 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -3,7 +3,7 @@ from ..base import register_relay_node from ..expr import Expr -from ..._ffi.function import Function +from ..._ffi.function import Function, register_func from ...api import convert @register_relay_node @@ -72,6 +72,46 @@ def _register(v): return v return _register(value) if value else _register +def compile_ops(op_names): + """Register an operator property of an operator. + + + Parameters + ---------- + op_name : str + The name of operator + + attr_key : str + The attribute name. + + value : object, optional + The value to set + + level : int, optional + The priority level + + Returns + ------- + fregister : function + Register function if value is not specified. + """ + fake_map = {} + for name in op_names: + fake_map[name] = LocalVar(name) + if isinstance({}, dict): + fake_map = None + return [] # _CompileOpsToModule(fake_map) + +# TODO(@jroesch): We should port to C++, just need to figure out how to write this code. +@register_func("relay.opt.compile_ops") +def _compile_ops(op_impls): + lowered = [] + for local, sch, inputs in op_impls: + lfn = tvm.lower(sch, inputs, name=local.name_hint) + lowered.append(lfn) + + # TOOD(@jroesch): Where should we read these settings from + return tvm.build(lowered, target='llvm', target_host=tvm.cpu(0)) _init_api("relay.op", __name__) diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py index 137230ace63a..d191e078dffe 100644 --- a/python/tvm/relay/to_tvm.py +++ b/python/tvm/relay/to_tvm.py @@ -4,12 +4,10 @@ from typing import Dict, Any, List, Tuple import attr - -from relay.frontend import get_env -from . import ir -from .tyck import get_checked_type -from .opt import AbstractExprVisitor, compile_ops_to_module -from ._make import Operator_is_generic +from .ir_pass import AbstractExprVisitor +from .op import compile_ops +from .type import TensorType +from .expr import LocalVar, Function, Let, Call @attr.s(auto_attribs=True) @@ -71,7 +69,7 @@ def to_json(self) -> Any: } -def from_tensor(typ: ir.TensorType) -> Tuple[str, List[int]]: +def from_tensor(typ: TensorType) -> Tuple[str, List[int]]: dtype = typ.dtype.dtype shape = typ.shape dims = [] @@ -83,7 +81,7 @@ def from_tensor(typ: ir.TensorType) -> Tuple[str, List[int]]: class TVMRTSCompiler(AbstractExprVisitor[NodeRef]): """The compiler from Relay to the TVM runtime system.""" nodes: List[Node] - id_map: Dict[ir.LocalId, NodeRef] + id_map: Dict[LocalVar, NodeRef] def __init__(self) -> None: self.nodes = [] @@ -94,10 +92,10 @@ def add_node(self, node: Node) -> NodeRef: ident = len(self.nodes) - 1 return NodeRef(ident) - def add_binding(self, ident: ir.LocalId, ref: NodeRef) -> None: + def add_binding(self, ident: LocalVar, ref: NodeRef) -> None: self.id_map[ident] = ref - def let_bind(self, ident: ir.LocalId, node: Node) -> NodeRef: + def let_bind(self, ident: LocalVar, node: Node) -> NodeRef: ref = self.add_node(node) self.add_binding(ident, ref) return ref @@ -105,10 +103,10 @@ def let_bind(self, ident: ir.LocalId, node: Node) -> NodeRef: def get_node(self, ref: NodeRef) -> Node: return self.nodes[ref.ident] - def lookup(self, ident: ir.LocalId) -> NodeRef: + def lookup(self, ident: LocalVar) -> NodeRef: return self.id_map[ident] - def compile(self, func: ir.Function) -> None: + def compile(self, func: Function) -> None: """Compile a single function into a graph.""" # TODO: (@jroesch) Restore me # assert len(fn.ty_params) == 0 @@ -132,30 +130,30 @@ def compile(self, func: ir.Function) -> None: # become our output node. self.get_node(output_ref).is_output = True - def visit_let(self, let: ir.Let) -> NodeRef: + def visit_let(self, let: Let) -> NodeRef: """Visit the Let binding, by first traversing its value, then setting the metadata on the returned NodeRef. Finally visit the body, and return the NodeRef corresponding to it. """ - ident = let.id + ident = let.var val = let.value body = let.body # Need to add type info? val_ref = self.visit(val) - dtype, shape = from_tensor(get_checked_type(val)) + dtype, shape = from_tensor(val.checked_type()) val_node = self.get_node(val_ref) val_node.attrs["dtype"] = dtype val_node.attrs["shape"] = shape self.add_binding(ident, val_ref) return self.visit(body) - def visit_local_id(self, ident: ir.LocalId) -> NodeRef: + def visit_local_id(self, ident: LocalVar) -> NodeRef: return self.lookup(ident) - def visit_call(self, call: ir.Call) -> NodeRef: + def visit_call(self, call: Call) -> NodeRef: inputs = [] for arg in call.args: inputs.append(self.visit(arg).to_json()) @@ -219,20 +217,21 @@ def to_json(self) -> str: return json.dumps(json_dict) -def compile_to_tvm(func): +def compile(func): """Compile a single function to the components needed by the TVM RTS. """ - env = get_env() - iids = [] + op_names = [] - # Why do I need to call items? - for op in env.operators(): - if not Operator_is_generic(op): - iids.append(op.id) + # # Why do I need to call items? + # for op in env.operators(): + # if not Operator_is_generic(op): + # iids.append(op.id) # TODO(@jroesch): Need to write test case for this - mod = compile_ops_to_module(env, iids) + print("above") + mod = compile_ops(op_names) + print("below") comp = TVMRTSCompiler() comp.compile(func) graph_json = comp.to_json() diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index a1a754615350..db7f11fb9e2b 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -3,10 +3,10 @@ * \file environment.cc * \brief The global environment in Relay. */ -#include #include -#include #include +#include +#include #include "./../pass/resolve.h" // #include "tvm/relay/util/rang.h" @@ -16,8 +16,7 @@ namespace relay { using tvm::IRPrinter; using namespace runtime; -Environment EnvironmentNode::make( - tvm::Map global_funcs) { +Environment EnvironmentNode::make(tvm::Map global_funcs) { std::shared_ptr n = std::make_shared(); n->functions = std::move(global_funcs); return Environment(n); @@ -26,11 +25,11 @@ Environment EnvironmentNode::make( GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { auto global_id = global_map_.find(str); if (global_id != global_map_.end()) { - return (*global_id).second; + return (*global_id).second; } else { - auto id = GlobalVarNode::make(str); - this->global_map_.Set(str, id); - return id; + auto id = GlobalVarNode::make(str); + this->global_map_.Set(str, id); + return id; } } @@ -39,7 +38,8 @@ GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { * definition will trigger an exception, otherwise we will * update the definition if and only if it is type compatible. */ -void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool update) { +void EnvironmentNode::Add(const GlobalVar &var, const Function &func, + bool update) { // Type check the item before we add it to the environment. auto env = GetRef(this); @@ -72,14 +72,14 @@ void EnvironmentNode::Add(const GlobalVar& var, const Function & func, bool upda } } -void EnvironmentNode::Update(const GlobalVar& var, const Function & func) { +void EnvironmentNode::Update(const GlobalVar &var, const Function &func) { this->Add(var, func, true); } -void EnvironmentNode::Remove(const GlobalVar&) { +void EnvironmentNode::Remove(const GlobalVar &) { // Clarify with @tqchen about how to use COW to do this. throw Error("NYI"); - // this->items.erase(id); + // this->items.erase(id); } Function EnvironmentNode::Lookup(const GlobalVar &var) { @@ -96,15 +96,14 @@ Function EnvironmentNode::Lookup(const std::string &str) { return this->Lookup(id); } -void EnvironmentNode::Merge(const Environment & env) { +void EnvironmentNode::Merge(const Environment &env) { for (auto pair : env->functions) { this->functions.Set(pair.first, pair.second); } } - inline SourceName EnvironmentNode::AddSource(std::string file_name, - std::string source) { + std::string source) { return this->source_map_.AddSource(file_name, source); } @@ -130,18 +129,35 @@ void EnvironmentNode::DisplayErrors() { // // Build the cursor. // // Fix this code, hardwired to compute alignment of pointer. - // size_t spaces = error_marker.size() + line_info.size() + file_name.size() + + // size_t spaces = error_marker.size() + line_info.size() + file_name.size() + // + // sp->col_offset - 3; // std::string cursor = "~~~~^~~~~"; // for (size_t i = 0; i < spaces; i++) { // std::cout << " "; // } - // std::cout << rang::fg::red << cursor << " " << err.msg << rang::style::reset + // std::cout << rang::fg::red << cursor << " " << err.msg << + // rang::style::reset // << std::endl; // } } +void EnvironmentNode::Transform(EnvironmentNode::Transformer transformer) { + Array to_process; + for (auto var_and_func : this->functions) { + to_process.push_back(var_and_func.first); + } + + auto for_each = transformer(GetRef(this)); + for (auto var : to_process) { + auto func = this->functions[var]; + auto transformed = for_each(var, func); + this->Add(var, transformed, true); + } +} + + TVM_REGISTER_API("relay._make.Environment") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = EnvironmentNode::make(args[0]); @@ -180,6 +196,11 @@ TVM_REGISTER_API("relay._env.Environment_Merge") env->Merge(args[1]); }); +TVM_REGISTER_API("relay._env.Environment_Transform") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + env->Transform(args[1]); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const EnvironmentNode *node, diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 664947425b53..e02a3163e8e7 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -1,4 +1,8 @@ +#include #include +#include +#include + #include #include @@ -132,5 +136,55 @@ TVM_REGISTER_API("relay.op._Register") } }); +bool IsGeneric(const Op& op) { + if (auto ty_func = op.as()) { + return ty_func->type_params.size() == 0; + } else { + return false; + } +} + +using namespace runtime; + +Module CompileOpsToModule(const std::vector & op_names) { + PackedFunc compile_ops = GetPackedFunc("relay.op.compile_ops"); + tvm::Array> args; + + auto compiler_map = Op::GetAttr("FRelayOpCompiler"); + + for (auto op_name : op_names) { + Op op = Op::Get(op_name); + + if (IsGeneric(op)) { + auto compiler = compiler_map[op]; + tvm::Array pair = + compiler(op->name, op->op_type); + //TODO(@jroesch): I can't pass strings across what should be the interface here. + tvm::Array triple = {LocalVarNode::make(op->name), pair[0], pair[1]}; + args.push_back(triple); + } else { + throw dmlc::Error("it is impossible to compile generic operators."); + } + } + + // Nothing to do, bail out earlier. + // TVM will complain if we try to generate a module of size 0. + if (args.size() == 0) { + return Module(nullptr); + } + + return compile_ops(args); +} + +TVM_REGISTER_API("relay.op._CompileOpsToModule") +.set_body([](TVMArgs args, TVMRetValue* ret) { + tvm::Map map = args[0]; + std::vector names; + for (auto pair : map) { + names.push_back(pair.first); + } + *ret = CompileOpsToModule(names); +}); + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index f18a67bcffc9..f513e36c9a30 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -33,7 +33,7 @@ struct ResolveTypeType : TypeFVisitor { } }; -struct ResolveTypeExpr : ExprFVisitor<> { +struct ResolveTypeExpr : ExprFVisitor { const TypeUnifier &unifier; explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {} @@ -53,6 +53,7 @@ struct ResolveTypeExpr : ExprFVisitor<> { // term, then resolve e's old type and write // it back into the new node. auto new_e = ExprFVisitor::VisitExpr(e); + std::cout << e << std::endl; CHECK(e->checked_type_.defined()); auto resolved_cty = VisitType(e->checked_type_); new_e->checked_type_ = resolved_cty; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 1adfb95d1e15..2ea205b511b1 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -211,6 +211,9 @@ CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { auto rtype = resolve(param->type); + // This is a special case ... not sure if there is a better way + // to handle this. + param->var->checked_type_ = rtype; return {ParamNode::make(param->var, rtype), rtype}; } @@ -545,7 +548,7 @@ Expr TypeInferencer::resolve(const Expr& e) { Expr InferType(const Environment &env, const Expr &e) { TypeInferencer ti(env); auto checked_expr = ti.Infer(e); - return checked_expr.expr; + return ti.resolve(checked_expr.expr); } Expr InferType(const Environment &env, const GlobalVar & var, const Function & func) { @@ -556,7 +559,7 @@ Expr InferType(const Environment &env, const GlobalVar & var, const Function & f auto checked_expr = ti.Infer(func); auto map_node = env->functions.CopyOnWrite(); map_node->data.erase(var.node_); - return checked_expr.expr; + return ti.resolve(checked_expr.expr); } diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 1ae78441e166..51833e13e475 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -5,7 +5,10 @@ from tvm.relay.ir_builder import IRBuilder, float_type, int_type from tvm.relay.ir_builder import func_type, tensor_type, into_ast from tvm.relay.env import Environment +from tvm.relay.ir_pass import Monomorphize from tvm.relay.op import log, add, equal, subtract +from tvm.relay.expr import Function +from tvm.relay import to_tvm def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) @@ -15,14 +18,26 @@ def decl_has_type(env, name, typ): func = env.lookup(name) return func.checked_type() == typ + +def run(env, expr): + if not isinstance(expr, Function): + expr = Function([], None, expr, []) + + env.add("main", expr) + env.transform(Monomorphize.to_pass()) + main = env.lookup("main") + graph_json, mod, _ = to_tvm.compile(main) + import pdb; pdb.set_trace() + def test_monomorphic_let(): "Program: let x = 1; return x" b = IRBuilder() x = b.let('x', 1.0, value_type=float_type(64)) b.ret(x) - prog, _ = b.get() + prog, env = b.get() assert has_type(prog, float_type(64)) + run(env, prog) def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" @@ -33,6 +48,25 @@ def test_single_op(): b.ret(t1) assert has_type(func.to_func(), func_type([float_type()], float_type())) +def test_binary_op(): + """ + Program: + fn (x, y) { + return x + y; + } + """ + b = IRBuilder() + x = b.param('x', tensor_type(5, 5, 5)) + y = b.param('y', tensor_type(5, 5, 5)) + with b.function(x, y) as func: + b.ret(add(x, y)) + b.ret(func) + prog, env = b.get() + ttype = tensor_type(5, 5, 5) + expected_ty = func_type([ttype, ttype], ttype) + assert has_type(func.to_func(), expected_ty) + run(env, prog) + def test_dual_op(): """Program: fn (x : Tensor[f32, (10, 10)]) { @@ -40,8 +74,7 @@ def test_dual_op(): let t2 = add(t1, x); return t1; } - """ - pass + """ b = IRBuilder() with b.function(('x', tensor_type(10, 10))) as func: x, = func.param_ids() @@ -58,7 +91,6 @@ def f(x : Tensor[f32, (10, 10)]) { return lx; } """ - pass b = IRBuilder() x = b.param('x') with b.decl('f', x): @@ -90,10 +122,11 @@ def f(n: i32, data: f32) -> f32 { b.ret(data) b.ret(f(into_ast(2.0), into_ast(10000.0))) assert decl_has_type(b.env, 'f', func_type([int_type(), float_type()], float_type())) - + if __name__ == "__main__": - test_monomorphic_let() - test_single_op() - test_dual_op() - test_decl() - test_recursion() + # test_monomorphic_let() + # test_single_op() + test_binary_op() + # test_dual_op() + # test_decl() + # test_recursion() From fb21edbc420c31286fede01fb27b0ce2c680d6d9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 5 Sep 2018 15:50:08 -0700 Subject: [PATCH 071/136] WIP debugging --- include/tvm/relay/op.h | 2 +- nnvm/python/nnvm/_base.py | 7 +- python/tvm/relay/expr.py | 9 + python/tvm/relay/ir_pass.py | 109 ++++++----- python/tvm/relay/op/_tensor.py | 52 ++++++ python/tvm/relay/op/op.py | 70 ++++--- python/tvm/relay/to_tvm.py | 54 +++--- src/relay/ir/op.cc | 172 +++++++++++------- src/relay/pass/type_infer.cc | 1 + .../relay/test_tyck_eval_integration.py | 18 +- 10 files changed, 314 insertions(+), 180 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 2e8d090f6625..756451e66768 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -35,7 +35,7 @@ class OpNode : public relay::ExprNode { /*! \brief name of the operator */ std::string name; /*! \brief the type of the operator */ - Type op_type; + mutable FuncType op_type; /*! * \brief detailed description of the operator * This can be used to generate docstring automatically for the operator. diff --git a/nnvm/python/nnvm/_base.py b/nnvm/python/nnvm/_base.py index 29390a2201bf..63b2f815ad9b 100644 --- a/nnvm/python/nnvm/_base.py +++ b/nnvm/python/nnvm/_base.py @@ -22,7 +22,12 @@ numeric_types = (float, int, np.float32, np.int32) # this function is needed for python3 # to convert ctypes.char_p .value back to python str - py_str = lambda x: x.decode('utf-8') + def py_str(x): + try: + return x.decode('utf-8') + except: + print(x) + # py_str = lambda x: x.decode('utf-8') else: string_types = basestring numeric_types = (float, int, long, np.float32, np.int32) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index ec0cfd55ad62..1558853c2820 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -10,7 +10,16 @@ from . import _make class ExprBuilder(): + # def convert_args(self, def __call__(self, *args, **kwargs): + converted_args = [] + for arg in args: + import pdb; pdb.set_trace() + if isinstance(arg, Param): + converted_args.append(arg.var) + else: + converted_args.append(arg) + return Call(self, args, None, None) class Expr(NodeBase, ExprBuilder): diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 70d5f09237d8..8b49710f70ec 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -13,9 +13,10 @@ from .expr import Expr from .expr import Function, Let, Call, LocalVar from .expr import GlobalVar, If, Constant -from .type import Type +from .type import Type, TypeParam from .env import Environment from .op import Op +from .op.op import specialize_op # import relay.make as relay_mk # from relay import ir # from relay.env import Environment @@ -126,7 +127,7 @@ def visit_let(self, let: Let) -> Expr: return Let(new_var, new_val, new_body, new_value_type) def visit_call(self, call: Call) -> Expr: - new_fn = self.visit(call.fn) + new_fn = self.visit(call.op) new_args = [self.visit(arg) for arg in call.args] return Call(new_fn, new_args, call.attrs) @@ -148,7 +149,7 @@ def visit_tuple(self, tup: Tuple) -> Expr: def visit_constant(self, const: Constant) -> Expr: return const -MMCacheKey = Tuple[GlobalVar, List[Type]] +MMCacheKey = Tuple[Union[GlobalVar, str], List[Type]] class Monomorphize(ExprVisitor): """A monomorphization pass. @@ -168,60 +169,54 @@ def __init__(self, env: Environment) -> None: # pylint: disable=no-else-return def visit_call(self, call: Call) -> Expr: - import pdb; pdb.set_trace() - # cache_key = (call.fn, call.ty_args) - # if isinstance(call.fn, OperatorId): - # if cache_key in self.monomorph_map: - # op = self.monomorph_map[cache_key] - # new_args = [self.visit(arg) for arg in call.args] - # return Call(op, new_args, call.attrs) - # else: - # new_name = mangle(call.fn.name, call.ty_args) - # new_id = self.env.operator_id(new_name) - # self.monomorph_map[cache_key] = new_id - # op = self.env.lookup(call.fn) - # for arg in call.ty_args: - # if isinstance(arg, TypeParam): - # return call # raise Exception("...") # Fix me in the morning!!! - # new_op = concretize(new_id, op, call.ty_args, call.attrs) - # self.monomorph_map[cache_key] = new_op.id - # self.env.add(new_op) - # new_args = [self.visit(arg) for arg in call.args] - # return Call(new_op.id, new_args, call.attrs) - # elif isinstance(call.fn, GlobalVar): - # if cache_key in self.monomorph_map: - # op_name = self.monomorph_map[cache_key] - # new_args = [self.visit(arg) for arg in call.args] - # return Call(op_name, new_args, call.attrs) - # else: - # defn = self.env.lookup(call.fn) - # new_id = self.env.global_id(defn.id.name + str(1)) - # cache_key = (call.fn, call.ty_args) - # self.monomorph_map[cache_key] = new_id - # new_body = self.visit(type_specialize(call.ty_args, defn.body)) - # new_body = Function( - # [], new_body.params, new_body.ret_type, new_body.body) - # new_ty = check_expr(self.env, new_body) - # # TODO(@jroesch): move into C++ - # # TODO(@joresch): implement and call name mangler - # defn = Defn(new_id, new_ty, new_body) - # self.env.add(defn) - # self.visit_item(defn) - # return Call(new_id, call.args, call.attrs) - # elif isinstance(call.fn, Function): - # new_args = [self.visit(arg) for arg in call.args] - # new_func = type_specialize(call.ty_args, call.fn) - # new_func = self.visit(new_func) - # new_func = Function([], - # new_func.params, - # new_func.ret_type, - # new_func.body) - # check_expr(self.env, new_func) - # return Call(new_func, call.args, call.attrs) - # else: - # new_fn = self.visit(call.fn) - # new_args = [self.visit(arg) for arg in call.args] - # return Call(new_fn, new_args, call.attrs) + cache_key = (call.op, call.type_args) + new_args = [self.visit(arg) for arg in call.args] + + if cache_key in self.monomorph_map: + op = self.monomorph_map[cache_key] + new_args = [self.visit(arg) for arg in call.args] + return Call(op, new_args, call.attrs) + else: + if isinstance(call.op, Op): + poly_name = call.op.name + mono_name = mangle(poly_name, call.type_args) + for arg in call.type_args: + if isinstance(arg, TypeParam): + return call # raise Exception("...") # Fix me in the morning!!! + + mono_op = specialize_op(poly_name, mono_name, call.type_args) + self.monomorph_map[cache_key] = mono_op + return Call(mono_op, new_args,call.attrs, []) + elif isinstance(call.op, GlobalVar): + return call + # defn = self.env.lookup(call.op) + # new_id = self.env.global_id(defn.id.name + str(1)) + # cache_key = (call.op, call.type_args) + # self.monomorph_map[cache_key] = new_id + # new_body = self.visit(type_specialize(call.type_args, defn.body)) + # new_body = Function( + # [], new_body.params, new_body.ret_type, new_body.body) + # new_ty = check_expr(self.env, new_body) + # # TODO(@jroesch): move into C++ + # # TODO(@joresch): implement and call name mangler + # defn = Defn(new_id, new_ty, new_body) + # self.env.add(defn) + # self.visit_item(defn) + # return Call(new_id, call.args, call.attrs) + + elif isinstance(call.op, Function): + return call + # new_func = type_specialize(call.type_args, call.op) + # new_func = self.visit(new_func) + # new_func = Function([], + # new_func.params, + # new_func.ret_type, + # new_func.body) + # check_expr(self.env, new_func) + # return Call(new_func, call.args, call.attrs) + else: + new_fn = self.visit(call.op) + return Call(new_fn, new_args, call.attrs) # TODO(@jroesch): Fix up my type diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 08dedee0923c..da94ec89b380 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -1,4 +1,56 @@ """Backend compiler related feature regsitration""" +from .op import register +from ..type import FuncType, TensorType +from ...schedule import create_schedule +from ...api import placeholder +from topi import add +def type_to_placeholder(name, ty): + if isinstance(ty, TensorType): + return placeholder(ty.shape, name=name, dtype=ty.dtype) + else: + raise Exception("can only pass Tensor values to TVM operators") +def func_ty_to_placeholders(func_ty): + if isinstance(func_ty, FuncType): + arg_types = func_ty.arg_types + ret_type = func_ty.ret_type + args = [] + var = 0 + for arg in arg_types: + var += 1 + args.append(type_to_placeholder(f"Input{var}", arg)) + return args, ret_type + else: + raise Exception("error") +# def lookup_in_topi(name): +# try: +# f = eval(f"topi.{name}") +# except: +# f = eval(f"topi.nn.{name}") + +# return f + +# @tvm.register_func("nnvm.relay._default_op_compiler") +# def _default_op_compile(op_name: str, func_ty: ir.Type, attrs: ir.Attributes=None) -> Any: +# Inputs, ret_ty = func_ty_to_placeholders(func_ty) +# op = lookup_in_topi(op_name) +# Output = op(*Inputs) + +# if Output.dtype == 'uint1': +# import pdb; pdb.set_trace() +# Output = Output.astype('uint8') + +# schedule = tvm.create_schedule(Output.op) +# return [schedule, Inputs + [Output]] + + +def add_compiler(op_name, func_type, *args): + Inputs, ret_ty = func_ty_to_placeholders(func_type) + # op = lookup_in_topi(op_name) + Output = add(*Inputs) + schedule = create_schedule(Output.op) + return [schedule, Inputs + [Output]] + +register("add", "FRelayOpCompiler", add_compiler) \ No newline at end of file diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index d351e6cdc88d..bb589f40f138 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -5,6 +5,8 @@ from ..expr import Expr from ..._ffi.function import Function, register_func from ...api import convert +from ...container import Map +from ... import lower, build, cpu @register_relay_node class Op(Expr): @@ -78,40 +80,64 @@ def compile_ops(op_names): Parameters ---------- - op_name : str - The name of operator - - attr_key : str - The attribute name. - - value : object, optional - The value to set - - level : int, optional - The priority level + op_names : List[str] + A list of operator names to compile to machine code. Returns ------- - fregister : function - Register function if value is not specified. + A module containing the compiled TVM operators. """ - fake_map = {} - for name in op_names: - fake_map[name] = LocalVar(name) - if isinstance({}, dict): - fake_map = None - return [] # _CompileOpsToModule(fake_map) + return _CompileOpsToModule(*op_names) # TODO(@jroesch): We should port to C++, just need to figure out how to write this code. -@register_func("relay.opt.compile_ops") +@register_func("relay.op._compile_ops") def _compile_ops(op_impls): lowered = [] for local, sch, inputs in op_impls: - lfn = tvm.lower(sch, inputs, name=local.name_hint) + lfn = lower(sch, inputs, name=local.name_hint) lowered.append(lfn) # TOOD(@jroesch): Where should we read these settings from - return tvm.build(lowered, target='llvm', target_host=tvm.cpu(0)) + return build(lowered, target='llvm', target_host='llvm') _init_api("relay.op", __name__) +def specialize_op(op_name, new_op_name, type_args): + """Specializes an operator to a set of types and assigns it new_op_name. + + The idea is to take operators with generic types such as broadcasting + addition: + + add : forall (T : Type) (U : Type), (U, T) -> Broadcast(U, T) + + This is a function which is polymorphic over two types `T` and `U` and + takes a value of type `T` and one of `U` and returns `Broadcast` of U + and T. + + Broadcast is a type relation which relates U and T to an output type. + + The idea is that the above type is shorthand for: + + add : forall (T : Type) (U : Type) (O : Type), Broadcast(U, T, O) => (U, T) -> O + + That is a function from U and T to O where the typing relation between the values + is specified by Broadcast. + + We implement a basic Broadcasting rule in `type_relations.h` but users can specify + their own. + + If we know T=Tensor[(10, 10), dtype], U=Tensor[(10, 10), dtype] then the result + should be Tensor[(10, 10), dtype]. + + We can use SpecializeOp to implement this change of operator. + + Parameters + ---------- + op_name : str + The operator to be specialized. + + Returns + ------- + The specialized operator. + """ + return _SpecializeOp(op_name, new_op_name, type_args) \ No newline at end of file diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py index d191e078dffe..f2c2a9ba5463 100644 --- a/python/tvm/relay/to_tvm.py +++ b/python/tvm/relay/to_tvm.py @@ -1,11 +1,11 @@ """A compiler from Relay programs to TVM's graph runtime. """ import json -from typing import Dict, Any, List, Tuple +from typing import Dict, Any, List, Tuple, Set import attr from .ir_pass import AbstractExprVisitor -from .op import compile_ops +from .op import compile_ops, Op from .type import TensorType from .expr import LocalVar, Function, Let, Call @@ -69,23 +69,23 @@ def to_json(self) -> Any: } +def shape_to_json(shape): + return [str(sh.value) for sh in shape] + def from_tensor(typ: TensorType) -> Tuple[str, List[int]]: - dtype = typ.dtype.dtype - shape = typ.shape - dims = [] - for dim in shape.shapes: - dims.append(dim.value) - return dtype, dims + return (typ.dtype, shape_to_json(typ.shape)) class TVMRTSCompiler(AbstractExprVisitor[NodeRef]): """The compiler from Relay to the TVM runtime system.""" nodes: List[Node] id_map: Dict[LocalVar, NodeRef] + all_ops: Set[Op] def __init__(self) -> None: self.nodes = [] self.id_map = {} + self.all_ops = set() def add_node(self, node: Node) -> NodeRef: self.nodes.append(node) @@ -116,11 +116,11 @@ def compile(self, func: Function) -> None: for param in params: dtype, shape = from_tensor(param.type) - node = InputNode(f"{param.id.name}", { + node = InputNode(f"{param.var.name_hint}", { "shape": shape, "dtype": dtype, }) - self.let_bind(param.id, node) + self.let_bind(param.var, node) # Then we compile the body into a graph which can depend # on input variables. @@ -150,7 +150,7 @@ def visit_let(self, let: Let) -> NodeRef: self.add_binding(ident, val_ref) return self.visit(body) - def visit_local_id(self, ident: LocalVar) -> NodeRef: + def visit_local_var(self, ident: LocalVar) -> NodeRef: return self.lookup(ident) def visit_call(self, call: Call) -> NodeRef: @@ -158,9 +158,13 @@ def visit_call(self, call: Call) -> NodeRef: for arg in call.args: inputs.append(self.visit(arg).to_json()) - # need to deal with name mangle - op_name = call.fn.name - op_node = OpNode("call_name", {}, op_name, inputs, {}) + assert isinstance(call.op, Op) + self.all_ops.add(call.op.name) + + op_name = call.op.name + attrs = { 'shape': shape_to_json(call.checked_type().shape), + 'dtype': call.checked_type().dtype } + op_node = OpNode("call_name", attrs, op_name, inputs, {}) return self.add_node(op_node) def to_json(self) -> str: @@ -221,18 +225,16 @@ def compile(func): """Compile a single function to the components needed by the TVM RTS. """ - op_names = [] - - # # Why do I need to call items? - # for op in env.operators(): - # if not Operator_is_generic(op): - # iids.append(op.id) - - # TODO(@jroesch): Need to write test case for this - print("above") - mod = compile_ops(op_names) - print("below") comp = TVMRTSCompiler() comp.compile(func) + op_names = list(comp.all_ops) + mod = compile_ops(op_names) graph_json = comp.to_json() - return graph_json, mod, None # params currently isn't supported by API + try: + import nnvm + graph = nnvm.graph.load_json(graph_json) + except Exception as e: + import traceback + traceback.print_tb(e.__traceback__) + import pdb; pdb.set_trace() + return graph, mod, None # params currently isn't supported by API diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index e02a3163e8e7..64467004a973 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -1,10 +1,11 @@ -#include #include +#include #include #include +#include "./../pass/type_subst.h" -#include #include +#include namespace dmlc { // enable registry @@ -25,7 +26,7 @@ struct OpManager { // global operator counter std::atomic op_counter{0}; // storage of additional attribute table. - std::unordered_map > attr; + std::unordered_map> attr; // frontend functions std::vector frontend_funcs; // get singleton of the @@ -38,8 +39,7 @@ struct OpManager { // find operator by name const Op& Op::Get(const std::string& name) { const OpRegistry* reg = dmlc::Registry::Find(name); - CHECK(reg != nullptr) - << "Operator " << name << " is not registered"; + CHECK(reg != nullptr) << "Operator " << name << " is not registered"; return reg->op(); } @@ -61,8 +61,8 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) { return *it->second.get(); } -void OpRegistry::UpdateAttr( - const std::string& key, TVMRetValue value, int plevel) { +void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, + int plevel) { OpManager* mgr = OpManager::Global(); std::lock_guard lock(mgr->mutex); std::unique_ptr& op_map = mgr->attr[key]; @@ -71,13 +71,11 @@ void OpRegistry::UpdateAttr( } uint32_t index = op_->index_; if (op_map->data_.size() <= index) { - op_map->data_.resize(index + 1, - std::make_pair(TVMRetValue(), 0)); + op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); } - std::pair & p = op_map->data_[index]; + std::pair& p = op_map->data_[index]; CHECK(p.second != plevel) - << "Attribute " << key - << " of operator " << this->name + << "Attribute " << key << " of operator " << this->name << " is already registered with same plevel=" << plevel; if (p.second < plevel) { op_map->data_[index] = std::make_pair(value, plevel); @@ -86,59 +84,57 @@ void OpRegistry::UpdateAttr( // Frontend APIs TVM_REGISTER_API("relay.op._ListOpNames") -.set_body_typed()>([]() { - Array ret; - for (const std::string& name : - dmlc::Registry::ListAllNames()) { - ret.push_back(tvm::Expr(name)); - } - return ret; - }); - -TVM_REGISTER_API("relay.op._GetOp") -.set_body_typed(Op::Get); + .set_body_typed()>([]() { + Array ret; + for (const std::string& name : + dmlc::Registry::ListAllNames()) { + ret.push_back(tvm::Expr(name)); + } + return ret; + }); +TVM_REGISTER_API("relay.op._GetOp").set_body_typed(Op::Get); TVM_REGISTER_API("relay.op._OpGetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - auto op_map = Op::GetAttr(attr_name); - if (op_map.count(op)) { - *rv = op_map[op]; - } - }); - + .set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto op_map = Op::GetAttr(attr_name); + if (op_map.count(op)) { + *rv = op_map[op]; + } + }); TVM_REGISTER_API("relay.op._Register") -.set_body([](TVMArgs args, TVMRetValue* rv) { - std::string op_name = args[0]; - std::string attr_key = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); - // enable resgiteration and override of certain properties - if (attr_key == "num_inputs" && plevel > 128) { - reg.set_num_inputs(value); - } else if (attr_key == "attrs_type_key" && plevel > 128) { - reg.set_attrs_type_key(value); - } else { - // normal attr table override. - if (args[2].type_code() == kFuncHandle) { - // do an eager copy of the PackedFunc - PackedFunc f = args[2]; - // If we get a function from frontend, avoid deleting it. - OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); - reg.set_attr(attr_key, f, plevel); + .set_body([](TVMArgs args, TVMRetValue* rv) { + std::string op_name = args[0]; + std::string attr_key = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = + OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); + // enable resgiteration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + reg.set_attrs_type_key(value); } else { - reg.set_attr(attr_key, args[2], plevel); + // normal attr table override. + if (args[2].type_code() == kFuncHandle) { + // do an eager copy of the PackedFunc + PackedFunc f = args[2]; + // If we get a function from frontend, avoid deleting it. + OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); + reg.set_attr(attr_key, f, plevel); + } else { + reg.set_attr(attr_key, args[2], plevel); + } } - } - }); + }); bool IsGeneric(const Op& op) { if (auto ty_func = op.as()) { - return ty_func->type_params.size() == 0; + return ty_func->type_params.size() != 0; } else { return false; } @@ -146,8 +142,8 @@ bool IsGeneric(const Op& op) { using namespace runtime; -Module CompileOpsToModule(const std::vector & op_names) { - PackedFunc compile_ops = GetPackedFunc("relay.op.compile_ops"); +Module CompileOpsToModule(const std::vector& op_names) { + PackedFunc compile_ops = GetPackedFunc("relay.op._compile_ops"); tvm::Array> args; auto compiler_map = Op::GetAttr("FRelayOpCompiler"); @@ -155,12 +151,15 @@ Module CompileOpsToModule(const std::vector & op_names) { for (auto op_name : op_names) { Op op = Op::Get(op_name); - if (IsGeneric(op)) { + if (!IsGeneric(op)) { auto compiler = compiler_map[op]; - tvm::Array pair = - compiler(op->name, op->op_type); - //TODO(@jroesch): I can't pass strings across what should be the interface here. - tvm::Array triple = {LocalVarNode::make(op->name), pair[0], pair[1]}; + std::cout << "ABOVE CALL" << std::endl; + tvm::Array pair = compiler(op->name, op->op_type); + std::cout << "BELOW CALL" << std::endl; + // TODO(@jroesch): I can't pass strings across what should be the + // interface here. + tvm::Array triple = {LocalVarNode::make(op->name), pair[0], + pair[1]}; args.push_back(triple); } else { throw dmlc::Error("it is impossible to compile generic operators."); @@ -177,14 +176,49 @@ Module CompileOpsToModule(const std::vector & op_names) { } TVM_REGISTER_API("relay.op._CompileOpsToModule") -.set_body([](TVMArgs args, TVMRetValue* ret) { - tvm::Map map = args[0]; - std::vector names; - for (auto pair : map) { - names.push_back(pair.first); + .set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector names; + for (auto i = 0; i < args.num_args; i++) { + names.push_back(args[i]); + } + std::cout << "Right here" << std::endl; + *ret = CompileOpsToModule(names); + }); + +Op SpecializeOp(const std::string& op_name, + const std::string& new_op_name, Array type_args) { + auto registry = ::tvm::relay::OpRegistry::Registry(); + auto op_reg = registry->__REGISTER_OR_GET__(op_name); + auto new_op_reg = registry->__REGISTER__(new_op_name).set_name(); + + auto fn_ty = op_reg.op()->op_type; + + tvm::Map subst_map; + + CHECK(fn_ty->type_params.size() == type_args.size()); + + // Build a subsitituion map up from the function type and type arguments. + // Eventually allow the type vars to be passed in. + for (auto i = 0; i < type_args.size(); i++) { + subst_map.Set(fn_ty->type_params[i], type_args[i]); } - *ret = CompileOpsToModule(names); -}); + + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); + inst_ty = TypeSubst(fn_ty, subst_map); + FuncType new_op_ty = GetRef(inst_ty.as()); + new_op_reg.op()->op_type = new_op_ty; + + // Now we want to copy over some attributes. + PackedFunc compiler = Op::GetAttr("FRelayOpCompiler")[op_reg.op()]; + new_op_reg.set_attr("FRelayOpCompiler", compiler); + + return new_op_reg.op(); +} + +TVM_REGISTER_API("relay.op._SpecializeOp") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = SpecializeOp(args[0], args[1], args[2]); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 2ea205b511b1..b624a5709ddd 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -210,6 +210,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { } CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { + // We should trigger error here and move param code direclty into function checking. auto rtype = resolve(param->type); // This is a special case ... not sure if there is a better way // to handle this. diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 51833e13e475..9e89e8813e08 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -1,6 +1,9 @@ """Test that type checker correcly computes types for expressions. """ +import tvm +import numpy as np +from nnvm import graph from tvm.relay.ir_pass import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, int_type from tvm.relay.ir_builder import func_type, tensor_type, into_ast @@ -9,6 +12,7 @@ from tvm.relay.op import log, add, equal, subtract from tvm.relay.expr import Function from tvm.relay import to_tvm +from tvm.contrib import graph_runtime def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) @@ -19,14 +23,18 @@ def decl_has_type(env, name, typ): return func.checked_type() == typ -def run(env, expr): +def run(env, expr, inputs, shape): if not isinstance(expr, Function): expr = Function([], None, expr, []) env.add("main", expr) env.transform(Monomorphize.to_pass()) main = env.lookup("main") - graph_json, mod, _ = to_tvm.compile(main) + graph, lib, _ = to_tvm.compile(main) + module = graph_runtime.create(graph, lib, tvm.cpu(0)) + module.set_input(None, None, **inputs) + module.run() + out = module.get_output(0, out=tvm.nd.array(shape)) import pdb; pdb.set_trace() def test_monomorphic_let(): @@ -59,13 +67,15 @@ def test_binary_op(): x = b.param('x', tensor_type(5, 5, 5)) y = b.param('y', tensor_type(5, 5, 5)) with b.function(x, y) as func: - b.ret(add(x, y)) + b.ret(add(x.var, y.var)) b.ret(func) prog, env = b.get() ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) assert has_type(func.to_func(), expected_ty) - run(env, prog) + x_data = np.random.rand(5, 5, 5) + y_data = np.random.rand(5, 5, 5) + run(env, prog, { 'x': x_data, 'y': y_data }, (5, 5, 5)) def test_dual_op(): """Program: From 9fcd45d1128b2e8d26c826959eddf0fe3ea550b3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 13:07:51 -0700 Subject: [PATCH 072/136] Add another test case and do a little clean up --- python/tvm/relay/to_tvm.py | 16 +- src/relay/pass/type_infer.cc | 142 ------------------ .../relay/test_tyck_eval_integration.py | 64 ++++++-- 3 files changed, 56 insertions(+), 166 deletions(-) diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py index f2c2a9ba5463..181251844a6d 100644 --- a/python/tvm/relay/to_tvm.py +++ b/python/tvm/relay/to_tvm.py @@ -70,7 +70,8 @@ def to_json(self) -> Any: def shape_to_json(shape): - return [str(sh.value) for sh in shape] + return [sh.value for sh in shape] + def from_tensor(typ: TensorType) -> Tuple[str, List[int]]: return (typ.dtype, shape_to_json(typ.shape)) @@ -162,8 +163,8 @@ def visit_call(self, call: Call) -> NodeRef: self.all_ops.add(call.op.name) op_name = call.op.name - attrs = { 'shape': shape_to_json(call.checked_type().shape), - 'dtype': call.checked_type().dtype } + attrs = {'shape': shape_to_json(call.checked_type().shape), + 'dtype': call.checked_type().dtype} op_node = OpNode("call_name", attrs, op_name, inputs, {}) return self.add_node(op_node) @@ -230,11 +231,4 @@ def compile(func): op_names = list(comp.all_ops) mod = compile_ops(op_names) graph_json = comp.to_json() - try: - import nnvm - graph = nnvm.graph.load_json(graph_json) - except Exception as e: - import traceback - traceback.print_tb(e.__traceback__) - import pdb; pdb.set_trace() - return graph, mod, None # params currently isn't supported by API + return graph_json, mod, None # params currently isn't supported by API diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index b624a5709ddd..6cc73d1b8fbe 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -138,8 +138,6 @@ class TypeInferencer : private ExprFunctor { Type resolve(const Type &t); Expr resolve(const Expr &e); CheckedExpr VisitFunction(const Function &f, bool generalize); - void CheckOp(Op op); - // Defn CheckDefn(Defn def); private: CheckedExpr VisitExpr_(const LocalVarNode *op) override; CheckedExpr VisitExpr_(const GlobalVarNode *op) override; @@ -218,43 +216,6 @@ CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { return {ParamNode::make(param->var, rtype), rtype}; } -// // We should probably generalize the subst code. -// struct GeneralizeTypeType : TypeFVisitor { -// Map vars_to_id; -// const TypeUnifier &unifier; - -// GeneralizeTypeType(Map vars_to_id, -// const TypeUnifier &unifier) -// : vars_to_id(vars_to_id), unifier(unifier) {} - -// Type VisitType_(const TypeVarNode *op) override { -// auto repr = unifier->subst(GetRef(op)); -// if (auto tvn = repr.as()) { -// auto ty_var = GetRef(tvn); -// if (vars_to_id.find(ty_var) != vars_to_id.end()) { -// return vars_to_id[ty_var]; -// } else { -// return ty_var; -// } -// } else { -// return this->VisitType(repr); -// } -// } -// }; - -// struct GeneralizeTypeExpr : ExprFVisitor<> { -// Map vars_to_id; -// const TypeUnifier &unifier; - -// GeneralizeTypeExpr(const TypeUnifier &unifier, -// Map vars_to_id) -// : vars_to_id(vars_to_id), unifier(unifier) {} - -// Type VisitType(const Type &t) { -// return GeneralizeTypeType(vars_to_id, unifier).VisitType(t); -// } -// }; - CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { // First we add the parameters to the context allowing us to check their // types. @@ -282,83 +243,6 @@ CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { return {FunctionNode::make(params, unified_rtype, checked_body.expr, {}), FuncTypeNode::make(param_types, unified_rtype, {}, {})}; }); - - // // typecheck body and ensure that it matches stated return type - // // TODO(sslyu): should the unified return type override the annotated - // one? Type checked_return = this->Check(f->body); Type ret_type = - // resolve(f->ret_type); Type unified = - // this->unify(simple_eval_shape(ret_type), - // simple_eval_shape(checked_return), f->span); - // return TypeArrowNode::make(arg_types, unified); - // }); - // if (generalize) { - // auto free_vars = free_type_vars(resolve(fn_type)); - // std::set dedup_free_vars; - - // for (auto free_var : free_vars) { - // auto repr = this->unifier->subst(free_var); - // if (auto new_free_var_node = repr.as()) { - // dedup_free_vars.insert(GetRef(new_free_var_node)); - // } else { - // // debug(repr); - // throw dmlc::Error( - // "internal error: this list should only contain type var - // nodes"); - // } - // } - - // Map vars_to_id; - - // GenFresh gf; - // for (auto free_var : dedup_free_vars) { - // vars_to_id.Set(free_var, gf.freshTV(free_var->kind)); - // } - - // fn_type = GeneralizeTypeType(vars_to_id, unifier).VisitType(fn_type); - // for (std::pair pair : vars_to_id) { - // // NB: In generalization we want to find type variables with - // // *no constraints* on them, and convert them to universally - // quantified - // // variables. - // // - // // i.e the program can be abstracted over the details of *that* type. - - // // For example a program that works irrespective of shape or - // datatype. - - // // In order to do this we find the set of free type variables in the - // // term, and then unify them with the fresh type ids we generate. - // // - // // Remember importantly these type variables still may appear in many - // // places in the program including both types and expressions. - - // // Our method for resolving these is to unify them with the variables - // // as we build the new quanitifer, changing from a program with - // "holes" - // // to one that is properly abstracted over. - - // // Finally later on we can iterate over the whole term and change - // from - // // type variables to these type ids. - // this->unify(pair.first, pair.second, pair.second->span); - // fn_type = TypeQuantifierNode::make(pair.second, fn_type); - // } - // } else { - // for (auto i = f->ty_params.size(); i > 0; i--) { - // auto ty_param = f->ty_params[i - 1]; - // auto ty_param_node = ty_param.as(); - // if (!ty_param_node) { - // throw dmlc::Error("internal error should be TypeParam"); - // } - // auto fresh_tid = - // TypeParamNode::make(ty_param_node->name, ty_param_node->kind); - // fn_type = - // TypeSubst(fn_type, GetRef(ty_param_node), fresh_tid); - // fn_type = TypeQuantifierNode::make(fresh_tid, fn_type); - // } - // } - - // return fn_type; } CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { @@ -520,32 +404,6 @@ Expr TypeInferencer::resolve(const Expr& e) { return ::tvm::relay::Resolve(this->unifier, e); } -// Defn TypeInferencer::CheckDefn(Defn defn) { -// // This is to handle recursion, but we need to speculatively -// // put it in env, then remove it. -// env->items.insert({defn->id, defn}); - -// Type expected_ty = this->resolve(defn->type); - -// Expr body = defn->body; - -// auto checked_ty = Check(body); - -// try { -// Type uret_type = unify(expected_ty, checked_ty, defn->body->span); -// CHECK(is_fully_resolved(uret_type)); -// // Now let's clean up our work from earlier. -// env->items.erase(defn->id); -// return DefnNode::make(defn->id, uret_type, this->resolve(defn->body)); -// } catch (const UnificationError& err) { -// std::string msg = std::string("mismatch between `") + -// PrintType(env, expected_ty, WrapWidth(40)) + "` and -// `" + PrintType(env, checked_ty, WrapWidth(40)) + -// "`"; -// fatal_error(msg, defn->span); -// } -// } - Expr InferType(const Environment &env, const Expr &e) { TypeInferencer ti(env); auto checked_expr = ti.Infer(e); diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 9e89e8813e08..cd87fb83ec52 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -2,7 +2,7 @@ for expressions. """ import tvm -import numpy as np +import numpy as np from nnvm import graph from tvm.relay.ir_pass import check_expr from tvm.relay.ir_builder import IRBuilder, float_type, int_type @@ -13,15 +13,18 @@ from tvm.relay.expr import Function from tvm.relay import to_tvm from tvm.contrib import graph_runtime +import nnvm + def has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) return checked_expr.checked_type() == typ + def decl_has_type(env, name, typ): func = env.lookup(name) return func.checked_type() == typ - + def run(env, expr, inputs, shape): if not isinstance(expr, Function): @@ -31,11 +34,14 @@ def run(env, expr, inputs, shape): env.transform(Monomorphize.to_pass()) main = env.lookup("main") graph, lib, _ = to_tvm.compile(main) - module = graph_runtime.create(graph, lib, tvm.cpu(0)) + # We use NNVM to load the graph right now because it populates node_row_ptr field. + nnvm_graph = nnvm.graph.load_json(graph) + module = graph_runtime.create(nnvm_graph, lib, tvm.cpu(0)) module.set_input(None, None, **inputs) module.run() - out = module.get_output(0, out=tvm.nd.array(shape)) - import pdb; pdb.set_trace() + out_nd_array = tvm.nd.array(np.empty(shape, dtype='float32')) + return module.get_output(0, out=out_nd_array) + def test_monomorphic_let(): "Program: let x = 1; return x" @@ -45,7 +51,8 @@ def test_monomorphic_let(): prog, env = b.get() assert has_type(prog, float_type(64)) - run(env, prog) + run(env, prog, [], float_type(64)) + def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" @@ -56,7 +63,8 @@ def test_single_op(): b.ret(t1) assert has_type(func.to_func(), func_type([float_type()], float_type())) -def test_binary_op(): + +def test_add_op(): """ Program: fn (x, y) { @@ -73,9 +81,34 @@ def test_binary_op(): ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) assert has_type(func.to_func(), expected_ty) - x_data = np.random.rand(5, 5, 5) - y_data = np.random.rand(5, 5, 5) - run(env, prog, { 'x': x_data, 'y': y_data }, (5, 5, 5)) + x_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + y_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 5, 5)) + np.testing.assert_allclose( + x_data.asnumpy() + y_data.asnumpy(), result.asnumpy()) + +def test_add_broadcast_op(): + """ + Program: + fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { + return x + y; + } + """ + b = IRBuilder() + x = b.param('x', tensor_type(10, 4)) + y = b.param('y', tensor_type(5, 10, 1)) + with b.function(x, y) as func: + b.ret(add(x.var, y.var)) + b.ret(func) + prog, env = b.get() + ttype = tensor_type(5, 5, 5) + expected_ty = func_type([ttype, ttype], ttype) + assert has_type(func.to_func(), expected_ty) + x_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + y_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 10, 4)) + np.testing.assert_allclose( + x_data.asnumpy() + y_data.asnumpy(), result.asnumpy()) def test_dual_op(): """Program: @@ -84,7 +117,7 @@ def test_dual_op(): let t2 = add(t1, x); return t1; } - """ + """ b = IRBuilder() with b.function(('x', tensor_type(10, 10))) as func: x, = func.param_ids() @@ -109,6 +142,7 @@ def f(x : Tensor[f32, (10, 10)]) { _, env = b.get() assert decl_has_type(env, 'f', func_type([float_type()], float_type())) + def test_recursion(): """ Program: @@ -131,12 +165,16 @@ def f(n: i32, data: f32) -> f32 { with b.else_scope(): b.ret(data) b.ret(f(into_ast(2.0), into_ast(10000.0))) - assert decl_has_type(b.env, 'f', func_type([int_type(), float_type()], float_type())) + assert decl_has_type(b.env, 'f', func_type( + [int_type(), float_type()], float_type())) + # TODO(@jroesch): need evaluator or new runtime + # to execute this. if __name__ == "__main__": # test_monomorphic_let() # test_single_op() - test_binary_op() + test_add_op() + test_add_broadcast_op() # test_dual_op() # test_decl() # test_recursion() From 6323e39cc9aa7af23861dfb9ae651eb0a2ae5c13 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 13:08:39 -0700 Subject: [PATCH 073/136] Port docs from previous Relay version --- docs/api/python/relay/index.rst | 17 ++ docs/langref/relay/expressions.rst | 178 +++++++++++++++++++++ docs/langref/relay/index.rst | 17 ++ docs/langref/relay/intro.rst | 17 ++ docs/langref/relay/type_system.rst | 137 ++++++++++++++++ tutorials/relay/implement_fma_transform.py | 141 ++++++++++++++++ 6 files changed, 507 insertions(+) create mode 100644 docs/api/python/relay/index.rst create mode 100644 docs/langref/relay/expressions.rst create mode 100644 docs/langref/relay/index.rst create mode 100644 docs/langref/relay/intro.rst create mode 100644 docs/langref/relay/type_system.rst create mode 100644 tutorials/relay/implement_fma_transform.py diff --git a/docs/api/python/relay/index.rst b/docs/api/python/relay/index.rst new file mode 100644 index 000000000000..32db5daded2b --- /dev/null +++ b/docs/api/python/relay/index.rst @@ -0,0 +1,17 @@ +Relay API +========= + +This document contains the Python API to the Relay frontend, optimizer, and +compiler toolchain. + +Relay is a new high level intermediate representation for the TVM compiler +stack. Our goal is to generalize computation graphs provided by previous +languages to full differentiable programs. + +.. toctree:: + :maxdepth: 2 + + env + ir + make + unifier diff --git a/docs/langref/relay/expressions.rst b/docs/langref/relay/expressions.rst new file mode 100644 index 000000000000..37dc62c6bc24 --- /dev/null +++ b/docs/langref/relay/expressions.rst @@ -0,0 +1,178 @@ +================== +Expressions +================== + +Relay's IR is a pure expression oriented language, that has a +dataflow fragment and structured control flow. Although Relay's +representation is a tree, it is possible to view the dataflow +fragments as graph for purposes of writing and expressing +transformations. + +The below sections make an attempt to clearly split the dataflow +fragment from the control fragment. + +================== +Dataflow Expressions +================== + +First we will cover the set of nodes which do not involve control flow, +this fragment of the language is semantically equivalent to pure +computation graphs without control flow. + +Constants +~~~~~~~~~ +Relay programs can contain constant Tensor values, since in Relay all +values are either Tensors, Products, or Closures. We will discuss the +later two later, but we represent Tensor constants as `tvm.NDArray`, +allowing us to utilize normal operators for constant evaluation. + + +Constructors +~~~~~~~~ + +Relay supports a handful of constructors which we will cover below. A +constructor enables programs to build new values from arbitrary Relay +expressions. + + +We support four types of literals, literals are type polymorphic and can +assigned any base type. If we can not solve for a concrete type we apply +a defaulting rule. + +We support signed and unsigned integers, floating point numbers, booleans, +and tensor literals. + +The base type literals are designed to closely model literals in TVM's +expressions langauge. + +### Boolean Literals +TODO: don't have these in any form right now + +### Integer Literals +TODO: don't have these in any form right now + +Tensor Constructor +~~~~~~~~~~~~~~~ + +A tensor literal allows us to build a Tensor from other expressions. + +TODO: Example here + + +Tuple Constructor +~~~~~~~~~~~~~~~ + +We support tuple constructors which allows us to build a fixed-k sized +sequence of heterogenous data. These tuples match closely to Python's +and enable efficient projection of their members due to their fixed length. + + (a, b, c) : Tuple + + (a + b + c, d) : Tuple, Tensor> + +Function +~~~~~~~~ + +A function node represents a function, it contains a seqeuence of +parameters, a return type, and a body. + + fun (x : Float, y: Float) -> Float { x + y } + +Functions are first class in Relay, and can be used in any expression +position. Functions are the same as global functions, but do not have +an explicit name. You can use a function in conjunction with a let +binding to define locally recursive functions. + + let fact = fun (x : Float) -> Float { + if (x == 0) { + 0 + } else { + x * fact(x - 1) + }; + fact(10) + +Identifiers +~~~~~~~~~~~ + +All of the identifiers are valid expressions, you can use a local identifier, +global identifier, or intrinsic identifier anywhere an expression may appear. + +For example the below fragment of code is a valid expression. + + %ret = @global(intrinsic, %local) + +Let Binding +~~~~~~~~~~~ + +An immutable variable binding, allows the user to bind an +expression to a name. A let binding contains a local identifier, +an optional type, a value, and body expression which may +reference the bound identifier. + +We will first introduce a single binding with no type +anntoations:: + let %x = %a + %b; + x + +The value of a let binding is the value of the final expression +after evaluating the bindings it depends on. + +A user can write a sequence of let bindings, we can view +these blocks and pure dataflow +single binding. These blocks are pure dataflow, and can +be evaluated in any order, reordered up to dataflow. + +We support a sequence of bindings followed by a body which +is the continutation after executing the sequence of bindings. + +I believe this representation will be easier to manipulate then +the mixed dataflow/control flow comptuation graphs. +Data flow and control flow are strictly seperated in this representation +and we can easily syntactically discriminate. When in ANF there should only be +general control flow between `Assignment` nodes and not within the values bound +in bindings. + +This representation also makes it easy to apply reverse more since +sequences of assignments where the only control flow is call instructions +are treated by the algorithm uniformly, and each control flow construct +must be handled individualy. + +TODO Add Ref, ReadRef, WriteRef, Projection, + +Gradient +~~~~~~~~ + +The `Reverse` acts as a marker node, when the compiler encounters it +we will apply the reverse mode transformation to the enclosed function. + +We will employ static analysis and constant evaluation in order to +simplify the node's argument to a known function call target. + + +You can compute the reverse node of a function node like so: + +Cast +~~~~~ + +Cast the type of the `node` to `ty`. + +======================= +Control Flow Expression +======================= +Control flow expressions change network topology based on values +computed by previous expressions. + +Call +~~~~ + +Terms with function types in Relay are "callable", that can be invoked like +a function in a typical programming language by supplying a set of arguments. + +Instrinsics with functions types, definitions, and functions are all callable. + +If-Then-Else +~~~~~~~~~~~~ + +Relay has a simple if/then/else expression which allows programs to branch +on a single control value which must be of type `Bool`, i.e a zero-rank +tensor of booleans. diff --git a/docs/langref/relay/index.rst b/docs/langref/relay/index.rst new file mode 100644 index 000000000000..617e745acdfc --- /dev/null +++ b/docs/langref/relay/index.rst @@ -0,0 +1,17 @@ +Relay Language Reference +======================== + +This document is a work in progress language reference describing +Relay, TVM's high level intermediate representation. The name is an +allusion to interneurons which are often referred to as intermediate, +or relay neurons. + +We will continually iterate on this document as we evolve the new IR +and update accordingly. + +.. toctree:: + :maxdepth: 2 + + intro + expressions + type_system diff --git a/docs/langref/relay/intro.rst b/docs/langref/relay/intro.rst new file mode 100644 index 000000000000..617e745acdfc --- /dev/null +++ b/docs/langref/relay/intro.rst @@ -0,0 +1,17 @@ +Relay Language Reference +======================== + +This document is a work in progress language reference describing +Relay, TVM's high level intermediate representation. The name is an +allusion to interneurons which are often referred to as intermediate, +or relay neurons. + +We will continually iterate on this document as we evolve the new IR +and update accordingly. + +.. toctree:: + :maxdepth: 2 + + intro + expressions + type_system diff --git a/docs/langref/relay/type_system.rst b/docs/langref/relay/type_system.rst new file mode 100644 index 000000000000..91a634431d7c --- /dev/null +++ b/docs/langref/relay/type_system.rst @@ -0,0 +1,137 @@ +================== +Type System +================== + +We have briefly introduced types while detailing the the expression language +of Relay, but have fully laid out the type system. + +Although the majority of Relay programs require no type annotations, Relay +is statically typed. Each expression in Relay has a precisely known type. + +You might ask why we want a statically typed IR, there are multiple advantages. +- efficient layout and code generation for tensors +- TODO +- debugging transformations (most program transformations should be type perserving) + +We are able to omit these type annotations by a process known as type inference. +Type inference is a technique that has its roots in the programming language +community, and can be viewed as a method for generalizing shape inference to +run over arbitrary user programs. + +Static typing means we know before executing the program properties about +the values it manipulates. Static types are useful for compiler optimization +because they communicate properties about the data we manipulate, such as +runtime shape, data layout, storage. + +Most current IRs use "shape inference" to recover Tensor dimensions from the user +provided program. Machine learning users have enjoyed shape inference for +tensors because it allows them to generate performant code without giving up +on the expressivity of the input language. + +Because Relay is intended as an IR we require *some* type information to provide +full inference. We don't believe this to be an issue as many of the IR builder +inferfaces require some type information, or can generate IR based on their own +higher level inferences. + +We view this limited shape inference as a simpler form of type +inference. Instead of relying on an ad-hoc procedure for recovering type +information from a potentially dynamic program, we apply ideas from compiler and IR design. + +Below we briefly dicsuss the different kinds of types in Relay. + +===== +Types +===== + +BaseType +~~~~~~~~~~ +Relay has a notion of a BaseType, which captures the set of types +that can be stored in a Tensor. Relay's base types map to the set +of types supported by TVM. + +Each of the base types can be parametrized by number of bits, and +lanes for vectorization purposes. We support four base types any:`Bool`, +any:`Int` + +Type Variables +~~~~~~~~~~~~~~ + +Type Parameters +~~~~~~ +TODO: type parameter + +Kind +~~~~ + +Function Types +~~~~~~~~~~ +TODO: rename function type? + +TypeQuantifier +~~~~~~~~~~~~~~ +TODO + +Placeholders +~~~~~~~~~~~~ + +TODO + +Tuple Types +~~~~~~~~~~~~~ + +Reference Types +~~~~~~~~~~~~~~~ + +A reference type is simply a mutable memory location, since Relay is a pure +language by default we need a way to introduce limited mutability. In this +case mutable data is clearly marked in the type system as a reference type. + + Ref + +Tensor Type +~~~~~~~~~~~ + +Tensor values in Relay are typed with tensor types. A tensor type is +parametrized by a data type, and shape. The data type must be a base +type as enforced by the kind checking rules described in TODO. + +This restriction importantly means + +The shape may be any valid Relay shape as described in the below +section on shapes. + + +====== +Shapes +====== + +Shape Singleton +~~~~~~~~~~~~~~~ +I don't like this name + +ShapeAttr +~~~~~~~~~ +TODO + +ShapeProjection +~~~~~~~~~~~~~~~ +TODO + +ShapeBinaryOp +~~~~~~~~~~~~~ + +enum ShapeOp : int { + SHPLUS = 0, + SHSUB = 1, + SHMUL = 2, + SHDIV = 3 +}; + + +Shape Sequence +~~~~~~~~ +A sequence of shapes ... + + +ShapeBroadcast +~~~~~~~~~~~~~~ diff --git a/tutorials/relay/implement_fma_transform.py b/tutorials/relay/implement_fma_transform.py new file mode 100644 index 000000000000..8410dd6c1152 --- /dev/null +++ b/tutorials/relay/implement_fma_transform.py @@ -0,0 +1,141 @@ +"""How to use Relay to implement a simple two-operator fusion pass. +================================== +**Author**: `Jared Roesch `_ + +In this tutorial, we will demonstrate how to write a fusion pass for +the Relay IR. We demonstrate many Relay features including defining a +new operator, a program transform, the NNVM compatibility layer, +and executing the original and transformed programs on the Relay +evaluator and TVM runtime system. +""" + +################################################################ +# Introduction +# ------------------------- +# +# We use the fixed size for input tensors with 256 channels and 14 x 14 +# dimensions. The batch size is 256. Convolution filters contain 512 filters +# of size 3 x 3. We use stride size 1 and padding size 1 for the +# convolution. The following code defines the convolution algorithm in TVM. +# + +from typing import Any, Dict + +import numpy as np +import tvm +import topi + +from relay import ir, make as mk +from relay.ir import OperatorId +from relay.opt import ItemVisitor, ExprVisitor +from relay.frontend.nnvm import Variable, symbol +from relay.frontend.nnvm import compiler +from relay.frontend.global_env import get_env +from relay.operators.register import func_ty_to_placeholders, register_op +from relay.eval import defn_to_pyfunc +from relay.tyck import check_expr + +class ExprAtVisitor(ExprVisitor): + """A demo visitor which adds a new traversal strategy.""" + expr_map: Dict[ir.LocalId, ir.Expr] + + def __init__(self): + self.expr_map = {} + + def expr_at(self,id: ir.LocalId) -> ir.Expr: + try: + return self.expr_map[id] + except KeyError: + return id + + def visit_let(self, let: ir.Let) -> ir.Expr: + self.expr_map[let.id] = let.value + return super().visit_let(let) + +# let x = 1 + 1; +# ... x will map to 1 + 1 + +class FuseTwo(ExprAtVisitor): + """Rewrite b(a(x, y), z) into ab(x, y, z). """ + def __init__(self, a: OperatorId, b: OperatorId, a_and_b: OperatorId) -> None: + self.a = a + self.b = b + self.a_and_b = a_and_b + super().__init__() + + def visit_call(self, call: ir.Call) -> ir.Expr: + func = call.fn + if func == self.b: + assert len(call.args) == 2 # An assumption of this fusion code. + arg0 = self.expr_at(call.args[0]) + arg1 = self.expr_at(call.args[1]) + if isinstance(arg0, ir.Call) and arg0.fn == self.a: + new_call = mk.Call(self.a_and_b, arg0.args[:] + [arg1]) + elif isinstance(arg1, ir.Call) and arg1.fn == self.a: + new_call = mk.Call(self.a_and_b, arg1.args[:] + [arg0]) + else: + new_call = super().visit_call(call) + + return new_call + else: + return super().visit_call(call) + +def fma_compile(op_name: str, func_ty: ir.Type, attrs: ir.Attributes=None) -> Any: + Inputs, ret_ty = func_ty_to_placeholders(func_ty) + x, y, z = Inputs + Output = topi.multiply(topi.add(x, y), z) + # this is not a python function call, but builds an AST + schedule = tvm.create_schedule(Output.op) + return [schedule, Inputs + [Output]] + + +def register_fma(env: Any) -> None: + """Register TOPI's elementwise broadcast addition for the `+` operator.""" + shape = mk.TypeParam("s", ir.Kind.Shape) + bt = mk.TypeParam("bt", ir.Kind.BaseType) + in_out_type = mk.TensorType(bt, shape) + fma_type = mk.TypeQuantifier(bt, mk.TypeQuantifier(shape, mk.TypeArrow([in_out_type, in_out_type, in_out_type], in_out_type))) + # forall (bt: BaseTYpe) (s : Shape), Tensor[bt, s] -> Tensor[bt, s] -> Tensor[bt, s] + # TODO: no reverse mode + register_op(env, 'fma', fma_type, compiler=fma_compile) + +# Get the global environment for demo purposes. +env = get_env() + +register_fma(env) + +# A small helper which applies just our transform to the Relay expression. +def transform(e): + fuse = FuseTwo(env.add_id(), env.mul_id(), env.operator_id('fma')) + e = fuse.visit(e) + # Now let's use the type checker to make sure we didn't make a mistake. + check_expr(env, e) + return e + +# We will use NNVM frontend. +x = Variable('x') +y = Variable('y') +z = x * (x + y) + +relay_func = compiler.to_relay(z) + +print(f"Relay Function:\n{compiler.pp(relay_func)}") + +xform_func = transform(relay_func) + +print(f"Transformed Function:\n{compiler.pp(xform_func)}") + +# Use the evaluator. +norm = defn_to_pyfunc(env, relay_func) +xform = defn_to_pyfunc(env, xform_func) + +x = np.random.uniform(size=(10, 5, 10)).astype('float32') +y = np.random.uniform(size=(10, 5, 10)).astype('float32') + +norm_out = norm(x, y).asnumpy() +xform_out = xform(x, y).asnumpy() + +np.testing.assert_allclose(norm_out, xform_out) + +# Use the TVM runtime. + From 9fac901714c38df98a23f12009223f6be2a3e0b4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 17:10:28 -0700 Subject: [PATCH 074/136] Update docs --- docs/api/python/index.rst | 1 + docs/api/python/relay/base.rst | 9 +++ docs/api/python/relay/env.rst | 6 ++ docs/api/python/relay/expr.rst | 36 +++++++++++ docs/api/python/relay/index.rst | 15 +++-- docs/api/python/relay/ir_builder.rst | 6 ++ docs/api/python/relay/ir_pass.rst | 3 + docs/api/python/relay/op.rst | 3 + docs/api/python/relay/to_tvm.rst | 3 + docs/api/python/relay/type.rst | 27 ++++++++ python/tvm/relay/type.py | 72 +++++++++++++++++++++- tutorials/relay/implement_fma_transform.py | 10 +-- 12 files changed, 178 insertions(+), 13 deletions(-) create mode 100644 docs/api/python/relay/base.rst create mode 100644 docs/api/python/relay/env.rst create mode 100644 docs/api/python/relay/expr.rst create mode 100644 docs/api/python/relay/ir_builder.rst create mode 100644 docs/api/python/relay/ir_pass.rst create mode 100644 docs/api/python/relay/op.rst create mode 100644 docs/api/python/relay/to_tvm.rst create mode 100644 docs/api/python/relay/type.rst diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index 59bd1795b7ec..ab411d77f4f4 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -23,4 +23,5 @@ Python API topi vta/index nnvm/index + relay/index hybrid diff --git a/docs/api/python/relay/base.rst b/docs/api/python/relay/base.rst new file mode 100644 index 000000000000..f0cec295ee6b --- /dev/null +++ b/docs/api/python/relay/base.rst @@ -0,0 +1,9 @@ +tvm.relay.base +----------- +.. automodule:: tvm.relay.base + +.. autoclass:: tvm.relay.base.NodeBase + :members: + +.. autoclass:: tvm.relay.base.Span + :members: \ No newline at end of file diff --git a/docs/api/python/relay/env.rst b/docs/api/python/relay/env.rst new file mode 100644 index 000000000000..eca7312d5bbb --- /dev/null +++ b/docs/api/python/relay/env.rst @@ -0,0 +1,6 @@ +tvm.relay.env +----------- +.. automodule:: tvm.relay.env + +.. autoclass:: tvm.relay.env.Environment + :members: \ No newline at end of file diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst new file mode 100644 index 000000000000..cd0cb5c308c4 --- /dev/null +++ b/docs/api/python/relay/expr.rst @@ -0,0 +1,36 @@ +tvm.relay.expr +----------- +.. automodule:: tvm.relay.expr + +.. autoclass:: tvm.relay.expr.ExprBuilder + :members: + +.. autoclass:: tvm.relay.expr.Expr + :members: + +.. autoclass:: tvm.relay.expr.Constant + :members: + +.. autoclass:: tvm.relay.expr.Tuple + :members: + +.. autoclass:: tvm.relay.expr.LocalVar + :members: + +.. autoclass:: tvm.relay.expr.GlobalVar + :members: + +.. autoclass:: tvm.relay.expr.Param + :members: + +.. autoclass:: tvm.relay.expr.Function + :members: + +.. autoclass:: tvm.relay.expr.Call + :members: + +.. autoclass:: tvm.relay.expr.Let + :members: + +.. autoclass:: tvm.relay.expr.If + :members: \ No newline at end of file diff --git a/docs/api/python/relay/index.rst b/docs/api/python/relay/index.rst index 32db5daded2b..231d49df0e6d 100644 --- a/docs/api/python/relay/index.rst +++ b/docs/api/python/relay/index.rst @@ -4,14 +4,17 @@ Relay API This document contains the Python API to the Relay frontend, optimizer, and compiler toolchain. -Relay is a new high level intermediate representation for the TVM compiler -stack. Our goal is to generalize computation graphs provided by previous -languages to full differentiable programs. +Relay is the second generation high level intermediate representation for the TVM +compiler stack. .. toctree:: :maxdepth: 2 + base env - ir - make - unifier + expr + ir_builder + ir_pass + op + to_tvm + type diff --git a/docs/api/python/relay/ir_builder.rst b/docs/api/python/relay/ir_builder.rst new file mode 100644 index 000000000000..b12e3cc6cdd1 --- /dev/null +++ b/docs/api/python/relay/ir_builder.rst @@ -0,0 +1,6 @@ +tvm.relay.ir_builder +----------- +.. automodule:: tvm.relay.ir_builder + +.. autoclass:: tvm.relay.ir_builder.IRBuilder + :members: \ No newline at end of file diff --git a/docs/api/python/relay/ir_pass.rst b/docs/api/python/relay/ir_pass.rst new file mode 100644 index 000000000000..e2e3b432e5bd --- /dev/null +++ b/docs/api/python/relay/ir_pass.rst @@ -0,0 +1,3 @@ +tvm.relay.ir_pass +----------- +.. automodule:: tvm.relay.ir_pass \ No newline at end of file diff --git a/docs/api/python/relay/op.rst b/docs/api/python/relay/op.rst new file mode 100644 index 000000000000..fb8e9ce774c2 --- /dev/null +++ b/docs/api/python/relay/op.rst @@ -0,0 +1,3 @@ +tvm.relay.op +----------- +.. automodule:: tvm.relay.op \ No newline at end of file diff --git a/docs/api/python/relay/to_tvm.rst b/docs/api/python/relay/to_tvm.rst new file mode 100644 index 000000000000..72d01d123e0f --- /dev/null +++ b/docs/api/python/relay/to_tvm.rst @@ -0,0 +1,3 @@ +tvm.relay.to_tvm +----------- +.. automodule:: tvm.relay.to_tvm diff --git a/docs/api/python/relay/type.rst b/docs/api/python/relay/type.rst new file mode 100644 index 000000000000..d357df8f08ac --- /dev/null +++ b/docs/api/python/relay/type.rst @@ -0,0 +1,27 @@ +tvm.relay.type +----------- +.. automodule:: tvm.relay.type + +.. autoclass:: tvm.relay.type.Type + :members: + +.. autoclass:: tvm.relay.type.TensorType + :members: + +.. autoclass:: tvm.relay.type.Kind + :members: + +.. autoclass:: tvm.relay.type.TypeParam + :members: + +.. autoclass:: tvm.relay.type.TypeConstraint + :members: + +.. autoclass:: tvm.relay.type.FuncType + :members: + +.. autoclass:: tvm.relay.type.TypeCall + :members: + +.. autoclass:: tvm.relay.type.IncompleteType + :members: \ No newline at end of file diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index 70e4666e96f9..cde989603929 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -6,11 +6,12 @@ from tvm import expr from . import _make + class Type(NodeBase): """The base type for all Relay types.""" def __eq__(self, other) -> bool: - """Compares two Relay types for structural equivalence using + """Compare two Relay types for structural equivalence using alpha equivalence. """ return bool(_make._type_alpha_eq(self, other)) @@ -22,46 +23,97 @@ def same_as(self, other) -> bool: """Compares two Relay types by referential equality.""" return super().__eq__(other) + @register_relay_node class TensorType(Type): """A concrete TensorType in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to tensor's with a known dype and shape. For + example a tensor of `float32` and `(5, 5)`. """ shape: List[expr.Expr] dtype: str span: Span def __init__(self, shape: List[expr.Expr], dtype: str) -> None: + """Construct a tensor type. + + Parameters + ---------- + shape: list of tvm.Expr + dtype: str + + Returns + ------- + tensor_type: The TensorType + """ self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) + class Kind(IntEnum): """The kind of a type parameter, represents a variable shape, base type, type, or dimension. + + This controls what a type parameter is allowed to be instantiated + with. For example one's of kind BaseType can only be `float32`, `int32`, + and so on. """ ShapeVar = 0 Shape = 1 BaseType = 1 Type = 2 + @register_relay_node class TypeParam(Type): """A type parameter used for generic types in Relay, see tvm/relay/type.h for more details. + + A type parameter represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. """ var: expr.Var kind: Kind span: Span def __init__(self, var: expr.Var, kind: Kind) -> None: + """Construct a TypeParam. + + Parameters + ---------- + var: tvm.expr.Var + The tvm.Var which backs the type parameter. + + kind: Kind + The kind of the type parameter. + + Returns + ------- + type_param: TypeParam + The type parameter. + """ self.__init_handle_by_constructor__(_make.TypeParam, var, kind) + @register_relay_node class TypeConstraint(Type): """Abstract class representing a type constraint.""" pass + @register_relay_node class FuncType(Type): """A function type in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to functions in Relay. They consist of + a list of type parameters which enable the definition of generic + fucntions, a set of type constraints which we omit for the time + being, a sequence of argument types, and a return type. + + We informally write them as: + `forall (type_params), (arg_types) -> ret_type + where type_constraints` """ type_params: List[TypeParam] type_constraints: List[TypeConstraint] @@ -70,7 +122,23 @@ class FuncType(Type): span: Span def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[TypeParam], type_constraints: List[TypeConstraint]) -> None: - self.__init_handle_by_constructor__(_make.FuncType, arg_types, ret_type, type_params, type_constraints) + """Construct a function type. + + Parameters + ---------- + arg_types: list of Type + ret_type: Type + type_params: list of TypeParam + type_constraints: list of TypeConstraint + + Returns + ------- + func_type: FuncType + The function type. + """ + self.__init_handle_by_constructor__( + _make.FuncType, arg_types, ret_type, type_params, type_constraints) + @register_relay_node class TypeCall(Type): diff --git a/tutorials/relay/implement_fma_transform.py b/tutorials/relay/implement_fma_transform.py index 8410dd6c1152..8c04e70aa846 100644 --- a/tutorials/relay/implement_fma_transform.py +++ b/tutorials/relay/implement_fma_transform.py @@ -13,11 +13,11 @@ # Introduction # ------------------------- # -# We use the fixed size for input tensors with 256 channels and 14 x 14 -# dimensions. The batch size is 256. Convolution filters contain 512 filters -# of size 3 x 3. We use stride size 1 and padding size 1 for the -# convolution. The following code defines the convolution algorithm in TVM. -# +# In this tutorial, we will demonstrate how to write a fusion pass for +# the Relay IR. We demonstrate many Relay features including defining a +# new operator, a program transform, the NNVM compatibility layer, +# and executing the original and transformed programs on the Relay +# evaluator and TVM runtime system. from typing import Any, Dict From d6ed95709eb354257d4db919fa7657e108c3b621 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 17:11:20 -0700 Subject: [PATCH 075/136] Add skeleton for converting from NNVM models --- python/tvm/relay/from_nnvm.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 python/tvm/relay/from_nnvm.py diff --git a/python/tvm/relay/from_nnvm.py b/python/tvm/relay/from_nnvm.py new file mode 100644 index 000000000000..18a1112c2629 --- /dev/null +++ b/python/tvm/relay/from_nnvm.py @@ -0,0 +1,4 @@ +import nnvm + +def from_nnvm(graph): + import pdb; pdb.set_trace() From f4a3358ab97524e1515994d80e299f7f2ba39499 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 17:27:08 -0700 Subject: [PATCH 076/136] Address more code review feedback --- include/tvm/relay/expr_functor.h | 2 +- include/tvm/relay/type.h | 5 ++--- python/tvm/relay/type.py | 4 ++-- src/relay/ir/op.cc | 10 +++------- src/relay/pass/unifier.cc | 2 +- 5 files changed, 9 insertions(+), 14 deletions(-) diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 4632733cbcfc..0d736212c9eb 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -23,7 +23,7 @@ namespace relay { * \sa tvm/ir_functor.h * * \tparam FType function signiture - * This type if only defined for FType with function signiture R(const Expr&, + * This type is only defined for FType with function signature R(const Expr&, * Args...) */ template diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 5d579b661280..54cf91cee4ec 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -131,9 +131,8 @@ class TypeParamNode : public TypeNode { kType = 3, }; /*! - * \brief The variable - * The variable itself is only meaningful when - * kind is ShapeVar, otherwise, we can only use the name. + * \brief The variable itself is only meaningful when + * kind is ShapeVar, otherwise, we only use the name. */ tvm::Var var; /*! \brief The kind of type parameter */ diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index cde989603929..d9fc1eff1fd0 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -60,8 +60,8 @@ class Kind(IntEnum): """ ShapeVar = 0 Shape = 1 - BaseType = 1 - Type = 2 + BaseType = 2 + Type = 3 @register_relay_node diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 64467004a973..064551efe9d6 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -132,12 +132,8 @@ TVM_REGISTER_API("relay.op._Register") } }); -bool IsGeneric(const Op& op) { - if (auto ty_func = op.as()) { - return ty_func->type_params.size() != 0; - } else { - return false; - } +bool IsGeneric(const FuncType & func_ty) { + return func_ty->type_params.size() != 0; } using namespace runtime; @@ -151,7 +147,7 @@ Module CompileOpsToModule(const std::vector& op_names) { for (auto op_name : op_names) { Op op = Op::Get(op_name); - if (!IsGeneric(op)) { + if (!IsGeneric(op->op_type)) { auto compiler = compiler_map[op]; std::cout << "ABOVE CALL" << std::endl; tvm::Array pair = compiler(op->name, op->op_type); diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 4558f6a24919..2c809a574cc6 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -29,7 +29,7 @@ void UnionFindNode::insert(const IncompleteType &v) { this->uf_map.Set(v, v); } void UnionFindNode::debug() { for (auto entry : this->uf_map) { - std::cout << entry.first << " = " << entry.second << std::endl; + RELAY_LOG(INFO) << entry.first << " = " << entry.second << std::endl; } } From 142d4e3e3758f266e0459feb703ffb9838f29802 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 19:02:34 -0700 Subject: [PATCH 077/136] Fix cpplint --- include/tvm/relay/error.h | 12 ++-- include/tvm/relay/op.h | 111 +++++++++++++++--------------- include/tvm/relay/pass.h | 2 +- include/tvm/relay/pass/alpha_eq.h | 7 +- include/tvm/relay/source_map.h | 19 +++-- include/tvm/relay/type.h | 20 +++--- src/relay/ir/environment.cc | 3 +- src/relay/ir/op.cc | 19 +++-- src/relay/op/tensor/elemwise.cc | 2 +- src/relay/op/type_relations.cc | 10 +-- src/relay/op/type_relations.h | 8 +-- src/relay/pass/incomplete_type.h | 14 ++-- src/relay/pass/kind_check.cc | 2 +- src/relay/pass/resolve.h | 8 +-- src/relay/pass/type_infer.cc | 2 +- src/relay/pass/type_subst.h | 8 +-- src/relay/pass/type_visitor.h | 10 +-- src/relay/pass/unifier.cc | 71 ++++++++++--------- src/relay/pass/unifier.h | 11 +-- src/relay/source_map.cc | 21 +++--- 20 files changed, 193 insertions(+), 167 deletions(-) diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 433c08abfd58..055cc42936df 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -13,11 +13,11 @@ namespace tvm { namespace relay { struct Error : dmlc::Error { - Error(std::string msg) : dmlc::Error(msg) {} + explicit Error(const std::string &msg) : dmlc::Error(msg) {} }; struct InternalError : Error { - InternalError(std::string msg) : Error(msg) {} + explicit InternalError(const std::string &msg) : Error(msg) {} }; struct SpannedError { @@ -26,14 +26,14 @@ struct SpannedError { SpannedError(std::string msg, Span sp) : msg(msg), sp(sp) {} }; -// FIX, we should change spanned errors to have a method which allow them to report on the Environment, -// inverting control to error definition. +// FIX, we should change spanned errors to have a method which allow them to +// report on the Environment, inverting control to error definition. struct FatalTypeError : dmlc::Error { - explicit FatalTypeError(const std::string & s) : dmlc::Error(s) {} + explicit FatalTypeError(const std::string &s) : dmlc::Error(s) {} }; struct TypecheckerError : public dmlc::Error { - explicit TypecheckerError(const std::string &msg) : Error(msg) {} + explicit TypecheckerError(const std::string &msg) : Error(msg) {} }; } // namespace relay diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 756451e66768..2d5627f2c844 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -6,23 +6,23 @@ #ifndef TVM_RELAY_OP_H_ #define TVM_RELAY_OP_H_ +#include +#include #include -#include -#include #include -#include -#include +#include +#include +#include "../attrs.h" #include "./base.h" -#include "./type.h" #include "./expr.h" -#include "../attrs.h" +#include "./type.h" namespace tvm { namespace relay { // forward declare name. -template +template class OpMap; class GenericOpMap; class OpRegistry; @@ -103,7 +103,7 @@ class Op : public relay::Expr { * \return An OpMap of specified attr_name. * \tparam ValueType The type of the attribute. */ - template + template inline static OpMap GetAttr(const std::string& attr_name); /*! * \brief Get an Op for a given operator name. @@ -129,9 +129,7 @@ class Op : public relay::Expr { class OpRegistry { public: /*! \return the operator */ - const Op& op() const { - return op_; - } + const Op& op() const { return op_; } /*! * \brief setter function during registration * Set the description of operator @@ -146,24 +144,25 @@ class OpRegistry { * \param description Description of the argument. * \return reference to self. */ - inline OpRegistry& add_argument(const std::string &name, - const std::string &type, - const std::string &description); - /*! + inline OpRegistry& add_argument(const std::string& name, + const std::string& type, + const std::string& description); + /*! * \brief Attach the type function corresponding to the return type. * \param ty_func The type function to register for the return type. * \return reference to self. */ - inline OpRegistry& add_type_func(const std::string & type_func_name, TypeRelationFn type_fn); + inline OpRegistry& add_type_func(const std::string& type_func_name, + TypeRelationFn type_fn); - /*! + /*! * \brief Attach the type function corresponding to the return type. * \param ty_func The type function to register for the return type. * \return reference to self. */ inline OpRegistry& add_type_func( - const std::string & type_func_name, - std::function(const Array &, int)> type_fn); + const std::string& type_func_name, + std::function(const Array&, int)> type_fn); /*! * \brief Set the type key of attributes. @@ -196,10 +195,9 @@ class OpRegistry { * * \tparam ValueType The type of the value to be set. */ - template + template inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 10); + const ValueType& value, int plevel = 10); // set the name of the op to be the same as registry inline OpRegistry& set_name() { // NOLINT(*) @@ -222,8 +220,7 @@ class OpRegistry { // return internal pointer to op. inline OpNode* get(); // update the attribute OpMap - TVM_DLL void UpdateAttr(const std::string& key, - TVMRetValue value, + TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, int plevel); }; @@ -251,7 +248,7 @@ class GenericOpMap { * \return the const reference to the content value. * \tparam ValueType The content value type. */ - template + template inline ValueType get(const Op& op, ValueType def_value) const; private: @@ -268,7 +265,7 @@ class GenericOpMap { * \brief Map used to store meta-information about Op. * \tparam ValueType The type of the value stored in map. */ -template +template class OpMap { public: /*! @@ -294,15 +291,14 @@ class OpMap { private: friend class Op; // constructor - explicit OpMap(const GenericOpMap& map) - : map_(map) {} + explicit OpMap(const GenericOpMap& map) : map_(map) {} /*! \brief The internal map field */ const GenericOpMap& map_; }; // internal macros to make -#define RELAY_REGISTER_VAR_DEF \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry & __make_ ## RelayOp +#define RELAY_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp /*! * \def RELAY_REGISTER_OP @@ -319,16 +315,18 @@ class OpMap { * * \endcode */ -#define RELAY_REGISTER_OP(OpName) \ - DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ - ::tvm::relay::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name() +#define RELAY_REGISTER_OP(OpName) \ + DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::relay::OpRegistry::Registry() \ + ->__REGISTER_OR_GET__(OpName) \ + .set_name() // implementations inline const OpNode* Op::operator->() const { return static_cast(node_.get()); } -template +template inline OpMap Op::GetAttr(const std::string& key) { return OpMap(Op::GetGenericAttr(key)); } @@ -337,14 +335,15 @@ inline OpNode* OpRegistry::get() { return const_cast(op_.operator->()); } -inline OpRegistry& OpRegistry::describe(const std::string& descr) { // NOLINT(*) +inline OpRegistry& OpRegistry::describe( + const std::string& descr) { // NOLINT(*) get()->description = descr; return *this; } -inline OpRegistry& OpRegistry::add_argument(const std::string &name, - const std::string &type, - const std::string &description) { +inline OpRegistry& OpRegistry::add_argument(const std::string& name, + const std::string& type, + const std::string& description) { std::shared_ptr n = std::make_shared(); n->name = name; n->type_info = type; @@ -354,13 +353,15 @@ inline OpRegistry& OpRegistry::add_argument(const std::string &name, } inline OpRegistry& OpRegistry::add_type_func( - const std::string & type_func_name, - std::function(const Array &, int)> type_fn) { - auto pfunc = runtime::TypedPackedFunc(const Array &, int)>(type_fn); + const std::string& type_func_name, + std::function(const Array&, int)> type_fn) { + auto pfunc = + runtime::TypedPackedFunc(const Array&, int)>(type_fn); return add_type_func(type_func_name, pfunc); } -inline OpRegistry& OpRegistry::add_type_func(const std::string & type_func_name, TypeRelationFn type_fn) { +inline OpRegistry& OpRegistry::add_type_func(const std::string& type_func_name, + TypeRelationFn type_fn) { auto type_func = TypeRelationNode::make(type_func_name, 0, type_fn); std::vector type_params; @@ -397,7 +398,7 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) return *this; } -inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) +inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) const std::string& type_key) { get()->attrs_type_key = type_key; return *this; @@ -408,13 +409,10 @@ inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) return *this; } -template +template inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) - const std::string& attr_name, - const ValueType& value, - int plevel) { - CHECK_GT(plevel, 0) - << "plevel in set_attr must be greater than 0"; + const std::string& attr_name, const ValueType& value, int plevel) { + CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; TVMRetValue rv; rv = value; UpdateAttr(attr_name, rv, plevel); @@ -435,12 +433,12 @@ inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { CHECK(op.defined()); const uint32_t idx = op->index_; CHECK(idx < data_.size() && data_[idx].second != 0) - << "Attribute " << attr_name_ - << " has not been registered for Operator " << op->name; + << "Attribute " << attr_name_ << " has not been registered for Operator " + << op->name; return data_[idx].first; } -template +template inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { CHECK(op.defined()); const uint32_t idx = op->index_; @@ -451,17 +449,18 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { } } -template +template inline int OpMap::count(const Op& op) const { return map_.count(op); } -template +template inline ValueType OpMap::operator[](const Op& op) const { return map_[op]; } -template -inline ValueType OpMap::get(const Op& op, ValueType def_value) const { +template +inline ValueType OpMap::get(const Op& op, + ValueType def_value) const { return map_.get(op, def_value); } diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index f92596c41179..46419bde3f97 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -47,4 +47,4 @@ bool KindCheck(const Environment& env, const Type& t); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_TYPECHECKER_H_ \ No newline at end of file +#endif // TVM_RELAY_PASS_H_ diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h index 51b5b4dd8b70..87b5164462d7 100644 --- a/include/tvm/relay/pass/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -3,8 +3,8 @@ * \file tvm/relay/alpha_eq.h * \brief Check expressions and types for structural equivalence. */ -#ifndef TVM_RELAY_ALPHA_EQ_H_ -#define TVM_RELAY_ALPHA_EQ_H_ +#ifndef TVM_RELAY_PASS_ALPHA_EQ_H_ +#define TVM_RELAY_PASS_ALPHA_EQ_H_ #include #include @@ -51,4 +51,5 @@ bool AlphaEqual(const Type& t1, const Type& t2); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_ALPHA_EQ_H_ +#endif // TVM_RELAY_PASS_ALPHA_EQ_H_ + diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h index a4dbc20b30ff..277c3875a17f 100644 --- a/include/tvm/relay/source_map.h +++ b/include/tvm/relay/source_map.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2018 by Contributors * \file source_map.h - * \brief A representation of source files and a data structure for + * \brief A representation of source files and a data structure for * storing them. */ #ifndef TVM_RELAY_SOURCE_MAP_H_ @@ -14,8 +14,15 @@ namespace tvm { namespace relay { +/*! \brief A fragment of a source file used for error reporting. + * + * These can be registered by the frontends and are used for + * displaying errors. + */ struct SourceFragment { + /*! \brief The file name which the source fragment originates from. */ std::string file_name; + /*! \brief The lines of source corresponding to the fragment. */ std::vector source_lines; SourceFragment(const std::string& file_name, const std::string& source); @@ -25,6 +32,7 @@ struct SourceFragment { this->source_lines = sf.source_lines; } + /*! \brief The lines of source code originate at lines. */ std::string SourceAt(Span sp, int lines); }; @@ -33,12 +41,15 @@ struct SourceFragment { class SourceMap { /*! \brief Map from unique token to a fragment of a source file. */ std::unordered_map map_; + public: SourceMap() : map_() {} - SourceName AddSource(std::string file_name, std::string source); - const SourceFragment & GetSource(SourceName id) const; + /*! \brief Add a source fragment with the file name and source. */ + SourceName AddSource(const std::string& file_name, const std::string& source); + /*! \brief Retrieve a source fragment by source name. */ + const SourceFragment& GetSource(SourceName id) const; }; } // namespace relay } // namespace tvm -#endif // TVM_RELAY_SOURCE_MAP_H_ \ No newline at end of file +#endif // TVM_RELAY_SOURCE_MAP_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 54cf91cee4ec..f485e0d8d62f 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -126,9 +126,9 @@ class TypeParamNode : public TypeNode { enum Kind : int { /*! \brief template variable in shape expression */ kShapeVar = 0, - kShape = 1, + kShape = 1, kBaseType = 2, - kType = 3, + kType = 3, }; /*! * \brief The variable itself is only meaningful when @@ -200,8 +200,8 @@ class FuncTypeNode : public TypeNode { } TVM_DLL static FuncType make(tvm::Array arg_types, Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints); + tvm::Array type_params, + tvm::Array type_constraints); static constexpr const char* _type_key = "relay.FuncType"; TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode); @@ -209,7 +209,8 @@ class FuncTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); -using TypeRelationFn = runtime::TypedPackedFunc(const Array&, int)>; +using TypeRelationFn = + runtime::TypedPackedFunc(const Array&, int)>; /*! * \brief Opaque type relation, is an input-output relation on types. @@ -238,7 +239,8 @@ class TypeRelationNode : public RelayNode { v->Visit("num_args", &num_args); } - TVM_DLL static TypeRelation make(std::string name, int num_args, TypeRelationFn func_); + TVM_DLL static TypeRelation make(std::string name, int num_args, + TypeRelationFn func_); static constexpr const char* _type_key = "relay.TypeRelation"; TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, RelayNode); @@ -257,7 +259,7 @@ class TypeCallNode : public TypeNode { public: /*! \brief The type function to be called. */ Type func; - + /*! \brief The type arguments to the type function. */ tvm::Array args; @@ -290,9 +292,7 @@ class TupleTypeNode : public TypeNode { TupleTypeNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("fields", &fields); - } + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } TVM_DLL static TupleType make(tvm::Array fields); diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index db7f11fb9e2b..b5f0d663d26a 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -149,7 +149,7 @@ void EnvironmentNode::Transform(EnvironmentNode::Transformer transformer) { to_process.push_back(var_and_func.first); } - auto for_each = transformer(GetRef(this)); + auto for_each = transformer(GetRef(this)); for (auto var : to_process) { auto func = this->functions[var]; auto transformed = for_each(var, func); @@ -157,7 +157,6 @@ void EnvironmentNode::Transform(EnvironmentNode::Transformer transformer) { } } - TVM_REGISTER_API("relay._make.Environment") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = EnvironmentNode::make(args[0]); diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 064551efe9d6..18a647798c9e 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -1,12 +1,18 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/op.cc + * \brief Resolve incomplete types to complete types. + */ #include #include #include #include -#include "./../pass/type_subst.h" #include #include +#include "./../pass/type_subst.h" + namespace dmlc { // enable registry DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); @@ -132,7 +138,7 @@ TVM_REGISTER_API("relay.op._Register") } }); -bool IsGeneric(const FuncType & func_ty) { +bool IsGeneric(const FuncType& func_ty) { return func_ty->type_params.size() != 0; } @@ -181,12 +187,12 @@ TVM_REGISTER_API("relay.op._CompileOpsToModule") *ret = CompileOpsToModule(names); }); -Op SpecializeOp(const std::string& op_name, - const std::string& new_op_name, Array type_args) { +Op SpecializeOp(const std::string& op_name, const std::string& new_op_name, + Array type_args) { auto registry = ::tvm::relay::OpRegistry::Registry(); auto op_reg = registry->__REGISTER_OR_GET__(op_name); auto new_op_reg = registry->__REGISTER__(new_op_name).set_name(); - + auto fn_ty = op_reg.op()->op_type; tvm::Map subst_map; @@ -205,7 +211,8 @@ Op SpecializeOp(const std::string& op_name, new_op_reg.op()->op_type = new_op_ty; // Now we want to copy over some attributes. - PackedFunc compiler = Op::GetAttr("FRelayOpCompiler")[op_reg.op()]; + PackedFunc compiler = + Op::GetAttr("FRelayOpCompiler")[op_reg.op()]; new_op_reg.set_attr("FRelayOpCompiler", compiler); return new_op_reg.op(); diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 76adfbbfb968..d6a04773b7fa 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -120,5 +120,5 @@ RELAY_REGISTER_OP("equal") .set_support_level(1) .add_type_func("BroadcastComp", BroadcastCompRel); -} // namespace relayv +} // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 32d81a1d445e..2a6efbcf71e4 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -29,7 +29,7 @@ int to_int(const tvm::Expr& e) { } Array IdentityRel(const Array& types, int num_args) { - CHECK(types.size() == 2); + CHECK_EQ(types.size(), 2); auto t1 = as_ttype(types[0]); if (t1 && types[1].as()) { return {t1, t1}; @@ -88,7 +88,7 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, smaller = sh2; } - CHECK(larger.size() == smaller.size()); + CHECK_EQ(larger.size(), smaller.size()); Array out_shape; for (int i = 0; i < smaller.size(); i++) { @@ -105,11 +105,11 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, } Array BroadcastRel(const Array& types, int num_args) { - CHECK(types.size() == 3); + CHECK_EQ(types.size(), 3); if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { std::cout << t1->dtype << t2->dtype << std::endl; - CHECK(t1->dtype == t2->dtype); + CHECK_EQ(t1->dtype, t2->dtype); return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)}; } } @@ -121,7 +121,7 @@ Array BroadcastRel(const Array& types, int num_args) { compute boolean results. */ Array BroadcastCompRel(const Array& types, int num_args) { - CHECK(types.size() == 3); + CHECK_EQ(types.size(), 3); if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { return {t1, t2, ConcreteBroadcast(t1, t2, HalideIR::Bool())}; diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 71c98fef7da1..3597246b5a4a 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -4,11 +4,11 @@ * \brief A set of utilities and common functionality * for type relations. */ -#ifndef TVM_RELAY_TYPECK_RESOLVE_H_ -#define TVM_RELAY_TYPECK_RESOLVE_H_ +#ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_ +#define TVM_RELAY_OP_TYPE_RELATIONS_H_ -#include #include +#include namespace tvm { namespace relay { @@ -20,4 +20,4 @@ Array BroadcastCompRel(const Array & types, int num_args); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPECK_RESOLVE_H_ +#endif // TVM_RELAY_OP_TYPE_RELATIONS_H_ diff --git a/src/relay/pass/incomplete_type.h b/src/relay/pass/incomplete_type.h index 3967b4e58657..78771dc6e9b7 100644 --- a/src/relay/pass/incomplete_type.h +++ b/src/relay/pass/incomplete_type.h @@ -4,8 +4,8 @@ * \brief A way to defined arbitrary function signature with dispatch on types. */ -#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H -#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H +#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H_ +#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H_ #include @@ -22,9 +22,7 @@ class IncompleteTypeNode : public TypeNode { public: TypeParamNode::Kind kind; - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("kind", &kind); - } + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); } TVM_DLL static IncompleteType make(TypeParamNode::Kind kind); @@ -34,7 +32,7 @@ class IncompleteTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); -} // namespace relay -} // namespace tvm +} // namespace relay +} // namespace tvm -#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H +#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H_ diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index c3823c8c3a35..522eb93483fb 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -39,4 +39,4 @@ bool KindCheck(const Environment& env, const Type &t) { } } // namespace relay -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/relay/pass/resolve.h b/src/relay/pass/resolve.h index 495c9658238a..deb6558322b8 100644 --- a/src/relay/pass/resolve.h +++ b/src/relay/pass/resolve.h @@ -3,11 +3,11 @@ * \file tvm/relay/resolve.h * \brief Resolve incomplete types to complete types. */ -#ifndef TVM_RELAY_TYPECK_RESOLVE_H_ -#define TVM_RELAY_TYPECK_RESOLVE_H_ +#ifndef TVM_RELAY_PASS_RESOLVE_H_ +#define TVM_RELAY_PASS_RESOLVE_H_ -#include #include +#include #include "./unifier.h" namespace tvm { @@ -20,4 +20,4 @@ bool IsFullyResolved(const Type & t); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPECK_RESOLVE_H_ +#endif // TVM_RELAY_PASS_RESOLVE_H_ diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 6cc73d1b8fbe..df896fa3936a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -60,7 +60,7 @@ struct TypeContext { struct TypeNormalizer : TypeFVisitor { TypeUnifier unifier; - TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {} + explicit TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {} Type VisitType_(const TypeCallNode *ty_call_node) { auto ty_call = GetRef(ty_call_node); diff --git a/src/relay/pass/type_subst.h b/src/relay/pass/type_subst.h index 3c248fdce3b7..5b6956f8e451 100644 --- a/src/relay/pass/type_subst.h +++ b/src/relay/pass/type_subst.h @@ -1,10 +1,10 @@ /*! * Copyright (c) 2018 by Contributors - * \file typeck/type_subst.h + * \file src/tvm/relay/pass/type_subst.h * \brief Utility functions for substituting types. */ -#ifndef TVM_RELAY_TYPECK_TYPE_SUBST_H_ -#define TVM_RELAY_TYPECK_TYPE_SUBST_H_ +#ifndef TVM_RELAY_PASS_TYPE_SUBST_H_ +#define TVM_RELAY_PASS_TYPE_SUBST_H_ #include @@ -16,4 +16,4 @@ Type TypeSubst(const Type &type, tvm::Map subst_map); } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPECK_TYPE_SUBST_H_ +#endif // TVM_RELAY_PASS_TYPE_SUBST_H_ diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index f3c0f9a74fb7..d65d6c567b23 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -3,8 +3,8 @@ * \file type_visitor.h * \brief A wrapper around TypeFunctor for common use cases. */ -#ifndef TVM_RELAY_TYPE_VISITOR_H_ -#define TVM_RELAY_TYPE_VISITOR_H_ +#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_ +#define TVM_RELAY_PASS_TYPE_VISITOR_H_ #include #include "./type_functor.h" @@ -54,7 +54,7 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { // A functional visitor for rebuilding an AST in place. struct TypeFVisitor : TypeFunctor { Type VisitType_(const TensorTypeNode* op) override { - // TODO (@jroesch): maybe we should recursively visit + // TODO(@jroesch): maybe we should recursively visit return TensorTypeNode::make(op->shape, op->dtype); } @@ -63,7 +63,7 @@ struct TypeFVisitor : TypeFunctor { } Type VisitType_(const FuncTypeNode* op) override { - // TODO (@jroesch): handle poly + // TODO(@jroesch): handle poly // auto new_id = this->VisitType(op->var); // if (const TypeParamNode* tin = new_id.as()) { @@ -107,4 +107,4 @@ struct TypeFVisitor : TypeFunctor { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPE_VISITOR_H_ +#endif // TVM_RELAY_PASS_TYPE_VISITOR_H_ diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 2c809a574cc6..7735ca8b0482 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -5,11 +5,11 @@ * incomplete types. */ -#include +#include "./unifier.h" #include #include #include -#include "./unifier.h" +#include #include "./type_visitor.h" // #include "tvm/relay/typeck/kindchecker.h" @@ -33,7 +33,7 @@ void UnionFindNode::debug() { } } -void UnionFindNode::AssertAlphaEqual(const Type & l, const Type & r) { +void UnionFindNode::AssertAlphaEqual(const Type &l, const Type &r) { if (!AlphaEqual(l, r)) { std::stringstream ss; ss << "Incompatible parent types in UF:" << l << " and " << r; @@ -141,7 +141,7 @@ Type TypeUnifierNode::unify(const Type &t1, const Type &t2) { Type unified = this->VisitType(t1, t2); // if (!check_kind(unified)) { - // throw UnificationError("Invalid kinds in unified type"); + // throw UnificationError("Invalid kinds in unified type"); // } return unified; } @@ -167,32 +167,34 @@ Type TypeUnifierNode::subst(const Type &t) { // normalize first so substitutions in quantifiers will be correct Type ret = tvsubst.VisitType(t); // if (!check_kind(ret)) { - // std::stringstream ss; - // ss << "Invalid Kinds in substituted type!"; - // ss << t << std::endl; - // ss << ret << std::endl; - // throw SubstitutionError(ss.str()); + // std::stringstream ss; + // ss << "Invalid Kinds in substituted type!"; + // ss << t << std::endl; + // ss << ret << std::endl; + // throw SubstitutionError(ss.str()); // } return ret; } -Type TypeUnifierNode::VisitType(const Type & t1, const Type t2) { +Type TypeUnifierNode::VisitType(const Type &t1, const Type t2) { // When the right hand size is a type variable immediately unify. if (const IncompleteTypeNode *tvn2 = t2.as()) { return this->unifyWithIncompleteType(t1, GetRef(tvn2)); - // The TypeCallNode case is special and not symmetric. - // - // We flip the arguments so we hit the TypeCall and other case in there is - // ever a type call. + // The TypeCallNode case is special and not symmetric. + // + // We flip the arguments so we hit the TypeCall and other case in there is + // ever a type call. } else if (const TypeCallNode *tvn2 = t2.as()) { - return TypeFunctor::VisitType(t2, t1); + return TypeFunctor::VisitType(t2, t1); } else { - return TypeFunctor::VisitType(t1, t2); + return TypeFunctor::VisitType(t1, t2); } } -Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, const IncompleteType tv2) { - RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; +Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, + const IncompleteType tv2) { + RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 + << std::endl; // Fix unify to return new representative this->uf->unify(tv2, t1); auto rep = this->uf->find(tv2); @@ -235,7 +237,8 @@ Type TypeUnifierNode::VisitType_(const FuncTypeNode *t1, const Type rt2) { FuncType ft2 = GetRef(tan2); if (ft1->type_params.size() != ft2->type_params.size()) { - throw UnificationError("unable to unify functions with differing number of type parameters"); + throw UnificationError( + "unable to unify functions with differing number of type parameters"); } if (ft1->type_params.size() != 0) { @@ -282,7 +285,7 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { TensorType tt2 = GetRef(ttn2); if (!AlphaEqual(tt1, tt2)) { - throw UnificationError("dtypes do not match"); + throw UnificationError("dtypes do not match"); } RELAY_LOG(INFO) << "Unify Tensor Shape s1=" << tt1->shape @@ -290,8 +293,9 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { try { // Type unified_shape = this->VisitType(tt1->shape, tt2->shape); return rt2; - } catch (const UnificationError & err) { - std::cout << "Need to check constraint " << tt1->shape << " = " << tt2->shape << std::endl; + } catch (const UnificationError &err) { + std::cout << "Need to check constraint " << tt1->shape << " = " + << tt2->shape << std::endl; } // fix me @@ -328,15 +332,16 @@ Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { } Type TypeUnifierNode::VisitType_(const TypeRelationNode *tr1, const Type t2) { - if (const TypeRelationNode *tr2 = t2.as()) { - if (tr1 == tr2) { - return GetRef(tr1); - } else { - throw UnificationError("Cannot unify different type relations"); - } - } else { - throw UnificationError("Cannot unify type relation with another type of type"); - } + if (const TypeRelationNode *tr2 = t2.as()) { + if (tr1 == tr2) { + return GetRef(tr1); + } else { + throw UnificationError("Cannot unify different type relations"); + } + } else { + throw UnificationError( + "Cannot unify type relation with another type of type"); + } } Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { @@ -347,7 +352,8 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { // For now, we will only unify if they are equal. if (ty_call1->args.size() != tcn2->args.size()) { - throw UnificationError("Cannot unify calls of different number of arguments"); + throw UnificationError( + "Cannot unify calls of different number of arguments"); } // Unify members, if possible @@ -364,6 +370,5 @@ Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { } } - } // namespace relay } // namespace tvm diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index 64485768c2f0..0671a40c0d74 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -7,8 +7,8 @@ #ifndef TVM_RELAY_PASS_UNIFIER_H_ #define TVM_RELAY_PASS_UNIFIER_H_ -#include #include +#include #include "./type_functor.h" namespace tvm { @@ -62,7 +62,7 @@ class UnionFind : public NodeRef { explicit UnionFind(std::shared_ptr p) : NodeRef(p) {} // The union find structure is mutable so we do not use the standard macros - // and expose the pointer via `->`. + // and expose the pointer via `->`. UnionFindNode* operator->() const { return static_cast(node_.get()); } @@ -102,8 +102,9 @@ class TypeUnifierNode : public Node, private: /*! \brief Unify incomplete type with another type. */ Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); - /*! \brief Implements unification between two types with incomplete portions. */ - Type VisitType(const Type & t1, const Type t2) override; + /*! \brief Implements unification between two types with incomplete portions. + */ + Type VisitType(const Type& t1, const Type t2) override; // Visitor Cases Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override; @@ -130,4 +131,4 @@ class TypeUnifier : public NodeRef { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_TYPECK_UNIFIER_H_ +#endif // TVM_RELAY_PASS_UNIFIER_H_ diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc index d784c7946954..9d3316cf38cf 100644 --- a/src/relay/source_map.cc +++ b/src/relay/source_map.cc @@ -14,15 +14,18 @@ namespace relay { using tvm::IRPrinter; using namespace tvm::runtime; -SourceFragment::SourceFragment(const std::string& file_name, const std::string& source) +SourceFragment::SourceFragment(const std::string& file_name, + const std::string& source) : file_name(file_name), source_lines({}) { - RELAY_LOG(INFO)<< "SourceFragment::SourceFragment source=" << source << std::endl; + RELAY_LOG(INFO) << "SourceFragment::SourceFragment source=" << source + << std::endl; std::stringstream source_stream; source_stream.str(source.c_str()); std::string line; while (std::getline(source_stream, line)) { - RELAY_LOG(INFO) << "SourceFragment::SourceFragment: line=" << line << std::endl; + RELAY_LOG(INFO) << "SourceFragment::SourceFragment: line=" << line + << std::endl; std::string copy(line); source_lines.push_back(copy); } @@ -38,7 +41,8 @@ std::string SourceFragment::SourceAt(Span sp, int max_lines = 1) { throw dmlc::Error("SourceFragment: index out of bounds"); } - auto lines = std::max(static_cast(max_lines), source_lines.size() - starting_line); + auto lines = std::max(static_cast(max_lines), + source_lines.size() - starting_line); for (size_t i = 0; i < lines; i++) { out << std::endl << this->source_lines.at(starting_line + i); @@ -46,11 +50,12 @@ std::string SourceFragment::SourceAt(Span sp, int max_lines = 1) { auto source_slice = out.str(); - RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice << std::endl; + RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice + << std::endl; return source_slice; } -SourceName SourceMap::AddSource(std::string file_name, std::string source) { +SourceName SourceMap::AddSource(const std::string & file_name, const std::string & source) { auto new_id = SourceNameNode::make(file_name); SourceFragment sfile(file_name, source); this->map_.insert({new_id, sfile}); @@ -62,9 +67,9 @@ const SourceFragment& SourceMap::GetSource(SourceName id) const { if (item != map_.end()) { return (*item).second; } else { - throw dmlc::Error("could not find requested source fragment"); + throw dmlc::Error("could not find requested source fragment"); } } } // namespace relay -} // namespace tvm \ No newline at end of file +} // namespace tvm From e0c61437d747822c64fe75ba6553a9992f9bf2a8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 6 Sep 2018 19:11:31 -0700 Subject: [PATCH 078/136] Fix pylint --- nnvm/python/nnvm/_base.py | 7 +- python/tvm/relay/__init__.py | 10 +-- python/tvm/relay/base.py | 1 - python/tvm/relay/env.py | 14 ++- python/tvm/relay/expr.py | 65 ++++++++++---- python/tvm/relay/from_nnvm.py | 5 +- python/tvm/relay/ir_builder.py | 87 ++++++++++++------- python/tvm/relay/ir_pass.py | 23 +++-- python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/_tensor.py | 20 +++-- python/tvm/relay/op/op.py | 19 ++-- python/tvm/relay/op/tensor.py | 5 +- python/tvm/relay/to_tvm.py | 3 +- python/tvm/relay/type.py | 34 ++++---- python/tvm/tensor.py | 5 ++ .../relay/test_tyck_eval_integration.py | 2 +- 16 files changed, 191 insertions(+), 110 deletions(-) diff --git a/nnvm/python/nnvm/_base.py b/nnvm/python/nnvm/_base.py index 63b2f815ad9b..29390a2201bf 100644 --- a/nnvm/python/nnvm/_base.py +++ b/nnvm/python/nnvm/_base.py @@ -22,12 +22,7 @@ numeric_types = (float, int, np.float32, np.int32) # this function is needed for python3 # to convert ctypes.char_p .value back to python str - def py_str(x): - try: - return x.decode('utf-8') - except: - print(x) - # py_str = lambda x: x.decode('utf-8') + py_str = lambda x: x.decode('utf-8') else: string_types = basestring numeric_types = (float, int, long, np.float32, np.int32) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index c36b9bcf8357..aae019c8d9c1 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,8 +1,12 @@ +# pylint: disable=wildcard-import """The Relay IR namespace containing the IR definition and compiler.""" from . import base from . import type as tpe from . import expr -from . import op + +# Operators +from .op import Op +from .op.tensor import * # Span Span = base.Span @@ -26,7 +30,3 @@ Let = expr.Let If = expr.If Var = LocalVar - -# Operators -from .op import Op -from .op.tensor import * diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index ee818617f629..0f3d2bc58d71 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -1,7 +1,6 @@ # pylint: disable=no-else-return, unidiomatic-typecheck """The base node types for the Relay language.""" from __future__ import absolute_import as _abs -from typing import Union from .._ffi.node import NodeBase, register_node as _register_tvm_node from . import _make diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 86c9ac794b4e..beef6fd1a62c 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -1,28 +1,26 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import """A global environment storing everything needed to interpret or compile a Realy program.""" -from typing import Union, List from .base import register_relay_node, NodeBase from . import _make from . import _env -import tvm @register_relay_node class Environment(NodeBase): - """The global Relay environment containing definitions, - primitives, options, and more. + """The global Relay environment containing functions, + options and more. """ def __init__(self, funcs) -> None: self.__init_handle_by_constructor__(_make.Environment, funcs) - + def add(self, var, func) -> None: if isinstance(var, str): var = _env.Environment_GetGlobalVar(self, var) _env.Environment_Add(self, var, func) - + def merge(self, other): return _env.Environment_Merge(self, other) - + def global_var(self, var): return _env.Environment_GetGlobalVar(self, var) @@ -31,6 +29,6 @@ def lookup(self, var): return _env.Environment_Lookup_str(self, var) else: return _env.Environment_Lookup(self, var) - + def transform(self, transformer): _env.Environment_Transform(self, transformer) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 1558853c2820..748b2aa1e282 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -1,20 +1,22 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression nodes of Relay.""" +from typing import List import tvm -from typing import Tuple as PyTuple, List -from enum import IntEnum from .base import Span, NodeBase, register_relay_node from .type import Type, TypeParam -from tvm import expr from ._ir_pass import _get_checked_type from . import _make + class ExprBuilder(): - # def convert_args(self, + """A set of methods useful for building expressions + from other expressions. + """ def __call__(self, *args, **kwargs): converted_args = [] for arg in args: - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() if isinstance(arg, Param): converted_args.append(arg.var) else: @@ -22,11 +24,14 @@ def __call__(self, *args, **kwargs): return Call(self, args, None, None) + class Expr(NodeBase, ExprBuilder): """The base type for all Relay exprressions.""" + def checked_type(self): return _get_checked_type(self) + @register_relay_node class Constant(Expr): """A constant tensor in Relay, see tvm/relay/type.h for more details. @@ -36,6 +41,7 @@ class Constant(Expr): def __init__(self, data: tvm.nd.NDArray) -> None: self.__init_handle_by_constructor__(_make.Constant, data) + @register_relay_node class Tuple(Expr): """A hetereogenous sequence of values. @@ -55,6 +61,7 @@ class LocalVar(Expr): def __init__(self, name_hint: str) -> None: self.__init_handle_by_constructor__(_make.LocalVar, name_hint) + @register_relay_node class GlobalVar(Expr): """A global variable in Relay.""" @@ -63,6 +70,7 @@ class GlobalVar(Expr): def __init__(self, name_hint: str) -> None: self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) + @register_relay_node class Param(Expr): """A function type in Relay, see tvm/relay/type.h for more details. @@ -70,47 +78,66 @@ class Param(Expr): var: LocalVar type: Type - def __init__(self, var: LocalVar, type: Type) -> None: - self.__init_handle_by_constructor__(_make.Param, var, type) + def __init__(self, var: LocalVar, ty: Type) -> None: + self.__init_handle_by_constructor__(_make.Param, var, ty) @register_relay_node class Function(Expr): + """A function in Relay, see tvm/relay/expr.h for more details.""" type_params: List[TypeParam] params: List[Param] ret_type: Type body: Expr - def __init__(self, params: List[Param], ret_type: Type, body: Expr, type_params: List[TypeParam]=[]) -> None: - self.__init_handle_by_constructor__(_make.Function, params, ret_type, body, type_params) + def __init__(self, + params: List[Param], + ret_type: Type, + body: Expr, + type_params: List[TypeParam] = None) -> None: + if not type_params: + type_params = [] + self.__init_handle_by_constructor__( + _make.Function, params, ret_type, body, type_params) + @register_relay_node class Call(Expr): - op: Expr - args: List[Expr] - # todo(@jroesch): add attrs + """A function call in Relay, see tvm/relay/expr.h for more details.""" + op: Expr + args: List[Expr] + # todo(@jroesch): add attrs + + def __init__(self, op: Expr, args: List[Expr], attrs, ty_args=None) -> None: + if not ty_args: + ty_args = [] + + self.__init_handle_by_constructor__( + _make.Call, op, args, attrs, ty_args) - def __init__(self, op: Expr, args: List[Expr], attrs, ty_args) -> None: - self.__init_handle_by_constructor__(_make.Call, op, args, attrs, ty_args) @register_relay_node class Let(Expr): + """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" var: LocalVar value: Expr body: Expr - value_type: Type # should be type nanotation + # should be type annotation + value_type: Type def __init__(self, var: LocalVar, value: Expr, body: Expr, value_type: Type) -> None: - self.__init_handle_by_constructor__(_make.Let, var, value, body, value_type) + self.__init_handle_by_constructor__( + _make.Let, var, value, body, value_type) + @register_relay_node class If(Expr): + """A conditional expression in Relay, see tvm/relay/expr.h for more details.""" cond: Expr true_value: Expr false_value: Expr span: Span def __init__(self, cond: Expr, true_value: Expr, false_value: Expr) -> None: - self.__init_handle_by_constructor__(_make.If, cond, true_value, false_value) - - + self.__init_handle_by_constructor__( + _make.If, cond, true_value, false_value) diff --git a/python/tvm/relay/from_nnvm.py b/python/tvm/relay/from_nnvm.py index 18a1112c2629..9700ea955f59 100644 --- a/python/tvm/relay/from_nnvm.py +++ b/python/tvm/relay/from_nnvm.py @@ -1,4 +1,7 @@ +#pylint: disable-all +"""Convert an nnvm.graph.Graph into a tvm.relay.Expr""" import nnvm def from_nnvm(graph): - import pdb; pdb.set_trace() + """Convert an nnvm.graph.Graph into a tvm.relay.Expr""" + raise Exception("NYI") diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 098eb474c6ee..a271a537b290 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,10 +1,14 @@ +"""IR builder for the Relay IR. + +Enables users to construct Relay programs with a Python API. +""" from typing import Any import numpy as np import tvm from .type import FuncType, TensorType -from .expr import Expr, Call, Constant, Let, LocalVar, Param, Function, If +from .expr import Expr, Constant, Let, LocalVar, Param, Function, If from .env import Environment -from . import op as _op + def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: """Convert Python values into the appropriate types @@ -24,6 +28,7 @@ def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: # raise Exception(f"can't convert {type(arg)} to a Relay AST") raise Exception(f"unsupported argument type {type(arg)}") + def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: if isinstance(arg, Expr): return arg @@ -35,6 +40,7 @@ def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: value = convert(arg, ctxt) return Constant(value) + class WithScope(object): """A wrapper for builder methods which introduce scoping.""" @@ -53,6 +59,7 @@ def __exit__(self, ptype, value, trace): class PartialFunc(): + """A wrapper around functions while they are being built.""" def __init__(self, params, ret_type, body, type_params): self.params = params self.ret_type = ret_type @@ -69,15 +76,20 @@ def to_func(self): self.body, self.type_params) - +#pylint: disable=invalid-name def _mk_let(bindings, ret_value): let_expr = ret_value - for var, value in reversed(list(bindings.items())): - let_expr = Let(var, value, let_expr, None) + for var, value, ty in reversed(list(bindings.items())): + let_expr = Let(var, value, let_expr, ty) return let_expr + class IRBuilder(): + """The IRBuilder class. + + Enables users to build up a Relay environment and program. + """ def __init__(self): self.bindings = [{}] self.scopes = [{}] @@ -85,13 +97,15 @@ def __init__(self): self.ret_values = [None] self.env = Environment({}) - def enter_scope(self, params=[]): + def enter_scope(self, params=None): + if not params: + params = [] + self.bindings.append({}) self.scopes.append({}) self.params.append(params) self.ret_values.append(None) - def exit_scope(self): bindings = self.bindings.pop() scopes = self.scopes.pop() @@ -99,14 +113,13 @@ def exit_scope(self): ret_value = self.ret_values.pop() return bindings, scopes, params, ret_value - - def bind(self, name, type, value): + #pylint: disable=invalid-name + def bind(self, name, ty, value): lv = LocalVar(name) self.scopes[-1][name] = lv - self.bindings[-1][lv] = value + self.bindings[-1][lv] = (value, ty) return lv - def let(self, name, value, value_type=None): if isinstance(value, Param): value = value.var @@ -117,6 +130,7 @@ def let(self, name, value, value_type=None): return self.bind(name, value_type, value) def function(self, *params): + """Construct a Relay function.""" relay_params = [] for param in params: name = param.var @@ -131,13 +145,12 @@ def function(self, *params): pfunc = PartialFunc(relay_params, None, None, []) def _on_exit(): - bindings, scope, params, ret_value = self.exit_scope() + bindings, _, _, ret_value = self.exit_scope() body = _mk_let(bindings, ret_value) pfunc.body = body return WithScope(pfunc, _on_exit) - def ret(self, x): if not self.ret_values[-1]: self.ret_values[-1] = into_ast(x) @@ -146,6 +159,7 @@ def ret(self, x): "return value already set, a function can only have one return value") def if_scope(self, cond): + """Construct the if branch an if expression with scoping.""" self.enter_scope() def _on_exit(): @@ -153,29 +167,30 @@ def _on_exit(): assert self.ret_values[-1] is None true_branch = _mk_let(bindings, ret_value) self.ret_values[-1] = If(cond, true_branch, None) - + return WithScope(10, _on_exit) - def else_scope(self): + """Construct the else branch of an if expression with scoping.""" self.enter_scope() def _on_exit(): bindings, _, _, ret_value = self.exit_scope() partial_if = self.ret_values[-1] - assert isinstance(partial_if, If) and partial_if.false_value is None + assert isinstance( + partial_if, If) and partial_if.false_value is None false_branch = _mk_let(bindings, ret_value) self.ret_values[-1] = If( - partial_if.cond, - partial_if.true_value, + partial_if.cond, + partial_if.true_value, false_branch) - + return WithScope(10, _on_exit) def param(self, name, ty=None): if not ty: ty = float_type() - + return Param(LocalVar(name), ty) # def params(*args): @@ -183,7 +198,7 @@ def param(self, name, ty=None): # while i < args.size(): # arg = args[i] # if isinstance(arg, str): - + def global_var(self, name: str): return self.env.global_var(name) @@ -197,8 +212,8 @@ def _on_exit(): return WithScope(10, _on_exit) - # def while_loop(cond) + def get(self): """Get the full program""" bindings = self.bindings.pop() @@ -215,33 +230,47 @@ def get(self): return _mk_let(bindings, self.ret_values[-1]), self.env + def bool_dtype(): return 'uint1' + def int_dtype(bits=32): return f'int{bits}' + def float_dtype(bits=32): return f'float{bits}' + def uint_dtype(bits=32): return f'uint{bits}' - -def int_type(bits=32, lanes=1): + + +def int_type(bits=32, _lanes=1): # TODO(@jroesch, @tqchen) How do we set lanes? return TensorType(tvm.convert([]), int_dtype(bits)) -def uint_type(bits=32, lanes=1): + +def uint_type(bits=32, _lanes=1): return TensorType(tvm.convert([]), uint_dtype(bits)) -def float_type(bits=32, lanes=1): + +def float_type(bits=32, _lanes=1): return TensorType(tvm.convert([]), float_dtype(bits)) -def bool_type(lanes=1): - return TensorType(tvm.convert([]), bool_dtype(bits)) + +def bool_type(_lanes=1): + return TensorType(tvm.convert([]), bool_dtype()) + def tensor_type(*shape, dtype='float32'): return TensorType(tvm.convert(shape), dtype) -def func_type(args, ret_type, type_params=[], type_constraints=[]): + +def func_type(args, ret_type, type_params=None, type_constraints=None): + if not type_params: + type_params = [] + if not type_constraints: + type_constraints = [] return FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 8b49710f70ec..b075704c212a 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -1,7 +1,6 @@ -# pylint: disable=no-else-return +# pylint: disable=no-else-return, # pylint: disable=unidiomatic-typecheck -""" -The optimizer for Relay. +"""The optimizer for Relay. Exposes an interface for configuring the optimizer and scripting it directly in Python. @@ -26,6 +25,7 @@ from . import _ir_pass # Expose checking expression, should rename to infer_type. +# pylint: disable=invalid-name check_expr = _ir_pass.check_expr # # pylint: disable=invalid-name @@ -47,7 +47,10 @@ def mangle(name: str, types: List[Type]) -> str: name += str(typ) + "_" return name + T = TypeVar('T') + + class AbstractExprVisitor(Generic[T]): """A functional visitor over Expr in Python.""" @@ -104,11 +107,13 @@ def visit_global_var(self, _: GlobalVar) -> T: def to_pass(cls) -> Callable[[Environment], Callable[[GlobalVar, Function], Function]]: def _outer_wrapper(env): visitor = cls(env) - def _inner_wrapper(var, func): + + def _inner_wrapper(_, func): return visitor.visit(func) return _inner_wrapper return _outer_wrapper + class ExprVisitor(AbstractExprVisitor[Expr]): """A functional visitor over Expr in Python.""" @@ -149,8 +154,10 @@ def visit_tuple(self, tup: Tuple) -> Expr: def visit_constant(self, const: Constant) -> Expr: return const + MMCacheKey = Tuple[Union[GlobalVar, str], List[Type]] + class Monomorphize(ExprVisitor): """A monomorphization pass. @@ -182,11 +189,12 @@ def visit_call(self, call: Call) -> Expr: mono_name = mangle(poly_name, call.type_args) for arg in call.type_args: if isinstance(arg, TypeParam): - return call # raise Exception("...") # Fix me in the morning!!! + # raise Exception("...") # Fix me in the morning!!! + return call mono_op = specialize_op(poly_name, mono_name, call.type_args) self.monomorph_map[cache_key] = mono_op - return Call(mono_op, new_args,call.attrs, []) + return Call(mono_op, new_args, call.attrs, []) elif isinstance(call.op, GlobalVar): return call # defn = self.env.lookup(call.op) @@ -203,7 +211,7 @@ def visit_call(self, call: Call) -> Expr: # self.env.add(defn) # self.visit_item(defn) # return Call(new_id, call.args, call.attrs) - + elif isinstance(call.op, Function): return call # new_func = type_specialize(call.type_args, call.op) @@ -222,4 +230,3 @@ def visit_call(self, call: Call) -> Expr: # TODO(@jroesch): Fix up my type __tgt_host__ = __tgt__ = "llvm" __relay_tvm_context__ = tvm.cpu() - diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 47ebc5501cab..5c3a8ac249a6 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,3 +1,4 @@ +#pylint: disable=wildcard-import """Relay core operators.""" # operator defs from .op import get, register, Op, compile_ops diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index da94ec89b380..4427faa6a3a6 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -1,17 +1,20 @@ +#pylint: disable=invalid-name """Backend compiler related feature regsitration""" +from topi import add from .op import register from ..type import FuncType, TensorType from ...schedule import create_schedule from ...api import placeholder -from topi import add def type_to_placeholder(name, ty): + """Convert a single type into the correct placeholder.""" if isinstance(ty, TensorType): return placeholder(ty.shape, name=name, dtype=ty.dtype) else: raise Exception("can only pass Tensor values to TVM operators") def func_ty_to_placeholders(func_ty): + """Build input placeholders based on a function type.""" if isinstance(func_ty, FuncType): arg_types = func_ty.arg_types ret_type = func_ty.ret_type @@ -45,12 +48,13 @@ def func_ty_to_placeholders(func_ty): # schedule = tvm.create_schedule(Output.op) # return [schedule, Inputs + [Output]] - -def add_compiler(op_name, func_type, *args): - Inputs, ret_ty = func_ty_to_placeholders(func_type) +#pylint: disable=duplicate-argument-name +def add_compiler(_, func_type, *_): + """The compilation code for the TVM compiler.""" + inputs, _ = func_ty_to_placeholders(func_type) # op = lookup_in_topi(op_name) - Output = add(*Inputs) - schedule = create_schedule(Output.op) - return [schedule, Inputs + [Output]] + output = add(*inputs) + schedule = create_schedule(output.op) + return [schedule, inputs + [output]] -register("add", "FRelayOpCompiler", add_compiler) \ No newline at end of file +register("add", "FRelayOpCompiler", add_compiler) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index bb589f40f138..14570b62269b 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -3,13 +3,13 @@ from ..base import register_relay_node from ..expr import Expr -from ..._ffi.function import Function, register_func -from ...api import convert -from ...container import Map -from ... import lower, build, cpu +from ..._ffi.function import register_func +from ... import lower, build + @register_relay_node class Op(Expr): + """A Relay operator definition.""" def __init__(self): raise RuntimeError("Cannot create op, use get instead") @@ -74,6 +74,7 @@ def _register(v): return v return _register(value) if value else _register + def compile_ops(op_names): """Register an operator property of an operator. @@ -90,6 +91,8 @@ def compile_ops(op_names): return _CompileOpsToModule(*op_names) # TODO(@jroesch): We should port to C++, just need to figure out how to write this code. + + @register_func("relay.op._compile_ops") def _compile_ops(op_impls): lowered = [] @@ -100,8 +103,10 @@ def _compile_ops(op_impls): # TOOD(@jroesch): Where should we read these settings from return build(lowered, target='llvm', target_host='llvm') + _init_api("relay.op", __name__) + def specialize_op(op_name, new_op_name, type_args): """Specializes an operator to a set of types and assigns it new_op_name. @@ -110,7 +115,7 @@ def specialize_op(op_name, new_op_name, type_args): add : forall (T : Type) (U : Type), (U, T) -> Broadcast(U, T) - This is a function which is polymorphic over two types `T` and `U` and + This is a function which is polymorphic over two types `T` and `U` and takes a value of type `T` and one of `U` and returns `Broadcast` of U and T. @@ -135,9 +140,9 @@ def specialize_op(op_name, new_op_name, type_args): ---------- op_name : str The operator to be specialized. - + Returns ------- The specialized operator. """ - return _SpecializeOp(op_name, new_op_name, type_args) \ No newline at end of file + return _SpecializeOp(op_name, new_op_name, type_args) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index d0c1b88eb240..57fbccf488dc 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -59,6 +59,7 @@ def sqrt(data): """ return _make.sqrt(data) + def add(lhs, rhs): """Take sqrt of data. @@ -76,6 +77,7 @@ def add(lhs, rhs): """ return _make.add(lhs, rhs) + def subtract(lhs, rhs): """Take sqrt of data. @@ -93,5 +95,6 @@ def subtract(lhs, rhs): """ return _make.add(lhs, rhs) + def equal(lhs, rhs): - return _make.equal(lhs, rhs) \ No newline at end of file + return _make.equal(lhs, rhs) diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py index 181251844a6d..615a39301142 100644 --- a/python/tvm/relay/to_tvm.py +++ b/python/tvm/relay/to_tvm.py @@ -155,6 +155,7 @@ def visit_local_var(self, ident: LocalVar) -> NodeRef: return self.lookup(ident) def visit_call(self, call: Call) -> NodeRef: + """Transform a ::tvm.relay.Call into an operator in the TVM graph.""" inputs = [] for arg in call.args: inputs.append(self.visit(arg).to_json()) @@ -222,7 +223,7 @@ def to_json(self) -> str: return json.dumps(json_dict) -def compile(func): +def compile_to_tvm(func): """Compile a single function to the components needed by the TVM RTS. """ diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index d9fc1eff1fd0..22c853ef512f 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -1,9 +1,9 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The type nodes of the Relay language.""" -from typing import Tuple, List +from typing import List from enum import IntEnum -from .base import Span, NodeBase, register_relay_node from tvm import expr +from .base import Span, NodeBase, register_relay_node from . import _make @@ -67,18 +67,18 @@ class Kind(IntEnum): @register_relay_node class TypeParam(Type): """A type parameter used for generic types in Relay, - see tvm/relay/type.h for more details. + see tvm/relay/type.h for more details. - A type parameter represents a type placeholder which will - be filled in later on. This allows the user to write - functions which are generic over types. + A type parameter represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. """ var: expr.Var kind: Kind span: Span def __init__(self, var: expr.Var, kind: Kind) -> None: - """Construct a TypeParam. + """Construct a TypeParam. Parameters ---------- @@ -87,7 +87,7 @@ def __init__(self, var: expr.Var, kind: Kind) -> None: kind: Kind The kind of the type parameter. - + Returns ------- type_param: TypeParam @@ -112,8 +112,7 @@ class FuncType(Type): being, a sequence of argument types, and a return type. We informally write them as: - `forall (type_params), (arg_types) -> ret_type - where type_constraints` + `forall (type_params), (arg_types) -> ret_type where type_constraints` """ type_params: List[TypeParam] type_constraints: List[TypeConstraint] @@ -121,8 +120,12 @@ class FuncType(Type): ret_type: Type span: Span - def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[TypeParam], type_constraints: List[TypeConstraint]) -> None: - """Construct a function type. + def __init__(self, + arg_types: List[Type], + ret_type: Type, + type_params: List[TypeParam], + type_constraints: List[TypeConstraint]) -> None: + """Construct a function type. Parameters ---------- @@ -130,7 +133,7 @@ def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[Type ret_type: Type type_params: list of TypeParam type_constraints: list of TypeConstraint - + Returns ------- func_type: FuncType @@ -142,8 +145,9 @@ def __init__(self, arg_types: List[Type], ret_type: Type, type_params: List[Type @register_relay_node class TypeCall(Type): - def __init__() -> None: - pass + def __init__(self, type_rel, args) -> None: + self.__init_handle_by_constructor__( + _make.TypeCall, type_rel, args) @register_relay_node diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index f169ff1b64ac..f0d60f514a37 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -6,8 +6,10 @@ from . import make as _make from . import expr as _expr + class TensorSlice(NodeGeneric, _expr.ExprOp): """Auxiliary data structure for enable slicing syntax from tensor.""" + def __init__(self, tensor, indices): if not isinstance(indices, tuple): indices = (indices,) @@ -31,9 +33,11 @@ def dtype(self): itervar_cls = None + @register_node class Tensor(NodeBase, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" + def __call__(self, *indices): ndim = self.ndim if len(indices) != ndim: @@ -104,6 +108,7 @@ def name(self): class Operation(NodeBase): """Represent an operation that generate a tensor""" + def output(self, index): """Get the index-th output of the operation diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index cd87fb83ec52..e225b0c5579a 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -33,7 +33,7 @@ def run(env, expr, inputs, shape): env.add("main", expr) env.transform(Monomorphize.to_pass()) main = env.lookup("main") - graph, lib, _ = to_tvm.compile(main) + graph, lib, _ = to_tvm.compile_to_tvm(main) # We use NNVM to load the graph right now because it populates node_row_ptr field. nnvm_graph = nnvm.graph.load_json(graph) module = graph_runtime.create(nnvm_graph, lib, tvm.cpu(0)) From 356a810bbc961716f67113bc8ed52b1318cd4cfb Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:18:51 -0700 Subject: [PATCH 079/136] Fix doc error --- include/tvm/relay/op.h | 8 ++++++-- include/tvm/relay/pass/alpha_eq.h | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 2d5627f2c844..f79728918086 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -149,7 +149,9 @@ class OpRegistry { const std::string& description); /*! * \brief Attach the type function corresponding to the return type. - * \param ty_func The type function to register for the return type. + * \param ty_func_name The type function name to register for the return type. + * \param type_fn The backing relation which can solve an arbitrary relation + * on variables. * \return reference to self. */ inline OpRegistry& add_type_func(const std::string& type_func_name, @@ -157,7 +159,9 @@ class OpRegistry { /*! * \brief Attach the type function corresponding to the return type. - * \param ty_func The type function to register for the return type. + * \param ty_func_name The type function name to register for the return type. + * \param type_fn The backing relation which can solve an arbitrary relation + * on variables. * \return reference to self. */ inline OpRegistry& add_type_func( diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h index 87b5164462d7..b6d98bd68940 100644 --- a/include/tvm/relay/pass/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -1,6 +1,6 @@ /*! * Copyright (c) 2018 by Contributors - * \file tvm/relay/alpha_eq.h + * \file tvm/relay/pass/alpha_eq.h * \brief Check expressions and types for structural equivalence. */ #ifndef TVM_RELAY_PASS_ALPHA_EQ_H_ From 12dcc2ce883fec4bec3dafa1d8e69c9f2da0e905 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:22:48 -0700 Subject: [PATCH 080/136] Fix doc error again --- include/tvm/relay/op.h | 24 ++++++++++++------------ src/relay/op/tensor/elemwise.cc | 12 ++++++------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index f79728918086..7d0a58265565 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -149,24 +149,24 @@ class OpRegistry { const std::string& description); /*! * \brief Attach the type function corresponding to the return type. - * \param ty_func_name The type function name to register for the return type. - * \param type_fn The backing relation which can solve an arbitrary relation + * \param type_rel_name The type function name to register for the return type. + * \param type_rel The backing relation which can solve an arbitrary relation * on variables. * \return reference to self. */ - inline OpRegistry& add_type_func(const std::string& type_func_name, - TypeRelationFn type_fn); + inline OpRegistry& add_type_rel(const std::string& type_rel_name, + TypeRelationFn type_rel); /*! * \brief Attach the type function corresponding to the return type. - * \param ty_func_name The type function name to register for the return type. - * \param type_fn The backing relation which can solve an arbitrary relation + * \param type_rel_name The type function name to register for the return type. + * \param type_rel The backing relation which can solve an arbitrary relation * on variables. * \return reference to self. */ - inline OpRegistry& add_type_func( - const std::string& type_func_name, - std::function(const Array&, int)> type_fn); + inline OpRegistry& add_type_rel( + const std::string& type_rel_name, + std::function(const Array&, int)> type_rel); /*! * \brief Set the type key of attributes. @@ -356,15 +356,15 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, return *this; } -inline OpRegistry& OpRegistry::add_type_func( +inline OpRegistry& OpRegistry::add_type_rel( const std::string& type_func_name, std::function(const Array&, int)> type_fn) { auto pfunc = runtime::TypedPackedFunc(const Array&, int)>(type_fn); - return add_type_func(type_func_name, pfunc); + return add_type_rel(type_func_name, pfunc); } -inline OpRegistry& OpRegistry::add_type_func(const std::string& type_func_name, +inline OpRegistry& OpRegistry::add_type_rel(const std::string& type_func_name, TypeRelationFn type_fn) { auto type_func = TypeRelationNode::make(type_func_name, 0, type_fn); diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index d6a04773b7fa..a18259c72117 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -37,7 +37,7 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Log", IdentityRel); +.add_type_rel("Log", IdentityRel); // data : Tensor[shape, dtype] // result: Tensor[shape, dtype] @@ -51,7 +51,7 @@ RELAY_REGISTER_UNARY_OP("exp") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Exp", IdentityRel); +.add_type_rel("Exp", IdentityRel); RELAY_REGISTER_UNARY_OP("sqrt") @@ -62,7 +62,7 @@ RELAY_REGISTER_UNARY_OP("sqrt") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_func("Sqrt", IdentityRel); +.add_type_rel("Sqrt", IdentityRel); // Addition TVM_REGISTER_API("relay.op._make.add") @@ -76,7 +76,7 @@ RELAY_REGISTER_OP("add") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_func("Broadcast", BroadcastRel); + .add_type_rel("Broadcast", BroadcastRel); // def broadcast(s1, s2): // ... @@ -97,7 +97,7 @@ RELAY_REGISTER_OP("subtract") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_func("BroadcastComp", BroadcastCompRel); + .add_type_rel("BroadcastComp", BroadcastCompRel); // def broadcast(s1, s2): // ... @@ -118,7 +118,7 @@ RELAY_REGISTER_OP("equal") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_func("BroadcastComp", BroadcastCompRel); + .add_type_rel("BroadcastComp", BroadcastCompRel); } // namespace relay } // namespace tvm From 2c4f54cabbaeecbcc29b90b33e6f67cdad8c6319 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:29:09 -0700 Subject: [PATCH 081/136] Fix signed/unsigned compare --- src/relay/op/type_relations.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 2a6efbcf71e4..fb9008b3e8f2 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -91,7 +91,7 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, CHECK_EQ(larger.size(), smaller.size()); Array out_shape; - for (int i = 0; i < smaller.size(); i++) { + for (size_t i = 0; i < smaller.size(); i++) { auto left = smaller[i].as(); auto right = larger[i].as(); CHECK(left); From 91757a76ce0378b50c4c1248f2182882a194a970 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:33:07 -0700 Subject: [PATCH 082/136] Kill a few more warnings --- src/relay/pass/type_infer.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index df896fa3936a..4873b0a55580 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -86,13 +86,14 @@ struct TypeNormalizer : TypeFVisitor { CHECK(new_args.size() == normalized_args.size()); tvm::Array final_args; - for (int i = 0; i < new_args.size(); i++) { + for (size_t i = 0; i < new_args.size(); i++) { final_args.push_back(unifier->unify(normalized_args[i], new_args[i])); } return TypeCallNode::make(ty_call->func, final_args); } else { - CHECK(false); + throw InternalError("found non type relation in the "\ + "type call function position"); } } } From 2d72cb5f0f8a803cadccf197435ca63f0ca383f3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:48:44 -0700 Subject: [PATCH 083/136] Remove another size_t --- src/relay/ir/op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 18a647798c9e..769f26a42101 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -201,7 +201,7 @@ Op SpecializeOp(const std::string& op_name, const std::string& new_op_name, // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. - for (auto i = 0; i < type_args.size(); i++) { + for (size_t i = 0; i < type_args.size(); i++) { subst_map.Set(fn_ty->type_params[i], type_args[i]); } From 0f48f4901392de7dae7f7ef617c9d3af01f58e27 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 7 Sep 2018 15:52:14 -0700 Subject: [PATCH 084/136] Fix warning --- src/relay/pass/unifier.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 7735ca8b0482..f1411bf9476c 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -184,7 +184,7 @@ Type TypeUnifierNode::VisitType(const Type &t1, const Type t2) { // // We flip the arguments so we hit the TypeCall and other case in there is // ever a type call. - } else if (const TypeCallNode *tvn2 = t2.as()) { + } else if (t2.as()) { return TypeFunctor::VisitType(t2, t1); } else { return TypeFunctor::VisitType(t1, t2); From 98efaf9b2b42d4e85c94d9ab6e6c627d20047615 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 8 Sep 2018 18:25:18 -0700 Subject: [PATCH 085/136] Rewriting language reference to newest version of Relay. --- docs/api/python/index.rst | 1 - docs/api/python/relay/base.rst | 9 - docs/api/python/relay/env.rst | 6 - docs/api/python/relay/expr.rst | 36 ---- docs/api/python/relay/index.rst | 20 -- docs/api/python/relay/ir_builder.rst | 6 - docs/api/python/relay/ir_pass.rst | 3 - docs/api/python/relay/op.rst | 3 - docs/api/python/relay/to_tvm.rst | 3 - docs/api/python/relay/type.rst | 27 --- docs/conf.py | 2 +- docs/langref/relay/expressions.rst | 178 ------------------ docs/langref/relay/index.rst | 17 -- docs/langref/relay/intro.rst | 17 -- docs/langref/relay/type_system.rst | 137 -------------- include/tvm/relay/expr_visitor.h | 4 + python/tvm/relay/__init__.py | 5 +- python/tvm/relay/env.py | 32 +++- python/tvm/relay/expr.py | 2 - python/tvm/relay/ir_builder.py | 37 +++- python/tvm/relay/op/_tensor.py | 2 +- src/relay/ir/op.cc | 3 - src/relay/op/type_relations.cc | 24 +-- src/relay/pass/resolve.cc | 1 - src/relay/pass/unifier.cc | 2 +- .../relay/test_tyck_eval_integration.py | 8 +- 26 files changed, 86 insertions(+), 499 deletions(-) delete mode 100644 docs/api/python/relay/base.rst delete mode 100644 docs/api/python/relay/env.rst delete mode 100644 docs/api/python/relay/expr.rst delete mode 100644 docs/api/python/relay/index.rst delete mode 100644 docs/api/python/relay/ir_builder.rst delete mode 100644 docs/api/python/relay/ir_pass.rst delete mode 100644 docs/api/python/relay/op.rst delete mode 100644 docs/api/python/relay/to_tvm.rst delete mode 100644 docs/api/python/relay/type.rst delete mode 100644 docs/langref/relay/expressions.rst delete mode 100644 docs/langref/relay/index.rst delete mode 100644 docs/langref/relay/intro.rst delete mode 100644 docs/langref/relay/type_system.rst diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index ab411d77f4f4..59bd1795b7ec 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -23,5 +23,4 @@ Python API topi vta/index nnvm/index - relay/index hybrid diff --git a/docs/api/python/relay/base.rst b/docs/api/python/relay/base.rst deleted file mode 100644 index f0cec295ee6b..000000000000 --- a/docs/api/python/relay/base.rst +++ /dev/null @@ -1,9 +0,0 @@ -tvm.relay.base ------------ -.. automodule:: tvm.relay.base - -.. autoclass:: tvm.relay.base.NodeBase - :members: - -.. autoclass:: tvm.relay.base.Span - :members: \ No newline at end of file diff --git a/docs/api/python/relay/env.rst b/docs/api/python/relay/env.rst deleted file mode 100644 index eca7312d5bbb..000000000000 --- a/docs/api/python/relay/env.rst +++ /dev/null @@ -1,6 +0,0 @@ -tvm.relay.env ------------ -.. automodule:: tvm.relay.env - -.. autoclass:: tvm.relay.env.Environment - :members: \ No newline at end of file diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst deleted file mode 100644 index cd0cb5c308c4..000000000000 --- a/docs/api/python/relay/expr.rst +++ /dev/null @@ -1,36 +0,0 @@ -tvm.relay.expr ------------ -.. automodule:: tvm.relay.expr - -.. autoclass:: tvm.relay.expr.ExprBuilder - :members: - -.. autoclass:: tvm.relay.expr.Expr - :members: - -.. autoclass:: tvm.relay.expr.Constant - :members: - -.. autoclass:: tvm.relay.expr.Tuple - :members: - -.. autoclass:: tvm.relay.expr.LocalVar - :members: - -.. autoclass:: tvm.relay.expr.GlobalVar - :members: - -.. autoclass:: tvm.relay.expr.Param - :members: - -.. autoclass:: tvm.relay.expr.Function - :members: - -.. autoclass:: tvm.relay.expr.Call - :members: - -.. autoclass:: tvm.relay.expr.Let - :members: - -.. autoclass:: tvm.relay.expr.If - :members: \ No newline at end of file diff --git a/docs/api/python/relay/index.rst b/docs/api/python/relay/index.rst deleted file mode 100644 index 231d49df0e6d..000000000000 --- a/docs/api/python/relay/index.rst +++ /dev/null @@ -1,20 +0,0 @@ -Relay API -========= - -This document contains the Python API to the Relay frontend, optimizer, and -compiler toolchain. - -Relay is the second generation high level intermediate representation for the TVM -compiler stack. - -.. toctree:: - :maxdepth: 2 - - base - env - expr - ir_builder - ir_pass - op - to_tvm - type diff --git a/docs/api/python/relay/ir_builder.rst b/docs/api/python/relay/ir_builder.rst deleted file mode 100644 index b12e3cc6cdd1..000000000000 --- a/docs/api/python/relay/ir_builder.rst +++ /dev/null @@ -1,6 +0,0 @@ -tvm.relay.ir_builder ------------ -.. automodule:: tvm.relay.ir_builder - -.. autoclass:: tvm.relay.ir_builder.IRBuilder - :members: \ No newline at end of file diff --git a/docs/api/python/relay/ir_pass.rst b/docs/api/python/relay/ir_pass.rst deleted file mode 100644 index e2e3b432e5bd..000000000000 --- a/docs/api/python/relay/ir_pass.rst +++ /dev/null @@ -1,3 +0,0 @@ -tvm.relay.ir_pass ------------ -.. automodule:: tvm.relay.ir_pass \ No newline at end of file diff --git a/docs/api/python/relay/op.rst b/docs/api/python/relay/op.rst deleted file mode 100644 index fb8e9ce774c2..000000000000 --- a/docs/api/python/relay/op.rst +++ /dev/null @@ -1,3 +0,0 @@ -tvm.relay.op ------------ -.. automodule:: tvm.relay.op \ No newline at end of file diff --git a/docs/api/python/relay/to_tvm.rst b/docs/api/python/relay/to_tvm.rst deleted file mode 100644 index 72d01d123e0f..000000000000 --- a/docs/api/python/relay/to_tvm.rst +++ /dev/null @@ -1,3 +0,0 @@ -tvm.relay.to_tvm ------------ -.. automodule:: tvm.relay.to_tvm diff --git a/docs/api/python/relay/type.rst b/docs/api/python/relay/type.rst deleted file mode 100644 index d357df8f08ac..000000000000 --- a/docs/api/python/relay/type.rst +++ /dev/null @@ -1,27 +0,0 @@ -tvm.relay.type ------------ -.. automodule:: tvm.relay.type - -.. autoclass:: tvm.relay.type.Type - :members: - -.. autoclass:: tvm.relay.type.TensorType - :members: - -.. autoclass:: tvm.relay.type.Kind - :members: - -.. autoclass:: tvm.relay.type.TypeParam - :members: - -.. autoclass:: tvm.relay.type.TypeConstraint - :members: - -.. autoclass:: tvm.relay.type.FuncType - :members: - -.. autoclass:: tvm.relay.type.TypeCall - :members: - -.. autoclass:: tvm.relay.type.IncompleteType - :members: \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index e3f7f6a82c24..717003824703 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -33,7 +33,7 @@ # General information about the project. project = u'tvm' author = u'%s developers' % project -copyright = u'2017, %s' % author +copyright = u'2018, %s' % author github_doc_root = 'https://github.com/tqchen/tvm/tree/master/docs/' # add markdown parser diff --git a/docs/langref/relay/expressions.rst b/docs/langref/relay/expressions.rst deleted file mode 100644 index 37dc62c6bc24..000000000000 --- a/docs/langref/relay/expressions.rst +++ /dev/null @@ -1,178 +0,0 @@ -================== -Expressions -================== - -Relay's IR is a pure expression oriented language, that has a -dataflow fragment and structured control flow. Although Relay's -representation is a tree, it is possible to view the dataflow -fragments as graph for purposes of writing and expressing -transformations. - -The below sections make an attempt to clearly split the dataflow -fragment from the control fragment. - -================== -Dataflow Expressions -================== - -First we will cover the set of nodes which do not involve control flow, -this fragment of the language is semantically equivalent to pure -computation graphs without control flow. - -Constants -~~~~~~~~~ -Relay programs can contain constant Tensor values, since in Relay all -values are either Tensors, Products, or Closures. We will discuss the -later two later, but we represent Tensor constants as `tvm.NDArray`, -allowing us to utilize normal operators for constant evaluation. - - -Constructors -~~~~~~~~ - -Relay supports a handful of constructors which we will cover below. A -constructor enables programs to build new values from arbitrary Relay -expressions. - - -We support four types of literals, literals are type polymorphic and can -assigned any base type. If we can not solve for a concrete type we apply -a defaulting rule. - -We support signed and unsigned integers, floating point numbers, booleans, -and tensor literals. - -The base type literals are designed to closely model literals in TVM's -expressions langauge. - -### Boolean Literals -TODO: don't have these in any form right now - -### Integer Literals -TODO: don't have these in any form right now - -Tensor Constructor -~~~~~~~~~~~~~~~ - -A tensor literal allows us to build a Tensor from other expressions. - -TODO: Example here - - -Tuple Constructor -~~~~~~~~~~~~~~~ - -We support tuple constructors which allows us to build a fixed-k sized -sequence of heterogenous data. These tuples match closely to Python's -and enable efficient projection of their members due to their fixed length. - - (a, b, c) : Tuple - - (a + b + c, d) : Tuple, Tensor> - -Function -~~~~~~~~ - -A function node represents a function, it contains a seqeuence of -parameters, a return type, and a body. - - fun (x : Float, y: Float) -> Float { x + y } - -Functions are first class in Relay, and can be used in any expression -position. Functions are the same as global functions, but do not have -an explicit name. You can use a function in conjunction with a let -binding to define locally recursive functions. - - let fact = fun (x : Float) -> Float { - if (x == 0) { - 0 - } else { - x * fact(x - 1) - }; - fact(10) - -Identifiers -~~~~~~~~~~~ - -All of the identifiers are valid expressions, you can use a local identifier, -global identifier, or intrinsic identifier anywhere an expression may appear. - -For example the below fragment of code is a valid expression. - - %ret = @global(intrinsic, %local) - -Let Binding -~~~~~~~~~~~ - -An immutable variable binding, allows the user to bind an -expression to a name. A let binding contains a local identifier, -an optional type, a value, and body expression which may -reference the bound identifier. - -We will first introduce a single binding with no type -anntoations:: - let %x = %a + %b; - x - -The value of a let binding is the value of the final expression -after evaluating the bindings it depends on. - -A user can write a sequence of let bindings, we can view -these blocks and pure dataflow -single binding. These blocks are pure dataflow, and can -be evaluated in any order, reordered up to dataflow. - -We support a sequence of bindings followed by a body which -is the continutation after executing the sequence of bindings. - -I believe this representation will be easier to manipulate then -the mixed dataflow/control flow comptuation graphs. -Data flow and control flow are strictly seperated in this representation -and we can easily syntactically discriminate. When in ANF there should only be -general control flow between `Assignment` nodes and not within the values bound -in bindings. - -This representation also makes it easy to apply reverse more since -sequences of assignments where the only control flow is call instructions -are treated by the algorithm uniformly, and each control flow construct -must be handled individualy. - -TODO Add Ref, ReadRef, WriteRef, Projection, - -Gradient -~~~~~~~~ - -The `Reverse` acts as a marker node, when the compiler encounters it -we will apply the reverse mode transformation to the enclosed function. - -We will employ static analysis and constant evaluation in order to -simplify the node's argument to a known function call target. - - -You can compute the reverse node of a function node like so: - -Cast -~~~~~ - -Cast the type of the `node` to `ty`. - -======================= -Control Flow Expression -======================= -Control flow expressions change network topology based on values -computed by previous expressions. - -Call -~~~~ - -Terms with function types in Relay are "callable", that can be invoked like -a function in a typical programming language by supplying a set of arguments. - -Instrinsics with functions types, definitions, and functions are all callable. - -If-Then-Else -~~~~~~~~~~~~ - -Relay has a simple if/then/else expression which allows programs to branch -on a single control value which must be of type `Bool`, i.e a zero-rank -tensor of booleans. diff --git a/docs/langref/relay/index.rst b/docs/langref/relay/index.rst deleted file mode 100644 index 617e745acdfc..000000000000 --- a/docs/langref/relay/index.rst +++ /dev/null @@ -1,17 +0,0 @@ -Relay Language Reference -======================== - -This document is a work in progress language reference describing -Relay, TVM's high level intermediate representation. The name is an -allusion to interneurons which are often referred to as intermediate, -or relay neurons. - -We will continually iterate on this document as we evolve the new IR -and update accordingly. - -.. toctree:: - :maxdepth: 2 - - intro - expressions - type_system diff --git a/docs/langref/relay/intro.rst b/docs/langref/relay/intro.rst deleted file mode 100644 index 617e745acdfc..000000000000 --- a/docs/langref/relay/intro.rst +++ /dev/null @@ -1,17 +0,0 @@ -Relay Language Reference -======================== - -This document is a work in progress language reference describing -Relay, TVM's high level intermediate representation. The name is an -allusion to interneurons which are often referred to as intermediate, -or relay neurons. - -We will continually iterate on this document as we evolve the new IR -and update accordingly. - -.. toctree:: - :maxdepth: 2 - - intro - expressions - type_system diff --git a/docs/langref/relay/type_system.rst b/docs/langref/relay/type_system.rst deleted file mode 100644 index 91a634431d7c..000000000000 --- a/docs/langref/relay/type_system.rst +++ /dev/null @@ -1,137 +0,0 @@ -================== -Type System -================== - -We have briefly introduced types while detailing the the expression language -of Relay, but have fully laid out the type system. - -Although the majority of Relay programs require no type annotations, Relay -is statically typed. Each expression in Relay has a precisely known type. - -You might ask why we want a statically typed IR, there are multiple advantages. -- efficient layout and code generation for tensors -- TODO -- debugging transformations (most program transformations should be type perserving) - -We are able to omit these type annotations by a process known as type inference. -Type inference is a technique that has its roots in the programming language -community, and can be viewed as a method for generalizing shape inference to -run over arbitrary user programs. - -Static typing means we know before executing the program properties about -the values it manipulates. Static types are useful for compiler optimization -because they communicate properties about the data we manipulate, such as -runtime shape, data layout, storage. - -Most current IRs use "shape inference" to recover Tensor dimensions from the user -provided program. Machine learning users have enjoyed shape inference for -tensors because it allows them to generate performant code without giving up -on the expressivity of the input language. - -Because Relay is intended as an IR we require *some* type information to provide -full inference. We don't believe this to be an issue as many of the IR builder -inferfaces require some type information, or can generate IR based on their own -higher level inferences. - -We view this limited shape inference as a simpler form of type -inference. Instead of relying on an ad-hoc procedure for recovering type -information from a potentially dynamic program, we apply ideas from compiler and IR design. - -Below we briefly dicsuss the different kinds of types in Relay. - -===== -Types -===== - -BaseType -~~~~~~~~~~ -Relay has a notion of a BaseType, which captures the set of types -that can be stored in a Tensor. Relay's base types map to the set -of types supported by TVM. - -Each of the base types can be parametrized by number of bits, and -lanes for vectorization purposes. We support four base types any:`Bool`, -any:`Int` - -Type Variables -~~~~~~~~~~~~~~ - -Type Parameters -~~~~~~ -TODO: type parameter - -Kind -~~~~ - -Function Types -~~~~~~~~~~ -TODO: rename function type? - -TypeQuantifier -~~~~~~~~~~~~~~ -TODO - -Placeholders -~~~~~~~~~~~~ - -TODO - -Tuple Types -~~~~~~~~~~~~~ - -Reference Types -~~~~~~~~~~~~~~~ - -A reference type is simply a mutable memory location, since Relay is a pure -language by default we need a way to introduce limited mutability. In this -case mutable data is clearly marked in the type system as a reference type. - - Ref - -Tensor Type -~~~~~~~~~~~ - -Tensor values in Relay are typed with tensor types. A tensor type is -parametrized by a data type, and shape. The data type must be a base -type as enforced by the kind checking rules described in TODO. - -This restriction importantly means - -The shape may be any valid Relay shape as described in the below -section on shapes. - - -====== -Shapes -====== - -Shape Singleton -~~~~~~~~~~~~~~~ -I don't like this name - -ShapeAttr -~~~~~~~~~ -TODO - -ShapeProjection -~~~~~~~~~~~~~~~ -TODO - -ShapeBinaryOp -~~~~~~~~~~~~~ - -enum ShapeOp : int { - SHPLUS = 0, - SHSUB = 1, - SHMUL = 2, - SHDIV = 3 -}; - - -Shape Sequence -~~~~~~~~ -A sequence of shapes ... - - -ShapeBroadcast -~~~~~~~~~~~~~~ diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 6f2a7f98542a..0febad503b12 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -75,6 +75,10 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { return GetRef(op); } + Expr VisitExpr_(const ConstantNode* op) override { + return GetRef(op); + } + Expr VisitExpr_(const GlobalVarNode* op) override { return GetRef(op); } diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index aae019c8d9c1..c254c7e9ce7a 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -3,7 +3,10 @@ from . import base from . import type as tpe from . import expr - +from . import to_tvm +from . import env +from . import ir_pass +from . import ir_builder # Operators from .op import Op from .op.tensor import * diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index beef6fd1a62c..93cbe1bca284 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -1,5 +1,5 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import -"""A global environment storing everything needed to interpret or compile a Realy program.""" +"""A global environment storing everything needed to interpret or compile a Relay program.""" from .base import register_relay_node, NodeBase from . import _make from . import _env @@ -10,25 +10,55 @@ class Environment(NodeBase): options and more. """ def __init__(self, funcs) -> None: + """Construct an environment. + + Parameters + ------ + funcs: list of relay.Function + + Returns + ------ + env: A new environment containing :py:class:`~relay.env.Environment`. + """ self.__init_handle_by_constructor__(_make.Environment, funcs) def add(self, var, func) -> None: + """Add a function to the environment. + + Parameters + --------- + var: GlobalVar + The global variable which names the function. + + func: Function + The function. + """ if isinstance(var, str): var = _env.Environment_GetGlobalVar(self, var) _env.Environment_Add(self, var, func) def merge(self, other): + """Merge two environments. + + Parameters + ---------- + other: Environment + The environment to merge into the current Environment. + """ return _env.Environment_Merge(self, other) def global_var(self, var): + """Get a global variable by name.""" return _env.Environment_GetGlobalVar(self, var) def lookup(self, var): + """Lookup a global function by name or by variable.""" if isinstance(var, str): return _env.Environment_Lookup_str(self, var) else: return _env.Environment_Lookup(self, var) def transform(self, transformer): + """Apply a transformer function to the environment.""" _env.Environment_Transform(self, transformer) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 748b2aa1e282..3cdaed89d2fb 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -15,8 +15,6 @@ class ExprBuilder(): def __call__(self, *args, **kwargs): converted_args = [] for arg in args: - import pdb - pdb.set_trace() if isinstance(arg, Param): converted_args.append(arg.var) else: diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index a271a537b290..c0c2e76c1157 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -79,7 +79,7 @@ def to_func(self): #pylint: disable=invalid-name def _mk_let(bindings, ret_value): let_expr = ret_value - for var, value, ty in reversed(list(bindings.items())): + for var, (value, ty) in reversed(list(bindings.items())): let_expr = Let(var, value, let_expr, ty) return let_expr @@ -114,7 +114,7 @@ def exit_scope(self): return bindings, scopes, params, ret_value #pylint: disable=invalid-name - def bind(self, name, ty, value): + def bind(self, name, value, ty): lv = LocalVar(name) self.scopes[-1][name] = lv self.bindings[-1][lv] = (value, ty) @@ -127,16 +127,35 @@ def let(self, name, value, value_type=None): if not isinstance(value, Expr): value = into_ast(value) - return self.bind(name, value_type, value) + return self.bind(name, value, value_type) + + def _convert_params(self, raw_params): + relay_params = [] + for raw_param in raw_params: + if isinstance(raw_param, Param): + var = raw_param.var + param = raw_param + elif isinstance(raw_param, tuple): + var, ty = raw_param + if isinstance(var, str): + var = LocalVar(var) + param = Param(var, ty) + elif isinstance(param, str): + var = LocalVar(raw_param) + ty = None + param = Param(var, ty) + else: + raise Exception("unknown parameter type") + + self.scopes[-1][var.name_hint] = var + relay_params.append(param) + + return relay_params def function(self, *params): """Construct a Relay function.""" - relay_params = [] - for param in params: - name = param.var - ty = param.type - self.scopes[-1][name.name_hint] = name - relay_params.append(Param(name, ty)) + + relay_params = self._convert_params(params) # self.params.append(relay_params) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 4427faa6a3a6..2a0ecc6c8550 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -49,7 +49,7 @@ def func_ty_to_placeholders(func_ty): # return [schedule, Inputs + [Output]] #pylint: disable=duplicate-argument-name -def add_compiler(_, func_type, *_): +def add_compiler(_, func_type, *__): """The compilation code for the TVM compiler.""" inputs, _ = func_ty_to_placeholders(func_type) # op = lookup_in_topi(op_name) diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 769f26a42101..7c005acb8648 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -155,9 +155,7 @@ Module CompileOpsToModule(const std::vector& op_names) { if (!IsGeneric(op->op_type)) { auto compiler = compiler_map[op]; - std::cout << "ABOVE CALL" << std::endl; tvm::Array pair = compiler(op->name, op->op_type); - std::cout << "BELOW CALL" << std::endl; // TODO(@jroesch): I can't pass strings across what should be the // interface here. tvm::Array triple = {LocalVarNode::make(op->name), pair[0], @@ -183,7 +181,6 @@ TVM_REGISTER_API("relay.op._CompileOpsToModule") for (auto i = 0; i < args.num_args; i++) { names.push_back(args[i]); } - std::cout << "Right here" << std::endl; *ret = CompileOpsToModule(names); }); diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index fb9008b3e8f2..e2b2cba1e0ef 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -22,9 +22,9 @@ TensorType as_ttype(const Type& t) { // TODO(@jroesch) what size value do we extract? int to_int(const tvm::Expr& e) { + CHECK(e.defined()); auto imm = e.as(); - CHECK(imm); - std::cout << "TYPE: " << imm << imm->type << std::endl; + CHECK(imm) << "TYPE: " << imm << imm->type << std::endl; return imm->value; } @@ -53,17 +53,17 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, auto suffix_len = static_cast(std::min(sh1.size(), sh2.size())); auto full_len = static_cast(std::max(sh1.size(), sh2.size())); - std::cout << "Length" << suffix_len << full_len - << (full_len - suffix_len - 1) << std::endl; - auto lower_bound = full_len - suffix_len - 1; + auto rev_sh1 = sh1.rbegin(); + auto rev_sh2 = sh2.rbegin(); - for (int64_t i = full_len - 1; i > lower_bound; i--) { - std::cout << "Index i=" << i << std::endl; - auto dim1 = to_int(sh1[i]); - auto dim2 = to_int(sh2[i]); - if (dim1 != dim2) { - CHECK(false); + while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) { + auto dim1 = to_int(*rev_sh1); + auto dim2 = to_int(*rev_sh2); + if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) { + CHECK(false) << "Dimension mistmatch " << "dim1: " << dim1 << " dim2: " << dim2 << std::endl; } + rev_sh1++; + rev_sh2++; } Array larger; @@ -106,9 +106,9 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, Array BroadcastRel(const Array& types, int num_args) { CHECK_EQ(types.size(), 3); + RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1] << "Out: " << types[2] << std::endl; if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { - std::cout << t1->dtype << t2->dtype << std::endl; CHECK_EQ(t1->dtype, t2->dtype); return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)}; } diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index f513e36c9a30..bc63d939959e 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -53,7 +53,6 @@ struct ResolveTypeExpr : ExprFVisitor { // term, then resolve e's old type and write // it back into the new node. auto new_e = ExprFVisitor::VisitExpr(e); - std::cout << e << std::endl; CHECK(e->checked_type_.defined()); auto resolved_cty = VisitType(e->checked_type_); new_e->checked_type_ = resolved_cty; diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index f1411bf9476c..f5e337eb17f7 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -294,7 +294,7 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { // Type unified_shape = this->VisitType(tt1->shape, tt2->shape); return rt2; } catch (const UnificationError &err) { - std::cout << "Need to check constraint " << tt1->shape << " = " + CHECK(false) << "Need to check constraint " << tt1->shape << " = " << tt2->shape << std::endl; } diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index e225b0c5579a..f9a3d098a3e2 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -104,8 +104,8 @@ def test_add_broadcast_op(): ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) assert has_type(func.to_func(), expected_ty) - x_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) - y_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) + x_data = tvm.nd.array(np.random.rand(10, 4).astype('float32')) + y_data = tvm.nd.array(np.random.rand(5, 10, 1).astype('float32')) result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 10, 4)) np.testing.assert_allclose( x_data.asnumpy() + y_data.asnumpy(), result.asnumpy()) @@ -171,8 +171,8 @@ def f(n: i32, data: f32) -> f32 { # to execute this. if __name__ == "__main__": - # test_monomorphic_let() - # test_single_op() + test_monomorphic_let() + test_single_op() test_add_op() test_add_broadcast_op() # test_dual_op() From ef2a2298b80fa46bc1393e5a02160a32545c44ca Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 13 Sep 2018 16:16:37 -0700 Subject: [PATCH 086/136] Revert "Port docs from previous Relay version" This reverts commit d47f637a14b4cbfcbd356cb55df78df3b5093b88. --- tutorials/relay/implement_fma_transform.py | 141 --------------------- 1 file changed, 141 deletions(-) delete mode 100644 tutorials/relay/implement_fma_transform.py diff --git a/tutorials/relay/implement_fma_transform.py b/tutorials/relay/implement_fma_transform.py deleted file mode 100644 index 8c04e70aa846..000000000000 --- a/tutorials/relay/implement_fma_transform.py +++ /dev/null @@ -1,141 +0,0 @@ -"""How to use Relay to implement a simple two-operator fusion pass. -================================== -**Author**: `Jared Roesch `_ - -In this tutorial, we will demonstrate how to write a fusion pass for -the Relay IR. We demonstrate many Relay features including defining a -new operator, a program transform, the NNVM compatibility layer, -and executing the original and transformed programs on the Relay -evaluator and TVM runtime system. -""" - -################################################################ -# Introduction -# ------------------------- -# -# In this tutorial, we will demonstrate how to write a fusion pass for -# the Relay IR. We demonstrate many Relay features including defining a -# new operator, a program transform, the NNVM compatibility layer, -# and executing the original and transformed programs on the Relay -# evaluator and TVM runtime system. - -from typing import Any, Dict - -import numpy as np -import tvm -import topi - -from relay import ir, make as mk -from relay.ir import OperatorId -from relay.opt import ItemVisitor, ExprVisitor -from relay.frontend.nnvm import Variable, symbol -from relay.frontend.nnvm import compiler -from relay.frontend.global_env import get_env -from relay.operators.register import func_ty_to_placeholders, register_op -from relay.eval import defn_to_pyfunc -from relay.tyck import check_expr - -class ExprAtVisitor(ExprVisitor): - """A demo visitor which adds a new traversal strategy.""" - expr_map: Dict[ir.LocalId, ir.Expr] - - def __init__(self): - self.expr_map = {} - - def expr_at(self,id: ir.LocalId) -> ir.Expr: - try: - return self.expr_map[id] - except KeyError: - return id - - def visit_let(self, let: ir.Let) -> ir.Expr: - self.expr_map[let.id] = let.value - return super().visit_let(let) - -# let x = 1 + 1; -# ... x will map to 1 + 1 - -class FuseTwo(ExprAtVisitor): - """Rewrite b(a(x, y), z) into ab(x, y, z). """ - def __init__(self, a: OperatorId, b: OperatorId, a_and_b: OperatorId) -> None: - self.a = a - self.b = b - self.a_and_b = a_and_b - super().__init__() - - def visit_call(self, call: ir.Call) -> ir.Expr: - func = call.fn - if func == self.b: - assert len(call.args) == 2 # An assumption of this fusion code. - arg0 = self.expr_at(call.args[0]) - arg1 = self.expr_at(call.args[1]) - if isinstance(arg0, ir.Call) and arg0.fn == self.a: - new_call = mk.Call(self.a_and_b, arg0.args[:] + [arg1]) - elif isinstance(arg1, ir.Call) and arg1.fn == self.a: - new_call = mk.Call(self.a_and_b, arg1.args[:] + [arg0]) - else: - new_call = super().visit_call(call) - - return new_call - else: - return super().visit_call(call) - -def fma_compile(op_name: str, func_ty: ir.Type, attrs: ir.Attributes=None) -> Any: - Inputs, ret_ty = func_ty_to_placeholders(func_ty) - x, y, z = Inputs - Output = topi.multiply(topi.add(x, y), z) - # this is not a python function call, but builds an AST - schedule = tvm.create_schedule(Output.op) - return [schedule, Inputs + [Output]] - - -def register_fma(env: Any) -> None: - """Register TOPI's elementwise broadcast addition for the `+` operator.""" - shape = mk.TypeParam("s", ir.Kind.Shape) - bt = mk.TypeParam("bt", ir.Kind.BaseType) - in_out_type = mk.TensorType(bt, shape) - fma_type = mk.TypeQuantifier(bt, mk.TypeQuantifier(shape, mk.TypeArrow([in_out_type, in_out_type, in_out_type], in_out_type))) - # forall (bt: BaseTYpe) (s : Shape), Tensor[bt, s] -> Tensor[bt, s] -> Tensor[bt, s] - # TODO: no reverse mode - register_op(env, 'fma', fma_type, compiler=fma_compile) - -# Get the global environment for demo purposes. -env = get_env() - -register_fma(env) - -# A small helper which applies just our transform to the Relay expression. -def transform(e): - fuse = FuseTwo(env.add_id(), env.mul_id(), env.operator_id('fma')) - e = fuse.visit(e) - # Now let's use the type checker to make sure we didn't make a mistake. - check_expr(env, e) - return e - -# We will use NNVM frontend. -x = Variable('x') -y = Variable('y') -z = x * (x + y) - -relay_func = compiler.to_relay(z) - -print(f"Relay Function:\n{compiler.pp(relay_func)}") - -xform_func = transform(relay_func) - -print(f"Transformed Function:\n{compiler.pp(xform_func)}") - -# Use the evaluator. -norm = defn_to_pyfunc(env, relay_func) -xform = defn_to_pyfunc(env, xform_func) - -x = np.random.uniform(size=(10, 5, 10)).astype('float32') -y = np.random.uniform(size=(10, 5, 10)).astype('float32') - -norm_out = norm(x, y).asnumpy() -xform_out = xform(x, y).asnumpy() - -np.testing.assert_allclose(norm_out, xform_out) - -# Use the TVM runtime. - From 059727a52934330311ad836e43ea3203e836c233 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Sep 2018 14:13:00 -0700 Subject: [PATCH 087/136] Rewrite type inferencer to use new relations --- include/tvm/relay/base.h | 43 +++- include/tvm/relay/expr_visitor.h | 14 +- include/tvm/relay/op.h | 5 +- include/tvm/relay/type.h | 51 +--- python/tvm/relay/type.py | 7 - src/relay/ir/op.cc | 4 +- src/relay/ir/type.cc | 25 +- src/relay/op/type_relations.cc | 6 +- src/relay/pass/alpha_eq.cc | 40 +-- src/relay/pass/type_functor.h | 2 - src/relay/pass/type_infer.cc | 242 ++++++++++++++---- src/relay/pass/type_visitor.h | 52 ++-- src/relay/pass/unifier.cc | 44 +--- src/relay/pass/unifier.h | 1 - .../relay/test_tyck_eval_integration.py | 31 +-- 15 files changed, 330 insertions(+), 237 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 09f3a94e1edb..3178d073c778 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -14,10 +14,10 @@ namespace tvm { /*! * \brief Relay: a high level functional IR for TVM. - * + * * This namespace contains the abstract syntax tree, and other * essential data structures for the Relay IR. - * + * * You can find more about Relay by reading the language reference. */ namespace relay { @@ -66,7 +66,6 @@ using NodeEqual = ::tvm::NodeEqual; using ContainerType = NodeName; \ }; - /*! * \brief The source name in the Span * \sa SourceNameNode, Span @@ -80,9 +79,7 @@ class SourceNameNode : public Node { /*! \brief The source name */ std::string name; // override attr visitor - void VisitAttrs(AttrVisitor* v) final { - v->Visit("name", &name); - } + void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); } TVM_DLL static SourceName make(std::string name); @@ -154,6 +151,40 @@ RefType GetRef(const NodeType* ptr) { return RefType(const_cast(ptr)->shared_from_this()); } +// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR +template +inline const T* As(const NodeRef& node) { + const Node* ptr = static_cast(node.get()); + if (ptr && (ptr->is_type() || ptr->derived_from())) { + return static_cast(ptr); + } + return nullptr; +} + +template +std::vector Downcast(std::vector array) { + std::vector out; + for (const U& elem : array) { + const typename T::ContainerType* node = + elem.template as(); + CHECK(node) << "Downcast failed" << std::endl; + out.push_back(GetRef(node)); + } + return out; +} + +template +Array Downcast(Array array) { + Array out; + for (const U& elem : array) { + const typename T::ContainerType* node = + elem.template as(); + CHECK(node) << "Downcast failed" << std::endl; + out.push_back(GetRef(node)); + } + return out; +} + /*! * \brief Get PackedFunction from global registry and * report error if it does not exist diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 0febad503b12..349ce3f9543b 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -1,8 +1,8 @@ /*! * Copyright (c) 2018 by Contributors * \file tvm/relay/expr_visitor.h - * \brief A simple visitor wrapper around ExprFunctor. - * + * \brief A simple visitor wrapper around ExprFunctor. + * * Exposes two visitors with default traversal strategies, one * which doesn't compute a result but can mutate internal state, * and another which functionally builds a new Expr. @@ -29,9 +29,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { } } - void VisitExpr_(const ParamNode* op) override { - this->VisitExpr(op->var); - } + void VisitExpr_(const ParamNode* op) override { this->VisitExpr(op->var); } void VisitExpr_(const FunctionNode* op) override { for (auto param : op->params) { @@ -75,7 +73,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { return GetRef(op); } - Expr VisitExpr_(const ConstantNode* op) override { + Expr VisitExpr_(const ConstantNode* op) override { return GetRef(op); } @@ -83,9 +81,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { return GetRef(op); } - Expr VisitExpr_(const OpNode* op) override { - return GetRef(op); - } + Expr VisitExpr_(const OpNode* op) override { return GetRef(op); } Expr VisitExpr_(const TupleNode* op) override { tvm::Array fields; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 7d0a58265565..9727e80fa561 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -366,7 +366,6 @@ inline OpRegistry& OpRegistry::add_type_rel( inline OpRegistry& OpRegistry::add_type_rel(const std::string& type_func_name, TypeRelationFn type_fn) { - auto type_func = TypeRelationNode::make(type_func_name, 0, type_fn); std::vector type_params; std::vector arg_types; @@ -388,9 +387,9 @@ inline OpRegistry& OpRegistry::add_type_rel(const std::string& type_func_name, type_params.push_back(out_param); ty_call_args.push_back(out_param); - auto type_result = TypeCallNode::make(type_func, ty_call_args); + TypeConstraint type_rel = TypeRelationNode::make(type_func_name, type_fn, ty_call_args); - auto func_type = FuncTypeNode::make(arg_types, type_result, type_params, {}); + auto func_type = FuncTypeNode::make(arg_types, out_param, type_params, { type_rel }); get()->op_type = func_type; diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index f485e0d8d62f..5e3665dfbd1d 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -129,6 +129,7 @@ class TypeParamNode : public TypeNode { kShape = 1, kBaseType = 2, kType = 3, + kTypeList = 4, }; /*! * \brief The variable itself is only meaningful when @@ -158,13 +159,13 @@ RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type); */ class TypeConstraint; /*! \brief TypeConstraint container node. */ -class TypeConstraintNode : public Node { +class TypeConstraintNode : public TypeNode { public: static constexpr const char* _type_key = "relay.TypeConstraint"; - TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, Node); + TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, NodeRef); +RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type); class FuncType; /*! @@ -221,12 +222,11 @@ class TypeRelation; * \note This node is not directly serializable. * The type function need to be lookedup in the environment. */ -class TypeRelationNode : public RelayNode { +class TypeRelationNode : public TypeConstraintNode { public: /*! \brief The name of the function */ std::string name; - /*! \brief Number of input type arguments, can be -1, which means VarArgs */ - int num_args; + /*! * \brief The function on input and output variables which * this is not directly serializable, @@ -234,49 +234,20 @@ class TypeRelationNode : public RelayNode { */ TypeRelationFn func_; - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("name", &name); - v->Visit("num_args", &num_args); - } - - TVM_DLL static TypeRelation make(std::string name, int num_args, - TypeRelationFn func_); - - static constexpr const char* _type_key = "relay.TypeRelation"; - TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, RelayNode); -}; - -RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, Type); - -/*! - * \brief Call a type function with some number of arguments. - */ -class TypeCall; -/*! - * \brief TypeCall container. - */ -class TypeCallNode : public TypeNode { - public: - /*! \brief The type function to be called. */ - Type func; - /*! \brief The type arguments to the type function. */ tvm::Array args; - TypeCallNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("func", &func); - v->Visit("args", &args); + v->Visit("name", &name); } - TVM_DLL static TypeCall make(Type func, tvm::Array args); + TVM_DLL static TypeRelation make(std::string name, TypeRelationFn func_, Array args); - static constexpr const char* _type_key = "relay.TypeCall"; - TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); + static constexpr const char* _type_key = "relay.TypeRelation"; + TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode); }; -RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); +RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint); /*! * \brief The type of tuple values. diff --git a/python/tvm/relay/type.py b/python/tvm/relay/type.py index 22c853ef512f..ca74f2c5deb3 100644 --- a/python/tvm/relay/type.py +++ b/python/tvm/relay/type.py @@ -143,13 +143,6 @@ def __init__(self, _make.FuncType, arg_types, ret_type, type_params, type_constraints) -@register_relay_node -class TypeCall(Type): - def __init__(self, type_rel, args) -> None: - self.__init_handle_by_constructor__( - _make.TypeCall, type_rel, args) - - @register_relay_node class IncompleteType(Type): """An incomplete type.""" diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 7c005acb8648..a6dbd769f75f 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -202,8 +202,8 @@ Op SpecializeOp(const std::string& op_name, const std::string& new_op_name, subst_map.Set(fn_ty->type_params[i], type_args[i]); } - Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); - inst_ty = TypeSubst(fn_ty, subst_map); + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, fn_ty->type_constraints); + inst_ty = TypeSubst(inst_ty, subst_map); FuncType new_op_ty = GetRef(inst_ty.as()); new_op_reg.op()->op_type = new_op_ty; diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 2975c60cc0c1..73be2400ba2e 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -93,12 +93,11 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->type_constraints << ")"; }); -TypeRelation TypeRelationNode::make(std::string name, int num_args, - TypeRelationFn func) { +TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array args) { std::shared_ptr n = std::make_shared(); n->name = std::move(name); - n->num_args = std::move(num_args); n->func_ = std::move(func); + n->args = std::move(args); return TypeRelation(n); } @@ -110,28 +109,10 @@ TVM_REGISTER_API("relay._make.TypeRelation") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TypeRelationNode *node, tvm::IRPrinter *p) { - p->stream << "TypeRelationNode(" << node->name << ", " << node->num_args + p->stream << "TypeRelationNode(" << node->name << ", " << node->args << ")"; }); -TypeCall TypeCallNode::make(Type func, Array args) { - std::shared_ptr n = std::make_shared(); - n->func = std::move(func); - n->args = std::move(args); - return TypeCall(n); -} - -TVM_REGISTER_API("relay._make.TypeCall") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TypeCallNode::make(args[0], args[1]); - }); - -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const TypeCallNode *node, - tvm::IRPrinter *p) { - p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; - }); - TupleType TupleTypeNode::make(Array fields) { std::shared_ptr n = std::make_shared(); n->fields = std::move(fields); diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index e2b2cba1e0ef..751583e738d4 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -60,7 +60,8 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, auto dim1 = to_int(*rev_sh1); auto dim2 = to_int(*rev_sh2); if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) { - CHECK(false) << "Dimension mistmatch " << "dim1: " << dim1 << " dim2: " << dim2 << std::endl; + CHECK(false) << "Dimension mistmatch " + << "dim1: " << dim1 << " dim2: " << dim2 << std::endl; } rev_sh1++; rev_sh2++; @@ -106,7 +107,8 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, Array BroadcastRel(const Array& types, int num_args) { CHECK_EQ(types.size(), 3); - RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1] << "Out: " << types[2] << std::endl; + RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1] + << "Out: " << types[2] << std::endl; if (auto t1 = as_ttype(types[0])) { if (auto t2 = as_ttype(types[1])) { CHECK_EQ(t1->dtype, t2->dtype); diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 555d4f2db99d..764a9139c5f6 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -118,26 +118,26 @@ struct TypeAlphaEq : TypeVisitor { // } // } - void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { - TypeCall tycall = GetRef(tyn1); - if (const TypeCallNode *tyn2 = t2.as()) { - if (tycall->func != tyn2->func) { - equal = false; - return; - } - - if (tycall->args.size() != tyn2->args.size()) { - equal = false; - return; - } - - for (size_t i = 0U; i < tycall->args.size(); i++) { - this->VisitType(tycall->args[i], tyn2->args[i]); - } - } else { - equal = false; - } - } + // void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { + // TypeCall tycall = GetRef(tyn1); + // if (const TypeCallNode *tyn2 = t2.as()) { + // if (tycall->func != tyn2->func) { + // equal = false; + // return; + // } + + // if (tycall->args.size() != tyn2->args.size()) { + // equal = false; + // return; + // } + + // for (size_t i = 0U; i < tycall->args.size(); i++) { + // this->VisitType(tycall->args[i], tyn2->args[i]); + // } + // } else { + // equal = false; + // } + // } }; bool AlphaEqual(const Type &t1, const Type &t2) { diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h index 9180703b49e8..339552108af4 100644 --- a/src/relay/pass/type_functor.h +++ b/src/relay/pass/type_functor.h @@ -64,7 +64,6 @@ class TypeFunctor { virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; @@ -83,7 +82,6 @@ class TypeFunctor { RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); - RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode); RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); return vtable; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 4873b0a55580..d530fe19782b 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -35,13 +35,31 @@ namespace relay { using namespace tvm::runtime; +struct TypeConstraintSet { + std::vector ty_rels; + TypeConstraintSet() : ty_rels() {} + TypeConstraintSet(const std::vector& cs) : ty_rels(cs) {} + void Add(const TypeConstraint & ty_rel) { + ty_rels.push_back(ty_rel); + } +}; + struct TypeContext { std::vector> stack; + std::vector constraints; TypeContext() { stack.push_back({}); } void insert(const LocalVar &id, const Type &t) { stack.back()[id] = t; } + void AddConstraint(const TypeConstraint & ty_rel) { + constraints.back().Add(ty_rel); + } + + // TypeConstraint & Constraints() { + // return + // } + Type lookup(const LocalVar &id) { for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) { if (frame->find(id) != frame->end()) { @@ -53,7 +71,10 @@ struct TypeContext { struct LocalFrame { TypeContext &tc; - explicit LocalFrame(TypeContext &tc) : tc(tc) { tc.stack.push_back({}); } + explicit LocalFrame(TypeContext &tc) : tc(tc) { + tc.stack.push_back({}); + tc.constraints.push_back({}); + } ~LocalFrame() { tc.stack.pop_back(); } }; }; @@ -62,41 +83,41 @@ struct TypeNormalizer : TypeFVisitor { TypeUnifier unifier; explicit TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {} - Type VisitType_(const TypeCallNode *ty_call_node) { - auto ty_call = GetRef(ty_call_node); - - Array normalized_args; - - for (auto arg : ty_call->args) { - normalized_args.push_back(VisitType(arg)); - } - - auto all_concrete = true; - for (auto arg : normalized_args) { - all_concrete = all_concrete && !arg.as(); - } - - if (all_concrete) { - return normalized_args[normalized_args.size() - 1]; - } else { - if (auto ty_rel_node = ty_call->func.as()) { - // NB: we substract 1 for the output argument. - auto new_args = - ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); - CHECK(new_args.size() == normalized_args.size()); - tvm::Array final_args; - - for (size_t i = 0; i < new_args.size(); i++) { - final_args.push_back(unifier->unify(normalized_args[i], new_args[i])); - } - - return TypeCallNode::make(ty_call->func, final_args); - } else { - throw InternalError("found non type relation in the "\ - "type call function position"); - } - } - } + // Type VisitType_(const TypeRelationNode *ty_call_node) { + // auto ty_call = GetRef(ty_call_node); + + // Array normalized_args; + + // for (auto arg : ty_call->args) { + // normalized_args.push_back(VisitType(arg)); + // } + + // auto all_concrete = true; + // for (auto arg : normalized_args) { + // all_concrete = all_concrete && !arg.as(); + // } + + // if (all_concrete) { + // return normalized_args[normalized_args.size() - 1]; + // } else { + // if (auto ty_rel_node = ty_call->func.as()) { + // // NB: we substract 1 for the output argument. + // auto new_args = + // ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); + // CHECK(new_args.size() == normalized_args.size()); + // tvm::Array final_args; + + // for (size_t i = 0; i < new_args.size(); i++) { + // final_args.push_back(unifier->unify(normalized_args[i], new_args[i])); + // } + + // return TypeCallNode::make(ty_call->func, final_args); + // } else { + // throw InternalError("found non type relation in the "\ + // "type call function position"); + // } + // } + // } }; struct CheckedExpr { @@ -106,9 +127,11 @@ struct CheckedExpr { CheckedExpr() {} }; +enum SolverResult : int; + class TypeInferencer : private ExprFunctor { private: - TypeContext local_stack; + TypeContext context; public: Environment env; @@ -117,7 +140,7 @@ class TypeInferencer : private ExprFunctor { // Should be in header? template T with_frame(const std::function &f) { - TypeContext::LocalFrame fr(local_stack); + TypeContext::LocalFrame fr(context); return f(); } @@ -138,6 +161,11 @@ class TypeInferencer : private ExprFunctor { Type unify(const Type &t1, const Type &t2, Span sp); Type resolve(const Type &t); Expr resolve(const Expr &e); + TypeRelation Solve(const TypeRelation & ty_rel); + SolverResult Solve(std::vector & rels); + + /*! \brief Check that all relations hold. */ + bool RelationsHold(bool scope_only = false); CheckedExpr VisitFunction(const Function &f, bool generalize); private: CheckedExpr VisitExpr_(const LocalVarNode *op) override; @@ -181,7 +209,7 @@ CheckedExpr TypeInferencer::Infer(const Expr &expr) { CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { auto var = GetRef(op); - return {var, this->local_stack.lookup(var)}; + return {var, this->context.lookup(var)}; } CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { @@ -232,7 +260,7 @@ CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { Type arg_type; param_types.push_back(checked_param.type); params.push_back(GetRef(checked_param.expr.as())); - this->local_stack.insert(param->var, checked_param.type); + this->context.insert(param->var, checked_param.type); } auto checked_body = this->Infer(f->body); @@ -263,8 +291,8 @@ FuncType TypeInferencer::instantiate(FuncType fn_ty, subst_map.Set(ty_param, fresh); } - Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, {}); - inst_ty = TypeSubst(fn_ty, subst_map); + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, fn_ty->type_constraints); + inst_ty = TypeSubst(inst_ty, subst_map); CHECK(KindCheck(this->env, inst_ty)); @@ -296,6 +324,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { fn_ty = instantiate(fn_ty, ty_args); + std::vector arg_types; std::vector checked_args; @@ -328,6 +357,11 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { ty_args.Set(i, this->unifier->subst(ty_args[i])); } + // Add type constraints from the function types. + for (auto cs : fn_ty->type_constraints) { + context.AddConstraint(cs); + } + auto new_call = CallNode::make(checked_op.expr, checked_args, c->attrs, ty_args); @@ -345,7 +379,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { // local definitions. if (let->value.as()) { with_frame([&]() { - local_stack.insert(let->var, annotated_ty); + context.insert(let->var, annotated_ty); checked_value = Infer(let->value); }); } else { @@ -356,10 +390,10 @@ CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { // Update type context with unified type now that we have // solved this equation. - local_stack.insert(let->var, unified_ty); + context.insert(let->var, unified_ty); auto checked_body = with_frame([&]() { - local_stack.insert(let->var, unified_ty); + context.insert(let->var, unified_ty); return Infer(let->body); }); @@ -405,9 +439,124 @@ Expr TypeInferencer::resolve(const Expr& e) { return ::tvm::relay::Resolve(this->unifier, e); } +TypeRelation TypeInferencer::Solve(const TypeRelation & ty_rel) { + Array normalized_args; + + for (auto arg : ty_rel->args) { + normalized_args.push_back(resolve(arg)); + } + + auto new_args = ty_rel->func_(normalized_args, ty_rel->args.size() - 1); + + CHECK(new_args.size() == normalized_args.size()); + tvm::Array final_args; + + for (size_t i = 0; i < new_args.size(); i++) { + final_args.push_back(unifier->unify(normalized_args[i], new_args[i])); + } + + return TypeRelationNode::make(ty_rel->name, ty_rel->func_, final_args); +} + +int NumSolvedVars(const TypeRelation & ty_rel) { + int num = 0; + for (auto arg : ty_rel->args) { + if (!arg.as()) { + num += 1; + } else { + std::cout << "arg :" << arg << std::endl; + } + } + return num; +} + +enum SolverResult : int { + Failed = -1, + Progress = 0, + Done = 1, +}; + + +SolverResult TypeInferencer::Solve(std::vector & rels) { + // We start in the done state with zero progress. + SolverResult status = SolverResult::Done; + int progress = 0; + + do { + // Upon rentering the loop we reset the state. + status = SolverResult::Done; + progress = 0; + + // We will now process each relation in order. + for (TypeRelation & ty_rel : rels) { + std::cout << "TypeRelation: " << ty_rel << std::endl; + int arity = ty_rel->args.size(); + int pre_solved = NumSolvedVars(ty_rel); + std::cout << "Arity: " << arity << " " << "Solved: " << pre_solved << std::endl; + // If the relation is already solved then we will make no progress but try to + // set the status to done. + if (pre_solved == arity) { + status = static_cast((status && SolverResult::Done)); + // If there are unsolved variables we will try to solve some. + } else if (pre_solved < arity) { + auto solved = Solve(ty_rel); + int post_solved = NumSolvedVars(solved); + + // If we solved any variables we will try to downgrade status to progress + // update the type relation, and then bump the progress counter by one. + if (post_solved > pre_solved) { + status = static_cast((status && SolverResult::Progress)); + ty_rel = solved; + progress += 1; + } + } + } + + // If we made no progress and we aren't finished, then the state should be + // downgraded to fail, then we should exit the loop. + if (progress == 0 && status != SolverResult::Done) { + status = SolverResult::Failed; + break; + } + } while (status == SolverResult::Progress); + return status; +} + + +bool TypeInferencer::RelationsHold(bool scope_only) { + // If we are only checking the top scope, + // slice out the constraints. + // + // Otherwise we use all of them. + std::vector constraints; + if (scope_only) { + constraints = { context.constraints[0] }; + } else { + constraints = context.constraints; + } + + std::cout << "Constraints hold " << std::endl; + bool all_hold = true; + for (auto cs_set : context.constraints) { + auto ty_rels = Downcast(cs_set.ty_rels); + auto status = Solve(ty_rels); + std::cout << "Status: " << status << std::endl; + if (status == SolverResult::Failed || status == SolverResult::Progress) { + all_hold = false; + } else if (status == SolverResult::Done) { + continue; + } else { + throw InternalError("found invalid value for SolverResult"); + } + } + + return all_hold; +} + Expr InferType(const Environment &env, const Expr &e) { TypeInferencer ti(env); auto checked_expr = ti.Infer(e); + CHECK(ti.RelationsHold()); return ti.resolve(checked_expr.expr); } @@ -417,6 +566,7 @@ Expr InferType(const Environment &env, const GlobalVar & var, const Function & f func_copy->checked_type_ = ti.resolve(func_copy->fn_type()); env->functions.Set(var, func_copy); auto checked_expr = ti.Infer(func); + CHECK(ti.RelationsHold()); auto map_node = env->functions.CopyOnWrite(); map_node->data.erase(var.node_); return ti.resolve(checked_expr.expr); @@ -465,6 +615,8 @@ TVM_REGISTER_API("relay._ir_pass._get_checked_type") *ret = e->checked_type(); }); +/* Incomplete Type */ + IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { std::shared_ptr n = std::make_shared(); diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index d65d6c567b23..ece9f27613bf 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -22,9 +22,14 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { void VisitType_(const TypeParamNode* op, Args... args) override {} void VisitType_(const FuncTypeNode* op, Args... args) override { - // TODO(@jroesch): handle poly - // this->VisitType(op->var, args...); - // this->VisitType(op->boundType, args...); + for (auto type_param : op->type_params) { + this->VisitType(type_param, args...); + } + + for (auto type_cs : op->type_constraints) { + this->VisitType(type_cs, args...); + } + for (auto arg_type : op->arg_types) { this->VisitType(arg_type, args...); } @@ -39,15 +44,12 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { } } - void VisitType_(const TypeCallNode* op, Args... args) override { - this->VisitType(op->func, args...); - + void VisitType_(const TypeRelationNode* op, Args... args) override { for (const Type& t : op->args) { this->VisitType(t, args...); } } - void VisitType_(const TypeRelationNode* op, Args... args) override {} void VisitType_(const IncompleteTypeNode* op, Args... args) override {} }; @@ -63,12 +65,25 @@ struct TypeFVisitor : TypeFunctor { } Type VisitType_(const FuncTypeNode* op) override { - // TODO(@jroesch): handle poly + Array type_params; + for (auto type_param : op->type_params) { + auto new_type_param = VisitType(type_param); + if (const TypeParamNode* tin = new_type_param.as()) { + type_params.push_back(GetRef(tin)); + } else { + CHECK(false) << new_type_param << std::endl; + } + } - // auto new_id = this->VisitType(op->var); - // if (const TypeParamNode* tin = new_id.as()) { - // return TypeQuantifierNode::make(GetRef(tin), - // this->VisitType(op->boundType)); + Array type_constraints; + for (auto type_cs : op->type_constraints) { + auto new_type_cs = VisitType(type_cs); + if (const TypeConstraintNode* tin = As(new_type_cs)) { + type_constraints.push_back(GetRef(tin)); + } else { + CHECK(false) << new_type_cs << std::endl; + } + } std::vector args; for (auto arg_type : op->arg_types) { @@ -76,7 +91,7 @@ struct TypeFVisitor : TypeFunctor { } return FuncTypeNode::make(tvm::Array(args), VisitType(op->ret_type), - {}, {}); // fix me + type_params, type_constraints); } Type VisitType_(const TupleTypeNode* op) override { @@ -87,17 +102,12 @@ struct TypeFVisitor : TypeFunctor { return TupleTypeNode::make(new_fields); } - Type VisitType_(const TypeRelationNode* op) override { - return GetRef(op); - } - - Type VisitType_(const TypeCallNode* op) override { - auto func = this->VisitType(op->func); + Type VisitType_(const TypeRelationNode* type_rel) override { std::vector new_args; - for (const Type& t : op->args) { + for (const Type& t : type_rel->args) { new_args.push_back(this->VisitType(t)); } - return TypeCallNode::make(func, new_args); + return TypeRelationNode::make(type_rel->name, type_rel->func_, new_args); } Type VisitType_(const IncompleteTypeNode* op) override { diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index f5e337eb17f7..f80daa8d3bd0 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -6,6 +6,7 @@ */ #include "./unifier.h" +#include #include #include #include @@ -180,12 +181,6 @@ Type TypeUnifierNode::VisitType(const Type &t1, const Type t2) { // When the right hand size is a type variable immediately unify. if (const IncompleteTypeNode *tvn2 = t2.as()) { return this->unifyWithIncompleteType(t1, GetRef(tvn2)); - // The TypeCallNode case is special and not symmetric. - // - // We flip the arguments so we hit the TypeCall and other case in there is - // ever a type call. - } else if (t2.as()) { - return TypeFunctor::VisitType(t2, t1); } else { return TypeFunctor::VisitType(t1, t2); } @@ -332,42 +327,7 @@ Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { } Type TypeUnifierNode::VisitType_(const TypeRelationNode *tr1, const Type t2) { - if (const TypeRelationNode *tr2 = t2.as()) { - if (tr1 == tr2) { - return GetRef(tr1); - } else { - throw UnificationError("Cannot unify different type relations"); - } - } else { - throw UnificationError( - "Cannot unify type relation with another type of type"); - } -} - -Type TypeUnifierNode::VisitType_(const TypeCallNode *tcn1, const Type t2) { - TypeCall ty_call1 = GetRef(tcn1); - - if (const TypeCallNode *tcn2 = t2.as()) { - Type unified_func = this->VisitType(ty_call1->func, tcn2->func); - - // For now, we will only unify if they are equal. - if (ty_call1->args.size() != tcn2->args.size()) { - throw UnificationError( - "Cannot unify calls of different number of arguments"); - } - - // Unify members, if possible - tvm::Array new_args; - for (size_t i = 0U; i < ty_call1->args.size(); i++) { - Type unified_member = this->VisitType(ty_call1->args[i], tcn2->args[i]); - new_args.push_back(unified_member); - } - - return TypeCallNode::make(unified_func, new_args); - } else { - auto args = ty_call1->args; - return this->VisitType(args[args.size() - 1], t2); - } + throw InternalError("Cannot unify different type relations"); } } // namespace relay diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index 0671a40c0d74..a5f3c60a85df 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -113,7 +113,6 @@ class TypeUnifierNode : public Node, Type VisitType_(const FuncTypeNode* t1, const Type t2) override; Type VisitType_(const TupleTypeNode* t1, const Type t2) override; Type VisitType_(const TypeRelationNode* s1, const Type t2) override; - Type VisitType_(const TypeCallNode* s1, const Type t2) override; }; class TypeUnifier : public NodeRef { diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index f9a3d098a3e2..5338fad9ad8c 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -16,14 +16,14 @@ import nnvm -def has_type(expr, typ, env=Environment({})): +def assert_has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) - return checked_expr.checked_type() == typ + assert checked_expr.checked_type() == typ -def decl_has_type(env, name, typ): +def assert_decl_has_type(env, name, typ): func = env.lookup(name) - return func.checked_type() == typ + assert func.checked_type() == typ def run(env, expr, inputs, shape): @@ -50,8 +50,9 @@ def test_monomorphic_let(): b.ret(x) prog, env = b.get() - assert has_type(prog, float_type(64)) - run(env, prog, [], float_type(64)) + assert_has_type(prog, float_type(64)) + # Need to handle constants + # run(env, prog, [], float_type(64)) def test_single_op(): @@ -61,7 +62,7 @@ def test_single_op(): x, = func.param_ids() t1 = b.let('t1', log(x)) b.ret(t1) - assert has_type(func.to_func(), func_type([float_type()], float_type())) + assert_has_type(func.to_func(), func_type([float_type()], float_type())) def test_add_op(): @@ -80,7 +81,7 @@ def test_add_op(): prog, env = b.get() ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) - assert has_type(func.to_func(), expected_ty) + assert_has_type(func.to_func(), expected_ty) x_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) y_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 5, 5)) @@ -103,7 +104,7 @@ def test_add_broadcast_op(): prog, env = b.get() ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) - assert has_type(func.to_func(), expected_ty) + assert_has_type(func.to_func(), expected_ty) x_data = tvm.nd.array(np.random.rand(10, 4).astype('float32')) y_data = tvm.nd.array(np.random.rand(5, 10, 1).astype('float32')) result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 10, 4)) @@ -124,7 +125,7 @@ def test_dual_op(): t1 = b.let('t1', log(x)) t2 = b.let('t2', add(t1, x)) b.ret(t2) - assert has_type(func.to_func(), func_type([float_type()], float_type())) + assert_has_type(func.to_func(), func_type([float_type()], float_type())) def test_decl(): @@ -140,7 +141,7 @@ def f(x : Tensor[f32, (10, 10)]) { lx = b.let('lx', log(x)) b.ret(lx) _, env = b.get() - assert decl_has_type(env, 'f', func_type([float_type()], float_type())) + assert_decl_has_type(env, 'f', func_type([float_type()], float_type())) def test_recursion(): @@ -165,7 +166,7 @@ def f(n: i32, data: f32) -> f32 { with b.else_scope(): b.ret(data) b.ret(f(into_ast(2.0), into_ast(10000.0))) - assert decl_has_type(b.env, 'f', func_type( + assert_decl_has_type(b.env, 'f', func_type( [int_type(), float_type()], float_type())) # TODO(@jroesch): need evaluator or new runtime # to execute this. @@ -175,6 +176,6 @@ def f(n: i32, data: f32) -> f32 { test_single_op() test_add_op() test_add_broadcast_op() - # test_dual_op() - # test_decl() - # test_recursion() + test_dual_op() + test_decl() + test_recursion() From cd02ec790c1d6c41ed8d9d4aad9cd7ede31d1d53 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Sep 2018 14:22:16 -0700 Subject: [PATCH 088/136] Convert printing to logging --- src/relay/pass/type_infer.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index d530fe19782b..385a9c5a4a1e 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -463,8 +463,6 @@ int NumSolvedVars(const TypeRelation & ty_rel) { for (auto arg : ty_rel->args) { if (!arg.as()) { num += 1; - } else { - std::cout << "arg :" << arg << std::endl; } } return num; @@ -489,10 +487,9 @@ SolverResult TypeInferencer::Solve(std::vector & rels) { // We will now process each relation in order. for (TypeRelation & ty_rel : rels) { - std::cout << "TypeRelation: " << ty_rel << std::endl; int arity = ty_rel->args.size(); int pre_solved = NumSolvedVars(ty_rel); - std::cout << "Arity: " << arity << " " << "Solved: " << pre_solved << std::endl; + RELAY_LOG(INFO) << "TypeInferencer::Solve: " << "TypeRelation= " << ", Arity=" << arity << ", Solved=" << pre_solved << std::endl; // If the relation is already solved then we will make no progress but try to // set the status to done. if (pre_solved == arity) { @@ -535,12 +532,12 @@ bool TypeInferencer::RelationsHold(bool scope_only) { constraints = context.constraints; } - std::cout << "Constraints hold " << std::endl; + RELAY_LOG(INFO) << "TypeInferencer::RelationsHold: scope_only= " << scope_only << std::endl; bool all_hold = true; for (auto cs_set : context.constraints) { auto ty_rels = Downcast(cs_set.ty_rels); auto status = Solve(ty_rels); - std::cout << "Status: " << status << std::endl; + RELAY_LOG(INFO) << "status= " << status << std::endl; if (status == SolverResult::Failed || status == SolverResult::Progress) { all_hold = false; } else if (status == SolverResult::Done) { From 561e710d278fec76ca358fed7e1e483aba093977 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Sep 2018 14:24:35 -0700 Subject: [PATCH 089/136] Remove Normalizer --- src/relay/pass/type_infer.cc | 144 +++++++++++++---------------------- 1 file changed, 52 insertions(+), 92 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 385a9c5a4a1e..e29a22234a8a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -38,10 +38,8 @@ using namespace tvm::runtime; struct TypeConstraintSet { std::vector ty_rels; TypeConstraintSet() : ty_rels() {} - TypeConstraintSet(const std::vector& cs) : ty_rels(cs) {} - void Add(const TypeConstraint & ty_rel) { - ty_rels.push_back(ty_rel); - } + TypeConstraintSet(const std::vector &cs) : ty_rels(cs) {} + void Add(const TypeConstraint &ty_rel) { ty_rels.push_back(ty_rel); } }; struct TypeContext { @@ -52,12 +50,12 @@ struct TypeContext { void insert(const LocalVar &id, const Type &t) { stack.back()[id] = t; } - void AddConstraint(const TypeConstraint & ty_rel) { + void AddConstraint(const TypeConstraint &ty_rel) { constraints.back().Add(ty_rel); } // TypeConstraint & Constraints() { - // return + // return // } Type lookup(const LocalVar &id) { @@ -69,57 +67,16 @@ struct TypeContext { throw FatalTypeError("Could not resolve local id"); } - struct LocalFrame { + struct Frame { TypeContext &tc; - explicit LocalFrame(TypeContext &tc) : tc(tc) { - tc.stack.push_back({}); + explicit Frame(TypeContext &tc) : tc(tc) { + tc.stack.push_back({}); tc.constraints.push_back({}); } - ~LocalFrame() { tc.stack.pop_back(); } + ~Frame() { tc.stack.pop_back(); } }; }; -struct TypeNormalizer : TypeFVisitor { - TypeUnifier unifier; - explicit TypeNormalizer(const TypeUnifier &unifier) : unifier(unifier) {} - - // Type VisitType_(const TypeRelationNode *ty_call_node) { - // auto ty_call = GetRef(ty_call_node); - - // Array normalized_args; - - // for (auto arg : ty_call->args) { - // normalized_args.push_back(VisitType(arg)); - // } - - // auto all_concrete = true; - // for (auto arg : normalized_args) { - // all_concrete = all_concrete && !arg.as(); - // } - - // if (all_concrete) { - // return normalized_args[normalized_args.size() - 1]; - // } else { - // if (auto ty_rel_node = ty_call->func.as()) { - // // NB: we substract 1 for the output argument. - // auto new_args = - // ty_rel_node->func_(ty_call->args, ty_call->args.size() - 1); - // CHECK(new_args.size() == normalized_args.size()); - // tvm::Array final_args; - - // for (size_t i = 0; i < new_args.size(); i++) { - // final_args.push_back(unifier->unify(normalized_args[i], new_args[i])); - // } - - // return TypeCallNode::make(ty_call->func, final_args); - // } else { - // throw InternalError("found non type relation in the "\ - // "type call function position"); - // } - // } - // } -}; - struct CheckedExpr { Expr expr; Type type; @@ -140,7 +97,7 @@ class TypeInferencer : private ExprFunctor { // Should be in header? template T with_frame(const std::function &f) { - TypeContext::LocalFrame fr(context); + TypeContext::Frame fr(context); return f(); } @@ -161,12 +118,13 @@ class TypeInferencer : private ExprFunctor { Type unify(const Type &t1, const Type &t2, Span sp); Type resolve(const Type &t); Expr resolve(const Expr &e); - TypeRelation Solve(const TypeRelation & ty_rel); - SolverResult Solve(std::vector & rels); - + TypeRelation Solve(const TypeRelation &ty_rel); + SolverResult Solve(std::vector &rels); + /*! \brief Check that all relations hold. */ bool RelationsHold(bool scope_only = false); CheckedExpr VisitFunction(const Function &f, bool generalize); + private: CheckedExpr VisitExpr_(const LocalVarNode *op) override; CheckedExpr VisitExpr_(const GlobalVarNode *op) override; @@ -189,18 +147,13 @@ TypeInferencer::TypeInferencer(Environment env) : env(env) { this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); } -Type TypeInferencer::Normalize(const Type &t) { - auto nt = this->resolve(t); - auto normalizer = TypeNormalizer(this->unifier); - return normalizer.VisitType(nt); -} CheckedExpr TypeInferencer::Infer(const Expr &expr) { RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; CheckedExpr checked_expr = this->VisitExpr(expr); RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type << std::endl; - Type final_type = Normalize(checked_expr.type); + Type final_type = checked_expr.type; RELAY_LOG(INFO) << "TypeInferencer::Check type_after_subst=" << final_type << std::endl; checked_expr.expr->checked_type_ = final_type; @@ -237,7 +190,8 @@ CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { } CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { - // We should trigger error here and move param code direclty into function checking. + // We should trigger error here and move param code direclty into function + // checking. auto rtype = resolve(param->type); // This is a special case ... not sure if there is a better way // to handle this. @@ -291,7 +245,8 @@ FuncType TypeInferencer::instantiate(FuncType fn_ty, subst_map.Set(ty_param, fresh); } - Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, fn_ty->type_constraints); + Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, + fn_ty->type_constraints); inst_ty = TypeSubst(inst_ty, subst_map); CHECK(KindCheck(this->env, inst_ty)); @@ -324,7 +279,6 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { fn_ty = instantiate(fn_ty, ty_args); - std::vector arg_types; std::vector checked_args; @@ -411,13 +365,14 @@ CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { auto checked_cond = this->Infer(ifn->cond); auto cond_type = checked_cond.type; - this->unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()), ifn->cond->span); + this->unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()), + ifn->cond->span); auto checked_true = this->Infer(ifn->true_value); auto checked_false = this->Infer(ifn->false_value); auto unified_type = - this->unify(checked_true.type, checked_false.type, ifn->span); - auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, - checked_false.expr); + this->unify(checked_true.type, checked_false.type, ifn->span); + auto checked_if = + IfNode::make(checked_cond.expr, checked_true.expr, checked_false.expr); return {checked_if, unified_type}; } @@ -426,7 +381,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { return {op, op->op_type}; } -Type TypeInferencer::resolve(const Type& t) { +Type TypeInferencer::resolve(const Type &t) { if (t.defined()) { return ::tvm::relay::Resolve(this->unifier, t); } else { @@ -434,12 +389,12 @@ Type TypeInferencer::resolve(const Type& t) { } } -Expr TypeInferencer::resolve(const Expr& e) { +Expr TypeInferencer::resolve(const Expr &e) { CHECK(e.defined()); return ::tvm::relay::Resolve(this->unifier, e); } -TypeRelation TypeInferencer::Solve(const TypeRelation & ty_rel) { +TypeRelation TypeInferencer::Solve(const TypeRelation &ty_rel) { Array normalized_args; for (auto arg : ty_rel->args) { @@ -458,7 +413,7 @@ TypeRelation TypeInferencer::Solve(const TypeRelation & ty_rel) { return TypeRelationNode::make(ty_rel->name, ty_rel->func_, final_args); } -int NumSolvedVars(const TypeRelation & ty_rel) { +int NumSolvedVars(const TypeRelation &ty_rel) { int num = 0; for (auto arg : ty_rel->args) { if (!arg.as()) { @@ -469,13 +424,12 @@ int NumSolvedVars(const TypeRelation & ty_rel) { } enum SolverResult : int { - Failed = -1, + Failed = -1, Progress = 0, Done = 1, }; - -SolverResult TypeInferencer::Solve(std::vector & rels) { +SolverResult TypeInferencer::Solve(std::vector &rels) { // We start in the done state with zero progress. SolverResult status = SolverResult::Done; int progress = 0; @@ -486,30 +440,35 @@ SolverResult TypeInferencer::Solve(std::vector & rels) { progress = 0; // We will now process each relation in order. - for (TypeRelation & ty_rel : rels) { + for (TypeRelation &ty_rel : rels) { int arity = ty_rel->args.size(); int pre_solved = NumSolvedVars(ty_rel); - RELAY_LOG(INFO) << "TypeInferencer::Solve: " << "TypeRelation= " << ", Arity=" << arity << ", Solved=" << pre_solved << std::endl; - // If the relation is already solved then we will make no progress but try to - // set the status to done. + RELAY_LOG(INFO) << "TypeInferencer::Solve: " + << "TypeRelation= " + << ", Arity=" << arity << ", Solved=" << pre_solved + << std::endl; + // If the relation is already solved then we will make no progress but try + // to set the status to done. if (pre_solved == arity) { status = static_cast((status && SolverResult::Done)); - // If there are unsolved variables we will try to solve some. + // If there are unsolved variables we will try to solve some. } else if (pre_solved < arity) { auto solved = Solve(ty_rel); int post_solved = NumSolvedVars(solved); - // If we solved any variables we will try to downgrade status to progress - // update the type relation, and then bump the progress counter by one. + // If we solved any variables we will try to downgrade status to + // progress update the type relation, and then bump the progress counter + // by one. if (post_solved > pre_solved) { - status = static_cast((status && SolverResult::Progress)); + status = + static_cast((status && SolverResult::Progress)); ty_rel = solved; progress += 1; - } + } } } - // If we made no progress and we aren't finished, then the state should be + // If we made no progress and we aren't finished, then the state should be // downgraded to fail, then we should exit the loop. if (progress == 0 && status != SolverResult::Done) { status = SolverResult::Failed; @@ -519,7 +478,6 @@ SolverResult TypeInferencer::Solve(std::vector & rels) { return status; } - bool TypeInferencer::RelationsHold(bool scope_only) { // If we are only checking the top scope, // slice out the constraints. @@ -527,12 +485,13 @@ bool TypeInferencer::RelationsHold(bool scope_only) { // Otherwise we use all of them. std::vector constraints; if (scope_only) { - constraints = { context.constraints[0] }; + constraints = {context.constraints[0]}; } else { constraints = context.constraints; } - RELAY_LOG(INFO) << "TypeInferencer::RelationsHold: scope_only= " << scope_only << std::endl; + RELAY_LOG(INFO) << "TypeInferencer::RelationsHold: scope_only= " << scope_only + << std::endl; bool all_hold = true; for (auto cs_set : context.constraints) { auto ty_rels = Downcast(cs_set.ty_rels); @@ -557,9 +516,11 @@ Expr InferType(const Environment &env, const Expr &e) { return ti.resolve(checked_expr.expr); } -Expr InferType(const Environment &env, const GlobalVar & var, const Function & func) { +Expr InferType(const Environment &env, const GlobalVar &var, + const Function &func) { TypeInferencer ti(env); - auto func_copy = FunctionNode::make(func->params, func->ret_type, func->body, func->type_params); + auto func_copy = FunctionNode::make(func->params, func->ret_type, func->body, + func->type_params); func_copy->checked_type_ = ti.resolve(func_copy->fn_type()); env->functions.Set(var, func_copy); auto checked_expr = ti.Infer(func); @@ -569,7 +530,6 @@ Expr InferType(const Environment &env, const GlobalVar & var, const Function & f return ti.resolve(checked_expr.expr); } - inline void TypeInferencer::report_error(const std::string &msg, Span sp) { this->env->AddDiagnostic({msg, sp}); } @@ -584,7 +544,7 @@ void TypeInferencer::fatal_error(const std::string &msg, Span sp) { Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { try { - return Normalize(this->unifier->unify(t1, t2)); + return this->unifier->unify(t1, t2); } catch (const dmlc::Error &e) { std::stringstream ss; ss << "Error unifying `"; From b327d051503ac87f888517f0db57d56daa8ddf54 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Sep 2018 15:02:52 -0700 Subject: [PATCH 090/136] Address a bunch of CR feedback --- include/tvm/base.h | 1 + include/tvm/relay/base.h | 10 +++--- include/tvm/relay/environment.h | 2 +- include/tvm/relay/expr.h | 31 ++++++++-------- include/tvm/relay/expr_visitor.h | 20 ++++++----- include/tvm/relay/pass/alpha_eq.h | 60 +++++++++++++++---------------- include/tvm/relay/type.h | 4 +-- python/tvm/relay/ir_builder.py | 4 +-- python/tvm/relay/op/_tensor.py | 2 +- python/tvm/relay/op/op.py | 4 +-- src/relay/ir/expr.cc | 10 +++--- src/relay/pass/type_infer.cc | 4 +-- src/relay/pass/unifier.cc | 2 +- src/relay/source_map.cc | 1 - 14 files changed, 77 insertions(+), 78 deletions(-) diff --git a/include/tvm/base.h b/include/tvm/base.h index 464259bc0527..c2d796b6002c 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -134,5 +134,6 @@ struct NodeFactoryReg { */ #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) + } // namespace tvm #endif // TVM_BASE_H_ diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 3178d073c778..098c47ff10a2 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -51,7 +51,7 @@ using NodeEqual = ::tvm::NodeEqual; /*! * \brief Macro to make it easy to define node ref type given node * \param TypeName The name of the reference type. - * \param NodeName The internal contrainer name. + * \param NodeName The internal container name. * \param NodeRefBase The base type. */ #define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ @@ -72,11 +72,11 @@ using NodeEqual = ::tvm::NodeEqual; */ class SourceName; /*! - * \brief The source name in the Span + * \brief The name of a source fragment. */ class SourceNameNode : public Node { public: - /*! \brief The source name */ + /*! \brief The source name. */ std::string name; // override attr visitor void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); } @@ -95,7 +95,6 @@ RELAY_DEFINE_NODE_REF(SourceName, SourceNameNode, NodeRef); class Span; /*! * \brief Stores locations in frontend source that generated a node. - * */ class SpanNode : public Node { public: @@ -125,7 +124,8 @@ RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef); */ class RelayNode : public Node { public: - /*! \brief The debug information, can be null, check with span.defined() */ + /*! \brief The location of the program in a SourceFragment can be null, + * check with span.defined() */ mutable Span span; static constexpr const char* _type_key = "relay.Node"; diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 29cde295398d..2df21b0bb2ce 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -31,7 +31,7 @@ struct Environment; * Many operations require access to the global * Environment. We pass the Environment by value * in a functional style as an explicit argument, - * but we will mutate the Environment while optimizing + * but we mutate the Environment while optimizing * Relay programs. * * The functional style allows users to construct custom diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index a882b7cc1ea7..8f597373960e 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -22,7 +22,7 @@ namespace relay { */ class Expr; /*! - * \brief Base type of the Relay type hiearchy. + * \brief Base type of the Relay expression hiearchy. */ class ExprNode : public RelayNode { public: @@ -30,7 +30,7 @@ class ExprNode : public RelayNode { * \brief Stores the result of type inference(type checking). * * \note This can be undefined before type inference. - * this value is discarded during serialization. + * This value is discarded during serialization. */ mutable Type checked_type_ = Type(nullptr); /*! @@ -39,7 +39,7 @@ class ExprNode : public RelayNode { const Type& checked_type() const { CHECK(checked_type_.defined()) << "internal error: the type checker has " "not populated the checked_type " - << "field for this node"; + "field for this node"; return this->checked_type_; } @@ -50,7 +50,7 @@ class ExprNode : public RelayNode { RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); /*! - * \brief Constant tensor, backed by an NDArray on cpu(0). + * \brief Constant tensor, backed by an NDArray on the cpu(0) device. * * \note scalar constants are represented by rank-0 const tensor. * Constant folding are handled uniformly via Tensor types. @@ -67,7 +67,7 @@ class ConstantNode : public ExprNode { /*! \return The corresponding tensor type of the data */ TensorType tensor_type() const; - /*! \return whether it is scalar(rank-0 tensor) */ + /*! \return Whether it is scalar(rank-0 tensor) */ bool is_scalar() const { return data->ndim == 0; } void VisitAttrs(tvm::AttrVisitor* v) final { @@ -114,7 +114,9 @@ class LocalVar; /*! \brief Container for LocalVar */ class LocalVarNode : public ExprNode { public: - /*! \brief The name of the variable, this only acts as a hint. */ + /*! \brief The name of the variable, this only acts as a hint to the user, + * and is not used for equality. + */ std::string name_hint; void VisitAttrs(tvm::AttrVisitor* v) final { @@ -133,7 +135,7 @@ RELAY_DEFINE_NODE_REF(LocalVar, LocalVarNode, Expr); * \brief Global variable that leaves in the top-level environment. * This is used to enable recursive calls between function. * - * \note GlobalVar can only corresponds to functions. + * \note A GlobalVar may only point to functions. */ class GlobalVar; /*! \brief A GlobalId from the node's current type to target type. */ @@ -343,20 +345,20 @@ class IfNode : public ExprNode { /*! \brief The condition */ Expr cond; /*! \brief The expression evaluated when condition is true. */ - Expr true_value; + Expr true_branch; /*! \brief The expression evaluated when condition is false */ - Expr false_value; + Expr false_branch; IfNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("cond", &cond); - v->Visit("true_value", &true_value); - v->Visit("false_value", &false_value); + v->Visit("true_branch", &true_branch); + v->Visit("false_branch", &false_branch); v->Visit("span", &span); } - TVM_DLL static If make(Expr cond, Expr true_value, Expr false_value); + TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch); static constexpr const char* _type_key = "relay.If"; TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode); @@ -364,11 +366,6 @@ class IfNode : public ExprNode { RELAY_DEFINE_NODE_REF(If, IfNode, Expr); -// template -// T Downcast(U u) { - -// } - } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 349ce3f9543b..1221a15c0da1 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -58,8 +58,8 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { void VisitExpr_(const IfNode* op) override { this->VisitExpr(op->cond); - this->VisitExpr(op->true_value); - this->VisitExpr(op->false_value); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); } void VisitExpr_(const OpNode* op) override { return; } @@ -99,7 +99,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto type = this->VisitType(op->type); return ParamNode::make(var, type); } else { - throw dmlc::Error("the default param visitor has bug"); + CHECK(false) << "the default param visitor expected a Var found: " << var_expr << std::endl; + __builtin_unreachable(); } } @@ -112,7 +113,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto ty_param_ref = GetRef(ty_param); ty_params.push_back(ty_param_ref); } else { - throw dmlc::Error("the default func visitor has bug"); + CHECK(false) << "the default function visitor expected a TypeParam found: " << ty_param_type << std::endl; + __builtin_unreachable(); } } @@ -123,7 +125,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto param = GetRef(param_node); params.push_back(param); } else { - throw dmlc::Error("the default func visitor has bug"); + CHECK(false) << "the default function visitor expected a Param found: " << param_expr << std::endl; + __builtin_unreachable(); } } @@ -160,14 +163,15 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto body = this->VisitExpr(op->body); return LetNode::make(var, value, body, type); } else { - throw dmlc::Error("the default let visitor has error"); + CHECK(false) << "the default let visitor expected a Var found: " << var_expr << std::endl; + __builtin_unreachable(); } } Expr VisitExpr_(const IfNode* op) override { auto guard = this->VisitExpr(op->cond); - auto true_b = this->VisitExpr(op->true_value); - auto false_b = this->VisitExpr(op->false_value); + auto true_b = this->VisitExpr(op->true_branch); + auto false_b = this->VisitExpr(op->false_branch); return IfNode::make(guard, true_b, false_b); } diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h index b6d98bd68940..cfc6f5aa1ae7 100644 --- a/include/tvm/relay/pass/alpha_eq.h +++ b/include/tvm/relay/pass/alpha_eq.h @@ -13,40 +13,38 @@ namespace tvm { namespace relay { /*! \brief Compare two expressions for structural equivalence. - - This comparsion operator respects scoping and compares - expressions without regard to variable choice. - - For example: `let x = 1 in x` is equal to `let y = 1 in y`. - - See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence - for more details. - - \param e1 The left hand expression. - \param e2 The right hand expression. - - \return true if equal, otherwise false - -*/ + * + * This comparison operator respects scoping and compares + * expressions without regard to variable choice. + * + * For example: `let x = 1 in x` is equal to `let y = 1 in y`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + * for more details. + * + * \param e1 The left hand expression. + * \param e2 The right hand expression. + * + * \return true if equal, otherwise false + */ bool AlphaEqual(const Expr& e1, const Expr& e2); /*! \brief Compare two types for structural equivalence. - - This comparsion operator respects scoping and compares - expressions without regard to variable choice. - - For example: `forall s, Tensor[f32, s]` is equal to - `forall w, Tensor[f32, w]`. - - See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence - for more details. - - \param t1 The left hand type. - \param t2 The right hand type. - - \return true if equal, otherwise false - -*/ + * + * This comparison operator respects scoping and compares + * expressions without regard to variable choice. + * + * For example: `forall s, Tensor[f32, s]` is equal to + * `forall w, Tensor[f32, w]`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + * for more details. + * + * \param t1 The left hand type. + * \param t2 The right hand type. + * + * \return true if equal, otherwise false + */ bool AlphaEqual(const Type& t1, const Type& t2); } // namespace relay diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 5e3665dfbd1d..d4b16043dbc0 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -83,10 +83,10 @@ class TensorTypeNode : public BaseTensorTypeNode { TVM_DLL static TensorType make(Array shape, DataType dtype); - /*! \brief Constructing an unsigned integer type */ + /*! \brief Construct an unsigned integer type */ TVM_DLL static TensorType Int(int bits, int lanes = 1); - /*! \brief Constructing an unsigned integer type */ + /*! \brief Construct an unsigned integer type */ TVM_DLL static TensorType UInt(int bits, int lanes = 1); /*! \brief Construct a floating-point type */ diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index c0c2e76c1157..69ec97a8fd6a 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -197,11 +197,11 @@ def _on_exit(): bindings, _, _, ret_value = self.exit_scope() partial_if = self.ret_values[-1] assert isinstance( - partial_if, If) and partial_if.false_value is None + partial_if, If) and partial_if.false_branch is None false_branch = _mk_let(bindings, ret_value) self.ret_values[-1] = If( partial_if.cond, - partial_if.true_value, + partial_if.true_branch, false_branch) return WithScope(10, _on_exit) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 2a0ecc6c8550..875df0e52561 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -1,5 +1,5 @@ #pylint: disable=invalid-name -"""Backend compiler related feature regsitration""" +"""Backend compiler related feature registration""" from topi import add from .op import register from ..type import FuncType, TensorType diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 14570b62269b..adf963403bd3 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -10,6 +10,7 @@ @register_relay_node class Op(Expr): """A Relay operator definition.""" + def __init__(self): raise RuntimeError("Cannot create op, use get instead") @@ -90,9 +91,8 @@ def compile_ops(op_names): """ return _CompileOpsToModule(*op_names) -# TODO(@jroesch): We should port to C++, just need to figure out how to write this code. - +# TODO(@jroesch): We should port to C++, just need to figure out how to write this code. @register_func("relay.op._compile_ops") def _compile_ops(op_impls): lowered = [] diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 8dce7c054c8e..7f9ef28cbe10 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -181,11 +181,11 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << ", " << node->body << ", " << node->value_type << ")"; }); -If IfNode::make(Expr cond, Expr true_value, Expr false_value) { +If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { std::shared_ptr n = std::make_shared(); n->cond = std::move(cond); - n->true_value = std::move(true_value); - n->false_value = std::move(false_value); + n->true_branch = std::move(true_branch); + n->false_branch = std::move(false_branch); return If(n); } @@ -195,8 +195,8 @@ TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { - p->stream << "IfNode(" << node->cond << ", " << node->true_value - << node->false_value << ")"; + p->stream << "IfNode(" << node->cond << ", " << node->true_branch + << node->false_branch << ")"; }); } // namespace relay diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index e29a22234a8a..78d973c0c261 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -367,8 +367,8 @@ CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { this->unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()), ifn->cond->span); - auto checked_true = this->Infer(ifn->true_value); - auto checked_false = this->Infer(ifn->false_value); + auto checked_true = this->Infer(ifn->true_branch); + auto checked_false = this->Infer(ifn->false_branch); auto unified_type = this->unify(checked_true.type, checked_false.type, ifn->span); auto checked_if = diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index f80daa8d3bd0..752c4c3f1116 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -43,7 +43,7 @@ void UnionFindNode::AssertAlphaEqual(const Type &l, const Type &r) { } void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { - RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << "t=" << t << std::endl; + RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << ", t=" << t << std::endl; auto parent1 = this->find(v1); // if t is a type var, then unify parents diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc index 9d3316cf38cf..7630135a1e5e 100644 --- a/src/relay/source_map.cc +++ b/src/relay/source_map.cc @@ -34,7 +34,6 @@ SourceFragment::SourceFragment(const std::string& file_name, std::string SourceFragment::SourceAt(Span sp, int max_lines = 1) { std::stringstream out; - // We need to move from 1 based indexing to zero based indexing. int starting_line = sp->lineno; if (starting_line >= static_cast(this->source_lines.size())) { From 032c0407c0e766db3a2fa26a3162e4a90273899b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Sep 2018 15:06:10 -0700 Subject: [PATCH 091/136] Rename LocalVar to Var --- include/tvm/relay/expr.h | 26 +++++++++++++------------- include/tvm/relay/expr_functor.h | 4 ++-- include/tvm/relay/expr_visitor.h | 14 +++++++------- python/tvm/relay/__init__.py | 4 ++-- python/tvm/relay/expr.py | 12 ++++++------ python/tvm/relay/ir_builder.py | 10 +++++----- python/tvm/relay/ir_pass.py | 8 ++++---- python/tvm/relay/to_tvm.py | 12 ++++++------ src/relay/ir/expr.cc | 18 +++++++++--------- src/relay/ir/op.cc | 2 +- src/relay/pass/type_infer.cc | 12 ++++++------ tests/python/relay/test_ir_nodes.py | 18 +++++++++--------- 12 files changed, 70 insertions(+), 70 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 8f597373960e..beb4770c30eb 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -108,11 +108,11 @@ RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr); * \brief Local variables used in the let expression. * This is similar to Var that is being used in the low level tensor expression. * - * \note Each LocalVar is bind only once and is immutable/ + * \note Each Var is bind only once and is immutable/ */ -class LocalVar; -/*! \brief Container for LocalVar */ -class LocalVarNode : public ExprNode { +class Var; +/*! \brief Container for Var */ +class VarNode : public ExprNode { public: /*! \brief The name of the variable, this only acts as a hint to the user, * and is not used for equality. @@ -123,13 +123,13 @@ class LocalVarNode : public ExprNode { v->Visit("name_hint", &name_hint); } - TVM_DLL static LocalVar make(std::string name_hint); + TVM_DLL static Var make(std::string name_hint); - static constexpr const char* _type_key = "relay.LocalVar"; - TVM_DECLARE_NODE_TYPE_INFO(LocalVarNode, ExprNode); + static constexpr const char* _type_key = "relay.Var"; + TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(LocalVar, LocalVarNode, Expr); +RELAY_DEFINE_NODE_REF(Var, VarNode, Expr); /*! * \brief Global variable that leaves in the top-level environment. @@ -164,7 +164,7 @@ class Param; class ParamNode : public ExprNode { public: /*! \brief The variable */ - LocalVar var; + Var var; /*! \brief The type of the parameter */ Type type; @@ -174,7 +174,7 @@ class ParamNode : public ExprNode { v->Visit("span", &span); } - TVM_DLL static Param make(LocalVar var, Type type); + TVM_DLL static Param make(Var var, Type type); static constexpr const char* _type_key = "relay.Param"; TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode); @@ -240,7 +240,7 @@ class CallNode : public ExprNode { * \brief The operator(function) being invoked * * - It can be relay::Op which corresponds to the primitive operators. - * - It can also be user defined functions (Function, GlobalVar, LocalVar). + * - It can also be user defined functions (Function, GlobalVar, Var). */ Expr op; @@ -305,7 +305,7 @@ class Let; class LetNode : public ExprNode { public: /*! \brief The variable we bind to */ - LocalVar var; + Var var; /*! \brief The value we bind var to */ Expr value; /*! \brief The body of the let binding */ @@ -321,7 +321,7 @@ class LetNode : public ExprNode { v->Visit("span", &span); } - TVM_DLL static Let make(LocalVar var, Expr value, Expr body, Type value_type); + TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type); static constexpr const char* _type_key = "relay.Let"; TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 0d736212c9eb..8e2f24837473 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -75,7 +75,7 @@ class ExprFunctor { Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const LocalVarNode* op, + virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -99,7 +99,7 @@ class ExprFunctor { // Set dispatch RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); - RELAY_EXPR_FUNCTOR_DISPATCH(LocalVarNode); + RELAY_EXPR_FUNCTOR_DISPATCH(VarNode); RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode); RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 1221a15c0da1..9f548223845b 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -17,7 +17,7 @@ namespace relay { class ExprVisitor : public ::tvm::relay::ExprFunctor { public: - void VisitExpr_(const LocalVarNode* op) override { return; } + void VisitExpr_(const VarNode* op) override { return; } void VisitExpr_(const GlobalVarNode* op) override { return; } @@ -69,8 +69,8 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { class ExprFVisitor : public ::tvm::relay::ExprFunctor { public: - Expr VisitExpr_(const LocalVarNode* op) override { - return GetRef(op); + Expr VisitExpr_(const VarNode* op) override { + return GetRef(op); } Expr VisitExpr_(const ConstantNode* op) override { @@ -94,8 +94,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { Expr VisitExpr_(const ParamNode* op) override { Expr var_expr = this->VisitExpr(op->var); - if (const LocalVarNode* var_node = var_expr.as()) { - auto var = GetRef(var_node); + if (const VarNode* var_node = var_expr.as()) { + auto var = GetRef(var_node); auto type = this->VisitType(op->type); return ParamNode::make(var, type); } else { @@ -156,8 +156,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { Expr VisitExpr_(const LetNode* op) override { Expr var_expr = this->VisitExpr(op->var); - if (const LocalVarNode* var_node = var_expr.as()) { - auto var = GetRef(var_node); + if (const VarNode* var_node = var_expr.as()) { + auto var = GetRef(var_node); auto type = this->VisitType(op->value_type); auto value = this->VisitExpr(op->value); auto body = this->VisitExpr(op->body); diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index c254c7e9ce7a..4f58958feaf9 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -25,11 +25,11 @@ # Expr Constant = expr.Constant Tuple = expr.Tuple -LocalVar = expr.LocalVar +Var = expr.Var GlobalVar = expr.GlobalVar Param = expr.Param Function = expr.Function Call = expr.Call Let = expr.Let If = expr.If -Var = LocalVar +Var = Var diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 3cdaed89d2fb..ebf8483ccc04 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -52,12 +52,12 @@ def __init__(self, fields: List[Expr]) -> None: @register_relay_node -class LocalVar(Expr): +class Var(Expr): """A local variable in Relay.""" name_hint: str def __init__(self, name_hint: str) -> None: - self.__init_handle_by_constructor__(_make.LocalVar, name_hint) + self.__init_handle_by_constructor__(_make.Var, name_hint) @register_relay_node @@ -73,10 +73,10 @@ def __init__(self, name_hint: str) -> None: class Param(Expr): """A function type in Relay, see tvm/relay/type.h for more details. """ - var: LocalVar + var: Var type: Type - def __init__(self, var: LocalVar, ty: Type) -> None: + def __init__(self, var: Var, ty: Type) -> None: self.__init_handle_by_constructor__(_make.Param, var, ty) @@ -117,13 +117,13 @@ def __init__(self, op: Expr, args: List[Expr], attrs, ty_args=None) -> None: @register_relay_node class Let(Expr): """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" - var: LocalVar + var: Var value: Expr body: Expr # should be type annotation value_type: Type - def __init__(self, var: LocalVar, value: Expr, body: Expr, value_type: Type) -> None: + def __init__(self, var: Var, value: Expr, body: Expr, value_type: Type) -> None: self.__init_handle_by_constructor__( _make.Let, var, value, body, value_type) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 69ec97a8fd6a..9dc4796802c1 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -6,7 +6,7 @@ import numpy as np import tvm from .type import FuncType, TensorType -from .expr import Expr, Constant, Let, LocalVar, Param, Function, If +from .expr import Expr, Constant, Let, Var, Param, Function, If from .env import Environment @@ -115,7 +115,7 @@ def exit_scope(self): #pylint: disable=invalid-name def bind(self, name, value, ty): - lv = LocalVar(name) + lv = Var(name) self.scopes[-1][name] = lv self.bindings[-1][lv] = (value, ty) return lv @@ -138,10 +138,10 @@ def _convert_params(self, raw_params): elif isinstance(raw_param, tuple): var, ty = raw_param if isinstance(var, str): - var = LocalVar(var) + var = Var(var) param = Param(var, ty) elif isinstance(param, str): - var = LocalVar(raw_param) + var = Var(raw_param) ty = None param = Param(var, ty) else: @@ -210,7 +210,7 @@ def param(self, name, ty=None): if not ty: ty = float_type() - return Param(LocalVar(name), ty) + return Param(Var(name), ty) # def params(*args): # i = 0 diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index b075704c212a..ca396404610a 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -10,7 +10,7 @@ import tvm from .expr import Expr -from .expr import Function, Let, Call, LocalVar +from .expr import Function, Let, Call, Var from .expr import GlobalVar, If, Constant from .type import Type, TypeParam from .env import Environment @@ -63,7 +63,7 @@ def visit(self, expr: Expr) -> T: return self.visit_call(expr) elif isinstance(expr, Let): return self.visit_let(expr) - elif isinstance(expr, LocalVar): + elif isinstance(expr, Var): return self.visit_local_var(expr) elif isinstance(expr, GlobalVar): return self.visit_global_var(expr) @@ -85,7 +85,7 @@ def visit_let(self, _: Let) -> T: def visit_call(self, _: Call) -> T: raise Exception("Abstract method please implement me.") - def visit_local_id(self, _: LocalVar) -> T: + def visit_local_id(self, _: Var) -> T: raise Exception("Abstract method please implement me.") def visit_type(self, typ: Type) -> Type: @@ -136,7 +136,7 @@ def visit_call(self, call: Call) -> Expr: new_args = [self.visit(arg) for arg in call.args] return Call(new_fn, new_args, call.attrs) - def visit_local_var(self, local_var: LocalVar) -> Expr: + def visit_local_var(self, local_var: Var) -> Expr: return local_var def visit_global_id(self, global_var: GlobalVar) -> Expr: diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py index 615a39301142..99e6b6d41674 100644 --- a/python/tvm/relay/to_tvm.py +++ b/python/tvm/relay/to_tvm.py @@ -7,7 +7,7 @@ from .ir_pass import AbstractExprVisitor from .op import compile_ops, Op from .type import TensorType -from .expr import LocalVar, Function, Let, Call +from .expr import Var, Function, Let, Call @attr.s(auto_attribs=True) @@ -80,7 +80,7 @@ def from_tensor(typ: TensorType) -> Tuple[str, List[int]]: class TVMRTSCompiler(AbstractExprVisitor[NodeRef]): """The compiler from Relay to the TVM runtime system.""" nodes: List[Node] - id_map: Dict[LocalVar, NodeRef] + id_map: Dict[Var, NodeRef] all_ops: Set[Op] def __init__(self) -> None: @@ -93,10 +93,10 @@ def add_node(self, node: Node) -> NodeRef: ident = len(self.nodes) - 1 return NodeRef(ident) - def add_binding(self, ident: LocalVar, ref: NodeRef) -> None: + def add_binding(self, ident: Var, ref: NodeRef) -> None: self.id_map[ident] = ref - def let_bind(self, ident: LocalVar, node: Node) -> NodeRef: + def let_bind(self, ident: Var, node: Node) -> NodeRef: ref = self.add_node(node) self.add_binding(ident, ref) return ref @@ -104,7 +104,7 @@ def let_bind(self, ident: LocalVar, node: Node) -> NodeRef: def get_node(self, ref: NodeRef) -> Node: return self.nodes[ref.ident] - def lookup(self, ident: LocalVar) -> NodeRef: + def lookup(self, ident: Var) -> NodeRef: return self.id_map[ident] def compile(self, func: Function) -> None: @@ -151,7 +151,7 @@ def visit_let(self, let: Let) -> NodeRef: self.add_binding(ident, val_ref) return self.visit(body) - def visit_local_var(self, ident: LocalVar) -> NodeRef: + def visit_local_var(self, ident: Var) -> NodeRef: return self.lookup(ident) def visit_call(self, call: Call) -> NodeRef: diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 7f9ef28cbe10..925630bc8399 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -58,21 +58,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TupleNode(" << node->fields << ")"; }); -LocalVar LocalVarNode::make(std::string name_hint) { - std::shared_ptr n = std::make_shared(); +Var VarNode::make(std::string name_hint) { + std::shared_ptr n = std::make_shared(); n->name_hint = std::move(name_hint); - return LocalVar(n); + return Var(n); } -TVM_REGISTER_API("relay._make.LocalVar") +TVM_REGISTER_API("relay._make.Var") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = LocalVarNode::make(args[0]); + *ret = VarNode::make(args[0]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const LocalVarNode *node, + .set_dispatch([](const VarNode *node, tvm::IRPrinter *p) { - p->stream << "LocalVarNode(" << node->name_hint << ")"; + p->stream << "VarNode(" << node->name_hint << ")"; }); GlobalVar GlobalVarNode::make(std::string name_hint) { @@ -92,7 +92,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "GlobalVarNode(" << node->name_hint << ")"; }); -Param ParamNode::make(LocalVar var, Type type) { +Param ParamNode::make(Var var, Type type) { std::shared_ptr n = std::make_shared(); n->var = std::move(var); n->type = std::move(type); @@ -161,7 +161,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << node->attrs << ", " << node->type_args << ")"; }); -Let LetNode::make(LocalVar var, Expr value, Expr body, Type value_type) { +Let LetNode::make(Var var, Expr value, Expr body, Type value_type) { std::shared_ptr n = std::make_shared(); n->var = std::move(var); n->value = std::move(value); diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index a6dbd769f75f..8b6748dfe482 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -158,7 +158,7 @@ Module CompileOpsToModule(const std::vector& op_names) { tvm::Array pair = compiler(op->name, op->op_type); // TODO(@jroesch): I can't pass strings across what should be the // interface here. - tvm::Array triple = {LocalVarNode::make(op->name), pair[0], + tvm::Array triple = {VarNode::make(op->name), pair[0], pair[1]}; args.push_back(triple); } else { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 78d973c0c261..7501aff67682 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -43,12 +43,12 @@ struct TypeConstraintSet { }; struct TypeContext { - std::vector> stack; + std::vector> stack; std::vector constraints; TypeContext() { stack.push_back({}); } - void insert(const LocalVar &id, const Type &t) { stack.back()[id] = t; } + void insert(const Var &id, const Type &t) { stack.back()[id] = t; } void AddConstraint(const TypeConstraint &ty_rel) { constraints.back().Add(ty_rel); @@ -58,7 +58,7 @@ struct TypeContext { // return // } - Type lookup(const LocalVar &id) { + Type lookup(const Var &id) { for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) { if (frame->find(id) != frame->end()) { return frame->at(id); @@ -126,7 +126,7 @@ class TypeInferencer : private ExprFunctor { CheckedExpr VisitFunction(const Function &f, bool generalize); private: - CheckedExpr VisitExpr_(const LocalVarNode *op) override; + CheckedExpr VisitExpr_(const VarNode *op) override; CheckedExpr VisitExpr_(const GlobalVarNode *op) override; CheckedExpr VisitExpr_(const ConstantNode *op) override; CheckedExpr VisitExpr_(const TupleNode *op) override; @@ -160,8 +160,8 @@ CheckedExpr TypeInferencer::Infer(const Expr &expr) { return checked_expr; } -CheckedExpr TypeInferencer::VisitExpr_(const LocalVarNode *op) { - auto var = GetRef(op); +CheckedExpr TypeInferencer::VisitExpr_(const VarNode *op) { + auto var = GetRef(op); return {var, this->context.lookup(var)}; } diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 676aa347950b..cf035f8a2b19 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -66,7 +66,7 @@ def test_tuple() -> None: def test_local_var() -> None: name_hint = 's' - lv = relay.LocalVar(name_hint) + lv = relay.Var(name_hint) lv.name_hint == name_hint # assert lv.span == None todo(@jroesch): what do we do about spans str(lv) @@ -81,7 +81,7 @@ def test_global_var() -> None: def test_param() -> None: - lv = relay.LocalVar('x') + lv = relay.Var('x') ty = None param = relay.Param(lv, ty) assert param.var == lv @@ -92,7 +92,7 @@ def test_param() -> None: def test_function() -> None: param_names = ['a', 'b', 'c', 'd'] - params = tvm.convert([relay.Param(relay.LocalVar(n), None) for n in param_names]) + params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names]) ret_type = None body = None type_params = tvm.convert([]) @@ -105,9 +105,9 @@ def test_function() -> None: def test_call() -> None: - op = relay.LocalVar('f') + op = relay.Var('f') arg_names = ['a', 'b', 'c', 'd'] - args = tvm.convert([relay.LocalVar(n) for n in arg_names]) + args = tvm.convert([relay.Var(n) for n in arg_names]) call = relay.Call(op, args, None, None) assert call.op == op assert call.args == args @@ -116,7 +116,7 @@ def test_call() -> None: def test_let() -> None: - lv = relay.LocalVar('x') + lv = relay.Var('x') ty = None arr = tvm.nd.array(10) value = relay.Constant(arr) @@ -132,9 +132,9 @@ def test_let() -> None: def test_if() -> None: - cond = relay.LocalVar('cond') - left = relay.LocalVar('left') - right = relay.LocalVar('right') + cond = relay.Var('cond') + left = relay.Var('left') + right = relay.Var('right') ife = relay.If(cond, left, right) assert ife.cond == cond assert ife.true_value == left From 1c0007b6207df52d80f43655d588d88b7a9d6667 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Sep 2018 15:14:18 -0700 Subject: [PATCH 092/136] Fix cpplint --- include/tvm/relay/base.h | 3 ++- include/tvm/relay/expr_visitor.h | 17 ++++++++++------- include/tvm/relay/op.h | 25 ++++++++++++------------- src/relay/pass/type_infer.cc | 2 +- src/relay/pass/type_visitor.h | 2 +- 5 files changed, 26 insertions(+), 23 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 098c47ff10a2..96508de66e51 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace tvm { /*! @@ -124,7 +125,7 @@ RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef); */ class RelayNode : public Node { public: - /*! \brief The location of the program in a SourceFragment can be null, + /*! \brief The location of the program in a SourceFragment can be null, * check with span.defined() */ mutable Span span; diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 9f548223845b..748b8ac02f97 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -69,9 +69,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { class ExprFVisitor : public ::tvm::relay::ExprFunctor { public: - Expr VisitExpr_(const VarNode* op) override { - return GetRef(op); - } + Expr VisitExpr_(const VarNode* op) override { return GetRef(op); } Expr VisitExpr_(const ConstantNode* op) override { return GetRef(op); @@ -99,7 +97,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto type = this->VisitType(op->type); return ParamNode::make(var, type); } else { - CHECK(false) << "the default param visitor expected a Var found: " << var_expr << std::endl; + CHECK(false) << "the default param visitor expected a Var found: " + << var_expr << std::endl; __builtin_unreachable(); } } @@ -113,7 +112,9 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto ty_param_ref = GetRef(ty_param); ty_params.push_back(ty_param_ref); } else { - CHECK(false) << "the default function visitor expected a TypeParam found: " << ty_param_type << std::endl; + CHECK(false) + << "the default function visitor expected a TypeParam found: " + << ty_param_type << std::endl; __builtin_unreachable(); } } @@ -125,7 +126,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto param = GetRef(param_node); params.push_back(param); } else { - CHECK(false) << "the default function visitor expected a Param found: " << param_expr << std::endl; + CHECK(false) << "the default function visitor expected a Param found: " + << param_expr << std::endl; __builtin_unreachable(); } } @@ -163,7 +165,8 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto body = this->VisitExpr(op->body); return LetNode::make(var, value, body, type); } else { - CHECK(false) << "the default let visitor expected a Var found: " << var_expr << std::endl; + CHECK(false) << "the default let visitor expected a Var found: " + << var_expr << std::endl; __builtin_unreachable(); } } diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 9727e80fa561..a3037d3bebf4 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -149,20 +149,18 @@ class OpRegistry { const std::string& description); /*! * \brief Attach the type function corresponding to the return type. - * \param type_rel_name The type function name to register for the return type. - * \param type_rel The backing relation which can solve an arbitrary relation - * on variables. - * \return reference to self. + * \param type_rel_name The type function name to register for the return + * type. \param type_rel The backing relation which can solve an arbitrary + * relation on variables. \return reference to self. */ inline OpRegistry& add_type_rel(const std::string& type_rel_name, - TypeRelationFn type_rel); + TypeRelationFn type_rel); /*! * \brief Attach the type function corresponding to the return type. - * \param type_rel_name The type function name to register for the return type. - * \param type_rel The backing relation which can solve an arbitrary relation - * on variables. - * \return reference to self. + * \param type_rel_name The type function name to register for the return + * type. \param type_rel The backing relation which can solve an arbitrary + * relation on variables. \return reference to self. */ inline OpRegistry& add_type_rel( const std::string& type_rel_name, @@ -365,8 +363,7 @@ inline OpRegistry& OpRegistry::add_type_rel( } inline OpRegistry& OpRegistry::add_type_rel(const std::string& type_func_name, - TypeRelationFn type_fn) { - + TypeRelationFn type_fn) { std::vector type_params; std::vector arg_types; @@ -387,9 +384,11 @@ inline OpRegistry& OpRegistry::add_type_rel(const std::string& type_func_name, type_params.push_back(out_param); ty_call_args.push_back(out_param); - TypeConstraint type_rel = TypeRelationNode::make(type_func_name, type_fn, ty_call_args); + TypeConstraint type_rel = + TypeRelationNode::make(type_func_name, type_fn, ty_call_args); - auto func_type = FuncTypeNode::make(arg_types, out_param, type_params, { type_rel }); + auto func_type = + FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); get()->op_type = func_type; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 7501aff67682..89b53bbb3c66 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -38,7 +38,7 @@ using namespace tvm::runtime; struct TypeConstraintSet { std::vector ty_rels; TypeConstraintSet() : ty_rels() {} - TypeConstraintSet(const std::vector &cs) : ty_rels(cs) {} + explicit TypeConstraintSet(const std::vector &cs) : ty_rels(cs) {} void Add(const TypeConstraint &ty_rel) { ty_rels.push_back(ty_rel); } }; diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index ece9f27613bf..252642b1a492 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -91,7 +91,7 @@ struct TypeFVisitor : TypeFunctor { } return FuncTypeNode::make(tvm::Array(args), VisitType(op->ret_type), - type_params, type_constraints); + type_params, type_constraints); } Type VisitType_(const TupleTypeNode* op) override { From c72d3a6521706c224beb4d41a44ded57857b7d48 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Sep 2018 16:29:37 -0700 Subject: [PATCH 093/136] More clean up to type inference --- src/relay/pass/type_infer.cc | 125 +++++++++++++++++------------------ 1 file changed, 59 insertions(+), 66 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 89b53bbb3c66..76321a0fd4c5 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -35,45 +35,35 @@ namespace relay { using namespace tvm::runtime; -struct TypeConstraintSet { - std::vector ty_rels; - TypeConstraintSet() : ty_rels() {} - explicit TypeConstraintSet(const std::vector &cs) : ty_rels(cs) {} - void Add(const TypeConstraint &ty_rel) { ty_rels.push_back(ty_rel); } -}; +using TypeConstraintSet = std::vector; struct TypeContext { - std::vector> stack; + std::unordered_map var_map; std::vector constraints; - TypeContext() { stack.push_back({}); } + TypeContext() { constraints.push_back({}); } - void insert(const Var &id, const Type &t) { stack.back()[id] = t; } + void Insert(const Var &id, const Type &t) { var_map[id] = t; } void AddConstraint(const TypeConstraint &ty_rel) { - constraints.back().Add(ty_rel); + constraints.back().push_back(ty_rel); } - // TypeConstraint & Constraints() { - // return - // } - - Type lookup(const Var &id) { - for (auto frame = stack.rbegin(); frame != stack.rend(); ++frame) { - if (frame->find(id) != frame->end()) { - return frame->at(id); - } + Type Lookup(const Var &id) { + auto type = var_map.find(id); + if (type != var_map.end()) { + return (*type).second; + } else { + throw FatalTypeError("Could not resolve local id"); } - throw FatalTypeError("Could not resolve local id"); } struct Frame { TypeContext &tc; explicit Frame(TypeContext &tc) : tc(tc) { - tc.stack.push_back({}); tc.constraints.push_back({}); } - ~Frame() { tc.stack.pop_back(); } + ~Frame() { tc.constraints.pop_back(); } }; }; @@ -96,7 +86,7 @@ class TypeInferencer : private ExprFunctor { // Should be in header? template - T with_frame(const std::function &f) { + T WithFrame(const std::function &f) { TypeContext::Frame fr(context); return f(); } @@ -108,16 +98,16 @@ class TypeInferencer : private ExprFunctor { CheckedExpr Infer(const Expr &expr); - FuncType instantiate(FuncType fn_ty, tvm::Array &ty_args); + FuncType Instantiate(FuncType fn_ty, tvm::Array &ty_args); Type Normalize(const Type &t); - void report_error(const std::string &msg, Span sp); - [[noreturn]] void fatal_error(const std::string &msg, Span sp); + void ReportError(const std::string &msg, Span sp); + [[noreturn]] void FatalError(const std::string &msg, Span sp); - Type unify(const Type &t1, const Type &t2, Span sp); - Type resolve(const Type &t); - Expr resolve(const Expr &e); + Type Unify(const Type &t1, const Type &t2, Span sp); + Type Resolve(const Type &t); + Expr Resolve(const Expr &e); TypeRelation Solve(const TypeRelation &ty_rel); SolverResult Solve(std::vector &rels); @@ -162,7 +152,7 @@ CheckedExpr TypeInferencer::Infer(const Expr &expr) { CheckedExpr TypeInferencer::VisitExpr_(const VarNode *op) { auto var = GetRef(op); - return {var, this->context.lookup(var)}; + return {var, this->context.Lookup(var)}; } CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { @@ -192,7 +182,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { // We should trigger error here and move param code direclty into function // checking. - auto rtype = resolve(param->type); + auto rtype = this->Resolve(param->type); // This is a special case ... not sure if there is a better way // to handle this. param->var->checked_type_ = rtype; @@ -208,23 +198,31 @@ CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { std::vector param_types; std::vector params; - return this->with_frame([&]() -> CheckedExpr { + return this->WithFrame([&]() -> CheckedExpr { for (auto param : f->params) { CheckedExpr checked_param = this->Infer(param); Type arg_type; param_types.push_back(checked_param.type); params.push_back(GetRef(checked_param.expr.as())); - this->context.insert(param->var, checked_param.type); + this->context.Insert(param->var, checked_param.type); } auto checked_body = this->Infer(f->body); auto inferred_rtype = checked_body.type; - auto annotated_rtype = resolve(f->ret_type); + auto annotated_rtype = Resolve(f->ret_type); + + auto unified_rtype = this->Unify(inferred_rtype, annotated_rtype, f->span); + + CHECK(RelationsHold(true)); - auto unified_rtype = this->unify(inferred_rtype, annotated_rtype, f->span); + Array cs; + + for (auto cons : this->context.constraints.back()) { + cs.push_back(cons); + } return {FunctionNode::make(params, unified_rtype, checked_body.expr, {}), - FuncTypeNode::make(param_types, unified_rtype, {}, {})}; + FuncTypeNode::make(param_types, unified_rtype, {}, cs)}; }); } @@ -232,7 +230,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { return this->VisitFunction(GetRef(op), false); } -FuncType TypeInferencer::instantiate(FuncType fn_ty, +FuncType TypeInferencer::Instantiate(FuncType fn_ty, tvm::Array &ty_args) { tvm::Map subst_map; @@ -265,7 +263,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { auto fn_ty_node = checked_op.type.as(); if (!fn_ty_node) { - this->fatal_error("only expressions with function types can be called", + this->FatalError("only expressions with function types can be called", c->op->span); } @@ -277,7 +275,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { throw Error("found manually suplied type args, not supported"); } - fn_ty = instantiate(fn_ty, ty_args); + fn_ty = Instantiate(fn_ty, ty_args); std::vector arg_types; std::vector checked_args; @@ -293,14 +291,14 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { if (type_arity != number_of_args) { if (type_arity < number_of_args) { - this->fatal_error("the function is provided too many arguments", c->span); + this->FatalError("the function is provided too many arguments", c->span); } else { - this->fatal_error("the function is provided too few arguments", c->span); + this->FatalError("the function is provided too few arguments", c->span); } } for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); + this->Unify(fn_ty->arg_types[i], arg_types[i], c->args[i]->span); } // After we unify the arguments we should know more about the type @@ -326,30 +324,25 @@ CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { Let let = GetRef(op); CheckedExpr checked_value; - Type annotated_ty = resolve(let->value_type); + Type annotated_ty = Resolve(let->value_type); // If we are let-defining a function, we want to be able to // recursively name the function in order to support recursive // local definitions. if (let->value.as()) { - with_frame([&]() { - context.insert(let->var, annotated_ty); - checked_value = Infer(let->value); - }); + context.Insert(let->var, annotated_ty); + checked_value = Infer(let->value); } else { checked_value = Infer(let->value); } - Type unified_ty = this->unify(checked_value.type, annotated_ty, let->span); + Type unified_ty = this->Unify(checked_value.type, annotated_ty, let->span); // Update type context with unified type now that we have // solved this equation. - context.insert(let->var, unified_ty); + context.Insert(let->var, unified_ty); - auto checked_body = with_frame([&]() { - context.insert(let->var, unified_ty); - return Infer(let->body); - }); + auto checked_body = Infer(let->body); auto checked_let = LetNode::make(let->var, checked_value.expr, checked_body.expr, let->value_type); @@ -365,12 +358,12 @@ CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { auto checked_cond = this->Infer(ifn->cond); auto cond_type = checked_cond.type; - this->unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()), + this->Unify(cond_type, TensorTypeNode::make({}, HalideIR::Bool()), ifn->cond->span); auto checked_true = this->Infer(ifn->true_branch); auto checked_false = this->Infer(ifn->false_branch); auto unified_type = - this->unify(checked_true.type, checked_false.type, ifn->span); + this->Unify(checked_true.type, checked_false.type, ifn->span); auto checked_if = IfNode::make(checked_cond.expr, checked_true.expr, checked_false.expr); return {checked_if, unified_type}; @@ -381,7 +374,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { return {op, op->op_type}; } -Type TypeInferencer::resolve(const Type &t) { +Type TypeInferencer::Resolve(const Type &t) { if (t.defined()) { return ::tvm::relay::Resolve(this->unifier, t); } else { @@ -389,7 +382,7 @@ Type TypeInferencer::resolve(const Type &t) { } } -Expr TypeInferencer::resolve(const Expr &e) { +Expr TypeInferencer::Resolve(const Expr &e) { CHECK(e.defined()); return ::tvm::relay::Resolve(this->unifier, e); } @@ -398,7 +391,7 @@ TypeRelation TypeInferencer::Solve(const TypeRelation &ty_rel) { Array normalized_args; for (auto arg : ty_rel->args) { - normalized_args.push_back(resolve(arg)); + normalized_args.push_back(Resolve(arg)); } auto new_args = ty_rel->func_(normalized_args, ty_rel->args.size() - 1); @@ -494,7 +487,7 @@ bool TypeInferencer::RelationsHold(bool scope_only) { << std::endl; bool all_hold = true; for (auto cs_set : context.constraints) { - auto ty_rels = Downcast(cs_set.ty_rels); + auto ty_rels = Downcast(cs_set); auto status = Solve(ty_rels); RELAY_LOG(INFO) << "status= " << status << std::endl; if (status == SolverResult::Failed || status == SolverResult::Progress) { @@ -513,7 +506,7 @@ Expr InferType(const Environment &env, const Expr &e) { TypeInferencer ti(env); auto checked_expr = ti.Infer(e); CHECK(ti.RelationsHold()); - return ti.resolve(checked_expr.expr); + return ti.Resolve(checked_expr.expr); } Expr InferType(const Environment &env, const GlobalVar &var, @@ -521,20 +514,20 @@ Expr InferType(const Environment &env, const GlobalVar &var, TypeInferencer ti(env); auto func_copy = FunctionNode::make(func->params, func->ret_type, func->body, func->type_params); - func_copy->checked_type_ = ti.resolve(func_copy->fn_type()); + func_copy->checked_type_ = ti.Resolve(func_copy->fn_type()); env->functions.Set(var, func_copy); auto checked_expr = ti.Infer(func); CHECK(ti.RelationsHold()); auto map_node = env->functions.CopyOnWrite(); map_node->data.erase(var.node_); - return ti.resolve(checked_expr.expr); + return ti.Resolve(checked_expr.expr); } -inline void TypeInferencer::report_error(const std::string &msg, Span sp) { +inline void TypeInferencer::ReportError(const std::string &msg, Span sp) { this->env->AddDiagnostic({msg, sp}); } -void TypeInferencer::fatal_error(const std::string &msg, Span sp) { +void TypeInferencer::FatalError(const std::string &msg, Span sp) { this->env->AddDiagnostic({msg, sp}); throw FatalTypeError( "internal error: this exception should" @@ -542,7 +535,7 @@ void TypeInferencer::fatal_error(const std::string &msg, Span sp) { msg); } -Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { +Type TypeInferencer::Unify(const Type &t1, const Type &t2, Span sp) { try { return this->unifier->unify(t1, t2); } catch (const dmlc::Error &e) { @@ -554,7 +547,7 @@ Type TypeInferencer::unify(const Type &t1, const Type &t2, Span sp) { ss << t2; // ss << PrintType(env, t2, WrapWidth(40)); ss << "`: " << e.what(); - this->fatal_error(ss.str(), sp); + this->FatalError(ss.str(), sp); } } From 4495bda85f77bfd7e66e10d1830dea4bfde2eb11 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Sep 2018 17:57:30 -0700 Subject: [PATCH 094/136] Remove error reporting and RTS code. --- include/tvm/relay/environment.h | 30 ++-- include/tvm/relay/error.h | 6 - include/tvm/relay/source_map.h | 55 -------- python/tvm/relay/env.py | 2 +- python/tvm/relay/expr.py | 19 +-- python/tvm/relay/op/op.py | 69 ---------- python/tvm/relay/to_tvm.py | 235 -------------------------------- src/relay/ir/environment.cc | 41 ------ src/relay/ir/op.cc | 2 +- src/relay/pass/type_infer.cc | 18 +-- src/relay/source_map.cc | 74 ---------- 11 files changed, 25 insertions(+), 526 deletions(-) delete mode 100644 include/tvm/relay/source_map.h delete mode 100644 python/tvm/relay/to_tvm.py delete mode 100644 src/relay/source_map.cc diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 2df21b0bb2ce..fa805e2944d0 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -40,22 +39,15 @@ struct Environment; * */ class EnvironmentNode : public RelayNode { - private: - /*! \brief A map from string names to global variables ensures global - * uniqueness. */ - tvm::Map global_map_; - /*! \brief A map from file names to source fragments. */ - SourceMap source_map_; - /*! \brief A list of the errors reported during the current run. */ - std::vector errors_; - public: /*! \brief A map from ids to all global functions. */ tvm::Map functions; EnvironmentNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final {} + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("global_map_", &global_map_); + } TVM_DLL static Environment make(tvm::Map global_funcs); @@ -75,20 +67,20 @@ class EnvironmentNode : public RelayNode { // TODO(@jroesch, @tqchen): what are the semantics here void Merge(const Environment& env); - /*! \brief Add a source fragment to the environment. */ - SourceName AddSource(std::string file_name, std::string source); - - using Transformer = runtime::TypedPackedFunc< - runtime::TypedPackedFunc(const Environment&)>; + using Transformer = + runtime::TypedPackedFunc(const Environment&)>; /*! \brief Apply a function over every function in the global environment. */ void Transform(Transformer tranformer); - void AddDiagnostic(SpannedError); - void DisplayErrors(); - static constexpr const char* _type_key = "relay.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); + + private: + /*! \brief A map from string names to global variables ensures global + * uniqueness. */ + tvm::Map global_map_; }; struct Environment : public NodeRef { diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 055cc42936df..696e5a05487d 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -20,12 +20,6 @@ struct InternalError : Error { explicit InternalError(const std::string &msg) : Error(msg) {} }; -struct SpannedError { - std::string msg; - Span sp; - SpannedError(std::string msg, Span sp) : msg(msg), sp(sp) {} -}; - // FIX, we should change spanned errors to have a method which allow them to // report on the Environment, inverting control to error definition. struct FatalTypeError : dmlc::Error { diff --git a/include/tvm/relay/source_map.h b/include/tvm/relay/source_map.h deleted file mode 100644 index 277c3875a17f..000000000000 --- a/include/tvm/relay/source_map.h +++ /dev/null @@ -1,55 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file source_map.h - * \brief A representation of source files and a data structure for - * storing them. - */ -#ifndef TVM_RELAY_SOURCE_MAP_H_ -#define TVM_RELAY_SOURCE_MAP_H_ - -#include -#include -#include - -namespace tvm { -namespace relay { - -/*! \brief A fragment of a source file used for error reporting. - * - * These can be registered by the frontends and are used for - * displaying errors. - */ -struct SourceFragment { - /*! \brief The file name which the source fragment originates from. */ - std::string file_name; - /*! \brief The lines of source corresponding to the fragment. */ - std::vector source_lines; - - SourceFragment(const std::string& file_name, const std::string& source); - - SourceFragment(const SourceFragment& sf) { - this->file_name = sf.file_name; - this->source_lines = sf.source_lines; - } - - /*! \brief The lines of source code originate at lines. */ - std::string SourceAt(Span sp, int lines); -}; - -/*! \brief Maps from FileId's to a SourceFragment. - */ -class SourceMap { - /*! \brief Map from unique token to a fragment of a source file. */ - std::unordered_map map_; - - public: - SourceMap() : map_() {} - /*! \brief Add a source fragment with the file name and source. */ - SourceName AddSource(const std::string& file_name, const std::string& source); - /*! \brief Retrieve a source fragment by source name. */ - const SourceFragment& GetSource(SourceName id) const; -}; - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_SOURCE_MAP_H_ diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 93cbe1bca284..cd3aba69813c 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -52,7 +52,7 @@ def global_var(self, var): """Get a global variable by name.""" return _env.Environment_GetGlobalVar(self, var) - def lookup(self, var): + def __get_item__(self, var): """Lookup a global function by name or by variable.""" if isinstance(var, str): return _env.Environment_Lookup_str(self, var) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index ebf8483ccc04..63f525a316b2 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -8,11 +8,13 @@ from . import _make -class ExprBuilder(): - """A set of methods useful for building expressions - from other expressions. - """ - def __call__(self, *args, **kwargs): +class Expr(NodeBase, ExprBuilder): + """The base type for all Relay exprressions.""" + + def checked_type(self): + return _get_checked_type(self) + + def __call__(self, *args): converted_args = [] for arg in args: if isinstance(arg, Param): @@ -23,13 +25,6 @@ def __call__(self, *args, **kwargs): return Call(self, args, None, None) -class Expr(NodeBase, ExprBuilder): - """The base type for all Relay exprressions.""" - - def checked_type(self): - return _get_checked_type(self) - - @register_relay_node class Constant(Expr): """A constant tensor in Relay, see tvm/relay/type.h for more details. diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index adf963403bd3..59616502d2d7 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -76,73 +76,4 @@ def _register(v): return _register(value) if value else _register -def compile_ops(op_names): - """Register an operator property of an operator. - - - Parameters - ---------- - op_names : List[str] - A list of operator names to compile to machine code. - - Returns - ------- - A module containing the compiled TVM operators. - """ - return _CompileOpsToModule(*op_names) - - -# TODO(@jroesch): We should port to C++, just need to figure out how to write this code. -@register_func("relay.op._compile_ops") -def _compile_ops(op_impls): - lowered = [] - for local, sch, inputs in op_impls: - lfn = lower(sch, inputs, name=local.name_hint) - lowered.append(lfn) - - # TOOD(@jroesch): Where should we read these settings from - return build(lowered, target='llvm', target_host='llvm') - - _init_api("relay.op", __name__) - - -def specialize_op(op_name, new_op_name, type_args): - """Specializes an operator to a set of types and assigns it new_op_name. - - The idea is to take operators with generic types such as broadcasting - addition: - - add : forall (T : Type) (U : Type), (U, T) -> Broadcast(U, T) - - This is a function which is polymorphic over two types `T` and `U` and - takes a value of type `T` and one of `U` and returns `Broadcast` of U - and T. - - Broadcast is a type relation which relates U and T to an output type. - - The idea is that the above type is shorthand for: - - add : forall (T : Type) (U : Type) (O : Type), Broadcast(U, T, O) => (U, T) -> O - - That is a function from U and T to O where the typing relation between the values - is specified by Broadcast. - - We implement a basic Broadcasting rule in `type_relations.h` but users can specify - their own. - - If we know T=Tensor[(10, 10), dtype], U=Tensor[(10, 10), dtype] then the result - should be Tensor[(10, 10), dtype]. - - We can use SpecializeOp to implement this change of operator. - - Parameters - ---------- - op_name : str - The operator to be specialized. - - Returns - ------- - The specialized operator. - """ - return _SpecializeOp(op_name, new_op_name, type_args) diff --git a/python/tvm/relay/to_tvm.py b/python/tvm/relay/to_tvm.py deleted file mode 100644 index 99e6b6d41674..000000000000 --- a/python/tvm/relay/to_tvm.py +++ /dev/null @@ -1,235 +0,0 @@ -"""A compiler from Relay programs to TVM's graph runtime. -""" -import json -from typing import Dict, Any, List, Tuple, Set - -import attr -from .ir_pass import AbstractExprVisitor -from .op import compile_ops, Op -from .type import TensorType -from .expr import Var, Function, Let, Call - - -@attr.s(auto_attribs=True) -class NodeRef: - ident: int - index: int = 0 - version: int = 0 - - def to_json(self) -> Any: - return [self.ident, self.index, self.version] - - -@attr.s(auto_attribs=True) -class Node(): - name: str - attrs: Dict[str, Any] - is_output: bool - - def to_json(self) -> Any: - raise Exception("Abstract method, please implement me.") - - -@attr.s(auto_attribs=True) -class InputNode(Node): - """An input node in the graph representation we lower to before NNVM's graph.""" - is_output: bool = False - - def to_json(self): - return { - "op": "null", - "name": self.name, - "inputs": [] - } - - -@attr.s(auto_attribs=True) -class OpNode(Node): - """An operator node in the graph representation we lower to before NNVM's graph.""" - op_name: str - inputs: List[NodeRef] - op_attrs: Dict[str, Any] - is_output: bool = False - - def to_json(self) -> Any: - attrs = dict.copy(self.op_attrs) - # Extend ops with extra info. - attrs['func_name'] = self.op_name - # When do we flatten? - attrs['flatten_data'] = "0" - # Fix me! - attrs['num_inputs'] = str(len(self.inputs)) - attrs['num_outputs'] = "1" - - return { - "op": "tvm_op", - "name": self.name, - "attrs": attrs, - "inputs": self.inputs - } - - -def shape_to_json(shape): - return [sh.value for sh in shape] - - -def from_tensor(typ: TensorType) -> Tuple[str, List[int]]: - return (typ.dtype, shape_to_json(typ.shape)) - - -class TVMRTSCompiler(AbstractExprVisitor[NodeRef]): - """The compiler from Relay to the TVM runtime system.""" - nodes: List[Node] - id_map: Dict[Var, NodeRef] - all_ops: Set[Op] - - def __init__(self) -> None: - self.nodes = [] - self.id_map = {} - self.all_ops = set() - - def add_node(self, node: Node) -> NodeRef: - self.nodes.append(node) - ident = len(self.nodes) - 1 - return NodeRef(ident) - - def add_binding(self, ident: Var, ref: NodeRef) -> None: - self.id_map[ident] = ref - - def let_bind(self, ident: Var, node: Node) -> NodeRef: - ref = self.add_node(node) - self.add_binding(ident, ref) - return ref - - def get_node(self, ref: NodeRef) -> Node: - return self.nodes[ref.ident] - - def lookup(self, ident: Var) -> NodeRef: - return self.id_map[ident] - - def compile(self, func: Function) -> None: - """Compile a single function into a graph.""" - # TODO: (@jroesch) Restore me - # assert len(fn.ty_params) == 0 - - # First we convert all the parameters into input nodes. - params = func.params - - for param in params: - dtype, shape = from_tensor(param.type) - node = InputNode(f"{param.var.name_hint}", { - "shape": shape, - "dtype": dtype, - }) - self.let_bind(param.var, node) - - # Then we compile the body into a graph which can depend - # on input variables. - output_ref = self.visit(func.body) - - # Finally we retreive return value of program, which will - # become our output node. - self.get_node(output_ref).is_output = True - - def visit_let(self, let: Let) -> NodeRef: - """Visit the Let binding, by first traversing its value, - then setting the metadata on the returned NodeRef. - - Finally visit the body, and return the NodeRef corresponding - to it. - """ - ident = let.var - val = let.value - body = let.body - - # Need to add type info? - val_ref = self.visit(val) - dtype, shape = from_tensor(val.checked_type()) - val_node = self.get_node(val_ref) - val_node.attrs["dtype"] = dtype - val_node.attrs["shape"] = shape - self.add_binding(ident, val_ref) - return self.visit(body) - - def visit_local_var(self, ident: Var) -> NodeRef: - return self.lookup(ident) - - def visit_call(self, call: Call) -> NodeRef: - """Transform a ::tvm.relay.Call into an operator in the TVM graph.""" - inputs = [] - for arg in call.args: - inputs.append(self.visit(arg).to_json()) - - assert isinstance(call.op, Op) - self.all_ops.add(call.op.name) - - op_name = call.op.name - attrs = {'shape': shape_to_json(call.checked_type().shape), - 'dtype': call.checked_type().dtype} - op_node = OpNode("call_name", attrs, op_name, inputs, {}) - return self.add_node(op_node) - - def to_json(self) -> str: - """Convert the sequence of nodes stored by the compiler into the - JSON format defined in: https://docs.tvm.ai/dev/nnvm_json_spec.html. - """ - nodes = [] - # First we compute "nodes" field. - for node in self.nodes: - nodes.append(node.to_json()) - - arg_nodes = [] - heads = [] - # Compute "arg_nodes" and "heads" fields. - for i, node in enumerate(self.nodes): - if isinstance(node, InputNode): - arg_nodes.append(i) - - if node.is_output: - # Need to fix this. - heads.append(NodeRef(i).to_json()) - - # Compute "node_row_ptr". - # TODO - - # Compute "attrs" field. - attrs = {} - - # A - shapes = [] - storage_ids = [] - dtype = [] - dltype = [] - - for i, node in enumerate(self.nodes): - storage_ids.append(i) - shapes.append(node.attrs['shape']) - if node.attrs['dtype'] == 'float32': - dtype.append(0) - dltype.append('float32') - - attrs["shape"] = ["list_shape", shapes] - attrs["storage_id"] = ["list_int", storage_ids] - attrs["dtype"] = ["list_int", dtype] - attrs["dltype"] = ["list_str", dltype] - - json_dict = { - "nodes": nodes, - "arg_nodes": arg_nodes, - "heads": heads, - "attrs": attrs - } - - return json.dumps(json_dict) - - -def compile_to_tvm(func): - """Compile a single function to the components needed by the - TVM RTS. - """ - comp = TVMRTSCompiler() - comp.compile(func) - op_names = list(comp.all_ops) - mod = compile_ops(op_names) - graph_json = comp.to_json() - return graph_json, mod, None # params currently isn't supported by API diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index b5f0d663d26a..c8d099b6f269 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -102,47 +102,6 @@ void EnvironmentNode::Merge(const Environment &env) { } } -inline SourceName EnvironmentNode::AddSource(std::string file_name, - std::string source) { - return this->source_map_.AddSource(file_name, source); -} - -void EnvironmentNode::AddDiagnostic(SpannedError error) { - this->errors_.push_back(error); -} - -void EnvironmentNode::DisplayErrors() { - throw Error("need to restore error printing"); - // for (auto err : this->errors_) { - // auto sp = err.sp; - // auto source_file = this->source_map_.GetSource(err.sp->file_id); - // auto file_name = source_file.file_name; - // auto source_at_span = source_file.SourceAt(err.sp, 1); - // std::string error_marker = "error:"; - // auto line_info = - // std::to_string(sp->lineno) + ":" + std::to_string(sp->col_offset); - - // std::cout << rang::style::bold << rang::fg::red << error_marker - // << rang::fg::reset << file_name << ":" << line_info - // << rang::style::reset << " " << source_at_span << std::endl; - - // // Build the cursor. - - // // Fix this code, hardwired to compute alignment of pointer. - // size_t spaces = error_marker.size() + line_info.size() + file_name.size() - // + - // sp->col_offset - 3; - - // std::string cursor = "~~~~^~~~~"; - // for (size_t i = 0; i < spaces; i++) { - // std::cout << " "; - // } - // std::cout << rang::fg::red << cursor << " " << err.msg << - // rang::style::reset - // << std::endl; - // } -} - void EnvironmentNode::Transform(EnvironmentNode::Transformer transformer) { Array to_process; for (auto var_and_func : this->functions) { diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 8b6748dfe482..61e4b34a0fc7 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -35,7 +35,7 @@ struct OpManager { std::unordered_map> attr; // frontend functions std::vector frontend_funcs; - // get singleton of the + // get singleton of the op manager static OpManager* Global() { static OpManager inst; return &inst; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 76321a0fd4c5..fa26601b608a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -50,9 +50,9 @@ struct TypeContext { } Type Lookup(const Var &id) { - auto type = var_map.find(id); - if (type != var_map.end()) { - return (*type).second; + auto type = var_map.find(id); + if (type != var_map.end()) { + return (*type).second; } else { throw FatalTypeError("Could not resolve local id"); } @@ -60,9 +60,7 @@ struct TypeContext { struct Frame { TypeContext &tc; - explicit Frame(TypeContext &tc) : tc(tc) { - tc.constraints.push_back({}); - } + explicit Frame(TypeContext &tc) : tc(tc) { tc.constraints.push_back({}); } ~Frame() { tc.constraints.pop_back(); } }; }; @@ -137,7 +135,6 @@ TypeInferencer::TypeInferencer(Environment env) : env(env) { this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); } - CheckedExpr TypeInferencer::Infer(const Expr &expr) { RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; CheckedExpr checked_expr = this->VisitExpr(expr); @@ -264,7 +261,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { if (!fn_ty_node) { this->FatalError("only expressions with function types can be called", - c->op->span); + c->op->span); } // We now have a function type. @@ -523,12 +520,7 @@ Expr InferType(const Environment &env, const GlobalVar &var, return ti.Resolve(checked_expr.expr); } -inline void TypeInferencer::ReportError(const std::string &msg, Span sp) { - this->env->AddDiagnostic({msg, sp}); -} - void TypeInferencer::FatalError(const std::string &msg, Span sp) { - this->env->AddDiagnostic({msg, sp}); throw FatalTypeError( "internal error: this exception should" "be handled and errors reported with Environment::display_errors\n" + diff --git a/src/relay/source_map.cc b/src/relay/source_map.cc deleted file mode 100644 index 7630135a1e5e..000000000000 --- a/src/relay/source_map.cc +++ /dev/null @@ -1,74 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file source_map.cc - * \brief Source maps for Relay. - */ - -#include -#include -#include - -namespace tvm { -namespace relay { - -using tvm::IRPrinter; -using namespace tvm::runtime; - -SourceFragment::SourceFragment(const std::string& file_name, - const std::string& source) - : file_name(file_name), source_lines({}) { - RELAY_LOG(INFO) << "SourceFragment::SourceFragment source=" << source - << std::endl; - std::stringstream source_stream; - source_stream.str(source.c_str()); - std::string line; - - while (std::getline(source_stream, line)) { - RELAY_LOG(INFO) << "SourceFragment::SourceFragment: line=" << line - << std::endl; - std::string copy(line); - source_lines.push_back(copy); - } -} - -std::string SourceFragment::SourceAt(Span sp, int max_lines = 1) { - std::stringstream out; - - int starting_line = sp->lineno; - - if (starting_line >= static_cast(this->source_lines.size())) { - throw dmlc::Error("SourceFragment: index out of bounds"); - } - - auto lines = std::max(static_cast(max_lines), - source_lines.size() - starting_line); - - for (size_t i = 0; i < lines; i++) { - out << std::endl << this->source_lines.at(starting_line + i); - } - - auto source_slice = out.str(); - - RELAY_LOG(INFO) << "SourceFragment::SourceAt: source_slice=" << source_slice - << std::endl; - return source_slice; -} - -SourceName SourceMap::AddSource(const std::string & file_name, const std::string & source) { - auto new_id = SourceNameNode::make(file_name); - SourceFragment sfile(file_name, source); - this->map_.insert({new_id, sfile}); - return new_id; -} - -const SourceFragment& SourceMap::GetSource(SourceName id) const { - auto item = map_.find(id); - if (item != map_.end()) { - return (*item).second; - } else { - throw dmlc::Error("could not find requested source fragment"); - } -} - -} // namespace relay -} // namespace tvm From 3f5e2c88fecd52f40e4b95eb20aad48ed363e41d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 13:18:27 -0700 Subject: [PATCH 095/136] Add documentation for env.py --- python/tvm/relay/env.py | 38 +++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index cd3aba69813c..4c73db0b524b 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -48,17 +48,45 @@ def merge(self, other): """ return _env.Environment_Merge(self, other) - def global_var(self, var): - """Get a global variable by name.""" - return _env.Environment_GetGlobalVar(self, var) + def global_var(self, name): + """Get a global variable by name. + + Parameters + ---------- + name: str + The name of the global variable. + + Returns + ------- + global_var: GlobalVar + The global variable mapped to :code:`name`. + """ + return _env.Environment_GetGlobalVar(self, name) def __get_item__(self, var): - """Lookup a global function by name or by variable.""" + """Lookup a global function by name or by variable. + + Parameters + ---------- + var: str or GlobalVar + The name or global variable. + + Returns + ------- + func: Function + The function referenced by :code:`var`. + """ if isinstance(var, str): return _env.Environment_Lookup_str(self, var) else: return _env.Environment_Lookup(self, var) def transform(self, transformer): - """Apply a transformer function to the environment.""" + """Apply a transformer function to the environment. + + Parameters + ---------- + transformer: function + The environment transformer function. + """ _env.Environment_Transform(self, transformer) From c908b7a466fba62348374e8a314214f18d431216 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 13:58:00 -0700 Subject: [PATCH 096/136] Clean up and docs --- include/tvm/relay/environment.h | 46 ++++- include/tvm/relay/expr_visitor.h | 14 +- include/tvm/relay/op.h | 2 +- python/tvm/relay/env.py | 2 + python/tvm/relay/ir_pass.py | 227 +---------------------- tests/scripts/task_python_integration.sh | 2 + 6 files changed, 52 insertions(+), 241 deletions(-) diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index fa805e2944d0..75ddc88674e6 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -51,28 +51,56 @@ class EnvironmentNode : public RelayNode { TVM_DLL static Environment make(tvm::Map global_funcs); + /*! \brief Add a function to the global environment. + * \param var The name of the global function. + * \param func The function. + * \param update Controls whether you can replace a definition in the + * environment. + */ void Add(const GlobalVar& var, const Function& func, bool update = false); + + /*! \brief Update a function in the global environment. + * \param var The name of the global function to update. + * \param func The new function. + */ void Update(const GlobalVar& var, const Function& func); + + /*! \brief Remove a function from the global environment. + * \param var The name of the global function to update. + */ void Remove(const GlobalVar& var); - /*! \brief Lookup a global function by its variable. */ + /*! \brief Lookup a global function by its variable. + * \param str The unique string specifying the global variable. + * \returns The global variable. + */ GlobalVar GetGlobalVar(const std::string& str); - /*! \brief Lookup a global function by its variable. */ - Function Lookup(const GlobalVar& id); + /*! \brief Lookup a global function by its variable. + * \param var The global var to lookup. + * \returns The function named by the variable argument. + */ + Function Lookup(const GlobalVar& var); - /*! \brief Lookup a global function by its string name */ - Function Lookup(const std::string& s); + /*! \brief Lookup a global function by its string name + * \param name The name of the function. + * \returns The function named by the argument. + */ + Function Lookup(const std::string& name); - // TODO(@jroesch, @tqchen): what are the semantics here - void Merge(const Environment& env); + /*! \brief Combine with another Environment. + * \param other The other environment. + */ + void Merge(const Environment& other); using Transformer = runtime::TypedPackedFunc(const Environment&)>; - /*! \brief Apply a function over every function in the global environment. */ - void Transform(Transformer tranformer); + /*! \brief Apply a function over every function in the global environment. + * \param transformer The transformation function. + */ + void Transform(Transformer transformer); static constexpr const char* _type_key = "relay.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 748b8ac02f97..4a26dcbd32e7 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -97,9 +97,9 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto type = this->VisitType(op->type); return ParamNode::make(var, type); } else { - CHECK(false) << "the default param visitor expected a Var found: " + LOG(FATAL) << "the default param visitor expected a Var found: " << var_expr << std::endl; - __builtin_unreachable(); + return Expr(); } } @@ -112,10 +112,10 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto ty_param_ref = GetRef(ty_param); ty_params.push_back(ty_param_ref); } else { - CHECK(false) + LOG(FATAL) << "the default function visitor expected a TypeParam found: " << ty_param_type << std::endl; - __builtin_unreachable(); + return Expr(); } } @@ -128,7 +128,7 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { } else { CHECK(false) << "the default function visitor expected a Param found: " << param_expr << std::endl; - __builtin_unreachable(); + return Expr(); } } @@ -165,9 +165,9 @@ class ExprFVisitor : public ::tvm::relay::ExprFunctor { auto body = this->VisitExpr(op->body); return LetNode::make(var, value, body, type); } else { - CHECK(false) << "the default let visitor expected a Var found: " + LOG(FATAL) << "the default let visitor expected a Var found: " << var_expr << std::endl; - __builtin_unreachable(); + return Expr(); } } diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index a3037d3bebf4..3ab8c778c76d 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -208,7 +208,7 @@ class OpRegistry { } return *this; } - /*! \return The global single retistry */ + /*! \return The global single registry */ TVM_DLL static ::dmlc::Registry* Registry(); private: diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 4c73db0b524b..6f4362a77c2d 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -4,11 +4,13 @@ from . import _make from . import _env + @register_relay_node class Environment(NodeBase): """The global Relay environment containing functions, options and more. """ + def __init__(self, funcs) -> None: """Construct an environment. diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index ca396404610a..37f7001c460b 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -1,232 +1,11 @@ # pylint: disable=no-else-return, # pylint: disable=unidiomatic-typecheck -"""The optimizer for Relay. +"""The set of passes for Relay. -Exposes an interface for configuring the optimizer and scripting -it directly in Python. +Exposes an interface for configuring the passes and scripting +them in Python. """ -from typing import TypeVar, Generic, Union -from typing import Dict, Tuple, List, Callable -import tvm - -from .expr import Expr -from .expr import Function, Let, Call, Var -from .expr import GlobalVar, If, Constant -from .type import Type, TypeParam -from .env import Environment -from .op import Op -from .op.op import specialize_op -# import relay.make as relay_mk -# from relay import ir -# from relay.env import Environment -# from relay.tyck import check_expr -# from relay.first_order_reverse_ad import fo_with_gradient -# from relay.anf import to_anf from . import _ir_pass # Expose checking expression, should rename to infer_type. -# pylint: disable=invalid-name check_expr = _ir_pass.check_expr - -# # pylint: disable=invalid-name -# concretize = _opt.concretize - -# # pylint: disable=invalid-name -# optimize = _opt.optimize - -# # pylint: disable=invalid-name -# type_specialize = _opt.type_specialize - -# # pylint: disable=invalid-name -# compile_ops_to_module = _opt.compile_ops_to_module - - -@tvm.register_func("relay.mangle") -def mangle(name: str, types: List[Type]) -> str: - for typ in types: - name += str(typ) + "_" - return name - - -T = TypeVar('T') - - -class AbstractExprVisitor(Generic[T]): - """A functional visitor over Expr in Python.""" - - # pylint: disable=no-else-return - def visit(self, expr: Expr) -> T: - """Apply the visitor to an expression.""" - if isinstance(expr, Function): - return self.visit_function(expr) - elif isinstance(expr, Call): - return self.visit_call(expr) - elif isinstance(expr, Let): - return self.visit_let(expr) - elif isinstance(expr, Var): - return self.visit_local_var(expr) - elif isinstance(expr, GlobalVar): - return self.visit_global_var(expr) - elif isinstance(expr, If): - return self.visit_if(expr) - elif isinstance(expr, Tuple): - return self.visit_tuple(expr) - elif isinstance(expr, Constant): - return self.visit_constant(expr) - else: - raise Exception(f"warning unhandled case: {type(expr)}") - - def visit_function(self, _: Function) -> T: - raise Exception("Abstract method please implement me.") - - def visit_let(self, _: Let) -> T: - raise Exception("Abstract method please implement me.") - - def visit_call(self, _: Call) -> T: - raise Exception("Abstract method please implement me.") - - def visit_local_id(self, _: Var) -> T: - raise Exception("Abstract method please implement me.") - - def visit_type(self, typ: Type) -> Type: - return typ - - def visit_if(self, _: If) -> T: - raise Exception("Abstract method please implement me.") - - def visit_tuple(self, _: Tuple) -> T: - raise Exception("Abstract method please implement me.") - - def visit_constant(self, _: Constant) -> T: - raise Exception("Abstract method please implement me.") - - def visit_global_var(self, _: GlobalVar) -> T: - raise Exception("Abstract method please implement me.") - - @classmethod - def to_pass(cls) -> Callable[[Environment], Callable[[GlobalVar, Function], Function]]: - def _outer_wrapper(env): - visitor = cls(env) - - def _inner_wrapper(_, func): - return visitor.visit(func) - return _inner_wrapper - return _outer_wrapper - - -class ExprVisitor(AbstractExprVisitor[Expr]): - """A functional visitor over Expr in Python.""" - - def visit_function(self, fn: Function) -> Expr: - new_body = self.visit(fn.body) - return Function( - list(fn.params), - fn.ret_type, new_body, - fn.type_params) - - def visit_let(self, let: Let) -> Expr: - new_var = self.visit(let.var) - new_value_type = self.visit_type(let.value_type) - new_val = self.visit(let.value) - new_body = self.visit(let.body) - return Let(new_var, new_val, new_body, new_value_type) - - def visit_call(self, call: Call) -> Expr: - new_fn = self.visit(call.op) - new_args = [self.visit(arg) for arg in call.args] - return Call(new_fn, new_args, call.attrs) - - def visit_local_var(self, local_var: Var) -> Expr: - return local_var - - def visit_global_id(self, global_var: GlobalVar) -> Expr: - return global_var - - def visit_if(self, ite: If) -> Expr: - return If( - self.visit(ite.guard), - self.visit(ite.true_b), - self.visit(ite.false_b)) - - def visit_tuple(self, tup: Tuple) -> Expr: - return Tuple([self.visit(field) for field in tup.fields]) - - def visit_constant(self, const: Constant) -> Expr: - return const - - -MMCacheKey = Tuple[Union[GlobalVar, str], List[Type]] - - -class Monomorphize(ExprVisitor): - """A monomorphization pass. - - Implements what is known as "monomorphization" in - classic compiler literature. This pass removes - polymorphism replacing calls to functions and - operators with type specialized versions. - """ - monomorph_map: Dict[MMCacheKey, Union[Op, Function]] - - # pylint: disable=super-init-not-called - def __init__(self, env: Environment) -> None: - self.env = env - # Stores (GlobalVar, Type), should eventually store attributes. - self.monomorph_map = {} - - # pylint: disable=no-else-return - def visit_call(self, call: Call) -> Expr: - cache_key = (call.op, call.type_args) - new_args = [self.visit(arg) for arg in call.args] - - if cache_key in self.monomorph_map: - op = self.monomorph_map[cache_key] - new_args = [self.visit(arg) for arg in call.args] - return Call(op, new_args, call.attrs) - else: - if isinstance(call.op, Op): - poly_name = call.op.name - mono_name = mangle(poly_name, call.type_args) - for arg in call.type_args: - if isinstance(arg, TypeParam): - # raise Exception("...") # Fix me in the morning!!! - return call - - mono_op = specialize_op(poly_name, mono_name, call.type_args) - self.monomorph_map[cache_key] = mono_op - return Call(mono_op, new_args, call.attrs, []) - elif isinstance(call.op, GlobalVar): - return call - # defn = self.env.lookup(call.op) - # new_id = self.env.global_id(defn.id.name + str(1)) - # cache_key = (call.op, call.type_args) - # self.monomorph_map[cache_key] = new_id - # new_body = self.visit(type_specialize(call.type_args, defn.body)) - # new_body = Function( - # [], new_body.params, new_body.ret_type, new_body.body) - # new_ty = check_expr(self.env, new_body) - # # TODO(@jroesch): move into C++ - # # TODO(@joresch): implement and call name mangler - # defn = Defn(new_id, new_ty, new_body) - # self.env.add(defn) - # self.visit_item(defn) - # return Call(new_id, call.args, call.attrs) - - elif isinstance(call.op, Function): - return call - # new_func = type_specialize(call.type_args, call.op) - # new_func = self.visit(new_func) - # new_func = Function([], - # new_func.params, - # new_func.ret_type, - # new_func.body) - # check_expr(self.env, new_func) - # return Call(new_func, call.args, call.attrs) - else: - new_fn = self.visit(call.op) - return Call(new_fn, new_args, call.attrs) - - -# TODO(@jroesch): Fix up my type -__tgt_host__ = __tgt__ = "llvm" -__relay_tvm_context__ = tvm.cpu() diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 8104bf079502..7dcd5c921905 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -18,6 +18,8 @@ TVM_FFI=cython python -m nose -v tests/python/integration || exit -1 TVM_FFI=ctypes python3 -m nose -v tests/python/integration || exit -1 TVM_FFI=cython python -m nose -v tests/python/contrib || exit -1 TVM_FFI=ctypes python3 -m nose -v tests/python/contrib || exit -1 +TVM_FFI=cython python -m nose -v tests/python/relay || exit -1 +TVM_FFI=ctypes python3 -m nose -v tests/python/relay || exit -1 # Do not enabke OpenGL # TVM_FFI=cython python -m nose -v tests/webgl || exit -1 From c03be26d2549800939f64531cca28838f77ae656 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 14:07:36 -0700 Subject: [PATCH 097/136] Add docs to unifier.h --- src/relay/pass/unifier.cc | 24 ++++++++++++------------ src/relay/pass/unifier.h | 28 ++++++++++++++++++---------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 752c4c3f1116..3a2a29dd0dc7 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -26,7 +26,7 @@ UnionFind UnionFindNode::make(tvm::Map uf_map) { return UnionFind(n); } -void UnionFindNode::insert(const IncompleteType &v) { this->uf_map.Set(v, v); } +void UnionFindNode::Insert(const IncompleteType &v) { this->uf_map.Set(v, v); } void UnionFindNode::debug() { for (auto entry : this->uf_map) { @@ -42,15 +42,15 @@ void UnionFindNode::AssertAlphaEqual(const Type &l, const Type &r) { } } -void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { +void UnionFindNode::Unify(const IncompleteType &v1, const Type &t) { RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << ", t=" << t << std::endl; - auto parent1 = this->find(v1); + auto parent1 = this->Find(v1); // if t is a type var, then unify parents const IncompleteTypeNode *tvn2 = t.as(); if (tvn2) { auto v2 = GetRef(tvn2); - auto parent2 = this->find(v2); + auto parent2 = this->Find(v2); // if parents are exactly equal, then we're done if (parent1 == parent2) { @@ -88,7 +88,7 @@ void UnionFindNode::unify(const IncompleteType &v1, const Type &t) { AssertAlphaEqual(parent1, t); } -Type UnionFindNode::find(const IncompleteType &v) { +Type UnionFindNode::Find(const IncompleteType &v) { // The node has no mapping, so its representative is just itself. if (this->uf_map.find(v) == this->uf_map.end()) { return v; @@ -108,7 +108,7 @@ Type UnionFindNode::find(const IncompleteType &v) { // otherwise, recurse and perform path compression IncompleteType pv = GetRef(rep); - Type higher_up = this->find(pv); + Type higher_up = this->Find(pv); this->uf_map.Set(v, higher_up); return higher_up; } @@ -134,7 +134,7 @@ TypeUnifier TypeUnifierNode::make(UnionFind uf) { return TypeUnifier(n); } -void TypeUnifierNode::insert(const IncompleteType &v) { this->uf->insert(v); } +void TypeUnifierNode::insert(const IncompleteType &v) { this->uf->Insert(v); } Type TypeUnifierNode::unify(const Type &t1, const Type &t2) { RELAY_LOG(INFO) << "TypeUnifierNode::unify: t1=" << t1 << " t2=" << t2 @@ -155,7 +155,7 @@ struct IncompleteTypeSubst : TypeFVisitor { // type var: look it up in the type map and recurse Type VisitType_(const IncompleteTypeNode *op) override { auto tv = GetRef(op); - auto parent = unifier->uf->find(tv); + auto parent = unifier->uf->Find(tv); if (parent == tv) { return tv; } @@ -191,8 +191,8 @@ Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; // Fix unify to return new representative - this->uf->unify(tv2, t1); - auto rep = this->uf->find(tv2); + this->uf->Unify(tv2, t1); + auto rep = this->uf->Find(tv2); RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl; return rep; } @@ -201,8 +201,8 @@ Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { IncompleteType tv1 = GetRef(t1); RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << t1 << " = " << rt2 << std::endl; - this->uf->unify(tv1, rt2); - auto rep = this->uf->find(tv1); + this->uf->Unify(tv1, rt2); + auto rep = this->uf->Find(tv1); RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode rep=" << rep << std::endl; return rep; } diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index a5f3c60a85df..06cba1fc4461 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -26,11 +26,12 @@ struct SubstitutionError : dmlc::Error { explicit SubstitutionError(const std::string& msg) : Error(msg) {} }; -/*! \brief a union-find data structure for the type-checker */ -class UnionFind; // forward declaration +/*! \brief A union-find data structure for the type-checker */ +class UnionFind; class UnionFindNode : public Node { public: + /*! \brief The inernal map from incomplete types to their representatives. */ tvm::Map uf_map; UnionFindNode() {} @@ -39,14 +40,21 @@ class UnionFindNode : public Node { TVM_DLL static UnionFind make(tvm::Map uf_map); - // insert v into UF - void insert(const IncompleteType& v); - - // infers that v1 and v2 must be of the smae type - void unify(const IncompleteType& v1, const Type& v2); - - // returns representative of v's UF-group - Type find(const IncompleteType& v); + /*! \brief Insert it into the union find. + * \param it The type to add to the union find. + */ + void Insert(const IncompleteType& it); + + /*! \brief Union operation, combine two equivalence classes. + * \param it The incomplete type to unify. + * \param ty The other type. + */ + void Unify(const IncompleteType& it, const Type& t); + + /*! \brief Find operation, returns the representative of the argument. + * \param it The element to lookup. + */ + Type Find(const IncompleteType& it); void debug(); From bc94272e8a1f06fa056a50629e2f75eae798d891 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 14:16:51 -0700 Subject: [PATCH 098/136] resolve.h docs and clean up whitespace --- python/tvm/relay/ir_builder.py | 6 +++++- python/tvm/relay/op/op.py | 2 -- src/relay/pass/resolve.cc | 1 - src/relay/pass/resolve.h | 25 +++++++++++++++++++++++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 9dc4796802c1..298172054bbe 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -60,6 +60,7 @@ def __exit__(self, ptype, value, trace): class PartialFunc(): """A wrapper around functions while they are being built.""" + def __init__(self, params, ret_type, body, type_params): self.params = params self.ret_type = ret_type @@ -77,6 +78,8 @@ def to_func(self): self.type_params) #pylint: disable=invalid-name + + def _mk_let(bindings, ret_value): let_expr = ret_value for var, (value, ty) in reversed(list(bindings.items())): @@ -90,6 +93,7 @@ class IRBuilder(): Enables users to build up a Relay environment and program. """ + def __init__(self): self.bindings = [{}] self.scopes = [{}] @@ -149,7 +153,7 @@ def _convert_params(self, raw_params): self.scopes[-1][var.name_hint] = var relay_params.append(param) - + return relay_params def function(self, *params): diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 59616502d2d7..f1130b52e7ce 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -3,8 +3,6 @@ from ..base import register_relay_node from ..expr import Expr -from ..._ffi.function import register_func -from ... import lower, build @register_relay_node diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index bc63d939959e..cf5fb2910245 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -12,7 +12,6 @@ namespace tvm { namespace relay { -// TODO(@jroesch): We should probably generalize the subst code. struct ResolveTypeType : TypeFVisitor { const TypeUnifier &unifier; diff --git a/src/relay/pass/resolve.h b/src/relay/pass/resolve.h index deb6558322b8..d213ddb2b1ff 100644 --- a/src/relay/pass/resolve.h +++ b/src/relay/pass/resolve.h @@ -13,8 +13,33 @@ namespace tvm { namespace relay { + +/*! \brief Resolve a type containing incomplete types. +* +* This pass replaces incomplete types with their representative, and +* converts types which are not defined into fresh variables. +* +* \param unifier The unifier containing the unification data. +* \param ty The type to resolve. +* \returns The resolved type. +*/ Type Resolve(const TypeUnifier & unifier, const Type & ty); + +/*! \brief Resolve an expression containing incomplete types. +* +* This pass replaces incomplete types with their representative, and +* converts types which are not defined into fresh variables. +* +* \param unifier The unifier containing the unification data. +* \param ty The expression to resolve. +* \returns The resolved expression. +*/ Expr Resolve(const TypeUnifier & unifier, const Expr & expr); + +/*! \brief Check if all types have been filled in. +* \param t The type. +* \returns True if the type is resolved, false otherwise. +*/ bool IsFullyResolved(const Type & t); } // namespace relay From 255cf151dfa3ce8cdb3cb9beed92844eb5562d3e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 14:24:57 -0700 Subject: [PATCH 099/136] Clean up interface a little more --- src/relay/pass/resolve.cc | 4 ++-- src/relay/pass/type_infer.cc | 8 ++++---- src/relay/pass/unifier.cc | 24 ++++++++++++------------ src/relay/pass/unifier.h | 12 ++++++------ 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index cf5fb2910245..470374da5a0e 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -20,7 +20,7 @@ struct ResolveTypeType : TypeFVisitor { Type VisitType(const Type &t) override { if (!t.defined()) { auto inc_ty = IncompleteTypeNode::make(TypeParamNode::Kind::kType); - unifier->insert(inc_ty); + unifier->Insert(inc_ty); return inc_ty; } else { return TypeFVisitor::VisitType(t); @@ -28,7 +28,7 @@ struct ResolveTypeType : TypeFVisitor { } Type VisitType_(const IncompleteTypeNode *op) override { - return unifier->subst(GetRef(op)); + return unifier->Subst(GetRef(op)); } }; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index fa26601b608a..1978d69ad9ce 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -235,7 +235,7 @@ FuncType TypeInferencer::Instantiate(FuncType fn_ty, // Eventually allow the type vars to be passed in. for (auto ty_param : fn_ty->type_params) { IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); - this->unifier->insert(fresh); + this->unifier->Insert(fresh); ty_args.push_back(fresh); subst_map.Set(ty_param, fresh); } @@ -303,7 +303,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { // representatives. for (size_t i = 0; i < ty_args.size(); i++) { - ty_args.Set(i, this->unifier->subst(ty_args[i])); + ty_args.Set(i, this->unifier->Subst(ty_args[i])); } // Add type constraints from the function types. @@ -397,7 +397,7 @@ TypeRelation TypeInferencer::Solve(const TypeRelation &ty_rel) { tvm::Array final_args; for (size_t i = 0; i < new_args.size(); i++) { - final_args.push_back(unifier->unify(normalized_args[i], new_args[i])); + final_args.push_back(Unify(normalized_args[i], new_args[i], ty_rel->span)); } return TypeRelationNode::make(ty_rel->name, ty_rel->func_, final_args); @@ -529,7 +529,7 @@ void TypeInferencer::FatalError(const std::string &msg, Span sp) { Type TypeInferencer::Unify(const Type &t1, const Type &t2, Span sp) { try { - return this->unifier->unify(t1, t2); + return this->unifier->Unify(t1, t2); } catch (const dmlc::Error &e) { std::stringstream ss; ss << "Error unifying `"; diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 3a2a29dd0dc7..fa65d1ef18aa 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -128,15 +128,15 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "UnionFindNode(" << node->uf_map << ")"; }); -TypeUnifier TypeUnifierNode::make(UnionFind uf) { +TypeUnifier TypeUnifierNode::make(UnionFind union_find) { std::shared_ptr n = std::make_shared(); - n->uf = uf; + n->union_find = union_find; return TypeUnifier(n); } -void TypeUnifierNode::insert(const IncompleteType &v) { this->uf->Insert(v); } +void TypeUnifierNode::Insert(const IncompleteType &v) { this->union_find->Insert(v); } -Type TypeUnifierNode::unify(const Type &t1, const Type &t2) { +Type TypeUnifierNode::Unify(const Type &t1, const Type &t2) { RELAY_LOG(INFO) << "TypeUnifierNode::unify: t1=" << t1 << " t2=" << t2 << std::endl; @@ -155,7 +155,7 @@ struct IncompleteTypeSubst : TypeFVisitor { // type var: look it up in the type map and recurse Type VisitType_(const IncompleteTypeNode *op) override { auto tv = GetRef(op); - auto parent = unifier->uf->Find(tv); + auto parent = unifier->union_find->Find(tv); if (parent == tv) { return tv; } @@ -163,7 +163,7 @@ struct IncompleteTypeSubst : TypeFVisitor { } }; -Type TypeUnifierNode::subst(const Type &t) { +Type TypeUnifierNode::Subst(const Type &t) { IncompleteTypeSubst tvsubst(this); // normalize first so substitutions in quantifiers will be correct Type ret = tvsubst.VisitType(t); @@ -180,19 +180,19 @@ Type TypeUnifierNode::subst(const Type &t) { Type TypeUnifierNode::VisitType(const Type &t1, const Type t2) { // When the right hand size is a type variable immediately unify. if (const IncompleteTypeNode *tvn2 = t2.as()) { - return this->unifyWithIncompleteType(t1, GetRef(tvn2)); + return this->UnifyWithIncompleteType(t1, GetRef(tvn2)); } else { return TypeFunctor::VisitType(t1, t2); } } -Type TypeUnifierNode::unifyWithIncompleteType(const Type &t1, +Type TypeUnifierNode::UnifyWithIncompleteType(const Type &t1, const IncompleteType tv2) { RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; // Fix unify to return new representative - this->uf->Unify(tv2, t1); - auto rep = this->uf->Find(tv2); + this->union_find->Unify(tv2, t1); + auto rep = this->union_find->Find(tv2); RELAY_LOG(INFO) << "unifyWithIncompleteType: rep =" << rep << std::endl; return rep; } @@ -201,8 +201,8 @@ Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { IncompleteType tv1 = GetRef(t1); RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << t1 << " = " << rt2 << std::endl; - this->uf->Unify(tv1, rt2); - auto rep = this->uf->Find(tv1); + this->union_find->Unify(tv1, rt2); + auto rep = this->union_find->Find(tv1); RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode rep=" << rep << std::endl; return rep; } diff --git a/src/relay/pass/unifier.h b/src/relay/pass/unifier.h index 06cba1fc4461..4e939cc26bca 100644 --- a/src/relay/pass/unifier.h +++ b/src/relay/pass/unifier.h @@ -82,24 +82,24 @@ class TypeUnifier; class TypeUnifierNode : public Node, private TypeFunctor { public: - UnionFind uf; + UnionFind union_find; TypeUnifierNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf", &uf); } + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("union_find", &union_find); } TVM_DLL static TypeUnifier make(UnionFind uf); /*! \brief Introduces a new type var into the unifier */ - void insert(const IncompleteType& v); + void Insert(const IncompleteType& v); /*! \brief Unifies two types if possible, throws a unification error if it * cannot */ - Type unify(const Type& t1, const Type& t2); + Type Unify(const Type& t1, const Type& t2); /*! \brief Attempts to substitute all type vars in t with concrete types, * throws substitution error if it cannot concretize*/ - Type subst(const Type& t); + Type Subst(const Type& t); // /*! \brief Checks the kinds in the given type */ // Type CheckKinds(const Type& t); @@ -109,7 +109,7 @@ class TypeUnifierNode : public Node, private: /*! \brief Unify incomplete type with another type. */ - Type unifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); + Type UnifyWithIncompleteType(const Type& t1, const IncompleteType tvn2); /*! \brief Implements unification between two types with incomplete portions. */ Type VisitType(const Type& t1, const Type t2) override; From 1e1924892b7d6d11d3f572559ea6a6b45ec3d220 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 14:28:21 -0700 Subject: [PATCH 100/136] Fix a few issues from clean up --- python/tvm/relay/__init__.py | 1 - python/tvm/relay/env.py | 2 +- python/tvm/relay/expr.py | 2 +- python/tvm/relay/op/__init__.py | 2 +- src/relay/pass/type_infer.cc | 2 ++ .../relay/test_tyck_eval_integration.py | 34 +------------------ 6 files changed, 6 insertions(+), 37 deletions(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 4f58958feaf9..493036857b29 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -3,7 +3,6 @@ from . import base from . import type as tpe from . import expr -from . import to_tvm from . import env from . import ir_pass from . import ir_builder diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 6f4362a77c2d..e36c27d1a632 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -65,7 +65,7 @@ def global_var(self, name): """ return _env.Environment_GetGlobalVar(self, name) - def __get_item__(self, var): + def __getitem__(self, var): """Lookup a global function by name or by variable. Parameters diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 63f525a316b2..85e69349321d 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -8,7 +8,7 @@ from . import _make -class Expr(NodeBase, ExprBuilder): +class Expr(NodeBase): """The base type for all Relay exprressions.""" def checked_type(self): diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 5c3a8ac249a6..0646a8326db6 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,7 +1,7 @@ #pylint: disable=wildcard-import """Relay core operators.""" # operator defs -from .op import get, register, Op, compile_ops +from .op import get, register, Op # Operators from .tensor import * diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 1978d69ad9ce..1f14c7879b45 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -464,6 +464,8 @@ SolverResult TypeInferencer::Solve(std::vector &rels) { status = SolverResult::Failed; break; } + + std::reverse(rels.begin(), rels.end()); } while (status == SolverResult::Progress); return status; } diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 5338fad9ad8c..216d370fac7b 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -8,13 +8,8 @@ from tvm.relay.ir_builder import IRBuilder, float_type, int_type from tvm.relay.ir_builder import func_type, tensor_type, into_ast from tvm.relay.env import Environment -from tvm.relay.ir_pass import Monomorphize from tvm.relay.op import log, add, equal, subtract from tvm.relay.expr import Function -from tvm.relay import to_tvm -from tvm.contrib import graph_runtime -import nnvm - def assert_has_type(expr, typ, env=Environment({})): checked_expr = check_expr(env, expr) @@ -22,27 +17,10 @@ def assert_has_type(expr, typ, env=Environment({})): def assert_decl_has_type(env, name, typ): - func = env.lookup(name) + func = env[name] assert func.checked_type() == typ -def run(env, expr, inputs, shape): - if not isinstance(expr, Function): - expr = Function([], None, expr, []) - - env.add("main", expr) - env.transform(Monomorphize.to_pass()) - main = env.lookup("main") - graph, lib, _ = to_tvm.compile_to_tvm(main) - # We use NNVM to load the graph right now because it populates node_row_ptr field. - nnvm_graph = nnvm.graph.load_json(graph) - module = graph_runtime.create(nnvm_graph, lib, tvm.cpu(0)) - module.set_input(None, None, **inputs) - module.run() - out_nd_array = tvm.nd.array(np.empty(shape, dtype='float32')) - return module.get_output(0, out=out_nd_array) - - def test_monomorphic_let(): "Program: let x = 1; return x" b = IRBuilder() @@ -82,11 +60,6 @@ def test_add_op(): ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) assert_has_type(func.to_func(), expected_ty) - x_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) - y_data = tvm.nd.array(np.random.rand(5, 5, 5).astype('float32')) - result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 5, 5)) - np.testing.assert_allclose( - x_data.asnumpy() + y_data.asnumpy(), result.asnumpy()) def test_add_broadcast_op(): """ @@ -105,11 +78,6 @@ def test_add_broadcast_op(): ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) assert_has_type(func.to_func(), expected_ty) - x_data = tvm.nd.array(np.random.rand(10, 4).astype('float32')) - y_data = tvm.nd.array(np.random.rand(5, 10, 1).astype('float32')) - result = run(env, prog, {'x': x_data, 'y': y_data}, (5, 10, 4)) - np.testing.assert_allclose( - x_data.asnumpy() + y_data.asnumpy(), result.asnumpy()) def test_dual_op(): """Program: From 1b625a60e843789bd8c49170534505f0b42d6155 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 14:44:10 -0700 Subject: [PATCH 101/136] Another few tweaks --- include/tvm/relay/expr.h | 4 +++- src/relay/pass/type_infer.cc | 35 ++++++++++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index beb4770c30eb..f61ee3503bdf 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -106,7 +106,9 @@ RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr); /*! * \brief Local variables used in the let expression. - * This is similar to Var that is being used in the low level tensor expression. + * + * Its semantics are similar to tvm.Var node used in TVM's low level + * tensor expression language. * * \note Each Var is bind only once and is immutable/ */ diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 1f14c7879b45..a787e39b1a3e 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -35,18 +35,42 @@ namespace relay { using namespace tvm::runtime; -using TypeConstraintSet = std::vector; +// @tqchen +// I wanted to use this data structure but then the algorithm gets more complex +// because we need to convert them back to the same representation as before +// when we check a single function scope. See line 240. +// +// I can see building an auxillary data structure at solve time but it seems +// like a lot of complexity for an unquantified speed gain, which we may or may +// not need. +// +// Thoughts? +// +// // We declare this for forward compatibility. +// struct ConstraintData {}; + +// struct TyRelData : ConstraintData { +// std::vector args; +// TypeRelationFn func; +// bool complete; +// TyRelData(Array args, TypeRelationFn func) : complete(false), +// func(func) { +// for (auto arg : args) { +// this->args.push_back(arg); +// } +// } +// }; struct TypeContext { std::unordered_map var_map; - std::vector constraints; + std::vector> constraints; TypeContext() { constraints.push_back({}); } void Insert(const Var &id, const Type &t) { var_map[id] = t; } - void AddConstraint(const TypeConstraint &ty_rel) { - constraints.back().push_back(ty_rel); + void AddConstraint(const TypeConstraint &constraint) { + constraints.back().push_back(constraint); } Type Lookup(const Var &id) { @@ -475,7 +499,8 @@ bool TypeInferencer::RelationsHold(bool scope_only) { // slice out the constraints. // // Otherwise we use all of them. - std::vector constraints; + std::vector> constraints; + if (scope_only) { constraints = {context.constraints[0]}; } else { From 6d386c8db8fe25a455027da5a6bdcca62a1f3edd Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 15:17:27 -0700 Subject: [PATCH 102/136] Add SourceName::Get and serialization --- include/tvm/relay/base.h | 30 ++++++++++++++++++++++- src/relay/ir/base.cc | 51 +++++++++++++++++++++++++++++----------- 2 files changed, 66 insertions(+), 15 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 96508de66e51..de5c93e9c94e 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -88,7 +88,35 @@ class SourceNameNode : public Node { TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); }; -RELAY_DEFINE_NODE_REF(SourceName, SourceNameNode, NodeRef); +/*! + * \brief The source name of a file span. + * \sa SourceNameNode, Span + */ +class SourceName : public NodeRef { + public: + /*! \brief default constructor */ + SourceName() {} + + /*! \brief constructor from node pointer */ + explicit SourceName(std::shared_ptr n) : NodeRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const SourceNameNode* operator->() const; + + /*! + * \brief Get an SourceName for a given operator name. + * Will raise an error if the source name has not been registered. + * \param op_name Name of the operator. + * \return Pointer to a Op, valid throughout program lifetime. + */ + TVM_DLL static const SourceName& Get(const std::string& name); + + /*! \brief specify container node */ + using ContainerType = SourceNameNode; +}; + /*! * \brief Span information for debugging purposes diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index d48b9a4c3e0c..e4ad08d893a3 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -18,17 +18,45 @@ SourceName SourceNameNode::make(std::string name) { return SourceName(n); } -// TVM_REGISTER_API("relay._make.SourceName") -// .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { -// *ret = SourceNameNode::make(args[0]); -// }); +std::shared_ptr CreateSourceName(const std::string& name) { + SourceName sn = SourceName::Get(name); + CHECK(!sn.defined()) << "Cannot find source name \'" << name << '\''; + std::shared_ptr node = sn.node_; + return std::dynamic_pointer_cast(node); +} + +const SourceName& SourceName::Get(const std::string& name) { + static std::unordered_map *source_map; + + if (source_map == nullptr) { + source_map = new std::unordered_map(); + } -// This causes a crash? + auto sn = source_map->find(name); + if (sn == source_map->end()) { + auto source_name = SourceNameNode::make(name); + source_map->insert({name, source_name}); + return source_map->at(name); + } else { + return sn->second; + } +} + +TVM_REGISTER_API("relay._make.SourceName") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) { + *ret = SourceNameNode::make(args[0]); + }); -// TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -// .set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { -// p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; -// }); +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { + p->stream << "SourceNameNode(" << node->name << ", " << node << ")"; + }); + +TVM_REGISTER_NODE_TYPE(SourceNameNode) +.set_creator(CreateSourceName) +.set_global_key([](const Node* n) { + return static_cast(n)->name; + }); Span SpanNode::make(SourceName source, int lineno, int col_offset) { std::shared_ptr n = std::make_shared(); @@ -43,11 +71,6 @@ TVM_REGISTER_API("relay._make.Span") *ret = SpanNode::make(args[0], args[1], args[2]); }); -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { - p->stream << node->name; - }); - TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SpanNode *node, tvm::IRPrinter *p) { p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", " From 303e512396f49bf23e97463e04ff31840ccc3877 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 15:20:16 -0700 Subject: [PATCH 103/136] Add serialization for op.cc --- src/relay/ir/op.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 61e4b34a0fc7..bd06d6f5200b 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -138,6 +138,19 @@ TVM_REGISTER_API("relay.op._Register") } }); +std::shared_ptr CreateOp(const std::string& name) { + auto op = Op::Get(name); + CHECK(!op.defined()) << "Cannot find op \'" << name << '\''; + std::shared_ptr node = op.node_; + return std::dynamic_pointer_cast(node); +} + +TVM_REGISTER_NODE_TYPE(OpNode) +.set_creator(CreateOp) +.set_global_key([](const Node* n) { + return static_cast(n)->name; + }); + bool IsGeneric(const FuncType& func_ty) { return func_ty->type_params.size() != 0; } From ac0bb19651c6f48b26cd1d674c112a0d2215a0e7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 15:24:35 -0700 Subject: [PATCH 104/136] Fix lint --- python/tvm/relay/ir_pass.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 37f7001c460b..bcd0556082dc 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -8,4 +8,5 @@ from . import _ir_pass # Expose checking expression, should rename to infer_type. +# pylint: disable: invalid-name check_expr = _ir_pass.check_expr From 0fae0dec974db9fd8913f797db07828562ea328f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 16 Sep 2018 15:26:44 -0700 Subject: [PATCH 105/136] mend --- python/tvm/relay/ir_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index bcd0556082dc..bbc294b59f5b 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -8,5 +8,5 @@ from . import _ir_pass # Expose checking expression, should rename to infer_type. -# pylint: disable: invalid-name +# pylint: disable=invalid-name check_expr = _ir_pass.check_expr From 2c81356fbd2e2bc79959a343001ec23d94c52cbf Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 17 Sep 2018 13:52:57 -0700 Subject: [PATCH 106/136] Restore alpha_eq.cc and refactor type_relations to use EnvFunc. --- include/tvm/relay/op.h | 28 +++++------------ include/tvm/relay/type.h | 2 +- src/relay/op/tensor/elemwise.cc | 12 ++++---- src/relay/op/type_relations.cc | 17 +++++++++++ src/relay/pass/alpha_eq.cc | 53 ++++++++++----------------------- 5 files changed, 47 insertions(+), 65 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 3ab8c778c76d..4f4b0ef87aaa 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -153,18 +153,8 @@ class OpRegistry { * type. \param type_rel The backing relation which can solve an arbitrary * relation on variables. \return reference to self. */ - inline OpRegistry& add_type_rel(const std::string& type_rel_name, - TypeRelationFn type_rel); - - /*! - * \brief Attach the type function corresponding to the return type. - * \param type_rel_name The type function name to register for the return - * type. \param type_rel The backing relation which can solve an arbitrary - * relation on variables. \return reference to self. - */ - inline OpRegistry& add_type_rel( - const std::string& type_rel_name, - std::function(const Array&, int)> type_rel); + inline OpRegistry& add_type_rel(const std::string& rel_name, + const std::string& type_rel_func_name); /*! * \brief Set the type key of attributes. @@ -355,15 +345,11 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, } inline OpRegistry& OpRegistry::add_type_rel( - const std::string& type_func_name, - std::function(const Array&, int)> type_fn) { - auto pfunc = - runtime::TypedPackedFunc(const Array&, int)>(type_fn); - return add_type_rel(type_func_name, pfunc); -} + const std::string& rel_name, const std::string& type_rel_func_name) { + auto env_func = EnvFunc::Get(type_rel_func_name); + TypedEnvFunc(const Array&, int)> type_rel_func; + type_rel_func = env_func; -inline OpRegistry& OpRegistry::add_type_rel(const std::string& type_func_name, - TypeRelationFn type_fn) { std::vector type_params; std::vector arg_types; @@ -385,7 +371,7 @@ inline OpRegistry& OpRegistry::add_type_rel(const std::string& type_func_name, ty_call_args.push_back(out_param); TypeConstraint type_rel = - TypeRelationNode::make(type_func_name, type_fn, ty_call_args); + TypeRelationNode::make(rel_name, type_rel_func, ty_call_args); auto func_type = FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index d4b16043dbc0..71a52fdeb88d 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -211,7 +211,7 @@ class FuncTypeNode : public TypeNode { RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); using TypeRelationFn = - runtime::TypedPackedFunc(const Array&, int)>; + TypedEnvFunc(const Array&, int)>; /*! * \brief Opaque type relation, is an input-output relation on types. diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index a18259c72117..683b661baa80 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -37,7 +37,7 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Log", IdentityRel); +.add_type_rel("Log", "tvm.relay.type_relations.IdentityRel"); // data : Tensor[shape, dtype] // result: Tensor[shape, dtype] @@ -51,7 +51,7 @@ RELAY_REGISTER_UNARY_OP("exp") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Exp", IdentityRel); +.add_type_rel("Exp", "tvm.relay.type_relations.IdentityRel"); RELAY_REGISTER_UNARY_OP("sqrt") @@ -62,7 +62,7 @@ RELAY_REGISTER_UNARY_OP("sqrt") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Sqrt", IdentityRel); +.add_type_rel("Sqrt", "tvm.relay.type_relations.IdentityRel"); // Addition TVM_REGISTER_API("relay.op._make.add") @@ -76,7 +76,7 @@ RELAY_REGISTER_OP("add") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_rel("Broadcast", BroadcastRel); + .add_type_rel("Broadcast", "tvm.relay.type_relations.BroadcastRel"); // def broadcast(s1, s2): // ... @@ -97,7 +97,7 @@ RELAY_REGISTER_OP("subtract") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_rel("BroadcastComp", BroadcastCompRel); + .add_type_rel("BroadcastComp", "tvm.relay.type_relations.BroadcastCompRel"); // def broadcast(s1, s2): // ... @@ -118,7 +118,7 @@ RELAY_REGISTER_OP("equal") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_rel("BroadcastComp", BroadcastCompRel); + .add_type_rel("BroadcastComp", "tvm.relay.type_relations.BroadcastCompRel"); } // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 751583e738d4..b61b2c1de554 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -38,6 +38,11 @@ Array IdentityRel(const Array& types, int num_args) { } } +TVM_REGISTER_API("tvm.relay.type_relations.IdentityRel") +.set_body_typed(const Array&, int)>([](const Array& types, int num_args) { + return IdentityRel(types, num_args); +}); + static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 @@ -119,6 +124,11 @@ Array BroadcastRel(const Array& types, int num_args) { return types; } +TVM_REGISTER_API("tvm.relay.type_relations.BroadcastRel") +.set_body_typed(const Array&, int)>([](const Array& types, int num_args) { + return BroadcastRel(types, num_args); +}); + /* A relation which specifies broadcasting rules for operations which compute boolean results. */ @@ -133,5 +143,12 @@ Array BroadcastCompRel(const Array& types, int num_args) { return types; } +TVM_REGISTER_API("tvm.relay.type_relations.BroadcastCompRel") +.set_body_typed(const Array&, int)>([](const Array& types, int num_args) { + return BroadcastCompRel(types, num_args); +}); + + + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 764a9139c5f6..2fad828c149a 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -100,44 +100,23 @@ struct TypeAlphaEq : TypeVisitor { } } -// void VisitType_(const TupleTypeNode *op, const Type &t2) override { -// if (const TupleTypeNode *pt = t2.as()) { -// if (op->fields.size() != pt->fields.size()) { -// equal = false; -// return; -// } - -// for (size_t i = 0U; i < op->fields.size(); i++) { -// if (!equal) { -// return; -// } -// this->VisitType(op->fields[i], pt->fields[i]); -// } -// } else { -// equal = false; -// } -// } + void VisitType_(const TupleTypeNode *op, const Type &t2) override { + if (const TupleTypeNode *pt = t2.as()) { + if (op->fields.size() != pt->fields.size()) { + equal = false; + return; + } - // void VisitType_(const TypeCallNode *tyn1, const Type &t2) override { - // TypeCall tycall = GetRef(tyn1); - // if (const TypeCallNode *tyn2 = t2.as()) { - // if (tycall->func != tyn2->func) { - // equal = false; - // return; - // } - - // if (tycall->args.size() != tyn2->args.size()) { - // equal = false; - // return; - // } - - // for (size_t i = 0U; i < tycall->args.size(); i++) { - // this->VisitType(tycall->args[i], tyn2->args[i]); - // } - // } else { - // equal = false; - // } - // } + for (size_t i = 0U; i < op->fields.size(); i++) { + if (!equal) { + return; + } + this->VisitType(op->fields[i], pt->fields[i]); + } + } else { + equal = false; + } + } }; bool AlphaEqual(const Type &t1, const Type &t2) { From 5c0a84c756c68bd4d51d5763d79b6fa7211fca92 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 17 Sep 2018 13:56:16 -0700 Subject: [PATCH 107/136] Rename ExprFVisitor to ExprMutator --- include/tvm/relay/expr_visitor.h | 2 +- src/relay/pass/resolve.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 4a26dcbd32e7..3595a516b0be 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -67,7 +67,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { virtual void VisitType(const Type& t) {} }; -class ExprFVisitor : public ::tvm::relay::ExprFunctor { +class ExprMutator : public ::tvm::relay::ExprFunctor { public: Expr VisitExpr_(const VarNode* op) override { return GetRef(op); } diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index 470374da5a0e..affe870484a3 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -32,7 +32,7 @@ struct ResolveTypeType : TypeFVisitor { } }; -struct ResolveTypeExpr : ExprFVisitor { +struct ResolveTypeExpr : ExprMutator { const TypeUnifier &unifier; explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {} @@ -51,7 +51,7 @@ struct ResolveTypeExpr : ExprFVisitor { // We will visit e like normal building a new // term, then resolve e's old type and write // it back into the new node. - auto new_e = ExprFVisitor::VisitExpr(e); + auto new_e = ExprMutator::VisitExpr(e); CHECK(e->checked_type_.defined()); auto resolved_cty = VisitType(e->checked_type_); new_e->checked_type_ = resolved_cty; From bee332540d7ffb53f4355c00d4b6e4a47811bccd Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 17 Sep 2018 13:56:36 -0700 Subject: [PATCH 108/136] Remove compiler bits --- src/relay/ir/op.cc | 82 ---------------------------------------------- 1 file changed, 82 deletions(-) diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index bd06d6f5200b..d1a9dd072d31 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -151,87 +151,5 @@ TVM_REGISTER_NODE_TYPE(OpNode) return static_cast(n)->name; }); -bool IsGeneric(const FuncType& func_ty) { - return func_ty->type_params.size() != 0; -} - -using namespace runtime; - -Module CompileOpsToModule(const std::vector& op_names) { - PackedFunc compile_ops = GetPackedFunc("relay.op._compile_ops"); - tvm::Array> args; - - auto compiler_map = Op::GetAttr("FRelayOpCompiler"); - - for (auto op_name : op_names) { - Op op = Op::Get(op_name); - - if (!IsGeneric(op->op_type)) { - auto compiler = compiler_map[op]; - tvm::Array pair = compiler(op->name, op->op_type); - // TODO(@jroesch): I can't pass strings across what should be the - // interface here. - tvm::Array triple = {VarNode::make(op->name), pair[0], - pair[1]}; - args.push_back(triple); - } else { - throw dmlc::Error("it is impossible to compile generic operators."); - } - } - - // Nothing to do, bail out earlier. - // TVM will complain if we try to generate a module of size 0. - if (args.size() == 0) { - return Module(nullptr); - } - - return compile_ops(args); -} - -TVM_REGISTER_API("relay.op._CompileOpsToModule") - .set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector names; - for (auto i = 0; i < args.num_args; i++) { - names.push_back(args[i]); - } - *ret = CompileOpsToModule(names); - }); - -Op SpecializeOp(const std::string& op_name, const std::string& new_op_name, - Array type_args) { - auto registry = ::tvm::relay::OpRegistry::Registry(); - auto op_reg = registry->__REGISTER_OR_GET__(op_name); - auto new_op_reg = registry->__REGISTER__(new_op_name).set_name(); - - auto fn_ty = op_reg.op()->op_type; - - tvm::Map subst_map; - - CHECK(fn_ty->type_params.size() == type_args.size()); - - // Build a subsitituion map up from the function type and type arguments. - // Eventually allow the type vars to be passed in. - for (size_t i = 0; i < type_args.size(); i++) { - subst_map.Set(fn_ty->type_params[i], type_args[i]); - } - - Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, fn_ty->ret_type, {}, fn_ty->type_constraints); - inst_ty = TypeSubst(inst_ty, subst_map); - FuncType new_op_ty = GetRef(inst_ty.as()); - new_op_reg.op()->op_type = new_op_ty; - - // Now we want to copy over some attributes. - PackedFunc compiler = - Op::GetAttr("FRelayOpCompiler")[op_reg.op()]; - new_op_reg.set_attr("FRelayOpCompiler", compiler); - - return new_op_reg.op(); -} - -TVM_REGISTER_API("relay.op._SpecializeOp") - .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = SpecializeOp(args[0], args[1], args[2]); - }); - } // namespace relay } // namespace tvm From 0352410375493f059b79442e89785ec64cb2c611 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 17 Sep 2018 14:12:58 -0700 Subject: [PATCH 109/136] Address CR feedback --- include/tvm/relay/type.h | 13 ++----------- src/relay/ir/type.cc | 16 ++-------------- src/relay/pass/resolve.cc | 4 ++-- src/relay/pass/type_subst.cc | 2 +- src/relay/pass/type_visitor.h | 2 +- src/relay/pass/unifier.cc | 2 +- 6 files changed, 9 insertions(+), 30 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 71a52fdeb88d..a17450d74ead 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -83,17 +83,8 @@ class TensorTypeNode : public BaseTensorTypeNode { TVM_DLL static TensorType make(Array shape, DataType dtype); - /*! \brief Construct an unsigned integer type */ - TVM_DLL static TensorType Int(int bits, int lanes = 1); - - /*! \brief Construct an unsigned integer type */ - TVM_DLL static TensorType UInt(int bits, int lanes = 1); - - /*! \brief Construct a floating-point type */ - TVM_DLL static TensorType Float(int bits, int lanes = 1); - - /*! \brief Construct a boolean type */ - TVM_DLL static TensorType Bool(int lanes = 1); + /*! \brief Construct an scalar containing elements of dtype. */ + TVM_DLL static TensorType Scalar(DataType dtype); static constexpr const char* _type_key = "relay.TensorType"; TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 73be2400ba2e..1282acafcb92 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -19,20 +19,8 @@ TensorType TensorTypeNode::make(Array shape, DataType dtype) { return TensorType(n); } -TensorType TensorTypeNode::Int(int bits, int lanes) { - return TensorTypeNode::make({}, HalideIR::Int(bits, lanes)); -} - -TensorType TensorTypeNode::UInt(int bits, int lanes) { - return TensorTypeNode::make({}, HalideIR::UInt(bits, lanes)); -} - -TensorType TensorTypeNode::Float(int bits, int lanes) { - return TensorTypeNode::make({}, HalideIR::Float(bits, lanes)); -} - -TensorType TensorTypeNode::Bool(int lanes) { - return TensorTypeNode::make({}, HalideIR::Bool(lanes)); +TensorType TensorTypeNode::Scalar(DataType dtype) { + return TensorTypeNode::make({}, dtype); } TVM_REGISTER_API("relay._make.TensorType") diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index affe870484a3..d6fdda4ae432 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -12,7 +12,7 @@ namespace tvm { namespace relay { -struct ResolveTypeType : TypeFVisitor { +struct ResolveTypeType : TypeMutator { const TypeUnifier &unifier; explicit ResolveTypeType(const TypeUnifier &unifier) : unifier(unifier) {} @@ -23,7 +23,7 @@ struct ResolveTypeType : TypeFVisitor { unifier->Insert(inc_ty); return inc_ty; } else { - return TypeFVisitor::VisitType(t); + return TypeMutator::VisitType(t); } } diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc index 91713976bcaa..5fce9c2ca73b 100644 --- a/src/relay/pass/type_subst.cc +++ b/src/relay/pass/type_subst.cc @@ -9,7 +9,7 @@ namespace tvm { namespace relay { -struct TypeSubstV : TypeFVisitor { +struct TypeSubstV : TypeMutator { tvm::Map subst_map; explicit TypeSubstV(tvm::Map subst_map) diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index 252642b1a492..357a36ffa41f 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -54,7 +54,7 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { }; // A functional visitor for rebuilding an AST in place. -struct TypeFVisitor : TypeFunctor { +struct TypeMutator : TypeFunctor { Type VisitType_(const TensorTypeNode* op) override { // TODO(@jroesch): maybe we should recursively visit return TensorTypeNode::make(op->shape, op->dtype); diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index fa65d1ef18aa..fc856d44fe23 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -147,7 +147,7 @@ Type TypeUnifierNode::Unify(const Type &t1, const Type &t2) { return unified; } -struct IncompleteTypeSubst : TypeFVisitor { +struct IncompleteTypeSubst : TypeMutator { const TypeUnifierNode *unifier; IncompleteTypeSubst(const TypeUnifierNode *unifier) : unifier(unifier) {} From 19e441920a81c169e83cb308f9a85de1c30e93d8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 17 Sep 2018 14:55:26 -0700 Subject: [PATCH 110/136] Refactor expr_visitor.h --- include/tvm/relay/base.h | 8 ++ include/tvm/relay/expr_mutator.h | 37 +++++++ include/tvm/relay/expr_visitor.h | 173 ++----------------------------- src/relay/ir/expr_mutator.cc | 150 +++++++++++++++++++++++++++ src/relay/ir/expr_visitor.cc | 67 ++++++++++++ src/relay/pass/resolve.cc | 8 +- 6 files changed, 277 insertions(+), 166 deletions(-) create mode 100644 include/tvm/relay/expr_mutator.h create mode 100644 src/relay/ir/expr_mutator.cc create mode 100644 src/relay/ir/expr_visitor.cc diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index de5c93e9c94e..421f7263f811 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -214,6 +214,14 @@ Array Downcast(Array array) { return out; } +template +SubRef Downcast(BaseRef ref) { + const typename SubRef::ContainerType* node = + ref.template as(); + CHECK(node) << "Downcast failed" << std::endl; + return GetRef(node); +} + /*! * \brief Get PackedFunction from global registry and * report error if it does not exist diff --git a/include/tvm/relay/expr_mutator.h b/include/tvm/relay/expr_mutator.h new file mode 100644 index 000000000000..4a33c00ced38 --- /dev/null +++ b/include/tvm/relay/expr_mutator.h @@ -0,0 +1,37 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/expr_mutator.h + * \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. + */ +#ifndef TVM_RELAY_EXPR_VISITOR_H_ +#define TVM_RELAY_EXPR_VISITOR_H_ + +#include + +namespace tvm { +namespace relay { + +class ExprMutator : public ::tvm::relay::ExprFunctor { + public: + Expr Mutate(const Expr& expr); + Expr VisitExpr_(const VarNode* op, const Expr & e) override; + Expr VisitExpr_(const ConstantNode* op, const Expr & e) override; + Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override; + Expr VisitExpr_(const OpNode* op, const Expr& expr) override; + Expr VisitExpr_(const TupleNode* op, const Expr& e) override; + Expr VisitExpr_(const ParamNode* op, const Expr& e) override; + Expr VisitExpr_(const FunctionNode* op, const Expr & e) override; + Expr VisitExpr_(const CallNode* call_node, const Expr & e) override; + Expr VisitExpr_(const LetNode* op, const Expr & e) override; + Expr VisitExpr_(const IfNode* op, const Expr & e) override; + virtual Type VisitType(const Type& t); +private: + tvm::Map memo_; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_EXPR_VISITOR_H_ diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h index 3595a516b0be..55da8e3d6e9c 100644 --- a/include/tvm/relay/expr_visitor.h +++ b/include/tvm/relay/expr_visitor.h @@ -17,168 +17,17 @@ namespace relay { class ExprVisitor : public ::tvm::relay::ExprFunctor { public: - void VisitExpr_(const VarNode* op) override { return; } - - void VisitExpr_(const GlobalVarNode* op) override { return; } - - void VisitExpr_(const ConstantNode* op) override { return; } - - void VisitExpr_(const TupleNode* op) override { - for (auto field : op->fields) { - this->VisitExpr(field); - } - } - - void VisitExpr_(const ParamNode* op) override { this->VisitExpr(op->var); } - - void VisitExpr_(const FunctionNode* op) override { - for (auto param : op->params) { - this->VisitExpr(param); - } - - this->VisitExpr(op->body); - } - - void VisitExpr_(const CallNode* op) override { - this->VisitExpr(op->op); - for (auto ty_arg : op->type_args) { - this->VisitType(ty_arg); - } - - for (auto arg : op->args) { - this->VisitExpr(arg); - } - } - - void VisitExpr_(const LetNode* op) override { - this->VisitExpr(op->var); - this->VisitExpr(op->value); - this->VisitExpr(op->body); - } - - void VisitExpr_(const IfNode* op) override { - this->VisitExpr(op->cond); - this->VisitExpr(op->true_branch); - this->VisitExpr(op->false_branch); - } - - void VisitExpr_(const OpNode* op) override { return; } - - virtual void VisitType(const Type& t) {} -}; - -class ExprMutator : public ::tvm::relay::ExprFunctor { - public: - Expr VisitExpr_(const VarNode* op) override { return GetRef(op); } - - Expr VisitExpr_(const ConstantNode* op) override { - return GetRef(op); - } - - Expr VisitExpr_(const GlobalVarNode* op) override { - return GetRef(op); - } - - Expr VisitExpr_(const OpNode* op) override { return GetRef(op); } - - Expr VisitExpr_(const TupleNode* op) override { - tvm::Array fields; - for (auto field : op->fields) { - fields.push_back(this->VisitExpr(field)); - } - - return TupleNode::make(fields); - } - - Expr VisitExpr_(const ParamNode* op) override { - Expr var_expr = this->VisitExpr(op->var); - if (const VarNode* var_node = var_expr.as()) { - auto var = GetRef(var_node); - auto type = this->VisitType(op->type); - return ParamNode::make(var, type); - } else { - LOG(FATAL) << "the default param visitor expected a Var found: " - << var_expr << std::endl; - return Expr(); - } - } - - Expr VisitExpr_(const FunctionNode* op) override { - tvm::Array ty_params; - - for (auto ty : op->type_params) { - Type ty_param_type = VisitType(ty); - if (auto ty_param = ty_param_type.as()) { - auto ty_param_ref = GetRef(ty_param); - ty_params.push_back(ty_param_ref); - } else { - LOG(FATAL) - << "the default function visitor expected a TypeParam found: " - << ty_param_type << std::endl; - return Expr(); - } - } - - tvm::Array params; - for (auto param : op->params) { - Expr param_expr = this->VisitExpr(param); - if (const ParamNode* param_node = param_expr.as()) { - auto param = GetRef(param_node); - params.push_back(param); - } else { - CHECK(false) << "the default function visitor expected a Param found: " - << param_expr << std::endl; - return Expr(); - } - } - - auto ret_type = this->VisitType(op->ret_type); - auto body = this->VisitExpr(op->body); - return FunctionNode::make(params, ret_type, body, ty_params); - } - - Expr VisitExpr_(const CallNode* call_node) override { - auto fn = this->VisitExpr(call_node->op); - - tvm::Array ty_args; - for (auto ty_arg : call_node->type_args) { - auto new_ty_arg = this->VisitType(ty_arg); - ty_args.push_back(new_ty_arg); - } - - tvm::Array call_args; - for (auto arg : call_node->args) { - call_args.push_back(this->VisitExpr(arg)); - } - - auto call = CallNode::make(fn, call_args, call_node->attrs, ty_args); - - return call; - } - - Expr VisitExpr_(const LetNode* op) override { - Expr var_expr = this->VisitExpr(op->var); - if (const VarNode* var_node = var_expr.as()) { - auto var = GetRef(var_node); - auto type = this->VisitType(op->value_type); - auto value = this->VisitExpr(op->value); - auto body = this->VisitExpr(op->body); - return LetNode::make(var, value, body, type); - } else { - LOG(FATAL) << "the default let visitor expected a Var found: " - << var_expr << std::endl; - return Expr(); - } - } - - Expr VisitExpr_(const IfNode* op) override { - auto guard = this->VisitExpr(op->cond); - auto true_b = this->VisitExpr(op->true_branch); - auto false_b = this->VisitExpr(op->false_branch); - return IfNode::make(guard, true_b, false_b); - } - - virtual Type VisitType(const Type& t) { return t; } + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const GlobalVarNode* op) override; + void VisitExpr_(const ConstantNode* op) override; + void VisitExpr_(const TupleNode* op) override; + void VisitExpr_(const ParamNode* op) override; + void VisitExpr_(const FunctionNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const LetNode* op) override; + void VisitExpr_(const IfNode* op) override; + void VisitExpr_(const OpNode* op) override; + virtual void VisitType(const Type& t); }; } // namespace relay diff --git a/src/relay/ir/expr_mutator.cc b/src/relay/ir/expr_mutator.cc new file mode 100644 index 000000000000..a3be288e2392 --- /dev/null +++ b/src/relay/ir/expr_mutator.cc @@ -0,0 +1,150 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/expr_mutator.cc + * \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. + */ + +#include + +namespace tvm { +namespace relay { + +Expr ExprMutator::Mutate(const Expr& expr) { + auto cached_expr = this->memo_.find(expr); + if (cached_expr != this->memo_.end()) { + return (*cached_expr).second; + } else { + auto new_expr = this->ExprMutator::VisitExpr(expr, expr); + this->memo_.Set(expr, new_expr); + return new_expr; + } +} + +Expr ExprMutator::VisitExpr_(const VarNode* op, const Expr& expr) { + return expr; +} + +Expr ExprMutator::VisitExpr_(const ConstantNode* op, const Expr& expr) { + return expr; +} + +Expr ExprMutator::VisitExpr_(const GlobalVarNode* op, const Expr& expr) { + return expr; +} + +Expr ExprMutator::VisitExpr_(const OpNode* op, const Expr& expr) { + return expr; +} + +Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) { + tvm::Array fields; + bool all_fields_unchanged = true; + for (auto field : op->fields) { + auto new_field = this->Mutate(field); + fields.push_back(new_field); + all_fields_unchanged &= new_field.same_as(field); + } + + if (all_fields_unchanged) { + return e; + } else { + return TupleNode::make(fields); + } +} + +Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) { + Var var = Downcast(this->Mutate(op->var)); + auto type = this->VisitType(op->type); + if (var == op->var && type == op->type) { + return e; + } else { + return ParamNode::make(var, type); + } +} + +Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) { + tvm::Array ty_params; + + for (auto ty : op->type_params) { + Type ty_param_type = VisitType(ty); + if (auto ty_param = ty_param_type.as()) { + auto ty_param_ref = GetRef(ty_param); + ty_params.push_back(ty_param_ref); + } else { + LOG(FATAL) << "the default function visitor expected a TypeParam found: " + << ty_param_type << std::endl; + return Expr(); + } + } + + tvm::Array params; + for (auto param : op->params) { + Expr param_expr = this->Mutate(param); + if (const ParamNode* param_node = param_expr.as()) { + auto param = GetRef(param_node); + params.push_back(param); + } else { + CHECK(false) << "the default function visitor expected a Param found: " + << param_expr << std::endl; + return Expr(); + } + } + + auto ret_type = this->VisitType(op->ret_type); + auto body = this->Mutate(op->body); + return FunctionNode::make(params, ret_type, body, ty_params); +} + +Expr ExprMutator::VisitExpr_(const CallNode* call_node, const Expr& e) { + auto fn = this->Mutate(call_node->op); + + tvm::Array ty_args; + for (auto ty_arg : call_node->type_args) { + auto new_ty_arg = this->VisitType(ty_arg); + ty_args.push_back(new_ty_arg); + } + + tvm::Array call_args; + for (auto arg : call_node->args) { + call_args.push_back(this->Mutate(arg)); + } + + auto call = CallNode::make(fn, call_args, call_node->attrs, ty_args); + + return call; +} + +Expr ExprMutator::VisitExpr_(const LetNode* op, const Expr& e) { + Expr var_expr = this->Mutate(op->var); + if (const VarNode* var_node = var_expr.as()) { + auto var = GetRef(var_node); + auto type = this->VisitType(op->value_type); + auto value = this->Mutate(op->value); + auto body = this->Mutate(op->body); + return LetNode::make(var, value, body, type); + } else { + LOG(FATAL) << "the default let visitor expected a Var found: " << var_expr + << std::endl; + return Expr(); + } +} + +Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) { + auto guard = this->Mutate(op->cond); + auto true_b = this->Mutate(op->true_branch); + auto false_b = this->Mutate(op->false_branch); + if (op->cond == guard && true_b == op->true_branch && + false_b == op->false_branch) { + return e; + } else { + return IfNode::make(guard, true_b, false_b); + } +} + +Type ExprMutator::VisitType(const Type& t) { return t; } + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/expr_visitor.cc b/src/relay/ir/expr_visitor.cc new file mode 100644 index 000000000000..acb9074347fe --- /dev/null +++ b/src/relay/ir/expr_visitor.cc @@ -0,0 +1,67 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/expr_visitor.h + * \brief A simple visitor wrapper around ExprFunctor. + * + * Exposes two visitors with default traversal strategies, one + * which doesn't compute a result but can mutate internal state, + * and another which functionally builds a new Expr. + */ +#include + +namespace tvm { +namespace relay { + +void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { return; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { return; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { return; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { + for (auto field : op->fields) { + this->VisitExpr(field); + } +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const ParamNode* op) { + this->VisitExpr(op->var); +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { + for (auto param : op->params) { + this->VisitExpr(param); + } + + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitExpr(op->op); + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (auto arg : op->args) { + this->VisitExpr(arg); + } +} + +void ExprVisitor::VisitExpr_(const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitExpr(op->cond); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); +} + +void ExprVisitor::VisitExpr_(const OpNode* op) { return; } + +void ExprVisitor::VisitType(const Type& t) { return; } + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index d6fdda4ae432..fe27adfb9a32 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -5,7 +5,7 @@ */ #include -#include +#include #include "./resolve.h" #include "./type_visitor.h" @@ -37,7 +37,7 @@ struct ResolveTypeExpr : ExprMutator { explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {} - Expr VisitExpr(const Expr &e) { + Expr Mutate(const Expr &e) { // NB: a bit tricky here. // // We want to store resolved type without having @@ -51,7 +51,7 @@ struct ResolveTypeExpr : ExprMutator { // We will visit e like normal building a new // term, then resolve e's old type and write // it back into the new node. - auto new_e = ExprMutator::VisitExpr(e); + auto new_e = ExprMutator::Mutate(e); CHECK(e->checked_type_.defined()); auto resolved_cty = VisitType(e->checked_type_); new_e->checked_type_ = resolved_cty; @@ -69,7 +69,7 @@ Type Resolve(const TypeUnifier &unifier, const Type &ty) { } Expr Resolve(const TypeUnifier &unifier, const Expr &expr) { - return ResolveTypeExpr(unifier).VisitExpr(expr); + return ResolveTypeExpr(unifier).Mutate(expr); } struct FullyResolved : TypeVisitor<> { From dc7330569f089f633f60e3da23d07fcd0008195c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 17 Sep 2018 16:25:50 -0700 Subject: [PATCH 111/136] WIP trying to rearrange type_relation --- include/tvm/relay/op.h | 26 +++++++++--- src/relay/ir/expr_mutator.cc | 75 ++++++++++++++++++--------------- src/relay/op/tensor/elemwise.cc | 12 +++--- src/relay/pass/alpha_eq.cc | 30 +++++++------ 4 files changed, 82 insertions(+), 61 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 4f4b0ef87aaa..7222504a0a3a 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -154,7 +154,7 @@ class OpRegistry { * relation on variables. \return reference to self. */ inline OpRegistry& add_type_rel(const std::string& rel_name, - const std::string& type_rel_func_name); + std::function(const Array &, int)> type_rel_func); /*! * \brief Set the type key of attributes. @@ -345,10 +345,24 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, } inline OpRegistry& OpRegistry::add_type_rel( - const std::string& rel_name, const std::string& type_rel_func_name) { - auto env_func = EnvFunc::Get(type_rel_func_name); - TypedEnvFunc(const Array&, int)> type_rel_func; - type_rel_func = env_func; + const std::string& rel_name, std::function(const Array &, int)> type_rel_func) { + + TypedEnvFunc(const Array&, int)> env_type_rel_func; + + std::cout << "BeforeHere" << std::endl; + + try { + auto env_func = EnvFunc::Get(rel_name); + env_type_rel_func = env_func; + } catch (const dmlc::Error& err) { + std::cout << "In Catch...." << rel_name << std::endl; + TVM_REGISTER_API(rel_name).set_body_typed(const Array &, int)>(type_rel_func); + std::cout << "After Reg...." << std::endl; + auto env_func = EnvFunc::Get(rel_name); + env_type_rel_func = env_func; + } + + std::cout << "AfterHere" << std::endl; std::vector type_params; std::vector arg_types; @@ -371,7 +385,7 @@ inline OpRegistry& OpRegistry::add_type_rel( ty_call_args.push_back(out_param); TypeConstraint type_rel = - TypeRelationNode::make(rel_name, type_rel_func, ty_call_args); + TypeRelationNode::make(rel_name, env_type_rel_func, ty_call_args); auto func_type = FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); diff --git a/src/relay/ir/expr_mutator.cc b/src/relay/ir/expr_mutator.cc index a3be288e2392..83017508d1cc 100644 --- a/src/relay/ir/expr_mutator.cc +++ b/src/relay/ir/expr_mutator.cc @@ -58,6 +58,7 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) { Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) { Var var = Downcast(this->Mutate(op->var)); auto type = this->VisitType(op->type); + if (var == op->var && type == op->type) { return e; } else { @@ -67,68 +68,71 @@ Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) { Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) { tvm::Array ty_params; + bool all_ty_params_changed = true; - for (auto ty : op->type_params) { - Type ty_param_type = VisitType(ty); - if (auto ty_param = ty_param_type.as()) { - auto ty_param_ref = GetRef(ty_param); - ty_params.push_back(ty_param_ref); - } else { - LOG(FATAL) << "the default function visitor expected a TypeParam found: " - << ty_param_type << std::endl; - return Expr(); - } + for (auto ty_param : op->type_params) { + TypeParam new_ty_param = Downcast(VisitType(ty_param)); + ty_params.push_back(new_ty_param); + all_ty_params_changed &= new_ty_param.same_as(ty_param); } tvm::Array params; + bool all_params_changed = true; for (auto param : op->params) { - Expr param_expr = this->Mutate(param); - if (const ParamNode* param_node = param_expr.as()) { - auto param = GetRef(param_node); - params.push_back(param); - } else { - CHECK(false) << "the default function visitor expected a Param found: " - << param_expr << std::endl; - return Expr(); - } + Param new_param = Downcast(this->Mutate(param)); + params.push_back(new_param); + all_params_changed &= param.same_as(new_param); } auto ret_type = this->VisitType(op->ret_type); auto body = this->Mutate(op->body); - return FunctionNode::make(params, ret_type, body, ty_params); + + if (ty_params.same_as(op->type_params) && params.same_as(op->params) && + ret_type.same_as(op->ret_type) && body.same_as(op->body)) { + return e; + } else { + return FunctionNode::make(params, ret_type, body, ty_params); + } } Expr ExprMutator::VisitExpr_(const CallNode* call_node, const Expr& e) { - auto fn = this->Mutate(call_node->op); + auto op = this->Mutate(call_node->op); tvm::Array ty_args; + bool all_ty_args_unchanged = true; for (auto ty_arg : call_node->type_args) { auto new_ty_arg = this->VisitType(ty_arg); ty_args.push_back(new_ty_arg); + all_ty_args_unchanged &= new_ty_arg.same_as(ty_arg); } tvm::Array call_args; + bool all_args_unchanged = true; for (auto arg : call_node->args) { - call_args.push_back(this->Mutate(arg)); + auto new_arg = this->Mutate(arg); + call_args.push_back(new_arg); + all_args_unchanged &= new_arg.same_as(arg); } - auto call = CallNode::make(fn, call_args, call_node->attrs, ty_args); - - return call; + if (all_ty_args_unchanged && all_args_unchanged && + call_node->op.same_as(op)) { + return e; + } else { + return CallNode::make(op, call_args, call_node->attrs, ty_args); + } } Expr ExprMutator::VisitExpr_(const LetNode* op, const Expr& e) { - Expr var_expr = this->Mutate(op->var); - if (const VarNode* var_node = var_expr.as()) { - auto var = GetRef(var_node); - auto type = this->VisitType(op->value_type); - auto value = this->Mutate(op->value); - auto body = this->Mutate(op->body); - return LetNode::make(var, value, body, type); + Var var = Downcast(this->Mutate(op->var)); + auto type = this->VisitType(op->value_type); + auto value = this->Mutate(op->value); + auto body = this->Mutate(op->body); + + if (var.same_as(op->var) && type.same_as(op->value_type) && + value.same_as(op->value) && body.same_as(op->body)) { + return e; } else { - LOG(FATAL) << "the default let visitor expected a Var found: " << var_expr - << std::endl; - return Expr(); + return LetNode::make(var, value, body, type); } } @@ -136,6 +140,7 @@ Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) { auto guard = this->Mutate(op->cond); auto true_b = this->Mutate(op->true_branch); auto false_b = this->Mutate(op->false_branch); + if (op->cond == guard && true_b == op->true_branch && false_b == op->false_branch) { return e; diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index 683b661baa80..df806899a7a0 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -37,7 +37,7 @@ RELAY_REGISTER_UNARY_OP("log") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Log", "tvm.relay.type_relations.IdentityRel"); +.add_type_rel("Identity", IdentityRel); // data : Tensor[shape, dtype] // result: Tensor[shape, dtype] @@ -51,7 +51,7 @@ RELAY_REGISTER_UNARY_OP("exp") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Exp", "tvm.relay.type_relations.IdentityRel"); +.add_type_rel("Identity", IdentityRel); RELAY_REGISTER_UNARY_OP("sqrt") @@ -62,7 +62,7 @@ RELAY_REGISTER_UNARY_OP("sqrt") )code" TVM_ADD_FILELINE) .set_support_level(1) -.add_type_rel("Sqrt", "tvm.relay.type_relations.IdentityRel"); +.add_type_rel("Identity", IdentityRel); // Addition TVM_REGISTER_API("relay.op._make.add") @@ -76,7 +76,7 @@ RELAY_REGISTER_OP("add") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_rel("Broadcast", "tvm.relay.type_relations.BroadcastRel"); + .add_type_rel("Broadcast", BroadcastRel); // def broadcast(s1, s2): // ... @@ -97,7 +97,7 @@ RELAY_REGISTER_OP("subtract") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_rel("BroadcastComp", "tvm.relay.type_relations.BroadcastCompRel"); + .add_type_rel("Broadcast", BroadcastRel); // def broadcast(s1, s2): // ... @@ -118,7 +118,7 @@ RELAY_REGISTER_OP("equal") .add_argument("lhs", "Tensor", "The left hand side tensor.") .add_argument("rhs", "Tensor", "The right hand side tensor.") .set_support_level(1) - .add_type_rel("BroadcastComp", "tvm.relay.type_relations.BroadcastCompRel"); + .add_type_rel("BroadcastComp", BroadcastCompRel); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 2fad828c149a..9b9b7a59dd28 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -3,8 +3,8 @@ * \file src/tvm/relay/pass/alpha_eq.cc * \brief Compute the set of variables not bound in the expression. */ +#include #include "tvm/relay/pass/alpha_eq.h" -#include "tvm/relay/expr_visitor.h" #include "./type_visitor.h" namespace tvm { @@ -244,19 +244,21 @@ bool AlphaEqual(const Type &t1, const Type &t2) { // } // }; -// bool alpha_eq(const Expr &e1, const Expr &e2) { -// AlphaEq eq; -// eq.VisitExpr(e1, e2); -// return eq.equal; -// } - -// // TODO(@jroesch): move to correct namespace? -// TVM_REGISTER_API("relay._make._alpha_eq") -// .set_body([](TVMArgs args, TVMRetValue *ret) { -// Expr e1 = args[0]; -// Expr e2 = args[1]; -// *ret = alpha_eq(e1, e2); -// }); +bool AlphaEqual(const Expr &e1, const Expr &e2) { + // AlphaEq eq; + // eq.VisitExpr(e1, e2); + // return eq.equal; + LOG(FATAL) << "NYI"; + return false; +} + +// TODO(@jroesch): move to correct namespace? +TVM_REGISTER_API("relay._make._alpha_eq") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Expr e1 = args[0]; + Expr e2 = args[1]; + *ret = AlphaEqual(e1, e2); + }); TVM_REGISTER_API("relay._make._type_alpha_eq") .set_body([](TVMArgs args, TVMRetValue *ret) { From 63cdbd2681aeff681593f40488d97025d3469d16 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 00:30:58 -0700 Subject: [PATCH 112/136] Fix crash --- include/tvm/relay/op.h | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 7222504a0a3a..472d109d8170 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -153,8 +153,9 @@ class OpRegistry { * type. \param type_rel The backing relation which can solve an arbitrary * relation on variables. \return reference to self. */ - inline OpRegistry& add_type_rel(const std::string& rel_name, - std::function(const Array &, int)> type_rel_func); + inline OpRegistry& add_type_rel( + const std::string& rel_name, + std::function(const Array&, int)> type_rel_func); /*! * \brief Set the type key of attributes. @@ -345,25 +346,20 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, } inline OpRegistry& OpRegistry::add_type_rel( - const std::string& rel_name, std::function(const Array &, int)> type_rel_func) { - + const std::string& rel_name, + std::function(const Array&, int)> type_rel_func) { TypedEnvFunc(const Array&, int)> env_type_rel_func; - std::cout << "BeforeHere" << std::endl; - - try { + if (runtime::Registry::Get(rel_name)) { auto env_func = EnvFunc::Get(rel_name); env_type_rel_func = env_func; - } catch (const dmlc::Error& err) { - std::cout << "In Catch...." << rel_name << std::endl; - TVM_REGISTER_API(rel_name).set_body_typed(const Array &, int)>(type_rel_func); - std::cout << "After Reg...." << std::endl; + } else { + runtime::Registry::Register(rel_name) + .set_body_typed(const Array&, int)>(type_rel_func); auto env_func = EnvFunc::Get(rel_name); env_type_rel_func = env_func; } - std::cout << "AfterHere" << std::endl; - std::vector type_params; std::vector arg_types; From 4e4ee0066a821193a6d1370fba4e23e5240008a0 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 01:29:22 -0700 Subject: [PATCH 113/136] Address more CR feedback --- include/tvm/relay/op.h | 10 +++--- python/tvm/relay/op/_tensor.py | 58 ---------------------------------- src/relay/ir/environment.cc | 7 ++-- src/relay/pass/type_infer.cc | 2 -- src/relay/pass/unifier.cc | 54 +++++++++++++------------------ 5 files changed, 31 insertions(+), 100 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 472d109d8170..0d1c2605b45b 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -348,15 +348,17 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, inline OpRegistry& OpRegistry::add_type_rel( const std::string& rel_name, std::function(const Array&, int)> type_rel_func) { + auto func_name = std::string("tvm.relay.type_relation.") + rel_name; + TypedEnvFunc(const Array&, int)> env_type_rel_func; - if (runtime::Registry::Get(rel_name)) { - auto env_func = EnvFunc::Get(rel_name); + if (runtime::Registry::Get(func_name)) { + auto env_func = EnvFunc::Get(func_name); env_type_rel_func = env_func; } else { - runtime::Registry::Register(rel_name) + runtime::Registry::Register(func_name) .set_body_typed(const Array&, int)>(type_rel_func); - auto env_func = EnvFunc::Get(rel_name); + auto env_func = EnvFunc::Get(func_name); env_type_rel_func = env_func; } diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 875df0e52561..0bc2054cebdf 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -1,60 +1,2 @@ #pylint: disable=invalid-name """Backend compiler related feature registration""" -from topi import add -from .op import register -from ..type import FuncType, TensorType -from ...schedule import create_schedule -from ...api import placeholder - -def type_to_placeholder(name, ty): - """Convert a single type into the correct placeholder.""" - if isinstance(ty, TensorType): - return placeholder(ty.shape, name=name, dtype=ty.dtype) - else: - raise Exception("can only pass Tensor values to TVM operators") - -def func_ty_to_placeholders(func_ty): - """Build input placeholders based on a function type.""" - if isinstance(func_ty, FuncType): - arg_types = func_ty.arg_types - ret_type = func_ty.ret_type - args = [] - var = 0 - for arg in arg_types: - var += 1 - args.append(type_to_placeholder(f"Input{var}", arg)) - return args, ret_type - else: - raise Exception("error") - -# def lookup_in_topi(name): -# try: -# f = eval(f"topi.{name}") -# except: -# f = eval(f"topi.nn.{name}") - -# return f - -# @tvm.register_func("nnvm.relay._default_op_compiler") -# def _default_op_compile(op_name: str, func_ty: ir.Type, attrs: ir.Attributes=None) -> Any: -# Inputs, ret_ty = func_ty_to_placeholders(func_ty) -# op = lookup_in_topi(op_name) -# Output = op(*Inputs) - -# if Output.dtype == 'uint1': -# import pdb; pdb.set_trace() -# Output = Output.astype('uint8') - -# schedule = tvm.create_schedule(Output.op) -# return [schedule, Inputs + [Output]] - -#pylint: disable=duplicate-argument-name -def add_compiler(_, func_type, *__): - """The compilation code for the TVM compiler.""" - inputs, _ = func_ty_to_placeholders(func_type) - # op = lookup_in_topi(op_name) - output = add(*inputs) - schedule = create_schedule(output.op) - return [schedule, inputs + [output]] - -register("add", "FRelayOpCompiler", add_compiler) diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index c8d099b6f269..d41fcf23c2c7 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -76,10 +76,9 @@ void EnvironmentNode::Update(const GlobalVar &var, const Function &func) { this->Add(var, func, true); } -void EnvironmentNode::Remove(const GlobalVar &) { - // Clarify with @tqchen about how to use COW to do this. - throw Error("NYI"); - // this->items.erase(id); +void EnvironmentNode::Remove(const GlobalVar & var) { + auto functions_node = this->functions.CopyOnWrite(); + functions_node->data.erase(var.node_); } Function EnvironmentNode::Lookup(const GlobalVar &var) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index a787e39b1a3e..6e5b091c72a0 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -561,10 +561,8 @@ Type TypeInferencer::Unify(const Type &t1, const Type &t2, Span sp) { std::stringstream ss; ss << "Error unifying `"; ss << t1; - // ss << PrintType(env, t1, WrapWidth(40)); ss << "` and `"; ss << t2; - // ss << PrintType(env, t2, WrapWidth(40)); ss << "`: " << e.what(); this->FatalError(ss.str(), sp); } diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index fc856d44fe23..548c5f610ba7 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -11,8 +11,8 @@ #include #include #include +#include "./type_subst.h" #include "./type_visitor.h" -// #include "tvm/relay/typeck/kindchecker.h" namespace tvm { namespace relay { @@ -43,7 +43,8 @@ void UnionFindNode::AssertAlphaEqual(const Type &l, const Type &r) { } void UnionFindNode::Unify(const IncompleteType &v1, const Type &t) { - RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << ", t=" << t << std::endl; + RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << ", t=" << t + << std::endl; auto parent1 = this->Find(v1); // if t is a type var, then unify parents @@ -134,13 +135,16 @@ TypeUnifier TypeUnifierNode::make(UnionFind union_find) { return TypeUnifier(n); } -void TypeUnifierNode::Insert(const IncompleteType &v) { this->union_find->Insert(v); } +void TypeUnifierNode::Insert(const IncompleteType &v) { + this->union_find->Insert(v); +} Type TypeUnifierNode::Unify(const Type &t1, const Type &t2) { RELAY_LOG(INFO) << "TypeUnifierNode::unify: t1=" << t1 << " t2=" << t2 << std::endl; Type unified = this->VisitType(t1, t2); + // TODO (@jroesch): Restore this code when we finish kind checker. // if (!check_kind(unified)) { // throw UnificationError("Invalid kinds in unified type"); // } @@ -167,6 +171,7 @@ Type TypeUnifierNode::Subst(const Type &t) { IncompleteTypeSubst tvsubst(this); // normalize first so substitutions in quantifiers will be correct Type ret = tvsubst.VisitType(t); + // TODO (@jroesch): Restore this code when we finish kind checker. // if (!check_kind(ret)) { // std::stringstream ss; // ss << "Invalid Kinds in substituted type!"; @@ -210,7 +215,6 @@ Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { TypeParam ti1 = GetRef(t1); - // for other type ids, only check equality if (const TypeParamNode *tin2 = rt2.as()) { TypeParam ti2 = GetRef(tin2); @@ -221,7 +225,6 @@ Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { return ti1; } - // cannot unify TypeParam with non-TypeParam throw UnificationError("Unable to unify TypeParamNode"); } @@ -236,24 +239,13 @@ Type TypeUnifierNode::VisitType_(const FuncTypeNode *t1, const Type rt2) { "unable to unify functions with differing number of type parameters"); } - if (ft1->type_params.size() != 0) { - throw dmlc::Error("NYI"); - } - - // TypeParam id1 = tq1->id; - // TypeParam id2 = tq2->id; - - // if (id1->kind != id2->kind) { - // throw UnificationError( - // "Cannot unify quantifiers over ids of different kinds"); - // } - - // TypeParam fresh = TypeParamNode::make(id1->name, id1->kind); + tvm::Map subst_map; - // auto bt1 = type_subst(tq1->boundType, id1, fresh); - // auto bt2 = type_subst(tq2->boundType, id2, fresh); + for (size_t i = 0; i < ft1->arg_types.size(); i++) { + subst_map.Set(ft1->type_params[i], ft2->type_params[i]); + } - // Type unified_bound_type = this->VisitType(bt1, bt2); + ft1 = Downcast(TypeSubst(ft1, subst_map)); if (ft1->arg_types.size() != ft2->arg_types.size()) { throw UnificationError("unable to unify functions of different arities"); @@ -285,27 +277,26 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { RELAY_LOG(INFO) << "Unify Tensor Shape s1=" << tt1->shape << " s2= " << tt2->shape << std::endl; - try { - // Type unified_shape = this->VisitType(tt1->shape, tt2->shape); - return rt2; - } catch (const UnificationError &err) { - CHECK(false) << "Need to check constraint " << tt1->shape << " = " - << tt2->shape << std::endl; + + if (tt1->shape.size() != tt2->shape.size()) { + throw UnificationError("shapes are not of the same length"); + } + + for (size_t i = 0U; i < tt1->shape.size(); i++) { + if (!tt1->shape[i].same_as(tt2->shape[i])) { + throw UnificationError("shapes do not match at index"); + } } - // fix me return rt2; - // return TensorTypeNode::make(unified_bt, tt2->shape); } - // nothing else can unify throw UnificationError("Cannot unify TensorTypeNode"); } Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { TupleType pt1 = GetRef(t1); - // When unifying tuple types we just solve each field in order. if (const TupleTypeNode *ptn2 = rt2.as()) { TupleType pt2 = GetRef(ptn2); @@ -322,7 +313,6 @@ Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { return TupleTypeNode::make(unified_fields); } - // otherwise cannot unify throw UnificationError("Cannot unify TupleTypeNode"); } From d100018f32167118814f5b8a8fb5506bc00d6282 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 01:35:18 -0700 Subject: [PATCH 114/136] Restore AlphaEqual for Expr --- src/relay/pass/alpha_eq.cc | 252 ++++++++++++++++++------------------- 1 file changed, 120 insertions(+), 132 deletions(-) diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 9b9b7a59dd28..6bffc411c9a1 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2018 by Contributors - * \file src/tvm/relay/pass/alpha_eq.cc + * \file src/tvm/relay/pass/AlphaEqual.cc * \brief Compute the set of variables not bound in the expression. */ #include @@ -24,7 +24,7 @@ struct TypeAlphaEq : TypeVisitor { void ShapeEqual(Array s1, Array s2) { } - void VisitType_(const TensorTypeNode *tt1, const Type &t2) override { + void VisitType_(const TensorTypeNode *tt1, const Type &t2) final { if (const TensorTypeNode *tt2 = t2.as()) { DataTypeEqual(tt1->dtype, tt2->dtype); ShapeEqual(tt1->shape, tt2->shape); @@ -33,7 +33,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const IncompleteTypeNode *bt1, const Type &t2) override { + void VisitType_(const IncompleteTypeNode *bt1, const Type &t2) final { if (const IncompleteTypeNode *bt2 = t2.as()) { equal = equal && bt1 == bt2; return; @@ -42,7 +42,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TypeParamNode *ti1, const Type &t2) override { + void VisitType_(const TypeParamNode *ti1, const Type &t2) final { if (const TypeParamNode *ti2 = t2.as()) { auto tid1 = GetRef(ti1); auto tid2 = GetRef(ti2); @@ -72,7 +72,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const FuncTypeNode *op, const Type &t2) override { + void VisitType_(const FuncTypeNode *op, const Type &t2) final { if (const FuncTypeNode *ta2 = t2.as()) { if (op->arg_types.size() != ta2->arg_types.size()) { equal = false; @@ -92,7 +92,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TypeRelationNode *tr1, const Type &t2) override { + void VisitType_(const TypeRelationNode *tr1, const Type &t2) final { if (const TypeRelationNode *tr2 = t2.as()) { equal = tr1 == tr2; } else { @@ -100,7 +100,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TupleTypeNode *op, const Type &t2) override { + void VisitType_(const TupleTypeNode *op, const Type &t2) final { if (const TupleTypeNode *pt = t2.as()) { if (op->fields.size() != pt->fields.size()) { equal = false; @@ -125,142 +125,130 @@ bool AlphaEqual(const Type &t1, const Type &t2) { return aeq.equal; } -// struct AlphaEq : ExprVisitor { -// public: -// tvm::Map eq_map; -// bool equal; -// AlphaEq() : eq_map(), equal(true) {} - -// void VisitExpr_(const LocalIdNode *e1, const Expr &e2) override { -// if (const LocalIdNode *id2 = e2.as()) { -// auto local1 = GetRef(e1); -// auto local2 = GetRef(id2); -// // -// // We handle open terms with this rule assuming variables are identical. -// // -// // Not sure if we should do this. -// if (local1 == local2) { -// equal = true; -// return; -// } - -// // Next we see if there is mapping for local1 into the rhs term. -// // If there is we check to see if those are equal. -// if (eq_map.find(local1) != eq_map.end()) { -// equal = equal && eq_map[local1] == local2; -// } else { -// equal = false; -// } -// } else { -// equal = false; -// } -// } - -// void VisitExpr_(const GlobalIdNode *g1, const Expr &e2) override { -// if (const GlobalIdNode *g2 = e2.as()) { -// equal = equal && g1 == g2; -// } else { -// equal = false; -// } -// } - -// void VisitExpr_(const OperatorIdNode *i1, const Expr &e2) override { -// if (const OperatorIdNode *i2 = e2.as()) { -// equal = equal && i1 == i2; -// } else { -// equal = false; -// } -// } - -// void VisitExpr_(const TupleNode *pl1, const Expr &e2) override { -// Tuple prod1 = GetRef(pl1); -// if (const TupleNode *pl2 = e2.as()) { -// Tuple prod2 = GetRef(pl2); -// if (prod1->fields.size() != prod2->fields.size()) { -// equal = false; -// return; -// } - -// for (size_t i = 0U; i < prod1->fields.size(); i++) { -// this->VisitExpr(prod1->fields[i], prod2->fields[i]); -// } -// } else { -// equal = false; -// } -// } - -// void VisitExpr_(const ParamNode *p1, const Expr &e2) override { -// if (const ParamNode *p2 = e2.as()) { -// eq_map.Set(p1->id, p2->id); -// equal = equal && alpha_eq(p1->type, p2->type); -// } else { -// equal = false; -// } -// } - -// void VisitExpr_(const FunctionNode *func1, const Expr &e2) override { -// if (const FunctionNode *func2 = e2.as()) { -// if (func1->params.size() != func2->params.size()) { -// equal = false; -// return; -// } - -// for (size_t i = 0U; i < func1->params.size(); i++) { -// this->VisitExpr(func1->params[i], func2->params[i]); -// } - -// this->VisitExpr(func1->body, func2->body); -// } else { -// equal = false; -// } -// } - -// void VisitExpr_(const CallNode *op, const Expr &e2) override { -// if (const CallNode *call = e2.as()) { -// this->VisitExpr(op->fn, call->fn); - -// if (op->args.size() != call->args.size()) { -// equal = false; -// return; -// } - -// for (size_t i = 0U; i < op->args.size(); i++) { -// this->VisitExpr(op->args[i], call->args[i]); -// } - -// } else { -// equal = false; -// } -// } - -// void VisitExpr_(const LetNode *op, const Expr &e2) override { -// if (const LetNode *let = e2.as()) { -// eq_map.Set(op->id, let->id); -// this->VisitExpr(op->value, let->value); -// this->VisitExpr(op->body, let->body); -// } else { -// equal = false; -// } -// } -// }; +struct AlphaEq : ExprFunctor { + public: + tvm::Map eq_map; + + bool equal; + AlphaEq() : eq_map(), equal(true) {} + + void VisitExpr_(const VarNode *e1, const Expr &e2) final { + if (const VarNode *id2 = e2.as()) { + auto local1 = GetRef(e1); + auto local2 = GetRef(id2); + // We handle open terms with this rule assuming variables are identical. + if (local1 == local2) { + equal = true; + return; + } + + // Next we see if there is mapping for local1 into the rhs term. + // If there is we check to see if those are equal. + if (eq_map.find(local1) != eq_map.end()) { + equal = equal && eq_map[local1] == local2; + } else { + equal = false; + } + } else { + equal = false; + } + } + + void VisitExpr_(const GlobalVarNode *g1, const Expr &e2) final { + if (const GlobalVarNode *g2 = e2.as()) { + equal = equal && g1 == g2; + } else { + equal = false; + } + } + + void VisitExpr_(const TupleNode *pl1, const Expr &e2) final { + Tuple prod1 = GetRef(pl1); + if (const TupleNode *pl2 = e2.as()) { + Tuple prod2 = GetRef(pl2); + if (prod1->fields.size() != prod2->fields.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < prod1->fields.size(); i++) { + this->VisitExpr(prod1->fields[i], prod2->fields[i]); + } + } else { + equal = false; + } + } + + void VisitExpr_(const ParamNode *p1, const Expr &e2) final { + if (const ParamNode *p2 = e2.as()) { + eq_map.Set(p1->var, p2->var); + equal = equal && AlphaEqual(p1->type, p2->type); + } else { + equal = false; + } + } + + void VisitExpr_(const FunctionNode *func1, const Expr &e2) final { + if (const FunctionNode *func2 = e2.as()) { + if (func1->params.size() != func2->params.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < func1->params.size(); i++) { + this->VisitExpr(func1->params[i], func2->params[i]); + } + + this->VisitExpr(func1->body, func2->body); + } else { + equal = false; + } + } + + void VisitExpr_(const CallNode *op, const Expr &e2) final { + if (const CallNode *call = e2.as()) { + this->VisitExpr(op->op, call->op); + + if (op->args.size() != call->args.size()) { + equal = false; + return; + } + + for (size_t i = 0U; i < op->args.size(); i++) { + this->VisitExpr(op->args[i], call->args[i]); + } + + } else { + equal = false; + } + } + + void VisitExpr_(const LetNode *op, const Expr &e2) final { + if (const LetNode *let = e2.as()) { + eq_map.Set(op->var, let->var); + this->VisitExpr(op->value, let->value); + this->VisitExpr(op->body, let->body); + } else { + equal = false; + } + } +}; bool AlphaEqual(const Expr &e1, const Expr &e2) { - // AlphaEq eq; - // eq.VisitExpr(e1, e2); - // return eq.equal; - LOG(FATAL) << "NYI"; - return false; + AlphaEq eq; + eq.VisitExpr(e1, e2); + return eq.equal; } // TODO(@jroesch): move to correct namespace? -TVM_REGISTER_API("relay._make._alpha_eq") +TVM_REGISTER_API("relay._make._AlphaEqual") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr e1 = args[0]; Expr e2 = args[1]; *ret = AlphaEqual(e1, e2); }); -TVM_REGISTER_API("relay._make._type_alpha_eq") +TVM_REGISTER_API("relay._make._type_AlphaEqual") .set_body([](TVMArgs args, TVMRetValue *ret) { Type t1 = args[0]; Type t2 = args[1]; From bd45407266fe25ee5f6cad379456e6ddfdaead7a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 01:38:41 -0700 Subject: [PATCH 115/136] Fix linting --- include/tvm/relay/expr_mutator.h | 26 ++++++++++++++------------ src/relay/ir/expr_mutator.cc | 2 -- src/relay/pass/unifier.cc | 4 ++-- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/include/tvm/relay/expr_mutator.h b/include/tvm/relay/expr_mutator.h index 4a33c00ced38..a53dbe845825 100644 --- a/include/tvm/relay/expr_mutator.h +++ b/include/tvm/relay/expr_mutator.h @@ -6,32 +6,34 @@ * ExprMutator uses memoization and self return in order to amortize * the cost of using functional updates. */ -#ifndef TVM_RELAY_EXPR_VISITOR_H_ -#define TVM_RELAY_EXPR_VISITOR_H_ +#ifndef TVM_RELAY_EXPR_MUTATOR_H_ +#define TVM_RELAY_EXPR_MUTATOR_H_ #include namespace tvm { namespace relay { -class ExprMutator : public ::tvm::relay::ExprFunctor { +class ExprMutator + : public ::tvm::relay::ExprFunctor { public: Expr Mutate(const Expr& expr); - Expr VisitExpr_(const VarNode* op, const Expr & e) override; - Expr VisitExpr_(const ConstantNode* op, const Expr & e) override; + Expr VisitExpr_(const VarNode* op, const Expr& e) override; + Expr VisitExpr_(const ConstantNode* op, const Expr& e) override; Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override; Expr VisitExpr_(const OpNode* op, const Expr& expr) override; Expr VisitExpr_(const TupleNode* op, const Expr& e) override; Expr VisitExpr_(const ParamNode* op, const Expr& e) override; - Expr VisitExpr_(const FunctionNode* op, const Expr & e) override; - Expr VisitExpr_(const CallNode* call_node, const Expr & e) override; - Expr VisitExpr_(const LetNode* op, const Expr & e) override; - Expr VisitExpr_(const IfNode* op, const Expr & e) override; + Expr VisitExpr_(const FunctionNode* op, const Expr& e) override; + Expr VisitExpr_(const CallNode* call_node, const Expr& e) override; + Expr VisitExpr_(const LetNode* op, const Expr& e) override; + Expr VisitExpr_(const IfNode* op, const Expr& e) override; virtual Type VisitType(const Type& t); -private: - tvm::Map memo_; + + private: + tvm::Map memo_; }; } // namespace relay } // namespace tvm -#endif // TVM_RELAY_EXPR_VISITOR_H_ +#endif // TVM_RELAY_EXPR_MUTATOR_H_ diff --git a/src/relay/ir/expr_mutator.cc b/src/relay/ir/expr_mutator.cc index 83017508d1cc..6ad3279a7890 100644 --- a/src/relay/ir/expr_mutator.cc +++ b/src/relay/ir/expr_mutator.cc @@ -58,7 +58,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) { Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) { Var var = Downcast(this->Mutate(op->var)); auto type = this->VisitType(op->type); - if (var == op->var && type == op->type) { return e; } else { @@ -140,7 +139,6 @@ Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) { auto guard = this->Mutate(op->cond); auto true_b = this->Mutate(op->true_branch); auto false_b = this->Mutate(op->false_branch); - if (op->cond == guard && true_b == op->true_branch && false_b == op->false_branch) { return e; diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 548c5f610ba7..3b3c1f883232 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -144,7 +144,7 @@ Type TypeUnifierNode::Unify(const Type &t1, const Type &t2) { << std::endl; Type unified = this->VisitType(t1, t2); - // TODO (@jroesch): Restore this code when we finish kind checker. + // TODO(@jroesch): Restore this code when we finish kind checker. // if (!check_kind(unified)) { // throw UnificationError("Invalid kinds in unified type"); // } @@ -171,7 +171,7 @@ Type TypeUnifierNode::Subst(const Type &t) { IncompleteTypeSubst tvsubst(this); // normalize first so substitutions in quantifiers will be correct Type ret = tvsubst.VisitType(t); - // TODO (@jroesch): Restore this code when we finish kind checker. + // TODO(@jroesch): Restore this code when we finish kind checker. // if (!check_kind(ret)) { // std::stringstream ss; // ss << "Invalid Kinds in substituted type!"; From 91c7b45253ad5e5dc59b975d92fb37060faa31b7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 01:42:55 -0700 Subject: [PATCH 116/136] Fix doc issue --- include/tvm/relay/base.h | 4 ++-- include/tvm/relay/op.h | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 421f7263f811..ca4c1335eb5b 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -108,8 +108,8 @@ class SourceName : public NodeRef { /*! * \brief Get an SourceName for a given operator name. * Will raise an error if the source name has not been registered. - * \param op_name Name of the operator. - * \return Pointer to a Op, valid throughout program lifetime. + * \param name Name of the operator. + * \return Reference to a SourceName valid throughout program lifetime. */ TVM_DLL static const SourceName& Get(const std::string& name); diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 0d1c2605b45b..45fbec715859 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -149,9 +149,10 @@ class OpRegistry { const std::string& description); /*! * \brief Attach the type function corresponding to the return type. - * \param type_rel_name The type function name to register for the return - * type. \param type_rel The backing relation which can solve an arbitrary - * relation on variables. \return reference to self. + * \param rel_name The type relation name to register. + * \param type_rel_func The backing relation function which can solve an arbitrary + * relation on variables. + * \return reference to self. */ inline OpRegistry& add_type_rel( const std::string& rel_name, From fb74e2c72e1458b736ebd96ab02cdffa810109c9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 16:28:56 -0700 Subject: [PATCH 117/136] Address majority of in-person feedback. --- include/tvm/relay/base.h | 45 +----- include/tvm/relay/environment.h | 10 +- include/tvm/relay/expr.h | 24 +-- include/tvm/relay/expr_functor.h | 54 +++++++ include/tvm/relay/expr_mutator.h | 39 ----- include/tvm/relay/expr_visitor.h | 35 ----- include/tvm/relay/pass.h | 35 +++++ include/tvm/relay/pass/alpha_eq.h | 53 ------- python/tvm/relay/__init__.py | 14 +- python/tvm/relay/env.py | 10 -- python/tvm/relay/expr.py | 2 +- python/tvm/relay/from_nnvm.py | 7 - python/tvm/relay/ir_builder.py | 148 +++++++++++------- python/tvm/relay/{type.py => ty.py} | 0 src/relay/ir/environment.cc | 22 --- .../ir/{expr_mutator.cc => expr_functor.cc} | 54 ++++++- src/relay/ir/expr_visitor.cc | 67 -------- src/relay/op/type_relations.cc | 35 ++--- src/relay/pass/alpha_eq.cc | 10 +- src/relay/pass/resolve.cc | 2 +- src/relay/pass/type_infer.cc | 5 +- src/relay/pass/unifier.cc | 2 +- .../relay/test_tyck_eval_integration.py | 29 ++-- 23 files changed, 299 insertions(+), 403 deletions(-) delete mode 100644 include/tvm/relay/expr_mutator.h delete mode 100644 include/tvm/relay/expr_visitor.h delete mode 100644 include/tvm/relay/pass/alpha_eq.h delete mode 100644 python/tvm/relay/from_nnvm.py rename python/tvm/relay/{type.py => ty.py} (100%) rename src/relay/ir/{expr_mutator.cc => expr_functor.cc} (76%) delete mode 100644 src/relay/ir/expr_visitor.cc diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index ca4c1335eb5b..7c66d2c2de43 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -117,7 +117,6 @@ class SourceName : public NodeRef { using ContainerType = SourceNameNode; }; - /*! * \brief Span information for debugging purposes */ @@ -190,48 +189,12 @@ inline const T* As(const NodeRef& node) { return nullptr; } -template -std::vector Downcast(std::vector array) { - std::vector out; - for (const U& elem : array) { - const typename T::ContainerType* node = - elem.template as(); - CHECK(node) << "Downcast failed" << std::endl; - out.push_back(GetRef(node)); - } - return out; -} - -template -Array Downcast(Array array) { - Array out; - for (const U& elem : array) { - const typename T::ContainerType* node = - elem.template as(); - CHECK(node) << "Downcast failed" << std::endl; - out.push_back(GetRef(node)); - } - return out; -} - template SubRef Downcast(BaseRef ref) { - const typename SubRef::ContainerType* node = - ref.template as(); - CHECK(node) << "Downcast failed" << std::endl; - return GetRef(node); -} - -/*! - * \brief Get PackedFunction from global registry and - * report error if it does not exist - * \param name The name of the function. - * \return The created PackedFunc. - */ -inline const PackedFunc& GetPackedFunc(const std::string& name) { - const PackedFunc* pf = tvm::runtime::Registry::Get(name); - CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; - return *pf; + CHECK(ref->template is_type()) + << "Downcast from " << ref->type_key() << " to " + << SubRef::ContainerType::_type_key << " failed."; + return SubRef(ref.node_); } } // namespace relay diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 75ddc88674e6..949339fc23b3 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -46,6 +46,7 @@ class EnvironmentNode : public RelayNode { EnvironmentNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("functions", &functions); v->Visit("global_map_", &global_map_); } @@ -93,15 +94,6 @@ class EnvironmentNode : public RelayNode { */ void Merge(const Environment& other); - using Transformer = - runtime::TypedPackedFunc(const Environment&)>; - - /*! \brief Apply a function over every function in the global environment. - * \param transformer The transformation function. - */ - void Transform(Transformer transformer); - static constexpr const char* _type_key = "relay.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index f61ee3503bdf..21ba9659ca9a 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -7,9 +7,9 @@ #define TVM_RELAY_EXPR_H_ #include +#include #include #include -#include #include #include "./base.h" #include "./type.h" @@ -73,6 +73,7 @@ class ConstantNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); } TVM_DLL static Constant make(runtime::NDArray data); @@ -94,6 +95,7 @@ class TupleNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); } TVM_DLL static Tuple make(tvm::Array fields); @@ -106,8 +108,8 @@ RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr); /*! * \brief Local variables used in the let expression. - * - * Its semantics are similar to tvm.Var node used in TVM's low level + * + * Its semantics are similar to tvm.Var node used in TVM's low level * tensor expression language. * * \note Each Var is bind only once and is immutable/ @@ -116,13 +118,14 @@ class Var; /*! \brief Container for Var */ class VarNode : public ExprNode { public: - /*! \brief The name of the variable, this only acts as a hint to the user, + /*! \brief The name of the variable, this only acts as a hint to the user, * and is not used for equality. */ std::string name_hint; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name_hint", &name_hint); + v->Visit("_checked_type_", &checked_type_); } TVM_DLL static Var make(std::string name_hint); @@ -148,6 +151,7 @@ class GlobalVarNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name_hint", &name_hint); + v->Visit("_checked_type_", &checked_type_); } TVM_DLL static GlobalVar make(std::string name_hint); @@ -217,6 +221,7 @@ class FunctionNode : public ExprNode { v->Visit("body", &body); v->Visit("type_params", &type_params); v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); } Type fn_type() const; @@ -278,11 +283,10 @@ class CallNode : public ExprNode { v->Visit("attrs", &attrs); v->Visit("type_args", &type_args); v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static Call make(Expr op, - Array args, - Attrs attrs = Attrs(), + TVM_DLL static Call make(Expr op, Array args, Attrs attrs = Attrs(), Array ty_args = Array()); static constexpr const char* _type_key = "relay.Call"; @@ -321,6 +325,7 @@ class LetNode : public ExprNode { v->Visit("body", &body); v->Visit("value_type", &value_type); v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); } TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type); @@ -333,10 +338,10 @@ RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); /*! * \brief Condition expression - * + * * Unlike traditional statement `if`s, the if evalutes * to the result of the branch taken. - * + * * let x = if (true) { 1 } else { 0 }; // x is 1 * let y = if (false) { 1 } else { 0 }; // y is 0 */ @@ -358,6 +363,7 @@ class IfNode : public ExprNode { v->Visit("true_branch", &true_branch); v->Visit("false_branch", &false_branch); v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); } TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch); diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 8e2f24837473..8ad0537ad68b 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -111,6 +111,60 @@ class ExprFunctor { } }; +/*! \brief A simple visitor wrapper around ExprFunctor. + * + * Exposes two visitors with default traversal strategies, one + * which doesn't compute a result but can mutate internal state, + * and another which functionally builds a new Expr. + */ + +class ExprVisitor : public ::tvm::relay::ExprFunctor { + public: + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const GlobalVarNode* op) override; + void VisitExpr_(const ConstantNode* op) override; + void VisitExpr_(const TupleNode* op) override; + void VisitExpr_(const ParamNode* op) override; + void VisitExpr_(const FunctionNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const LetNode* op) override; + void VisitExpr_(const IfNode* op) override; + void VisitExpr_(const OpNode* op) override; + virtual void VisitType(const Type& t); +}; + +/*! \brief A wrapper around ExprFunctor which functionally updates the AST. +* +* ExprMutator uses memoization and self return in order to amortize +* the cost of using functional updates. +*/ +class ExprMutator + : public ::tvm::relay::ExprFunctor { + public: + Expr Mutate(const Expr& expr); + Expr VisitExpr_(const VarNode* op, const Expr& e) override; + Expr VisitExpr_(const ConstantNode* op, const Expr& e) override; + Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override; + Expr VisitExpr_(const OpNode* op, const Expr& expr) override; + Expr VisitExpr_(const TupleNode* op, const Expr& e) override; + Expr VisitExpr_(const ParamNode* op, const Expr& e) override; + Expr VisitExpr_(const FunctionNode* op, const Expr& e) override; + Expr VisitExpr_(const CallNode* call_node, const Expr& e) override; + Expr VisitExpr_(const LetNode* op, const Expr& e) override; + Expr VisitExpr_(const IfNode* op, const Expr& e) override; + /*! \brief Used to visit the types inside of expressions. + * + * Can be overloaded to transform the types in arbitrary + * ways, one way would be to define a sub-class of type + * visitor for types which transform them appropriately. + */ + virtual Type VisitType(const Type& t); + + private: + /*! \brief Internal map used for memoization. */ + tvm::Map memo_; +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_FUNCTOR_H_ diff --git a/include/tvm/relay/expr_mutator.h b/include/tvm/relay/expr_mutator.h deleted file mode 100644 index a53dbe845825..000000000000 --- a/include/tvm/relay/expr_mutator.h +++ /dev/null @@ -1,39 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/expr_mutator.h - * \brief A wrapper around ExprFunctor which functionally updates the AST. - * - * ExprMutator uses memoization and self return in order to amortize - * the cost of using functional updates. - */ -#ifndef TVM_RELAY_EXPR_MUTATOR_H_ -#define TVM_RELAY_EXPR_MUTATOR_H_ - -#include - -namespace tvm { -namespace relay { - -class ExprMutator - : public ::tvm::relay::ExprFunctor { - public: - Expr Mutate(const Expr& expr); - Expr VisitExpr_(const VarNode* op, const Expr& e) override; - Expr VisitExpr_(const ConstantNode* op, const Expr& e) override; - Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override; - Expr VisitExpr_(const OpNode* op, const Expr& expr) override; - Expr VisitExpr_(const TupleNode* op, const Expr& e) override; - Expr VisitExpr_(const ParamNode* op, const Expr& e) override; - Expr VisitExpr_(const FunctionNode* op, const Expr& e) override; - Expr VisitExpr_(const CallNode* call_node, const Expr& e) override; - Expr VisitExpr_(const LetNode* op, const Expr& e) override; - Expr VisitExpr_(const IfNode* op, const Expr& e) override; - virtual Type VisitType(const Type& t); - - private: - tvm::Map memo_; -}; - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_EXPR_MUTATOR_H_ diff --git a/include/tvm/relay/expr_visitor.h b/include/tvm/relay/expr_visitor.h deleted file mode 100644 index 55da8e3d6e9c..000000000000 --- a/include/tvm/relay/expr_visitor.h +++ /dev/null @@ -1,35 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/expr_visitor.h - * \brief A simple visitor wrapper around ExprFunctor. - * - * Exposes two visitors with default traversal strategies, one - * which doesn't compute a result but can mutate internal state, - * and another which functionally builds a new Expr. - */ -#ifndef TVM_RELAY_EXPR_VISITOR_H_ -#define TVM_RELAY_EXPR_VISITOR_H_ - -#include - -namespace tvm { -namespace relay { - -class ExprVisitor : public ::tvm::relay::ExprFunctor { - public: - void VisitExpr_(const VarNode* op) override; - void VisitExpr_(const GlobalVarNode* op) override; - void VisitExpr_(const ConstantNode* op) override; - void VisitExpr_(const TupleNode* op) override; - void VisitExpr_(const ParamNode* op) override; - void VisitExpr_(const FunctionNode* op) override; - void VisitExpr_(const CallNode* op) override; - void VisitExpr_(const LetNode* op) override; - void VisitExpr_(const IfNode* op) override; - void VisitExpr_(const OpNode* op) override; - virtual void VisitType(const Type& t); -}; - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_EXPR_VISITOR_H_ diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 46419bde3f97..db29ad418d24 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -45,6 +45,41 @@ Expr InferType(const Environment& env, const GlobalVar & v, const Function & e); */ bool KindCheck(const Environment& env, const Type& t); +/*! \brief Compare two expressions for structural equivalence. + * + * This comparison operator respects scoping and compares + * expressions without regard to variable choice. + * + * For example: `let x = 1 in x` is equal to `let y = 1 in y`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + * for more details. + * + * \param e1 The left hand expression. + * \param e2 The right hand expression. + * + * \return true if equal, otherwise false + */ +bool AlphaEqual(const Expr& e1, const Expr& e2); + +/*! \brief Compare two types for structural equivalence. + * + * This comparison operator respects scoping and compares + * expressions without regard to variable choice. + * + * For example: `forall s, Tensor[f32, s]` is equal to + * `forall w, Tensor[f32, w]`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + * for more details. + * + * \param t1 The left hand type. + * \param t2 The right hand type. + * + * \return true if equal, otherwise false + */ +bool AlphaEqual(const Type& t1, const Type& t2); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_H_ diff --git a/include/tvm/relay/pass/alpha_eq.h b/include/tvm/relay/pass/alpha_eq.h deleted file mode 100644 index cfc6f5aa1ae7..000000000000 --- a/include/tvm/relay/pass/alpha_eq.h +++ /dev/null @@ -1,53 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/pass/alpha_eq.h - * \brief Check expressions and types for structural equivalence. - */ -#ifndef TVM_RELAY_PASS_ALPHA_EQ_H_ -#define TVM_RELAY_PASS_ALPHA_EQ_H_ - -#include -#include - -namespace tvm { -namespace relay { - -/*! \brief Compare two expressions for structural equivalence. - * - * This comparison operator respects scoping and compares - * expressions without regard to variable choice. - * - * For example: `let x = 1 in x` is equal to `let y = 1 in y`. - * - * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence - * for more details. - * - * \param e1 The left hand expression. - * \param e2 The right hand expression. - * - * \return true if equal, otherwise false - */ -bool AlphaEqual(const Expr& e1, const Expr& e2); - -/*! \brief Compare two types for structural equivalence. - * - * This comparison operator respects scoping and compares - * expressions without regard to variable choice. - * - * For example: `forall s, Tensor[f32, s]` is equal to - * `forall w, Tensor[f32, w]`. - * - * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence - * for more details. - * - * \param t1 The left hand type. - * \param t2 The right hand type. - * - * \return true if equal, otherwise false - */ -bool AlphaEqual(const Type& t1, const Type& t2); - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_PASS_ALPHA_EQ_H_ - diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 493036857b29..18a53be92815 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,7 +1,7 @@ # pylint: disable=wildcard-import """The Relay IR namespace containing the IR definition and compiler.""" from . import base -from . import type as tpe +from . import ty from . import expr from . import env from . import ir_pass @@ -14,12 +14,12 @@ Span = base.Span # Type -Type = tpe.Type -TensorType = tpe.TensorType -Kind = tpe.Kind -TypeParam = tpe.TypeParam -TypeConstraint = tpe.TypeConstraint -FuncType = tpe.FuncType +Type = ty.Type +TensorType = ty.TensorType +Kind = ty.Kind +TypeParam = ty.TypeParam +TypeConstraint = ty.TypeConstraint +FuncType = ty.FuncType # Expr Constant = expr.Constant diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index e36c27d1a632..8c9150d18835 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -82,13 +82,3 @@ def __getitem__(self, var): return _env.Environment_Lookup_str(self, var) else: return _env.Environment_Lookup(self, var) - - def transform(self, transformer): - """Apply a transformer function to the environment. - - Parameters - ---------- - transformer: function - The environment transformer function. - """ - _env.Environment_Transform(self, transformer) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 85e69349321d..52a3aca7590f 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -3,7 +3,7 @@ from typing import List import tvm from .base import Span, NodeBase, register_relay_node -from .type import Type, TypeParam +from .ty import Type, TypeParam from ._ir_pass import _get_checked_type from . import _make diff --git a/python/tvm/relay/from_nnvm.py b/python/tvm/relay/from_nnvm.py deleted file mode 100644 index 9700ea955f59..000000000000 --- a/python/tvm/relay/from_nnvm.py +++ /dev/null @@ -1,7 +0,0 @@ -#pylint: disable-all -"""Convert an nnvm.graph.Graph into a tvm.relay.Expr""" -import nnvm - -def from_nnvm(graph): - """Convert an nnvm.graph.Graph into a tvm.relay.Expr""" - raise Exception("NYI") diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 298172054bbe..3f1bdf288609 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -5,12 +5,12 @@ from typing import Any import numpy as np import tvm -from .type import FuncType, TensorType +from .ty import Type, FuncType, TensorType from .expr import Expr, Constant, Let, Var, Param, Function, If from .env import Environment -def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: +def _convert_to_value(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: """Convert Python values into the appropriate types for the Relay evaluator. """ @@ -28,8 +28,15 @@ def convert(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: # raise Exception(f"can't convert {type(arg)} to a Relay AST") raise Exception(f"unsupported argument type {type(arg)}") +def _convert_type(rtype): + if isinstance(rtype, str): + return scalar_type(rtype) + elif isinstance(rtype, Type): + return rtype + else: + raise Exception(f"unsupported conversion to Relay type {type(rtype)}") -def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: +def convert(arg: Any, ctxt=tvm.cpu(0)) -> Expr: if isinstance(arg, Expr): return arg elif isinstance(arg, tuple): @@ -37,7 +44,7 @@ def into_ast(arg: Any, ctxt=tvm.cpu(0)) -> Expr: elif isinstance(arg, PartialFunc): return arg.to_func() else: - value = convert(arg, ctxt) + value = _convert_to_value(arg, ctxt) return Constant(value) @@ -58,8 +65,13 @@ def __exit__(self, ptype, value, trace): self._exit_cb() -class PartialFunc(): - """A wrapper around functions while they are being built.""" +class PartialFunc(object): + """A wrapper around functions while they are being built. + + Used by the builder as a user is building up a function, + allows Function nodes which contain partially initialized + state. + """ def __init__(self, params, ret_type, body, type_params): self.params = params @@ -71,6 +83,7 @@ def param_ids(self): return [p.var for p in self.params] def to_func(self): + """Converts a PartialFunc into a :py:class:`~relay.Function`.""" return Function( self.params, self.ret_type, @@ -78,8 +91,6 @@ def to_func(self): self.type_params) #pylint: disable=invalid-name - - def _mk_let(bindings, ret_value): let_expr = ret_value for var, (value, ty) in reversed(list(bindings.items())): @@ -88,10 +99,28 @@ def _mk_let(bindings, ret_value): return let_expr -class IRBuilder(): +class IRBuilder(object): """The IRBuilder class. Enables users to build up a Relay environment and program. + + Examples + -------- + + Program: + fn (x : Tensor[f32, (10, 10)]) { + let t1 = log(x); + let t2 = add(t1, x); + return t1; + } + + ..code-block: python + b = IRBuilder() + with b.function(('x', tensor_type(10, 10))) as func: + x, = func.param_ids() + t1 = b.let('t1', log(x)) + t2 = b.let('t2', add(t1, x)) + b.ret(t2) """ def __init__(self): @@ -129,7 +158,7 @@ def let(self, name, value, value_type=None): value = value.var if not isinstance(value, Expr): - value = into_ast(value) + value = convert(value) return self.bind(name, value, value_type) @@ -143,6 +172,7 @@ def _convert_params(self, raw_params): var, ty = raw_param if isinstance(var, str): var = Var(var) + ty = _convert_type(ty) param = Param(var, ty) elif isinstance(param, str): var = Var(raw_param) @@ -175,8 +205,9 @@ def _on_exit(): return WithScope(pfunc, _on_exit) def ret(self, x): + """Set `x` to be the return value of the current function.""" if not self.ret_values[-1]: - self.ret_values[-1] = into_ast(x) + self.ret_values[-1] = convert(x) else: raise Exception( "return value already set, a function can only have one return value") @@ -212,17 +243,26 @@ def _on_exit(): def param(self, name, ty=None): if not ty: - ty = float_type() + ty = scalar_type('float32') + else: + ty = _convert_type(ty) return Param(Var(name), ty) - # def params(*args): - # i = 0 - # while i < args.size(): - # arg = args[i] - # if isinstance(arg, str): - def global_var(self, name: str): + """Construct a global var with `name` as its name hint. + + Parameters + ---------- + name: str + The name of the global variable. + + Returns + ------- + global_var: relay.GlobalVar + The global variable with `name`. + + """ return self.env.global_var(name) def decl(self, name: str, *params, ret_type=None): @@ -235,10 +275,14 @@ def _on_exit(): return WithScope(10, _on_exit) - # def while_loop(cond) - def get(self): - """Get the full program""" + """Get the full program. + + Returns + ---------- + (prog, env) : (relay.Expr, relay.Environment) + A pair of the partial program, and the modified environment. + """ bindings = self.bindings.pop() scope = self.scopes.pop() @@ -254,46 +298,44 @@ def get(self): return _mk_let(bindings, self.ret_values[-1]), self.env -def bool_dtype(): - return 'uint1' - - -def int_dtype(bits=32): - return f'int{bits}' - - -def float_dtype(bits=32): - return f'float{bits}' - - -def uint_dtype(bits=32): - return f'uint{bits}' - - -def int_type(bits=32, _lanes=1): - # TODO(@jroesch, @tqchen) How do we set lanes? - return TensorType(tvm.convert([]), int_dtype(bits)) - - -def uint_type(bits=32, _lanes=1): - return TensorType(tvm.convert([]), uint_dtype(bits)) - - -def float_type(bits=32, _lanes=1): - return TensorType(tvm.convert([]), float_dtype(bits)) - - -def bool_type(_lanes=1): - return TensorType(tvm.convert([]), bool_dtype()) +def scalar_type(dtype): + """Construct a Relay scalar type. + + Parameters + ---------- + dtype: dtype + The dtype of the scalar type. + + Returns: + scalar_type: relay.Type + The scalar type. + """ + return TensorType(tvm.convert([]), dtype) def tensor_type(*shape, dtype='float32'): + """Construct a Relay Tensor type. + + Parameters + ---------- + shape: list of tvm.Expr + The shape of the Tensor type. + dtype: dtype + The dtype of the Tensor type. + + Returns + ------- + tensor_type: relay.Type + The resulting tensor types. + """ return TensorType(tvm.convert(shape), dtype) - def func_type(args, ret_type, type_params=None, type_constraints=None): + """document""" if not type_params: type_params = [] if not type_constraints: type_constraints = [] + args = [_convert_type(arg) for arg in args] + ret_type = _convert_type(ret_type) return FuncType(args, ret_type, type_params, type_constraints) diff --git a/python/tvm/relay/type.py b/python/tvm/relay/ty.py similarity index 100% rename from python/tvm/relay/type.py rename to python/tvm/relay/ty.py diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index d41fcf23c2c7..47c9789ab5ae 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -5,10 +5,8 @@ */ #include #include -#include #include #include "./../pass/resolve.h" -// #include "tvm/relay/util/rang.h" namespace tvm { namespace relay { @@ -101,20 +99,6 @@ void EnvironmentNode::Merge(const Environment &env) { } } -void EnvironmentNode::Transform(EnvironmentNode::Transformer transformer) { - Array to_process; - for (auto var_and_func : this->functions) { - to_process.push_back(var_and_func.first); - } - - auto for_each = transformer(GetRef(this)); - for (auto var : to_process) { - auto func = this->functions[var]; - auto transformed = for_each(var, func); - this->Add(var, transformed, true); - } -} - TVM_REGISTER_API("relay._make.Environment") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = EnvironmentNode::make(args[0]); @@ -153,12 +137,6 @@ TVM_REGISTER_API("relay._env.Environment_Merge") env->Merge(args[1]); }); -TVM_REGISTER_API("relay._env.Environment_Transform") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - env->Transform(args[1]); - }); - TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const EnvironmentNode *node, tvm::IRPrinter *p) { diff --git a/src/relay/ir/expr_mutator.cc b/src/relay/ir/expr_functor.cc similarity index 76% rename from src/relay/ir/expr_mutator.cc rename to src/relay/ir/expr_functor.cc index 6ad3279a7890..85ae5ffa694e 100644 --- a/src/relay/ir/expr_mutator.cc +++ b/src/relay/ir/expr_functor.cc @@ -7,7 +7,7 @@ * the cost of using functional updates. */ -#include +#include namespace tvm { namespace relay { @@ -149,5 +149,57 @@ Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) { Type ExprMutator::VisitType(const Type& t) { return t; } +void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { return; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { return; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { return; } + +void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { + for (auto field : op->fields) { + this->VisitExpr(field); + } +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const ParamNode* op) { + this->VisitExpr(op->var); +} + +void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { + for (auto param : op->params) { + this->VisitExpr(param); + } + + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitExpr(op->op); + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (auto arg : op->args) { + this->VisitExpr(arg); + } +} + +void ExprVisitor::VisitExpr_(const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitExpr(op->cond); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); +} + +void ExprVisitor::VisitExpr_(const OpNode* op) { return; } + +void ExprVisitor::VisitType(const Type& t) { return; } + } // namespace relay } // namespace tvm + diff --git a/src/relay/ir/expr_visitor.cc b/src/relay/ir/expr_visitor.cc deleted file mode 100644 index acb9074347fe..000000000000 --- a/src/relay/ir/expr_visitor.cc +++ /dev/null @@ -1,67 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file tvm/relay/expr_visitor.h - * \brief A simple visitor wrapper around ExprFunctor. - * - * Exposes two visitors with default traversal strategies, one - * which doesn't compute a result but can mutate internal state, - * and another which functionally builds a new Expr. - */ -#include - -namespace tvm { -namespace relay { - -void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { return; } - -void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { return; } - -void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { return; } - -void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { - for (auto field : op->fields) { - this->VisitExpr(field); - } -} - -void ExprVisitor::ExprVisitor::VisitExpr_(const ParamNode* op) { - this->VisitExpr(op->var); -} - -void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { - for (auto param : op->params) { - this->VisitExpr(param); - } - - this->VisitExpr(op->body); -} - -void ExprVisitor::VisitExpr_(const CallNode* op) { - this->VisitExpr(op->op); - for (auto ty_arg : op->type_args) { - this->VisitType(ty_arg); - } - - for (auto arg : op->args) { - this->VisitExpr(arg); - } -} - -void ExprVisitor::VisitExpr_(const LetNode* op) { - this->VisitExpr(op->var); - this->VisitExpr(op->value); - this->VisitExpr(op->body); -} - -void ExprVisitor::VisitExpr_(const IfNode* op) { - this->VisitExpr(op->cond); - this->VisitExpr(op->true_branch); - this->VisitExpr(op->false_branch); -} - -void ExprVisitor::VisitExpr_(const OpNode* op) { return; } - -void ExprVisitor::VisitType(const Type& t) { return; } - -} // namespace relay -} // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index b61b2c1de554..4a2464044a1f 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -12,7 +12,7 @@ namespace tvm { namespace relay { -TensorType as_ttype(const Type& t) { +TensorType ToTensorType(const Type& t) { if (auto tt_node = t.as()) { return GetRef(tt_node); } else { @@ -21,7 +21,7 @@ TensorType as_ttype(const Type& t) { } // TODO(@jroesch) what size value do we extract? -int to_int(const tvm::Expr& e) { +int ToInt(const tvm::Expr& e) { CHECK(e.defined()); auto imm = e.as(); CHECK(imm) << "TYPE: " << imm << imm->type << std::endl; @@ -30,7 +30,7 @@ int to_int(const tvm::Expr& e) { Array IdentityRel(const Array& types, int num_args) { CHECK_EQ(types.size(), 2); - auto t1 = as_ttype(types[0]); + auto t1 = ToTensorType(types[0]); if (t1 && types[1].as()) { return {t1, t1}; } else { @@ -38,11 +38,6 @@ Array IdentityRel(const Array& types, int num_args) { } } -TVM_REGISTER_API("tvm.relay.type_relations.IdentityRel") -.set_body_typed(const Array&, int)>([](const Array& types, int num_args) { - return IdentityRel(types, num_args); -}); - static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 @@ -62,8 +57,8 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, auto rev_sh2 = sh2.rbegin(); while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) { - auto dim1 = to_int(*rev_sh1); - auto dim2 = to_int(*rev_sh2); + auto dim1 = ToInt(*rev_sh1); + auto dim2 = ToInt(*rev_sh2); if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) { CHECK(false) << "Dimension mistmatch " << "dim1: " << dim1 << " dim2: " << dim2 << std::endl; @@ -114,8 +109,8 @@ Array BroadcastRel(const Array& types, int num_args) { CHECK_EQ(types.size(), 3); RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1] << "Out: " << types[2] << std::endl; - if (auto t1 = as_ttype(types[0])) { - if (auto t2 = as_ttype(types[1])) { + if (auto t1 = ToTensorType(types[0])) { + if (auto t2 = ToTensorType(types[1])) { CHECK_EQ(t1->dtype, t2->dtype); return {t1, t2, ConcreteBroadcast(t1, t2, t1->dtype)}; } @@ -124,18 +119,13 @@ Array BroadcastRel(const Array& types, int num_args) { return types; } -TVM_REGISTER_API("tvm.relay.type_relations.BroadcastRel") -.set_body_typed(const Array&, int)>([](const Array& types, int num_args) { - return BroadcastRel(types, num_args); -}); - /* A relation which specifies broadcasting rules for operations which compute boolean results. */ Array BroadcastCompRel(const Array& types, int num_args) { CHECK_EQ(types.size(), 3); - if (auto t1 = as_ttype(types[0])) { - if (auto t2 = as_ttype(types[1])) { + if (auto t1 = ToTensorType(types[0])) { + if (auto t2 = ToTensorType(types[1])) { return {t1, t2, ConcreteBroadcast(t1, t2, HalideIR::Bool())}; } } @@ -143,12 +133,5 @@ Array BroadcastCompRel(const Array& types, int num_args) { return types; } -TVM_REGISTER_API("tvm.relay.type_relations.BroadcastCompRel") -.set_body_typed(const Array&, int)>([](const Array& types, int num_args) { - return BroadcastCompRel(types, num_args); -}); - - - } // namespace relay } // namespace tvm diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 6bffc411c9a1..a3fab22570ec 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -1,10 +1,10 @@ /*! * Copyright (c) 2018 by Contributors - * \file src/tvm/relay/pass/AlphaEqual.cc + * \file src/tvm/relay/pass/alpha_eq.cc * \brief Compute the set of variables not bound in the expression. */ -#include -#include "tvm/relay/pass/alpha_eq.h" +#include +#include "tvm/relay/pass.h" #include "./type_visitor.h" namespace tvm { @@ -241,14 +241,14 @@ bool AlphaEqual(const Expr &e1, const Expr &e2) { } // TODO(@jroesch): move to correct namespace? -TVM_REGISTER_API("relay._make._AlphaEqual") +TVM_REGISTER_API("relay._make._alpha_eq") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr e1 = args[0]; Expr e2 = args[1]; *ret = AlphaEqual(e1, e2); }); -TVM_REGISTER_API("relay._make._type_AlphaEqual") +TVM_REGISTER_API("relay._make._type_alpha_eq") .set_body([](TVMArgs args, TVMRetValue *ret) { Type t1 = args[0]; Type t2 = args[1]; diff --git a/src/relay/pass/resolve.cc b/src/relay/pass/resolve.cc index fe27adfb9a32..b073613bafc2 100644 --- a/src/relay/pass/resolve.cc +++ b/src/relay/pass/resolve.cc @@ -5,7 +5,7 @@ */ #include -#include +#include #include "./resolve.h" #include "./type_visitor.h" diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 6e5b091c72a0..b3f7f34597d9 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -511,7 +511,10 @@ bool TypeInferencer::RelationsHold(bool scope_only) { << std::endl; bool all_hold = true; for (auto cs_set : context.constraints) { - auto ty_rels = Downcast(cs_set); + std::vector ty_rels; + for (auto cs : cs_set) { + ty_rels.push_back(Downcast(cs)); + } auto status = Solve(ty_rels); RELAY_LOG(INFO) << "status= " << status << std::endl; if (status == SolverResult::Failed || status == SolverResult::Progress) { diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 3b3c1f883232..73e837c52fa7 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include "./type_subst.h" #include "./type_visitor.h" diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 216d370fac7b..95f296657380 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -5,8 +5,8 @@ import numpy as np from nnvm import graph from tvm.relay.ir_pass import check_expr -from tvm.relay.ir_builder import IRBuilder, float_type, int_type -from tvm.relay.ir_builder import func_type, tensor_type, into_ast +from tvm.relay.ir_builder import IRBuilder, func_type +from tvm.relay.ir_builder import scalar_type, convert, tensor_type from tvm.relay.env import Environment from tvm.relay.op import log, add, equal, subtract from tvm.relay.expr import Function @@ -24,11 +24,11 @@ def assert_decl_has_type(env, name, typ): def test_monomorphic_let(): "Program: let x = 1; return x" b = IRBuilder() - x = b.let('x', 1.0, value_type=float_type(64)) + x = b.let('x', 1.0, value_type=scalar_type('float64')) b.ret(x) prog, env = b.get() - assert_has_type(prog, float_type(64)) + assert_has_type(prog, scalar_type('float64')) # Need to handle constants # run(env, prog, [], float_type(64)) @@ -36,12 +36,11 @@ def test_monomorphic_let(): def test_single_op(): "Program: fn (x : float32) { let t1 = f(x); t1 }" b = IRBuilder() - with b.function(('x', float_type())) as func: + with b.function(('x', 'float32')) as func: x, = func.param_ids() t1 = b.let('t1', log(x)) b.ret(t1) - assert_has_type(func.to_func(), func_type([float_type()], float_type())) - + assert_has_type(func.to_func(), func_type(['float32'], 'float32')) def test_add_op(): """ @@ -93,7 +92,7 @@ def test_dual_op(): t1 = b.let('t1', log(x)) t2 = b.let('t2', add(t1, x)) b.ret(t2) - assert_has_type(func.to_func(), func_type([float_type()], float_type())) + assert_has_type(func.to_func(), func_type(['float32'], 'float32')) def test_decl(): @@ -109,7 +108,7 @@ def f(x : Tensor[f32, (10, 10)]) { lx = b.let('lx', log(x)) b.ret(lx) _, env = b.get() - assert_decl_has_type(env, 'f', func_type([float_type()], float_type())) + assert_decl_has_type(env, 'f', func_type(['float32'], 'float32')) def test_recursion(): @@ -126,16 +125,16 @@ def f(n: i32, data: f32) -> f32 { """ b = IRBuilder() f = b.global_var('f') - n = b.param('n', ty=int_type()) - data = b.param('data', ty=float_type()) + n = b.param('n', ty='int32') + data = b.param('data', ty='float32') with b.decl(f, n, data): - with b.if_scope(equal(n, into_ast(0.0))): - b.ret(f(subtract(n, into_ast(1)), log(data))) + with b.if_scope(equal(n, convert(0.0))): + b.ret(f(subtract(n, convert(1)), log(data))) with b.else_scope(): b.ret(data) - b.ret(f(into_ast(2.0), into_ast(10000.0))) + b.ret(f(convert(2.0), convert(10000.0))) assert_decl_has_type(b.env, 'f', func_type( - [int_type(), float_type()], float_type())) + ['int32', 'float32'], 'float32')) # TODO(@jroesch): need evaluator or new runtime # to execute this. From defa8e4532edcf0eb5fd53c365f19f102097b538 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 16:33:01 -0700 Subject: [PATCH 118/136] Remove alpha_eq tests --- include/tvm/relay/environment.h | 5 +- python/tvm/relay/ir_builder.py | 4 +- tests/python/relay/test_alpha_eq.py | 573 ---------------------------- 3 files changed, 4 insertions(+), 578 deletions(-) delete mode 100644 tests/python/relay/test_alpha_eq.py diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 949339fc23b3..7e07dc01eab4 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -98,8 +98,9 @@ class EnvironmentNode : public RelayNode { TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); private: - /*! \brief A map from string names to global variables ensures global - * uniqueness. */ + /*! \brief A map from string names to global variables that + * ensures global uniqueness. + */ tvm::Map global_map_; }; diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 3f1bdf288609..fb5e9bb71956 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -40,7 +40,7 @@ def convert(arg: Any, ctxt=tvm.cpu(0)) -> Expr: if isinstance(arg, Expr): return arg elif isinstance(arg, tuple): - raise Exception("..") + return relay.Tuple([convert(el) for el in arg]) elif isinstance(arg, PartialFunc): return arg.to_func() else: @@ -191,8 +191,6 @@ def function(self, *params): relay_params = self._convert_params(params) - # self.params.append(relay_params) - self.enter_scope() pfunc = PartialFunc(relay_params, None, None, []) diff --git a/tests/python/relay/test_alpha_eq.py b/tests/python/relay/test_alpha_eq.py deleted file mode 100644 index 6c0e7779eae6..000000000000 --- a/tests/python/relay/test_alpha_eq.py +++ /dev/null @@ -1,573 +0,0 @@ -"""Test alpha-equivalence of expressions and types.""" -# from relay.ir import alpha_eq, ShapeOp, Kind -# from relay.typing import TYPE_DEFAULTS -# from relay import ir - -# INT_TYPE_WIDTH = TYPE_DEFAULTS["INT_WIDTH"] -# INT_TYPE_LANES = TYPE_DEFAULTS["INT_LANES"] - -# def int_type(width=32) -> ir.Type: -# return TensorType(IntType(width), ShapeSeq([])) - -# def float_type(width=32) -> ir.Type: -# return TensorType(FloatType(width), ShapeSeq([])) - -# def bool_type() -> ir.Type: -# return TensorType(BoolType(), ShapeSeq([])) - -# def nest_quantifiers(ids, body) -> ir.Type: -# ret = body -# for tid in reversed(ids): -# ret = TypeQuantifier(tid, ret) -# return ret - -# def test_local_id_not_eq() -> None: -# assert not alpha_eq(LocalId("x"), LocalId("y")) - -# def test_local_id_eq() -> None: -# x = LocalId("x") -# assert alpha_eq(x, x) - -# def test_global_id_not_eq() -> None: -# left = GlobalId("xyz") -# right = GlobalId("xyz") -# assert not alpha_eq(left, right) - -# def test_global_id_eq() -> None: -# ident = GlobalId("xyz") -# assert alpha_eq(ident, ident) - -# def test_operator_id_not_eq() -> None: -# left = OperatorId("xyz") -# right = OperatorId("xyz") -# # equality on operator id is pointer equality -# assert not alpha_eq(left, right) - -# def test_operator_id_eq() -> None: -# x = OperatorId("xyz") -# assert alpha_eq(x, x) - -# def test_float_literal_eq() -> None: -# x = FloatLit(1.0) -# y = FloatLit(1.0) -# assert alpha_eq(x, y) - -# def test_float_literal_not_eq() -> None: -# x = FloatLit(1.0) -# y = FloatLit(2.0) -# assert not alpha_eq(x, y) - -# def test_int_literal_eq() -> None: -# x = IntLit(1) -# y = IntLit(1) -# assert alpha_eq(x, y) - -# def test_int_literal_not_eq() -> None: -# x = IntLit(1) -# y = IntLit(2) -# assert not alpha_eq(x, y) - -# def test_bool_literal_eq() -> None: -# x = BoolLit(True) -# y = BoolLit(True) -# assert alpha_eq(x, y) - -# def test_bool_literal_not_eq() -> None: -# x = BoolLit(True) -# y = BoolLit(False) -# assert not alpha_eq(x, y) - -# def test_tensor_literal_eq() -> None: -# x = TensorLit([IntLit(1), IntLit(2)]) -# y = TensorLit([IntLit(1), IntLit(2)]) -# assert alpha_eq(x, y) - -# def test_tensor_literal_not_eq() -> None: -# x = TensorLit([IntLit(1), IntLit(2)]) -# y = TensorLit([IntLit(1), IntLit(3)]) -# z = TensorLit([IntLit(1)]) -# assert not alpha_eq(x, y) -# assert not alpha_eq(x, z) - -# def test_product_literal_eq() -> None: -# x = Tuple([IntLit(1), IntLit(2)]) -# y = Tuple([IntLit(1), IntLit(2)]) -# assert alpha_eq(x, y) - -# def test_product_literal_not_eq() -> None: -# x = Tuple([IntLit(1), IntLit(2)]) -# y = Tuple([IntLit(2), IntLit(2)]) -# z = Tuple([IntLit(1), IntLit(2), IntLit(3)]) -# assert not alpha_eq(x, y) -# assert not alpha_eq(x, z) - -# def test_projection_eq() -> None: -# prod = Tuple([IntLit(3), FloatLit(3.5)]) - -# assert alpha_eq(Projection(prod, 0), Projection(prod, 0)) -# assert alpha_eq(Projection(prod, 1), Projection(prod, 1)) - -# def test_projection_not_eq() -> None: -# prod1 = Tuple([IntLit(3), IntLit(4)]) -# prod2 = Tuple([IntLit(3)]) -# prod3 = Tuple([IntLit(3), IntLit(4), FloatLit(3.5)]) - -# assert not alpha_eq(Projection(prod1, 0), Projection(prod1, 1)) -# assert not alpha_eq(Projection(prod1, 0), Projection(prod2, 0)) -# assert not alpha_eq(Projection(prod1, 0), Projection(prod3, 0)) -# assert not alpha_eq(Projection(prod1, 1), Projection(prod3, 1)) - -# def test_cast_not_eq() -> None: -# left = Cast(IntType(1), IntLit(2)) -# right = Cast(IntType(1), IntLit(1)) -# assert not alpha_eq(left, right) - -# # same literal, different type -# left = Cast(IntType(1), IntLit(2)) -# right = Cast(IntType(2), IntLit(2)) -# assert not alpha_eq(left, right) - -# def test_cast_eq() -> None: -# left = Cast(IntType(1), IntLit(2)) -# right = Cast(IntType(1), IntLit(2)) -# assert alpha_eq(left, right) - -# def test_param_not_eq() -> None: -# left = Param(LocalId("foo"), int_type()) -# right = Param(LocalId("foo"), bool_type()) -# assert not alpha_eq(left, right) - -# def test_param_eq() -> None: -# left = Param(LocalId("foo"), int_type()) -# right = Param(LocalId("bar"), int_type()) -# assert alpha_eq(left, right) - -# def test_function_not_eq() -> None: -# params1 = [Param(LocalId("x"), int_type())] -# fn1 = Function([], params1, int_type(), LocalId("x")) -# params2 = [Param(LocalId("y"), bool_type())] -# fn2 = Function([], params2, int_type(), LocalId("y")) -# assert not alpha_eq(fn1, fn2) - -# params3 = [Param(LocalId("x"), int_type()), Param(LocalId("y"), int_type())] -# fn3 = Function([], params3, int_type(), LocalId("z")) -# assert not alpha_eq(fn1, fn3) - -# def test_function_eq() -> None: -# x = LocalId("x") -# y = LocalId("y") -# params1 = [Param(x, int_type())] -# fn1 = Function([], params1, int_type(), x) -# params2 = [Param(y, int_type())] -# fn2 = Function([], params2, int_type(), y) -# assert alpha_eq(fn1, fn2) - -# def test_call_not_eq() -> None: -# x = LocalId("x") -# y = LocalId("y") -# params1 = [Param(x, int_type())] -# fn1 = Function([], params1, int_type(), x) -# args1 = [IntLit(1)] -# call1 = Call(fn1, args1) - -# args2 = [IntLit(2)] -# call2 = Call(fn1, args2) -# assert not alpha_eq(call1, call2) - -# params2 = [Param(y, int_type())] -# fn2 = Function([], params2, float_type(), FloatLit(0.0)) -# call3 = Call(fn2, args1) -# assert not alpha_eq(call1, call3) -# assert not alpha_eq(call2, call3) - -# def test_call_eq() -> None: -# x = LocalId("x") -# y = LocalId("y") -# params1 = [Param(x, int_type())] -# fn1 = Function([], params1, int_type(), x) -# args = [IntLit(1)] -# call1 = Call(fn1, args) - -# params2 = [Param(y, int_type())] -# fn2 = Function([], params2, int_type(), y) -# call2 = Call(fn2, args) -# assert alpha_eq(call1, call2) - -# def test_debug_not_eq() -> None: -# left = Debug(IntLit(1)) -# right = Debug(IntLit(2)) -# assert not alpha_eq(left, right) - -# def test_debug_eq() -> None: -# left = Debug(IntLit(1)) -# right = Debug(IntLit(1)) -# assert alpha_eq(left, right) - -# def test_let_not_eq() -> None: -# x = LocalId("x") -# y = LocalId("y") -# let1 = Let(x, int_type(), IntLit(10), IntLit(11)) -# let2 = Let(y, int_type(), IntLit(10), IntLit(12)) -# assert not alpha_eq(let1, let2) - -# let3 = Let(x, int_type(), IntLit(10), x) -# let4 = Let(y, int_type(), IntLit(12), y) -# assert not alpha_eq(let3, let4) - -# def test_let_eq() -> None: -# x = LocalId("x") -# y = LocalId("y") -# let1 = Let(x, int_type(), IntLit(10), x) -# let2 = Let(y, int_type(), IntLit(10), y) -# assert alpha_eq(let1, let2) - -# def test_ref_eq() -> None: -# r1 = Ref(IntLit(5)) -# r2 = Ref(IntLit(5)) -# assert alpha_eq(r1, r2) - -# def test_ref_not_eq() -> None: -# r1 = Ref(IntLit(5)) -# r2 = Ref(FloatLit(3.5)) -# r3 = Ref(r1) -# assert not alpha_eq(r1, r2) -# assert not alpha_eq(r1, r3) -# assert not alpha_eq(r2, r3) - -# def test_val_ref_eq() -> None: -# vr1 = ReadRef(Ref(IntLit(35))) -# vr2 = ReadRef(Ref(Tuple([IntLit(12), FloatLit(2.5)]))) -# assert alpha_eq(vr1, vr1) -# assert alpha_eq(vr2, vr2) - -# def test_val_ref_not_eq() -> None: -# vr1 = ReadRef(Ref(IntLit(5))) -# vr2 = ReadRef(Ref(vr1)) -# vr3 = ReadRef(Ref(FloatLit(5.0))) -# assert not alpha_eq(vr1, vr2) -# assert not alpha_eq(vr1, vr3) -# assert not alpha_eq(vr2, vr3) - -# def test_set_ref_eq() -> None: -# sr1 = WriteRef(Ref(FloatLit(5.0)), FloatLit(6.0)) -# sr2 = WriteRef(Ref(Tuple([IntLit(3), BoolLit(False)])), -# Tuple([IntLit(5), BoolLit(True)])) -# assert alpha_eq(sr1, sr1) -# assert alpha_eq(sr2, sr2) - -# def test_set_ref_not_eq() -> None: -# r1 = Ref(FloatLit(5.0)) -# r2 = Ref(IntLit(5)) -# r3 = Ref(IntLit(6)) - -# assert not alpha_eq(WriteRef(r1, FloatLit(6.0)), -# WriteRef(r2, IntLit(6))) -# assert not alpha_eq(WriteRef(r2, IntLit(6)), WriteRef(r2, IntLit(7))) -# assert not alpha_eq(WriteRef(r2, IntLit(7)), WriteRef(r3, IntLit(7))) - -# # Type alpha-equality tests - -# def test_base_type_eq() -> None: -# assert alpha_eq(IntType(32), IntType(32)) -# assert alpha_eq(BoolType(), BoolType()) -# assert alpha_eq(FloatType(32), FloatType(32)) - -# def test_tensor_type_eq() -> None: -# tt1 = TensorType( -# IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) -# tt2 = TensorType( -# FloatType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) -# assert alpha_eq(tt1, tt1) -# assert alpha_eq(tt2, tt2) - -# def test_tensor_type_not_eq() -> None: -# tt1 = TensorType( -# IntType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) -# tt2 = TensorType( -# FloatType(32), ShapeSeq([ShapeSingleton(1), ShapeSingleton(2), ShapeSingleton(3)])) -# tt3 = TensorType( -# IntType(32), ShapeSeq([ShapeSingleton(3), ShapeSingleton(3)])) -# assert not alpha_eq(tt1, tt2) -# assert not alpha_eq(tt1, tt3) - -# def test_ref_type_eq() -> None: -# rt1 = RefType(int_type()) -# rt2 = RefType(float_type()) -# assert alpha_eq(rt1, rt1) -# assert alpha_eq(rt2, rt2) - -# def test_ref_type_not_eq() -> None: -# rt1 = RefType(int_type()) -# rt2 = RefType(float_type()) -# assert not alpha_eq(rt1, rt2) - -# def test_product_type_eq() -> None: -# pt1 = TupleType([int_type(), RefType(float_type())]) -# pt2 = TupleType([float_type(), float_type(), int_type()]) -# assert alpha_eq(pt1, pt1) -# assert alpha_eq(pt2, pt2) - -# def test_product_type_not_eq() -> None: -# pt1 = TupleType([int_type(), int_type()]) -# pt2 = TupleType([int_type(), int_type(), float_type()]) -# pt3 = TupleType([bool_type(), float_type()]) -# assert not alpha_eq(pt1, pt2) -# assert not alpha_eq(pt1, pt3) - -# def test_type_id_eq() -> None: -# id1 = TypeParam("id1", Kind.Shape) -# id2 = TypeParam("id2", Kind.BaseType) -# id3 = TypeParam("id2", Kind.Type) - -# assert alpha_eq(id1, id1) -# assert alpha_eq(id2, id2) -# assert alpha_eq(id3, id3) - -# def test_type_id_not_eq() -> None: -# # name is just a hint, we use pointer equality as the rule -# # (unless there is a quantifier to give context) -# id1 = TypeParam("id1", Kind.Shape) -# id2 = TypeParam("id1", Kind.Shape) -# id3 = TypeParam("id3", Kind.BaseType) - -# assert not alpha_eq(id1, id2) -# assert not alpha_eq(id1, id3) - -# def test_arrow_type_eq() -> None: -# ar1 = TypeArrow([int_type()], bool_type()) -# ar2 = TypeArrow([int_type(), int_type()], TupleType([])) -# assert alpha_eq(ar1, ar1) -# assert alpha_eq(ar2, ar2) - -# def test_arrow_type_not_eq() -> None: -# t1 = int_type() -# t2 = bool_type() -# t3 = [int_type(), bool_type()] - -# assert not alpha_eq(TypeArrow([t1], t2), TypeArrow([t1], t1)) -# assert not alpha_eq(TypeArrow(t3, t1), TypeArrow([t2], t1)) -# assert not alpha_eq(TypeArrow([t1], TypeArrow([t1], t1)), -# TypeArrow([t1], t1)) - -# def test_type_quantifier_eq() -> None: -# id1 = TypeParam("id1", Kind.Shape) -# id2 = TypeParam("id2", Kind.Shape) -# tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) -# tq2 = TypeQuantifier(id2, TensorType(IntType(32), id2)) - -# assert alpha_eq(tq1, tq1) -# assert alpha_eq(tq1, tq2) - -# def test_nested_type_quantifier_eq() -> None: -# id1 = TypeParam("id1", Kind.BaseType) -# id2 = TypeParam("id2", Kind.Shape) -# id3 = TypeParam("id3", Kind.BaseType) -# id4 = TypeParam("id4", Kind.Shape) -# tq1 = TypeQuantifier(id1, TypeQuantifier(id2, TensorType(id1, id2))) -# tq2 = TypeQuantifier(id3, TypeQuantifier(id4, TensorType(id3, id4))) - -# assert alpha_eq(tq1, tq1) -# assert alpha_eq(tq1, tq2) - -# def test_type_quantifier_not_eq() -> None: -# id1 = TypeParam("id1", Kind.Shape) -# id2 = TypeParam("id2", Kind.BaseType) -# id3 = TypeParam("id3", Kind.Shape) - -# tq1 = TypeQuantifier(id1, TensorType(IntType(32), id1)) -# tq2 = TypeQuantifier(id2, TensorType(id2, ShapeSeq([ShapeSingleton(3)]))) -# tq3 = TypeQuantifier(id1, TensorType(IntType(32), id3)) -# tq4 = TypeQuantifier(id1, TensorType(FloatType(32), id1)) - -# assert not alpha_eq(tq1, tq2) -# assert not alpha_eq(tq1, tq3) -# assert not alpha_eq(tq1, tq4) -# assert not alpha_eq(tq2, tq3) -# assert not alpha_eq(tq2, tq4) - -# def test_shape_singleton_eq() -> None: -# single1 = ShapeSingleton(10) -# single2 = ShapeSingleton(10) - -# assert alpha_eq(single1, single1) -# assert alpha_eq(single1, single2) - -# def test_shape_singelton_not_eq() -> None: -# single1 = ShapeSingleton(10) -# single2 = ShapeSingleton(11) - -# assert not alpha_eq(single1, single2) - -# def test_shape_attr_eq() -> None: -# attr1 = ShapeAttr("x") -# attr2 = ShapeAttr("x") - -# assert alpha_eq(attr1, attr1) -# assert alpha_eq(attr1, attr2) - -# def test_shape_attr_not_eq() -> None: -# id1 = "x" -# id2 = "y" -# attr1 = ShapeAttr(id1) -# attr2 = ShapeAttr(id2) - -# assert not alpha_eq(attr1, attr2) - -# def test_shape_seq_eq() -> None: -# empty = ShapeSeq([]) -# seq1 = ShapeSeq([ShapeSingleton(5)]) -# seq2 = ShapeSeq([ShapeSingleton(5)]) - -# assert alpha_eq(empty, empty) -# assert alpha_eq(seq1, seq2) - -# def test_shape_seq_not_eq() -> None: -# empty = ShapeSeq([]) -# seq = ShapeSeq([ShapeSingleton(5)]) -# single = ShapeSingleton(5) - -# assert not alpha_eq(empty, seq) -# assert not alpha_eq(seq, single) - -# def test_shape_projection_eq() -> None: -# proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) -# proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) - -# assert alpha_eq(proj1, proj2) - -# def test_shape_projection_not_eq() -> None: -# proj1 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 0) -# proj2 = ShapeProjection(ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]), 1) -# proj3 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 0) -# proj4 = ShapeProjection(ShapeSeq([ShapeSingleton(2), ShapeSingleton(1)]), 1) - -# assert not alpha_eq(proj1, proj2) -# assert not alpha_eq(proj1, proj3) -# assert not alpha_eq(proj1, proj4) -# assert not alpha_eq(proj2, proj3) -# assert not alpha_eq(proj2, proj4) -# assert not alpha_eq(proj3, proj4) - -# def test_shape_binary_op_eq() -> None: -# empty = ShapeSeq([]) -# single = ShapeSingleton(5) -# seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) - -# op1 = ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty) -# op2 = ShapeBinaryOp(ShapeOp.SHSUB, single, single) -# op3 = ShapeBinaryOp(ShapeOp.SHMUL, seq, seq) -# op4 = ShapeBinaryOp(ShapeOp.SHDIV, seq, seq) - -# assert alpha_eq(op1, op1) -# assert alpha_eq(op2, op2) -# assert alpha_eq(op3, op3) -# assert alpha_eq(op4, op4) - -# def test_shape_binary_op_not_eq() -> None: -# empty = ShapeSeq([]) -# single = ShapeSingleton(5) -# seq = ShapeSeq([ShapeSingleton(1), ShapeSingleton(2)]) - -# assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), empty) -# assert not alpha_eq(ShapeBinaryOp(ShapeOp.SHMUL, seq, ShapeSingleton(1)), seq) -# assert not alpha_eq( -# ShapeBinaryOp(ShapeOp.SHPLUS, single, single), -# ShapeBinaryOp(ShapeOp.SHPLUS, -# ShapeSeq([single]), -# ShapeSeq([single]))) -# assert not alpha_eq( -# ShapeBinaryOp(ShapeOp.SHPLUS, empty, empty), -# ShapeBinaryOp(ShapeOp.SHSUB, empty, empty)) -# assert not alpha_eq( -# ShapeBinaryOp(ShapeOp.SHMUL, empty, empty), -# ShapeBinaryOp(ShapeOp.SHDIV, empty, empty)) - -# def test_shape_nested_in_quantifier() -> None: -# b1 = TypeParam("b", Kind.BaseType) -# x1 = TypeParam("x", Kind.Shape) -# y1 = TypeParam("y", Kind.Shape) - -# b2 = TypeParam("b", Kind.BaseType) -# x2 = TypeParam("x", Kind.Shape) -# y2 = TypeParam("y", Kind.Shape) - -# b3 = TypeParam("b", Kind.BaseType) -# x3 = TypeParam("x", Kind.Shape) -# y3 = TypeParam("y", Kind.Shape) - -# tq1 = nest_quantifiers( -# [b1, x1, y1], -# TypeArrow( -# [TensorType(b1, x1), TensorType(b1, y2)], -# TensorType( -# b1, -# ShapeBinaryOp(ShapeOp.SHPLUS, -# ShapeSeq([x1, ShapeProjection(y1, 1), -# ShapeSingleton(5), ShapeAttr("att")]), -# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - -# tq2 = nest_quantifiers( -# [b2, x2, y2], -# TypeArrow( -# [TensorType(b2, x2), TensorType(b2, y2)], -# TensorType( -# b2, -# ShapeBinaryOp(ShapeOp.SHPLUS, -# ShapeSeq([x2, ShapeProjection(y2, 1), -# ShapeSingleton(5), ShapeAttr("att")]), -# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - -# # different attr, var order, position, and constant -# tq3 = nest_quantifiers( -# [b3, x3, y3], -# TypeArrow( -# [TensorType(b3, x3), TensorType(b3, y3)], -# TensorType( -# b3, -# ShapeBinaryOp(ShapeOp.SHPLUS, -# ShapeSeq([x3, ShapeProjection(y3, 1), -# ShapeSingleton(4), ShapeAttr("att")]), -# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - -# tq4 = nest_quantifiers( -# [b3, x3, y3], -# TypeArrow( -# [TensorType(b3, x3), TensorType(b3, y3)], -# TensorType( -# b3, -# ShapeBinaryOp(ShapeOp.SHPLUS, -# ShapeSeq([x3, ShapeProjection(y3, 2), -# ShapeSingleton(5), ShapeAttr("att2")]), -# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - -# tq5 = nest_quantifiers( -# [b3, x3, y3], -# TypeArrow( -# [TensorType(b3, x3), TensorType(b3, y3)], -# TensorType( -# b3, -# ShapeBinaryOp(ShapeOp.SHMUL, -# ShapeSeq([x3, ShapeProjection(y3, 1), -# ShapeSingleton(5), ShapeAttr("att")]), -# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - -# tq6 = nest_quantifiers( -# [b3, y3, x3], -# TypeArrow( -# [TensorType(b3, x3), TensorType(b3, y3)], -# TensorType( -# b3, -# ShapeBinaryOp(ShapeOp.SHPLUS, -# ShapeSeq([x3, ShapeProjection(y3, 1), -# ShapeSingleton(5), ShapeAttr("att")]), -# ShapeSeq([ShapeSingleton(1) for i in range(6)]))))) - -# assert alpha_eq(tq1, tq2) -# assert not alpha_eq(tq1, tq3) -# assert not alpha_eq(tq2, tq3) -# assert not alpha_eq(tq1, tq4) -# assert not alpha_eq(tq2, tq4) -# assert not alpha_eq(tq1, tq5) -# assert not alpha_eq(tq2, tq5) -# assert not alpha_eq(tq1, tq6) -# assert not alpha_eq(tq2, tq6) From 167d2352e595e6867adc3b880d9a5e2b97ff2057 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 16:41:30 -0700 Subject: [PATCH 119/136] Style fix in expr.cc --- src/relay/ir/expr.cc | 62 +++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 925630bc8399..f4363f5312c4 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -30,9 +30,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TensorType ConstantNode::tensor_type() const { - auto dl_dtype = data->dtype; - auto dtype = HalideIR::Type(static_cast(dl_dtype.code), - dl_dtype.bits, dl_dtype.lanes); + auto dtype = TVMType2Type(data->dtype); Array shape; for (int i = 0; i < data->ndim; i++) { @@ -100,14 +98,14 @@ Param ParamNode::make(Var var, Type type) { } TVM_REGISTER_API("relay._make.Param") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ParamNode::make(args[0], args[1]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = ParamNode::make(args[0], args[1]); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const ParamNode *node, tvm::IRPrinter *p) { - p->stream << "ParamNode(" << node->var << ", " << node->type << ")"; - }); +.set_dispatch([](const ParamNode *node, tvm::IRPrinter *p) { + p->stream << "ParamNode(" << node->var << ", " << node->type << ")"; +}); Function FunctionNode::make(tvm::Array params, Type ret_type, Expr body, tvm::Array type_params) { @@ -129,16 +127,16 @@ Type FunctionNode::fn_type() const { } TVM_REGISTER_API("relay._make.Function") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const FunctionNode *node, +.set_dispatch([](const FunctionNode *node, tvm::IRPrinter *p) { p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body << ", " << node->type_params << ")"; - }); +}); Call CallNode::make(Expr op, Array args, Attrs attrs, Array type_args) { @@ -151,15 +149,15 @@ Call CallNode::make(Expr op, Array args, Attrs attrs, } TVM_REGISTER_API("relay._make.Call") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = CallNode::make(args[0], args[1], args[2], args[3]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CallNode::make(args[0], args[1], args[2], args[3]); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const CallNode *node, tvm::IRPrinter *p) { - p->stream << "CallNode(" << node->op << ", " << node->args << ", " - << node->attrs << ", " << node->type_args << ")"; - }); +.set_dispatch([](const CallNode *node, tvm::IRPrinter *p) { + p->stream << "CallNode(" << node->op << ", " << node->args << ", " + << node->attrs << ", " << node->type_args << ")"; +}); Let LetNode::make(Var var, Expr value, Expr body, Type value_type) { std::shared_ptr n = std::make_shared(); @@ -171,15 +169,15 @@ Let LetNode::make(Var var, Expr value, Expr body, Type value_type) { } TVM_REGISTER_API("relay._make.Let") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = LetNode::make(args[0], args[1], args[2], args[3]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LetNode::make(args[0], args[1], args[2], args[3]); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { - p->stream << "LetNode(" << node->var << ", " << node->value - << ", " << node->body << ", " << node->value_type << ")"; - }); +.set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { + p->stream << "LetNode(" << node->var << ", " << node->value + << ", " << node->body << ", " << node->value_type << ")"; +}); If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { std::shared_ptr n = std::make_shared(); @@ -194,10 +192,10 @@ TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { - p->stream << "IfNode(" << node->cond << ", " << node->true_branch - << node->false_branch << ")"; - }); +.set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { + p->stream << "IfNode(" << node->cond << ", " << node->true_branch + << node->false_branch << ")"; +}); } // namespace relay } // namespace tvm From a6b60380b894d9e431f306dd070f1a226ddd001a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 17:03:16 -0700 Subject: [PATCH 120/136] Fix & style --- include/tvm/relay/base.h | 1 + include/tvm/relay/error.h | 4 ++-- include/tvm/relay/expr.h | 11 ++++----- include/tvm/relay/op.h | 7 +++--- include/tvm/relay/pass.h | 2 +- src/relay/pass/alpha_eq.cc | 43 +++++++++++++++++----------------- src/relay/pass/resolve.h | 47 +++++++++++++++++++------------------- src/relay/pass/unifier.cc | 32 +++++++++++++------------- 8 files changed, 72 insertions(+), 75 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 7c66d2c2de43..2b5667dc4dc9 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 696e5a05487d..8ce73a027ca0 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -20,8 +20,8 @@ struct InternalError : Error { explicit InternalError(const std::string &msg) : Error(msg) {} }; -// FIX, we should change spanned errors to have a method which allow them to -// report on the Environment, inverting control to error definition. +// TODO(@jroesch): we should change spanned errors to report +// errors against the Environment, inverting control to error definition. struct FatalTypeError : dmlc::Error { explicit FatalTypeError(const std::string &s) : dmlc::Error(s) {} }; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 21ba9659ca9a..886d16a9400a 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -6,11 +6,8 @@ #ifndef TVM_RELAY_EXPR_H_ #define TVM_RELAY_EXPR_H_ -#include -#include -#include -#include #include +#include #include "./base.h" #include "./type.h" @@ -52,7 +49,7 @@ RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); /*! * \brief Constant tensor, backed by an NDArray on the cpu(0) device. * - * \note scalar constants are represented by rank-0 const tensor. + * \note Scalar constants are represented by rank-0 const tensor. * Constant folding are handled uniformly via Tensor types. */ class Constant; @@ -316,7 +313,7 @@ class LetNode : public ExprNode { Expr value; /*! \brief The body of the let binding */ Expr body; - /*! \brief type annotation of value, this can be null */ + /*! \brief Type annotation of value, this can be null */ Type value_type; void VisitAttrs(tvm::AttrVisitor* v) final { @@ -344,6 +341,8 @@ RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); * * let x = if (true) { 1 } else { 0 }; // x is 1 * let y = if (false) { 1 } else { 0 }; // y is 0 + * + * \note This is similar to C's ternary operator. */ class If; /*! \brief container of If */ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 45fbec715859..67a2ff3381bb 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -367,10 +367,9 @@ inline OpRegistry& OpRegistry::add_type_rel( std::vector arg_types; // Add inputs. - int i = 0; - for (auto arg : get()->arguments) { - std::string name = "in"; - name += std::to_string(i++); + std::string input_name_prefix = "in"; + for (int i = 0; i < get()->arguments.size(); i++) { + auto name = input_name_prefix + std::to_string(i++); auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); type_params.push_back(param); arg_types.push_back(param); diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index db29ad418d24..e956097780bb 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -26,7 +26,7 @@ namespace relay { * \return A type checked expression with its checked_type field populated. */ Expr InferType(const Environment& env, const Expr& e); -Expr InferType(const Environment& env, const GlobalVar & v, const Function & e); +Expr InferType(const Environment& env, const GlobalVar& v, const Function& e); /*! * \brief Check that types are well formed by applying "kinding rules". diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index a3fab22570ec..f76da793c503 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -4,27 +4,26 @@ * \brief Compute the set of variables not bound in the expression. */ #include -#include "tvm/relay/pass.h" #include "./type_visitor.h" +#include "tvm/relay/pass.h" namespace tvm { namespace relay { using namespace tvm::runtime; -struct TypeAlphaEq : TypeVisitor { +struct TypeAlphaEq : TypeVisitor { tvm::Map eq_map; bool equal; TypeAlphaEq() : eq_map(), equal(true) {} - void DataTypeEqual(const DataType & dt1, const DataType & dt2) { - equal = equal && dt1 == dt2; - } - void ShapeEqual(Array s1, Array s2) { + void DataTypeEqual(const DataType& dt1, const DataType& dt2) { + equal = equal && dt1 == dt2; } + void ShapeEqual(Array s1, Array s2) {} - void VisitType_(const TensorTypeNode *tt1, const Type &t2) final { + void VisitType_(const TensorTypeNode *tt1, const Type& t2) final { if (const TensorTypeNode *tt2 = t2.as()) { DataTypeEqual(tt1->dtype, tt2->dtype); ShapeEqual(tt1->shape, tt2->shape); @@ -33,7 +32,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const IncompleteTypeNode *bt1, const Type &t2) final { + void VisitType_(const IncompleteTypeNode *bt1, const Type& t2) final { if (const IncompleteTypeNode *bt2 = t2.as()) { equal = equal && bt1 == bt2; return; @@ -42,7 +41,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TypeParamNode *ti1, const Type &t2) final { + void VisitType_(const TypeParamNode *ti1, const Type& t2) final { if (const TypeParamNode *ti2 = t2.as()) { auto tid1 = GetRef(ti1); auto tid2 = GetRef(ti2); @@ -72,7 +71,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const FuncTypeNode *op, const Type &t2) final { + void VisitType_(const FuncTypeNode *op, const Type& t2) final { if (const FuncTypeNode *ta2 = t2.as()) { if (op->arg_types.size() != ta2->arg_types.size()) { equal = false; @@ -92,7 +91,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TypeRelationNode *tr1, const Type &t2) final { + void VisitType_(const TypeRelationNode *tr1, const Type& t2) final { if (const TypeRelationNode *tr2 = t2.as()) { equal = tr1 == tr2; } else { @@ -100,7 +99,7 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TupleTypeNode *op, const Type &t2) final { + void VisitType_(const TupleTypeNode *op, const Type& t2) final { if (const TupleTypeNode *pt = t2.as()) { if (op->fields.size() != pt->fields.size()) { equal = false; @@ -119,20 +118,20 @@ struct TypeAlphaEq : TypeVisitor { } }; -bool AlphaEqual(const Type &t1, const Type &t2) { +bool AlphaEqual(const Type& t1, const Type& t2) { TypeAlphaEq aeq; aeq.VisitType(t1, t2); return aeq.equal; } -struct AlphaEq : ExprFunctor { +struct AlphaEq : ExprFunctor { public: tvm::Map eq_map; bool equal; AlphaEq() : eq_map(), equal(true) {} - void VisitExpr_(const VarNode *e1, const Expr &e2) final { + void VisitExpr_(const VarNode *e1, const Expr& e2) final { if (const VarNode *id2 = e2.as()) { auto local1 = GetRef(e1); auto local2 = GetRef(id2); @@ -154,7 +153,7 @@ struct AlphaEq : ExprFunctor { } } - void VisitExpr_(const GlobalVarNode *g1, const Expr &e2) final { + void VisitExpr_(const GlobalVarNode *g1, const Expr& e2) final { if (const GlobalVarNode *g2 = e2.as()) { equal = equal && g1 == g2; } else { @@ -162,7 +161,7 @@ struct AlphaEq : ExprFunctor { } } - void VisitExpr_(const TupleNode *pl1, const Expr &e2) final { + void VisitExpr_(const TupleNode *pl1, const Expr& e2) final { Tuple prod1 = GetRef(pl1); if (const TupleNode *pl2 = e2.as()) { Tuple prod2 = GetRef(pl2); @@ -179,7 +178,7 @@ struct AlphaEq : ExprFunctor { } } - void VisitExpr_(const ParamNode *p1, const Expr &e2) final { + void VisitExpr_(const ParamNode *p1, const Expr& e2) final { if (const ParamNode *p2 = e2.as()) { eq_map.Set(p1->var, p2->var); equal = equal && AlphaEqual(p1->type, p2->type); @@ -188,7 +187,7 @@ struct AlphaEq : ExprFunctor { } } - void VisitExpr_(const FunctionNode *func1, const Expr &e2) final { + void VisitExpr_(const FunctionNode *func1, const Expr& e2) final { if (const FunctionNode *func2 = e2.as()) { if (func1->params.size() != func2->params.size()) { equal = false; @@ -205,7 +204,7 @@ struct AlphaEq : ExprFunctor { } } - void VisitExpr_(const CallNode *op, const Expr &e2) final { + void VisitExpr_(const CallNode *op, const Expr& e2) final { if (const CallNode *call = e2.as()) { this->VisitExpr(op->op, call->op); @@ -223,7 +222,7 @@ struct AlphaEq : ExprFunctor { } } - void VisitExpr_(const LetNode *op, const Expr &e2) final { + void VisitExpr_(const LetNode *op, const Expr& e2) final { if (const LetNode *let = e2.as()) { eq_map.Set(op->var, let->var); this->VisitExpr(op->value, let->value); @@ -234,7 +233,7 @@ struct AlphaEq : ExprFunctor { } }; -bool AlphaEqual(const Expr &e1, const Expr &e2) { +bool AlphaEqual(const Expr& e1, const Expr& e2) { AlphaEq eq; eq.VisitExpr(e1, e2); return eq.equal; diff --git a/src/relay/pass/resolve.h b/src/relay/pass/resolve.h index d213ddb2b1ff..0cd7dce2d88d 100644 --- a/src/relay/pass/resolve.h +++ b/src/relay/pass/resolve.h @@ -13,34 +13,33 @@ namespace tvm { namespace relay { - /*! \brief Resolve a type containing incomplete types. -* -* This pass replaces incomplete types with their representative, and -* converts types which are not defined into fresh variables. -* -* \param unifier The unifier containing the unification data. -* \param ty The type to resolve. -* \returns The resolved type. -*/ -Type Resolve(const TypeUnifier & unifier, const Type & ty); + * + * This pass replaces incomplete types with their representative, and + * converts types which are not defined into fresh variables. + * + * \param unifier The unifier containing the unification data. + * \param ty The type to resolve. + * \returns The resolved type. + */ +Type Resolve(const TypeUnifier& unifier, const Type& ty); /*! \brief Resolve an expression containing incomplete types. -* -* This pass replaces incomplete types with their representative, and -* converts types which are not defined into fresh variables. -* -* \param unifier The unifier containing the unification data. -* \param ty The expression to resolve. -* \returns The resolved expression. -*/ -Expr Resolve(const TypeUnifier & unifier, const Expr & expr); + * + * This pass replaces incomplete types with their representative, and + * converts types which are not defined into fresh variables. + * + * \param unifier The unifier containing the unification data. + * \param ty The expression to resolve. + * \returns The resolved expression. + */ +Expr Resolve(const TypeUnifier& unifier, const Expr& expr); -/*! \brief Check if all types have been filled in. -* \param t The type. -* \returns True if the type is resolved, false otherwise. -*/ -bool IsFullyResolved(const Type & t); +/*! \brief Check if all types have been filled in. + * \param t The type. + * \returns True if the type is resolved, false otherwise. + */ +bool IsFullyResolved(const Type& t); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 73e837c52fa7..164404b19248 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -26,15 +26,15 @@ UnionFind UnionFindNode::make(tvm::Map uf_map) { return UnionFind(n); } -void UnionFindNode::Insert(const IncompleteType &v) { this->uf_map.Set(v, v); } +void UnionFindNode::Insert(const IncompleteType& v) { this->uf_map.Set(v, v); } void UnionFindNode::debug() { - for (auto entry : this->uf_map) { + for (const auto& entry : this->uf_map) { RELAY_LOG(INFO) << entry.first << " = " << entry.second << std::endl; } } -void UnionFindNode::AssertAlphaEqual(const Type &l, const Type &r) { +void UnionFindNode::AssertAlphaEqual(const Type& l, const Type& r) { if (!AlphaEqual(l, r)) { std::stringstream ss; ss << "Incompatible parent types in UF:" << l << " and " << r; @@ -42,7 +42,7 @@ void UnionFindNode::AssertAlphaEqual(const Type &l, const Type &r) { } } -void UnionFindNode::Unify(const IncompleteType &v1, const Type &t) { +void UnionFindNode::Unify(const IncompleteType& v1, const Type& t) { RELAY_LOG(INFO) << "UnionFindNode::Unify v1=" << v1 << ", t=" << t << std::endl; auto parent1 = this->Find(v1); @@ -89,7 +89,7 @@ void UnionFindNode::Unify(const IncompleteType &v1, const Type &t) { AssertAlphaEqual(parent1, t); } -Type UnionFindNode::Find(const IncompleteType &v) { +Type UnionFindNode::Find(const IncompleteType& v) { // The node has no mapping, so its representative is just itself. if (this->uf_map.find(v) == this->uf_map.end()) { return v; @@ -135,11 +135,11 @@ TypeUnifier TypeUnifierNode::make(UnionFind union_find) { return TypeUnifier(n); } -void TypeUnifierNode::Insert(const IncompleteType &v) { +void TypeUnifierNode::Insert(const IncompleteType& v) { this->union_find->Insert(v); } -Type TypeUnifierNode::Unify(const Type &t1, const Type &t2) { +Type TypeUnifierNode::Unify(const Type& t1, const Type& t2) { RELAY_LOG(INFO) << "TypeUnifierNode::unify: t1=" << t1 << " t2=" << t2 << std::endl; @@ -167,7 +167,7 @@ struct IncompleteTypeSubst : TypeMutator { } }; -Type TypeUnifierNode::Subst(const Type &t) { +Type TypeUnifierNode::Subst(const Type& t) { IncompleteTypeSubst tvsubst(this); // normalize first so substitutions in quantifiers will be correct Type ret = tvsubst.VisitType(t); @@ -182,7 +182,7 @@ Type TypeUnifierNode::Subst(const Type &t) { return ret; } -Type TypeUnifierNode::VisitType(const Type &t1, const Type t2) { +Type TypeUnifierNode::VisitType(const Type& t1, const Type t2) { // When the right hand size is a type variable immediately unify. if (const IncompleteTypeNode *tvn2 = t2.as()) { return this->UnifyWithIncompleteType(t1, GetRef(tvn2)); @@ -191,7 +191,7 @@ Type TypeUnifierNode::VisitType(const Type &t1, const Type t2) { } } -Type TypeUnifierNode::UnifyWithIncompleteType(const Type &t1, +Type TypeUnifierNode::UnifyWithIncompleteType(const Type& t1, const IncompleteType tv2) { RELAY_LOG(INFO) << "unifyWithIncompleteType: t1=" << t1 << " t2=" << tv2 << std::endl; @@ -202,7 +202,7 @@ Type TypeUnifierNode::UnifyWithIncompleteType(const Type &t1, return rep; } -Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { +Type TypeUnifierNode::VisitType_(const IncompleteTypeNode* t1, const Type rt2) { IncompleteType tv1 = GetRef(t1); RELAY_LOG(INFO) << "VisitType_: IncompleteTypeNode t1=" << t1 << " = " << rt2 << std::endl; @@ -212,7 +212,7 @@ Type TypeUnifierNode::VisitType_(const IncompleteTypeNode *t1, const Type rt2) { return rep; } -Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { +Type TypeUnifierNode::VisitType_(const TypeParamNode* t1, const Type rt2) { TypeParam ti1 = GetRef(t1); if (const TypeParamNode *tin2 = rt2.as()) { @@ -228,7 +228,7 @@ Type TypeUnifierNode::VisitType_(const TypeParamNode *t1, const Type rt2) { throw UnificationError("Unable to unify TypeParamNode"); } -Type TypeUnifierNode::VisitType_(const FuncTypeNode *t1, const Type rt2) { +Type TypeUnifierNode::VisitType_(const FuncTypeNode* t1, const Type rt2) { FuncType ft1 = GetRef(t1); if (const FuncTypeNode *tan2 = rt2.as()) { @@ -265,7 +265,7 @@ Type TypeUnifierNode::VisitType_(const FuncTypeNode *t1, const Type rt2) { throw UnificationError("unable to unify function types"); } -Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { +Type TypeUnifierNode::VisitType_(const TensorTypeNode* t1, const Type rt2) { TensorType tt1 = GetRef(t1); if (const TensorTypeNode *ttn2 = rt2.as()) { @@ -294,7 +294,7 @@ Type TypeUnifierNode::VisitType_(const TensorTypeNode *t1, const Type rt2) { throw UnificationError("Cannot unify TensorTypeNode"); } -Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { +Type TypeUnifierNode::VisitType_(const TupleTypeNode* t1, const Type rt2) { TupleType pt1 = GetRef(t1); if (const TupleTypeNode *ptn2 = rt2.as()) { @@ -316,7 +316,7 @@ Type TypeUnifierNode::VisitType_(const TupleTypeNode *t1, const Type rt2) { throw UnificationError("Cannot unify TupleTypeNode"); } -Type TypeUnifierNode::VisitType_(const TypeRelationNode *tr1, const Type t2) { +Type TypeUnifierNode::VisitType_(const TypeRelationNode* tr1, const Type t2) { throw InternalError("Cannot unify different type relations"); } From 41098c421487ded4c53660a7fb3e43ec0f2c8770 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 17:20:09 -0700 Subject: [PATCH 121/136] Fix more style issues --- include/tvm/relay/expr.h | 2 +- include/tvm/relay/type.h | 3 +-- src/relay/pass/type_subst.cc | 4 ++-- src/relay/pass/type_subst.h | 4 ++-- src/relay/pass/unifier.cc | 2 +- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 886d16a9400a..6388e8367bf6 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -6,8 +6,8 @@ #ifndef TVM_RELAY_EXPR_H_ #define TVM_RELAY_EXPR_H_ -#include #include +#include #include "./base.h" #include "./type.h" diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index a17450d74ead..44030ad8d97f 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -119,8 +119,7 @@ class TypeParamNode : public TypeNode { kShapeVar = 0, kShape = 1, kBaseType = 2, - kType = 3, - kTypeList = 4, + kType = 3 }; /*! * \brief The variable itself is only meaningful when diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc index 5fce9c2ca73b..1f816dc1659f 100644 --- a/src/relay/pass/type_subst.cc +++ b/src/relay/pass/type_subst.cc @@ -25,12 +25,12 @@ struct TypeSubstV : TypeMutator { } }; -Type TypeSubst(const Type &type, const TypeParam &target, const Type &subst) { +Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst) { TypeSubstV ty_sub({ {target, subst} }); return ty_sub.VisitType(type); } -Type TypeSubst(const Type &type, tvm::Map subst_map) { +Type TypeSubst(const Type& type, tvm::Map subst_map) { TypeSubstV ty_sub(subst_map); return ty_sub.VisitType(type); } diff --git a/src/relay/pass/type_subst.h b/src/relay/pass/type_subst.h index 5b6956f8e451..aee3209afb7a 100644 --- a/src/relay/pass/type_subst.h +++ b/src/relay/pass/type_subst.h @@ -11,8 +11,8 @@ namespace tvm { namespace relay { -Type TypeSubst(const Type & type, const TypeParam & target, const Type & subst); -Type TypeSubst(const Type &type, tvm::Map subst_map); +Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst); +Type TypeSubst(const Type& type, tvm::Map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/unifier.cc b/src/relay/pass/unifier.cc index 164404b19248..b0ed71d17911 100644 --- a/src/relay/pass/unifier.cc +++ b/src/relay/pass/unifier.cc @@ -157,7 +157,7 @@ struct IncompleteTypeSubst : TypeMutator { IncompleteTypeSubst(const TypeUnifierNode *unifier) : unifier(unifier) {} // type var: look it up in the type map and recurse - Type VisitType_(const IncompleteTypeNode *op) override { + Type VisitType_(const IncompleteTypeNode* op) override { auto tv = GetRef(op); auto parent = unifier->union_find->Find(tv); if (parent == tv) { From bee42344e4528890db6d723245b80c4623c3d0f4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 17:43:21 -0700 Subject: [PATCH 122/136] Remove whitespace --- python/tvm/relay/ir_builder.py | 22 ++++++++++++++-------- src/relay/ir/base.cc | 14 +++++--------- src/relay/pass/type_subst.cc | 2 +- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index fb5e9bb71956..99da2cf46420 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,3 +1,4 @@ +#pylint: disable=no-else-return """IR builder for the Relay IR. Enables users to construct Relay programs with a Python API. @@ -28,6 +29,7 @@ def _convert_to_value(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: # raise Exception(f"can't convert {type(arg)} to a Relay AST") raise Exception(f"unsupported argument type {type(arg)}") + def _convert_type(rtype): if isinstance(rtype, str): return scalar_type(rtype) @@ -36,6 +38,7 @@ def _convert_type(rtype): else: raise Exception(f"unsupported conversion to Relay type {type(rtype)}") + def convert(arg: Any, ctxt=tvm.cpu(0)) -> Expr: if isinstance(arg, Expr): return arg @@ -91,6 +94,8 @@ def to_func(self): self.type_params) #pylint: disable=invalid-name + + def _mk_let(bindings, ret_value): let_expr = ret_value for var, (value, ty) in reversed(list(bindings.items())): @@ -106,14 +111,14 @@ class IRBuilder(object): Examples -------- - + Program: fn (x : Tensor[f32, (10, 10)]) { let t1 = log(x); let t2 = add(t1, x); return t1; } - + ..code-block: python b = IRBuilder() with b.function(('x', tensor_type(10, 10))) as func: @@ -254,12 +259,12 @@ def global_var(self, name: str): ---------- name: str The name of the global variable. - + Returns ------- global_var: relay.GlobalVar The global variable with `name`. - + """ return self.env.global_var(name) @@ -298,12 +303,12 @@ def get(self): def scalar_type(dtype): """Construct a Relay scalar type. - + Parameters ---------- dtype: dtype The dtype of the scalar type. - + Returns: scalar_type: relay.Type The scalar type. @@ -313,14 +318,14 @@ def scalar_type(dtype): def tensor_type(*shape, dtype='float32'): """Construct a Relay Tensor type. - + Parameters ---------- shape: list of tvm.Expr The shape of the Tensor type. dtype: dtype The dtype of the Tensor type. - + Returns ------- tensor_type: relay.Type @@ -328,6 +333,7 @@ def tensor_type(*shape, dtype='float32'): """ return TensorType(tvm.convert(shape), dtype) + def func_type(args, ret_type, type_params=None, type_constraints=None): """document""" if not type_params: diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index e4ad08d893a3..7e7fb71f6d6c 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -26,17 +26,13 @@ std::shared_ptr CreateSourceName(const std::string& name) { } const SourceName& SourceName::Get(const std::string& name) { - static std::unordered_map *source_map; + static std::unordered_map source_map; - if (source_map == nullptr) { - source_map = new std::unordered_map(); - } - - auto sn = source_map->find(name); - if (sn == source_map->end()) { + auto sn = source_map.find(name); + if (sn == source_map.end()) { auto source_name = SourceNameNode::make(name); - source_map->insert({name, source_name}); - return source_map->at(name); + source_map.insert({name, source_name}); + return source_map.at(name); } else { return sn->second; } diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc index 1f816dc1659f..0b17fa0bc4f8 100644 --- a/src/relay/pass/type_subst.cc +++ b/src/relay/pass/type_subst.cc @@ -15,7 +15,7 @@ struct TypeSubstV : TypeMutator { explicit TypeSubstV(tvm::Map subst_map) : subst_map(subst_map) {} - Type VisitType_(const TypeParamNode *op) override { + Type VisitType_(const TypeParamNode* op) override { auto id = GetRef(op); if (subst_map.find(id) != subst_map.end()) { return this->subst_map[id]; From 7e86f56564be5d013f02c463165296353f00e814 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 22:39:29 -0700 Subject: [PATCH 123/136] Address final comments modulo type_infer.cc --- python/tvm/relay/ir_builder.py | 8 ++++---- src/relay/op/type_relations.h | 25 +++++++++++++++++++++++++ src/relay/pass/type_visitor.h | 12 ++++++------ 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 99da2cf46420..94b927e95e04 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -112,10 +112,10 @@ class IRBuilder(object): Examples -------- - Program: - fn (x : Tensor[f32, (10, 10)]) { - let t1 = log(x); - let t2 = add(t1, x); + Program: + fn (x : Tensor[f32, (10, 10)]) { + let t1 = log(x); + let t2 = add(t1, x); return t1; } diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 3597246b5a4a..521c6f8e1681 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -13,8 +13,33 @@ namespace tvm { namespace relay { +/*! \brief The identity type relation maps a single input variable + * to the output variable. + * + * \param types The input and output types to the relation. + * \param num_args The number of input arguments. + * \return The (potentially partial) solution to the relation. + */ Array IdentityRel(const Array & types, int num_args); +/*! \brief The broadcast type relation, implements the broadcasting + * rule over the two input types producing the broadcasted type. + * + * \param types The input and output types to the relation. + * \param num_args The number of input arguments. + * \return The (potentially partial) solution to the relation. + */ Array BroadcastRel(const Array & types, int num_args); +/*! \brief The broadcast type relation, implements the broadcasting + * rule over the two input types producing the broadcasted type. + * + * This differs from BroadcastRel in the return dtype, + * it instead returns bool, for use in comparsion operators + * such as equal, not_equal, lt, and so on. + * + * \param types The input and output types to the relation. + * \param num_args The number of input arguments. + * \return The (potentially partial) solution to the relation. + */ Array BroadcastCompRel(const Array & types, int num_args); } // namespace relay diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index 357a36ffa41f..725e3d9b3846 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -23,30 +23,30 @@ struct TypeVisitor : ::tvm::relay::TypeFunctor { void VisitType_(const FuncTypeNode* op, Args... args) override { for (auto type_param : op->type_params) { - this->VisitType(type_param, args...); + this->VisitType(type_param, std::forward(args)...); } for (auto type_cs : op->type_constraints) { - this->VisitType(type_cs, args...); + this->VisitType(type_cs, std::forward(args)...); } for (auto arg_type : op->arg_types) { - this->VisitType(arg_type, args...); + this->VisitType(arg_type, std::forward(args)...); } - this->VisitType(op->ret_type, args...); + this->VisitType(op->ret_type, std::forward(args)...); } void VisitType_(const TensorTypeNode* op, Args... args) override {} void VisitType_(const TupleTypeNode* op, Args... args) override { for (const Type& t : op->fields) { - this->VisitType(t, args...); + this->VisitType(t, std::forward(args)...); } } void VisitType_(const TypeRelationNode* op, Args... args) override { for (const Type& t : op->args) { - this->VisitType(t, args...); + this->VisitType(t, std::forward(args)...); } } From 8339eaf18dc7651f1e4b797c71260109b6031793 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 23:25:28 -0700 Subject: [PATCH 124/136] Address CR and work on 2.7 support --- include/tvm/relay/base.h | 1 - python/tvm/relay/base.py | 8 +- python/tvm/relay/env.py | 4 +- python/tvm/relay/expr.py | 52 +++++++------ python/tvm/relay/ir_builder.py | 29 ++++++-- python/tvm/relay/ty.py | 47 ++++++------ src/relay/ir/type.cc | 73 +++++++++---------- .../relay/test_tyck_eval_integration.py | 1 + 8 files changed, 120 insertions(+), 95 deletions(-) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 2b5667dc4dc9..7c66d2c2de43 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -9,7 +9,6 @@ #include #include #include -#include #include #include diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 0f3d2bc58d71..b2c65231cca1 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -22,9 +22,9 @@ def register_relay_node(type_key=None): @register_relay_node class Span(NodeBase): - source: "FileSource" - lineno: int - col_offset: int + source = None # type: FileSource + lineno = None # type: int + col_offset = None # type: int - def __init__(self, source, lineno, col_offset): + def __init__(self, source, lineno, col_offset): # type: (FileSource, int, int) -> None self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 8c9150d18835..62afef76425a 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -11,7 +11,7 @@ class Environment(NodeBase): options and more. """ - def __init__(self, funcs) -> None: + def __init__(self, funcs): """Construct an environment. Parameters @@ -24,7 +24,7 @@ def __init__(self, funcs) -> None: """ self.__init_handle_by_constructor__(_make.Environment, funcs) - def add(self, var, func) -> None: + def add(self, var, func): """Add a function to the environment. Parameters diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 52a3aca7590f..da0fc0f5bc7e 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -29,9 +29,10 @@ def __call__(self, *args): class Constant(Expr): """A constant tensor in Relay, see tvm/relay/type.h for more details. """ - data: tvm.nd.NDArray + data = None # type: tvm.nd.NDArray - def __init__(self, data: tvm.nd.NDArray) -> None: + def __init__(self, data): + # type: (tvm.nd.NDArray) -> None self.__init_handle_by_constructor__(_make.Constant, data) @@ -40,27 +41,30 @@ class Tuple(Expr): """A hetereogenous sequence of values. see tvm/relay/type.h for more details. """ - fields: List[Expr] + fields = None # type: List[Expr] - def __init__(self, fields: List[Expr]) -> None: + def __init__(self, fields): + # type: (List[Expr]) -> None self.__init_handle_by_constructor__(_make.Tuple, fields) @register_relay_node class Var(Expr): """A local variable in Relay.""" - name_hint: str + name_hint = None # type: str - def __init__(self, name_hint: str) -> None: + def __init__(self, name_hint): + # type: (str) -> None self.__init_handle_by_constructor__(_make.Var, name_hint) @register_relay_node class GlobalVar(Expr): """A global variable in Relay.""" - name_hint: str + name_hint = None # type: str - def __init__(self, name_hint: str) -> None: + def __init__(self, name_hint): + # type: (str) -> None self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) @@ -68,26 +72,29 @@ def __init__(self, name_hint: str) -> None: class Param(Expr): """A function type in Relay, see tvm/relay/type.h for more details. """ - var: Var - type: Type + var = None # type: Var + type = None # type: Type - def __init__(self, var: Var, ty: Type) -> None: + def __init__(self, var, ty): + # type: (Var, Type) -> None self.__init_handle_by_constructor__(_make.Param, var, ty) @register_relay_node class Function(Expr): """A function in Relay, see tvm/relay/expr.h for more details.""" - type_params: List[TypeParam] - params: List[Param] - ret_type: Type - body: Expr + type_params = None # type: List[TypeParam] + params = None # type: List[Param] + ret_type = None # type: Type + body = None # type: Expr def __init__(self, - params: List[Param], - ret_type: Type, - body: Expr, - type_params: List[TypeParam] = None) -> None: + params, # type: List[Param], + ret_type, # type: Type, + body, # type: Expr, + type_params = None, # type: List[TypeParam] + ): + # type: (...) -> None if not type_params: type_params = [] self.__init_handle_by_constructor__( @@ -97,8 +104,8 @@ def __init__(self, @register_relay_node class Call(Expr): """A function call in Relay, see tvm/relay/expr.h for more details.""" - op: Expr - args: List[Expr] + op = None # type: Expr + args = None # type: List[Expr] # todo(@jroesch): add attrs def __init__(self, op: Expr, args: List[Expr], attrs, ty_args=None) -> None: @@ -118,7 +125,8 @@ class Let(Expr): # should be type annotation value_type: Type - def __init__(self, var: Var, value: Expr, body: Expr, value_type: Type) -> None: + def __init__(self, var: Var, value: Expr, body, value_type) -> None: + # type: (Var, Expr, Expr, Type) -> None self.__init_handle_by_constructor__( _make.Let, var, value, body, value_type) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 94b927e95e04..e3342ac4a6c8 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -11,7 +11,8 @@ from .env import Environment -def _convert_to_value(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: +def _convert_to_value(arg, ctxt=tvm.cpu(0)): + # type: (Any, tvm.Context) -> tvm.nd.NDArray """Convert Python values into the appropriate types for the Relay evaluator. """ @@ -27,7 +28,7 @@ def _convert_to_value(arg: Any, ctxt=tvm.cpu(0)) -> tvm.nd.NDArray: return arg else: # raise Exception(f"can't convert {type(arg)} to a Relay AST") - raise Exception(f"unsupported argument type {type(arg)}") + raise Exception("unsupported argument type {0}".format(type(arg))) def _convert_type(rtype): @@ -36,10 +37,11 @@ def _convert_type(rtype): elif isinstance(rtype, Type): return rtype else: - raise Exception(f"unsupported conversion to Relay type {type(rtype)}") + raise Exception("unsupported conversion to Relay type {0}".format(type(rtype))) -def convert(arg: Any, ctxt=tvm.cpu(0)) -> Expr: +def convert(arg): + # type: (Any) -> Expr if isinstance(arg, Expr): return arg elif isinstance(arg, tuple): @@ -47,7 +49,7 @@ def convert(arg: Any, ctxt=tvm.cpu(0)) -> Expr: elif isinstance(arg, PartialFunc): return arg.to_func() else: - value = _convert_to_value(arg, ctxt) + value = _convert_to_value(arg) return Constant(value) @@ -252,7 +254,8 @@ def param(self, name, ty=None): return Param(Var(name), ty) - def global_var(self, name: str): + def global_var(self, name): + # type: (str) -> GlobalVar """Construct a global var with `name` as its name hint. Parameters @@ -268,7 +271,12 @@ def global_var(self, name: str): """ return self.env.global_var(name) - def decl(self, name: str, *params, ret_type=None): + def decl(self, name, *params, **kwargs): + if 'ret_type' in kwargs: + ret_type = kwargs['ret_type'] + else: + ret_type = None + self.enter_scope() def _on_exit(): @@ -316,7 +324,7 @@ def scalar_type(dtype): return TensorType(tvm.convert([]), dtype) -def tensor_type(*shape, dtype='float32'): +def tensor_type(*shape, **kwargs): """Construct a Relay Tensor type. Parameters @@ -331,6 +339,11 @@ def tensor_type(*shape, dtype='float32'): tensor_type: relay.Type The resulting tensor types. """ + if 'dtype' in kwargs: + dtype = kwargs['dtype'] + else: + dtype = 'float32' + return TensorType(tvm.convert(shape), dtype) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index ca74f2c5deb3..66bd50871358 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -10,16 +10,16 @@ class Type(NodeBase): """The base type for all Relay types.""" - def __eq__(self, other) -> bool: + def __eq__(self, other): # type: (Type) -> bool """Compare two Relay types for structural equivalence using alpha equivalence. """ return bool(_make._type_alpha_eq(self, other)) - def __ne__(self, other) -> bool: + def __ne__(self, other): # (Type) -> bool return not self.__eq__(other) - def same_as(self, other) -> bool: + def same_as(self, other): # (Type) -> bool """Compares two Relay types by referential equality.""" return super().__eq__(other) @@ -31,11 +31,12 @@ class TensorType(Type): This is the type assigned to tensor's with a known dype and shape. For example a tensor of `float32` and `(5, 5)`. """ - shape: List[expr.Expr] - dtype: str - span: Span + shape = None # type: List[expr.Expr] + dtype = None # type: str + span = None # type: Span - def __init__(self, shape: List[expr.Expr], dtype: str) -> None: + def __init__(self, shape, dtype): + # type: (List[expr.Expr], str) -> None """Construct a tensor type. Parameters @@ -73,11 +74,12 @@ class TypeParam(Type): be filled in later on. This allows the user to write functions which are generic over types. """ - var: expr.Var - kind: Kind - span: Span + var = None # type: expr.Var + kind = None # type: Kind + span = None # type: Span - def __init__(self, var: expr.Var, kind: Kind) -> None: + def __init__(self, var, kind): + # type: (expr.Var, Kind) -> None """Construct a TypeParam. Parameters @@ -114,17 +116,19 @@ class FuncType(Type): We informally write them as: `forall (type_params), (arg_types) -> ret_type where type_constraints` """ - type_params: List[TypeParam] - type_constraints: List[TypeConstraint] - arg_types: List[Type] - ret_type: Type - span: Span + type_params = None # type: List[TypeParam] + type_constraints = None # type: List[TypeConstraint] + arg_types = None # type: List[Type] + ret_type = None # type: Type + span = None # type: Span def __init__(self, - arg_types: List[Type], - ret_type: Type, - type_params: List[TypeParam], - type_constraints: List[TypeConstraint]) -> None: + arg_types, # type: List[Type], + ret_type, # type: Type, + type_params, # type: List[TypeParam], + type_constraints, # type: List[TypeConstraint] + ): + # type: (...) -> None """Construct a function type. Parameters @@ -147,5 +151,6 @@ def __init__(self, class IncompleteType(Type): """An incomplete type.""" - def __init__(self, kind: Kind) -> None: + def __init__(self, kind): + # type: (Kind) -> None self.__init_handle_by_constructor__(_make.IncompleteType, kind) diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 1282acafcb92..c13fea26dacd 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -24,17 +24,16 @@ TensorType TensorTypeNode::Scalar(DataType dtype) { } TVM_REGISTER_API("relay._make.TensorType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Array shape = args[0]; - *ret = TensorTypeNode::make(shape, args[1]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + Array shape = args[0]; + *ret = TensorTypeNode::make(shape, args[1]); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const TensorTypeNode *node, +.set_dispatch([](const TensorTypeNode *node, tvm::IRPrinter *p) { - p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape - << ")"; - }); + p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape << ")"; +}); TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { std::shared_ptr n = std::make_shared(); @@ -44,18 +43,18 @@ TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { } TVM_REGISTER_API("relay._make.TypeParam") - .set_body([](TVMArgs args, TVMRetValue *ret) { - int kind = args[1]; - *ret = - TypeParamNode::make(args[0], static_cast(kind)); +.set_body([](TVMArgs args, TVMRetValue *ret) { + int kind = args[1]; + *ret = + TypeParamNode::make(args[0], static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const TypeParamNode *node, +.set_dispatch([](const TypeParamNode *node, tvm::IRPrinter *p) { - p->stream << "TypeParamNode(" << node->var->name_hint << ", " - << node->kind << ")"; - }); + p->stream << "TypeParamNode(" << node->var->name_hint << ", " + << node->kind << ")"; +}); FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, tvm::Array type_params, @@ -69,17 +68,17 @@ FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, } TVM_REGISTER_API("relay._make.FuncType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const FuncTypeNode *node, +.set_dispatch([](const FuncTypeNode *node, tvm::IRPrinter *p) { - p->stream << "FuncTypeNode(" << node->type_params << ", " - << node->arg_types << ", " << node->ret_type << ", " - << node->type_constraints << ")"; - }); + p->stream << "FuncTypeNode(" << node->type_params << ", " + << node->arg_types << ", " << node->ret_type << ", " + << node->type_constraints << ")"; +}); TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array args) { std::shared_ptr n = std::make_shared(); @@ -90,16 +89,16 @@ TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array } TVM_REGISTER_API("relay._make.TypeRelation") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TypeRelationNode::make(args[0], args[1], args[2]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TypeRelationNode::make(args[0], args[1], args[2]); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const TypeRelationNode *node, +.set_dispatch([](const TypeRelationNode *node, tvm::IRPrinter *p) { - p->stream << "TypeRelationNode(" << node->name << ", " << node->args - << ")"; - }); + p->stream << "TypeRelationNode(" << node->name << ", " << node->args + << ")"; +}); TupleType TupleTypeNode::make(Array fields) { std::shared_ptr n = std::make_shared(); @@ -108,15 +107,15 @@ TupleType TupleTypeNode::make(Array fields) { } TVM_REGISTER_API("relay._make.TupleType") - .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TupleTypeNode::make(args[0]); - }); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = TupleTypeNode::make(args[0]); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const TupleTypeNode *node, +.set_dispatch([](const TupleTypeNode *node, tvm::IRPrinter *p) { - p->stream << "TupleTypeNode(" << node->fields << ")"; - }); + p->stream << "TupleTypeNode(" << node->fields << ")"; +}); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 95f296657380..69c5f534ee94 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -12,6 +12,7 @@ from tvm.relay.expr import Function def assert_has_type(expr, typ, env=Environment({})): + import pdb; pdb.set_trace() checked_expr = check_expr(env, expr) assert checked_expr.checked_type() == typ From a7b67ad27bd1dd9ae3b0a347a1fe9487304a32a4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 23:31:54 -0700 Subject: [PATCH 125/136] Move annotations to expr.pyi --- python/tvm/relay/expr.py | 39 ++----------- python/tvm/relay/expr.pyi | 114 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 34 deletions(-) create mode 100644 python/tvm/relay/expr.pyi diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index da0fc0f5bc7e..f5303fdfe80a 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -9,7 +9,7 @@ class Expr(NodeBase): - """The base type for all Relay exprressions.""" + """The base type for all Relay expressions.""" def checked_type(self): return _get_checked_type(self) @@ -29,10 +29,8 @@ def __call__(self, *args): class Constant(Expr): """A constant tensor in Relay, see tvm/relay/type.h for more details. """ - data = None # type: tvm.nd.NDArray def __init__(self, data): - # type: (tvm.nd.NDArray) -> None self.__init_handle_by_constructor__(_make.Constant, data) @@ -41,30 +39,24 @@ class Tuple(Expr): """A hetereogenous sequence of values. see tvm/relay/type.h for more details. """ - fields = None # type: List[Expr] def __init__(self, fields): - # type: (List[Expr]) -> None self.__init_handle_by_constructor__(_make.Tuple, fields) @register_relay_node class Var(Expr): """A local variable in Relay.""" - name_hint = None # type: str def __init__(self, name_hint): - # type: (str) -> None self.__init_handle_by_constructor__(_make.Var, name_hint) @register_relay_node class GlobalVar(Expr): """A global variable in Relay.""" - name_hint = None # type: str def __init__(self, name_hint): - # type: (str) -> None self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) @@ -72,29 +64,20 @@ def __init__(self, name_hint): class Param(Expr): """A function type in Relay, see tvm/relay/type.h for more details. """ - var = None # type: Var - type = None # type: Type def __init__(self, var, ty): - # type: (Var, Type) -> None self.__init_handle_by_constructor__(_make.Param, var, ty) @register_relay_node class Function(Expr): """A function in Relay, see tvm/relay/expr.h for more details.""" - type_params = None # type: List[TypeParam] - params = None # type: List[Param] - ret_type = None # type: Type - body = None # type: Expr - def __init__(self, - params, # type: List[Param], - ret_type, # type: Type, - body, # type: Expr, - type_params = None, # type: List[TypeParam] + params, + ret_type, + body, + type_params = None, ): - # type: (...) -> None if not type_params: type_params = [] self.__init_handle_by_constructor__( @@ -104,9 +87,6 @@ def __init__(self, @register_relay_node class Call(Expr): """A function call in Relay, see tvm/relay/expr.h for more details.""" - op = None # type: Expr - args = None # type: List[Expr] - # todo(@jroesch): add attrs def __init__(self, op: Expr, args: List[Expr], attrs, ty_args=None) -> None: if not ty_args: @@ -119,14 +99,9 @@ def __init__(self, op: Expr, args: List[Expr], attrs, ty_args=None) -> None: @register_relay_node class Let(Expr): """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" - var: Var - value: Expr - body: Expr # should be type annotation - value_type: Type def __init__(self, var: Var, value: Expr, body, value_type) -> None: - # type: (Var, Expr, Expr, Type) -> None self.__init_handle_by_constructor__( _make.Let, var, value, body, value_type) @@ -134,10 +109,6 @@ def __init__(self, var: Var, value: Expr, body, value_type) -> None: @register_relay_node class If(Expr): """A conditional expression in Relay, see tvm/relay/expr.h for more details.""" - cond: Expr - true_value: Expr - false_value: Expr - span: Span def __init__(self, cond: Expr, true_value: Expr, false_value: Expr) -> None: self.__init_handle_by_constructor__( diff --git a/python/tvm/relay/expr.pyi b/python/tvm/relay/expr.pyi new file mode 100644 index 000000000000..fd30e3ed25cf --- /dev/null +++ b/python/tvm/relay/expr.pyi @@ -0,0 +1,114 @@ +from typing import List +import tvm +from .base import Span, NodeBase +from .ty import Type, TypeParam +from ._ir_pass import _get_checked_type + + +class Expr(NodeBase): + def checked_type(self): + ... + + def __call__(self, *args): + ... + + +class Constant(Expr): + data = ... # type: tvm.nd.NDArray + + def __init__(self, data): + # type: (tvm.nd.NDArray) -> None + ... + + +class Tuple(Expr): + fields = .. # type: List[Expr] + + def __init__(self, fields): + # type: (List[Expr]) -> None + ... + + +class Var(Expr): + """A local variable in Relay.""" + name_hint = ... # type: str + + def __init__(self, name_hint): + # type: (str) -> None + ... + + +class GlobalVar(Expr): + name_hint = ... # type: str + + def __init__(self, name_hint): + # type: (str) -> None + ... + + +class Param(Expr): + var = ... # type: Var + type = ... # type: Type + + def __init__(self, var, ty): + # type: (Var, Type) -> None + ... + + +class Function(Expr): + """A function in Relay, see tvm/relay/expr.h for more details.""" + type_params = ... # type: List[TypeParam] + params = ... # type: List[Param] + ret_type = ... # type: Type + body = ... # type: Expr + + def __init__(self, + params, # type: List[Param], + ret_type, # type: Type, + body, # type: Expr, + type_params=None, # type: List[TypeParam] + ): + # type: (...) -> None + ... + + +@register_relay_node +class Call(Expr): + """A function call in Relay, see tvm/relay/expr.h for more details.""" + op = ... # type: Expr + args = ... # type: List[Expr] + # todo(@jroesch): add attrs + + def __init__(self, op, args, attrs, ty_args=None): + # type: (Expr, List[Expr], Optional[List[Type]]) -> None + if not ty_args: + ty_args = [] + + self.__init_handle_by_constructor__( + _make.Call, op, args, attrs, ty_args) + + +@register_relay_node +class Let(Expr): + """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" + var = ... # type: Var + value = ... # type: Expr + body = ... # type: Expr + value_type = ... # type: Type + + def __init__(self, var, value, body, value_type): + # type: (Var, Expr, Expr, Type) -> None + ... + + +@register_relay_node +class If(Expr): + """A conditional expression in Relay, see tvm/relay/expr.h for more details.""" + cond = ... # type: Expr + true_value = ... # type: Expr + false_value = ... # type: Expr + span = ... # type: Span + + def __init__(self, cond, true_value, false_value): + # type: (Expr, Expr, Expr) -> None + ... From a1734d70f11e53a63bf3c6b14f6a81f93971850f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 23:48:19 -0700 Subject: [PATCH 126/136] Fix bug introduced by refactor --- include/tvm/relay/op.h | 4 ++-- python/tvm/relay/expr.py | 7 +++---- src/relay/pass/type_infer.cc | 2 ++ tests/python/relay/test_tyck_eval_integration.py | 14 +++++++------- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 67a2ff3381bb..49661fec5731 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -368,8 +368,8 @@ inline OpRegistry& OpRegistry::add_type_rel( // Add inputs. std::string input_name_prefix = "in"; - for (int i = 0; i < get()->arguments.size(); i++) { - auto name = input_name_prefix + std::to_string(i++); + for (int i = 0; i < get()->num_inputs; i++) { + auto name = input_name_prefix + std::to_string(i); auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); type_params.push_back(param); arg_types.push_back(param); diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index f5303fdfe80a..b872f48fd8b2 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -88,7 +88,7 @@ def __init__(self, class Call(Expr): """A function call in Relay, see tvm/relay/expr.h for more details.""" - def __init__(self, op: Expr, args: List[Expr], attrs, ty_args=None) -> None: + def __init__(self, op, args, attrs, ty_args=None): if not ty_args: ty_args = [] @@ -99,9 +99,8 @@ def __init__(self, op: Expr, args: List[Expr], attrs, ty_args=None) -> None: @register_relay_node class Let(Expr): """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" - # should be type annotation - def __init__(self, var: Var, value: Expr, body, value_type) -> None: + def __init__(self, var, value, body, value_type): self.__init_handle_by_constructor__( _make.Let, var, value, body, value_type) @@ -110,6 +109,6 @@ def __init__(self, var: Var, value: Expr, body, value_type) -> None: class If(Expr): """A conditional expression in Relay, see tvm/relay/expr.h for more details.""" - def __init__(self, cond: Expr, true_value: Expr, false_value: Expr) -> None: + def __init__(self, cond, true_value, false_value): self.__init_handle_by_constructor__( _make.If, cond, true_value, false_value) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index b3f7f34597d9..6eb095338732 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -291,6 +291,8 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { // We now have a function type. FuncType fn_ty = GetRef(fn_ty_node); + std::cout << fn_ty << std::endl; + tvm::Array ty_args; if (ty_args.size() != 0) { throw Error("found manually suplied type args, not supported"); diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 69c5f534ee94..50a34c89ec80 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -12,7 +12,6 @@ from tvm.relay.expr import Function def assert_has_type(expr, typ, env=Environment({})): - import pdb; pdb.set_trace() checked_expr = check_expr(env, expr) assert checked_expr.checked_type() == typ @@ -59,6 +58,7 @@ def test_add_op(): prog, env = b.get() ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) + import pdb; pdb.set_trace() assert_has_type(func.to_func(), expected_ty) def test_add_broadcast_op(): @@ -140,10 +140,10 @@ def f(n: i32, data: f32) -> f32 { # to execute this. if __name__ == "__main__": - test_monomorphic_let() - test_single_op() + # test_monomorphic_let() + # test_single_op() test_add_op() - test_add_broadcast_op() - test_dual_op() - test_decl() - test_recursion() + # test_add_broadcast_op() + # test_dual_op() + # test_decl() + # test_recursion() From 49768089d06bfca244a1af11f6d30870cede4329 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 18 Sep 2018 23:50:43 -0700 Subject: [PATCH 127/136] Python2 and Python3 tests pass locally --- src/relay/pass/type_infer.cc | 2 -- tests/python/relay/test_tyck_eval_integration.py | 15 ++++++--------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 6eb095338732..b3f7f34597d9 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -291,8 +291,6 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { // We now have a function type. FuncType fn_ty = GetRef(fn_ty_node); - std::cout << fn_ty << std::endl; - tvm::Array ty_args; if (ty_args.size() != 0) { throw Error("found manually suplied type args, not supported"); diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index 50a34c89ec80..b8f3dfc0fb34 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -29,8 +29,6 @@ def test_monomorphic_let(): prog, env = b.get() assert_has_type(prog, scalar_type('float64')) - # Need to handle constants - # run(env, prog, [], float_type(64)) def test_single_op(): @@ -58,7 +56,6 @@ def test_add_op(): prog, env = b.get() ttype = tensor_type(5, 5, 5) expected_ty = func_type([ttype, ttype], ttype) - import pdb; pdb.set_trace() assert_has_type(func.to_func(), expected_ty) def test_add_broadcast_op(): @@ -140,10 +137,10 @@ def f(n: i32, data: f32) -> f32 { # to execute this. if __name__ == "__main__": - # test_monomorphic_let() - # test_single_op() + test_monomorphic_let() + test_single_op() test_add_op() - # test_add_broadcast_op() - # test_dual_op() - # test_decl() - # test_recursion() + test_add_broadcast_op() + test_dual_op() + test_decl() + test_recursion() From f074edd01d7525c7617b7ef358d744f24c30a2bc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 19 Sep 2018 00:21:35 -0700 Subject: [PATCH 128/136] Refactor type relation solver. --- src/relay/pass/type_infer.cc | 137 ++++++++++++++++++++--------------- 1 file changed, 78 insertions(+), 59 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index b3f7f34597d9..bdda3940fc6e 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -35,42 +35,45 @@ namespace relay { using namespace tvm::runtime; -// @tqchen -// I wanted to use this data structure but then the algorithm gets more complex -// because we need to convert them back to the same representation as before -// when we check a single function scope. See line 240. -// -// I can see building an auxillary data structure at solve time but it seems -// like a lot of complexity for an unquantified speed gain, which we may or may -// not need. -// -// Thoughts? -// // // We declare this for forward compatibility. -// struct ConstraintData {}; - -// struct TyRelData : ConstraintData { -// std::vector args; -// TypeRelationFn func; -// bool complete; -// TyRelData(Array args, TypeRelationFn func) : complete(false), -// func(func) { -// for (auto arg : args) { -// this->args.push_back(arg); -// } -// } -// }; +struct ConstraintData {}; + +/*! \brief A more efficient representation of the type relation + * data needed for type checking. + */ +struct TypeRelationData : ConstraintData { + std::string name; + std::vector args; + TypeRelationFn func; + Span span; + + explicit TypeRelationData(const TypeRelation& ty_rel) + : TypeRelationData(ty_rel->args, ty_rel->func_, ty_rel->span) {} + + TypeRelationData(const Array& args, const TypeRelationFn& func, const Span& sp) + : func(func), span(sp) { + for (auto arg : args) { + this->args.push_back(arg); + } + } + + TypeRelation ToTypeRel() const { + Array args = Array(this->args.begin(), this->args.end()); + return TypeRelationNode::make( + this->name, this->func, args); + } +}; struct TypeContext { std::unordered_map var_map; - std::vector> constraints; + std::vector > constraints; TypeContext() { constraints.push_back({}); } void Insert(const Var &id, const Type &t) { var_map[id] = t; } void AddConstraint(const TypeConstraint &constraint) { - constraints.back().push_back(constraint); + constraints.back().push_back(TypeRelationData(Downcast(constraint))); } Type Lookup(const Var &id) { @@ -82,10 +85,10 @@ struct TypeContext { } } - struct Frame { + struct Scope { TypeContext &tc; - explicit Frame(TypeContext &tc) : tc(tc) { tc.constraints.push_back({}); } - ~Frame() { tc.constraints.pop_back(); } + explicit Scope(TypeContext &tc) : tc(tc) { tc.constraints.push_back({}); } + ~Scope() { tc.constraints.pop_back(); } }; }; @@ -106,10 +109,9 @@ class TypeInferencer : private ExprFunctor { Environment env; TypeUnifier unifier; - // Should be in header? template - T WithFrame(const std::function &f) { - TypeContext::Frame fr(context); + T WithScope(const std::function &f) { + TypeContext::Scope fr(context); return f(); } @@ -130,11 +132,20 @@ class TypeInferencer : private ExprFunctor { Type Unify(const Type &t1, const Type &t2, Span sp); Type Resolve(const Type &t); Expr Resolve(const Expr &e); - TypeRelation Solve(const TypeRelation &ty_rel); - SolverResult Solve(std::vector &rels); + + /*! \brief Attempt to solve a single relation. */ + void Solve(TypeRelationData & ty_rel); + + /*! \brief Attempt to solve all pending relations. + * + * If the solver + */ + SolverResult Solve(std::vector &rels); /*! \brief Check that all relations hold. */ bool RelationsHold(bool scope_only = false); + + /*! \brief Visit a function node, extra flag controls behavior. */ CheckedExpr VisitFunction(const Function &f, bool generalize); private: @@ -219,7 +230,7 @@ CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { std::vector param_types; std::vector params; - return this->WithFrame([&]() -> CheckedExpr { + return this->WithScope([&]() -> CheckedExpr { for (auto param : f->params) { CheckedExpr checked_param = this->Infer(param); Type arg_type; @@ -239,7 +250,7 @@ CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { Array cs; for (auto cons : this->context.constraints.back()) { - cs.push_back(cons); + cs.push_back(cons.ToTypeRel()); } return {FunctionNode::make(params, unified_rtype, checked_body.expr, {}), @@ -408,29 +419,27 @@ Expr TypeInferencer::Resolve(const Expr &e) { return ::tvm::relay::Resolve(this->unifier, e); } -TypeRelation TypeInferencer::Solve(const TypeRelation &ty_rel) { +void TypeInferencer::Solve(TypeRelationData & ty_rel) { Array normalized_args; - for (auto arg : ty_rel->args) { + for (auto arg : ty_rel.args) { normalized_args.push_back(Resolve(arg)); } - auto new_args = ty_rel->func_(normalized_args, ty_rel->args.size() - 1); + auto new_args = ty_rel.func(normalized_args, ty_rel.args.size()); CHECK(new_args.size() == normalized_args.size()); tvm::Array final_args; for (size_t i = 0; i < new_args.size(); i++) { - final_args.push_back(Unify(normalized_args[i], new_args[i], ty_rel->span)); + ty_rel.args[i] = Unify(normalized_args[i], new_args[i], ty_rel.span); } - - return TypeRelationNode::make(ty_rel->name, ty_rel->func_, final_args); } -int NumSolvedVars(const TypeRelation &ty_rel) { +int NumSolvedVars(const Array & vars) { int num = 0; - for (auto arg : ty_rel->args) { - if (!arg.as()) { + for (auto var : vars) { + if (!var.as()) { num += 1; } } @@ -443,7 +452,7 @@ enum SolverResult : int { Done = 1, }; -SolverResult TypeInferencer::Solve(std::vector &rels) { +SolverResult TypeInferencer::Solve(std::vector &rels) { // We start in the done state with zero progress. SolverResult status = SolverResult::Done; int progress = 0; @@ -453,10 +462,13 @@ SolverResult TypeInferencer::Solve(std::vector &rels) { status = SolverResult::Done; progress = 0; + std::vector complete; + + int i = 0; // We will now process each relation in order. - for (TypeRelation &ty_rel : rels) { - int arity = ty_rel->args.size(); - int pre_solved = NumSolvedVars(ty_rel); + for (TypeRelationData &ty_rel : rels) { + int arity = ty_rel.args.size(); + int pre_solved = NumSolvedVars(ty_rel.args); RELAY_LOG(INFO) << "TypeInferencer::Solve: " << "TypeRelation= " << ", Arity=" << arity << ", Solved=" << pre_solved @@ -465,10 +477,11 @@ SolverResult TypeInferencer::Solve(std::vector &rels) { // to set the status to done. if (pre_solved == arity) { status = static_cast((status && SolverResult::Done)); - // If there are unsolved variables we will try to solve some. + complete.push_back(i); + // If there are unsolved variables we will try to solve some. } else if (pre_solved < arity) { - auto solved = Solve(ty_rel); - int post_solved = NumSolvedVars(solved); + Solve(ty_rel); + int post_solved = NumSolvedVars(ty_rel.args); // If we solved any variables we will try to downgrade status to // progress update the type relation, and then bump the progress counter @@ -476,10 +489,10 @@ SolverResult TypeInferencer::Solve(std::vector &rels) { if (post_solved > pre_solved) { status = static_cast((status && SolverResult::Progress)); - ty_rel = solved; progress += 1; } } + i++; } // If we made no progress and we aren't finished, then the state should be @@ -489,6 +502,16 @@ SolverResult TypeInferencer::Solve(std::vector &rels) { break; } + // Remove the satisfied relations. + for (auto i : complete) { + if (rels.size() > 1) { + rels[i] = rels.back(); + rels.pop_back(); + } else { + rels.pop_back(); + } + } + std::reverse(rels.begin(), rels.end()); } while (status == SolverResult::Progress); return status; @@ -499,7 +522,7 @@ bool TypeInferencer::RelationsHold(bool scope_only) { // slice out the constraints. // // Otherwise we use all of them. - std::vector> constraints; + std::vector > constraints; if (scope_only) { constraints = {context.constraints[0]}; @@ -510,11 +533,7 @@ bool TypeInferencer::RelationsHold(bool scope_only) { RELAY_LOG(INFO) << "TypeInferencer::RelationsHold: scope_only= " << scope_only << std::endl; bool all_hold = true; - for (auto cs_set : context.constraints) { - std::vector ty_rels; - for (auto cs : cs_set) { - ty_rels.push_back(Downcast(cs)); - } + for (auto ty_rels : context.constraints) { auto status = Solve(ty_rels); RELAY_LOG(INFO) << "status= " << status << std::endl; if (status == SolverResult::Failed || status == SolverResult::Progress) { From d7835fa91205fa302747740ba9477f85ed745991 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 19 Sep 2018 00:32:17 -0700 Subject: [PATCH 129/136] Fix PyLint --- python/tvm/relay/base.py | 2 +- python/tvm/relay/expr.py | 9 ++++--- python/tvm/relay/ir_builder.py | 44 +++++++++++++++++++++++++++------- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index b2c65231cca1..60015551bd83 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -23,7 +23,7 @@ def register_relay_node(type_key=None): @register_relay_node class Span(NodeBase): source = None # type: FileSource - lineno = None # type: int + lineno = None # type: int col_offset = None # type: int def __init__(self, source, lineno, col_offset): # type: (FileSource, int, int) -> None diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index b872f48fd8b2..f76d1552f34c 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -1,8 +1,6 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression nodes of Relay.""" -from typing import List -import tvm -from .base import Span, NodeBase, register_relay_node +from .base import NodeBase, register_relay_node from .ty import Type, TypeParam from ._ir_pass import _get_checked_type from . import _make @@ -72,11 +70,12 @@ def __init__(self, var, ty): @register_relay_node class Function(Expr): """A function in Relay, see tvm/relay/expr.h for more details.""" + def __init__(self, params, ret_type, - body, - type_params = None, + body, + type_params=None, ): if not type_params: type_params = [] diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index e3342ac4a6c8..42347b0eeb75 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -1,4 +1,4 @@ -#pylint: disable=no-else-return +# pylint: disable=no-else-return """IR builder for the Relay IR. Enables users to construct Relay programs with a Python API. @@ -37,11 +37,24 @@ def _convert_type(rtype): elif isinstance(rtype, Type): return rtype else: - raise Exception("unsupported conversion to Relay type {0}".format(type(rtype))) + raise Exception( + "unsupported conversion to Relay type {0}".format(type(rtype))) def convert(arg): # type: (Any) -> Expr + """Convert some Python objects into a Relay AST fragment. + + Parameters + ---------- + arg: Any + The Python object + + Returns + ------- + expr: relay.Expr + The converted expression. + """ if isinstance(arg, Expr): return arg elif isinstance(arg, tuple): @@ -343,16 +356,31 @@ def tensor_type(*shape, **kwargs): dtype = kwargs['dtype'] else: dtype = 'float32' - + return TensorType(tvm.convert(shape), dtype) -def func_type(args, ret_type, type_params=None, type_constraints=None): - """document""" +def func_type(args, ret_type, type_params=None): + """Construct a Relay function type. + + Parameters + ---------- + args: list of relay.Type + The argument types. + + ret_type: relay.Type + The return type. + + type_params: list of relay.TypeParam + The type parameters. + + Returns + ------- + func_type: The function type. + """ if not type_params: type_params = [] - if not type_constraints: - type_constraints = [] + args = [_convert_type(arg) for arg in args] ret_type = _convert_type(ret_type) - return FuncType(args, ret_type, type_params, type_constraints) + return FuncType(args, ret_type, type_params, []) From 92810775bee06577b97d2ad848e71bc8c5ab6ae0 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 19 Sep 2018 01:45:06 -0700 Subject: [PATCH 130/136] Fix PyLint --- python/tvm/relay/expr.py | 5 +- python/tvm/relay/ir_builder.py | 25 +++--- python/tvm/relay/ty.py | 8 +- python/tvm/relay/ty.pyi | 139 +++++++++++++++++++++++++++++++++ 4 files changed, 160 insertions(+), 17 deletions(-) create mode 100644 python/tvm/relay/ty.pyi diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index f76d1552f34c..1f34294d6c15 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -1,7 +1,6 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression nodes of Relay.""" from .base import NodeBase, register_relay_node -from .ty import Type, TypeParam from ._ir_pass import _get_checked_type from . import _make @@ -75,8 +74,8 @@ def __init__(self, params, ret_type, body, - type_params=None, - ): + type_params=None + ): if not type_params: type_params = [] self.__init_handle_by_constructor__( diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 42347b0eeb75..1ef09f1874a5 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -3,7 +3,6 @@ Enables users to construct Relay programs with a Python API. """ -from typing import Any import numpy as np import tvm from .ty import Type, FuncType, TensorType @@ -285,10 +284,21 @@ def global_var(self, name): return self.env.global_var(name) def decl(self, name, *params, **kwargs): - if 'ret_type' in kwargs: - ret_type = kwargs['ret_type'] - else: - ret_type = None + """Create a global function. + + Parameters + ---------- + name: str or GlobalVar + The name of the function. + params: params + The parameters of the function. + + Returns + ------- + with_scope: Scope for the function. + """ + + ret_type = kwargs.get('ret_type', None) self.enter_scope() @@ -352,10 +362,7 @@ def tensor_type(*shape, **kwargs): tensor_type: relay.Type The resulting tensor types. """ - if 'dtype' in kwargs: - dtype = kwargs['dtype'] - else: - dtype = 'float32' + dtype = kwargs.get('dtype', 'float32') return TensorType(tvm.convert(shape), dtype) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 66bd50871358..768b01609ca8 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -1,9 +1,7 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The type nodes of the Relay language.""" -from typing import List from enum import IntEnum -from tvm import expr -from .base import Span, NodeBase, register_relay_node +from .base import NodeBase, register_relay_node from . import _make @@ -126,8 +124,8 @@ def __init__(self, arg_types, # type: List[Type], ret_type, # type: Type, type_params, # type: List[TypeParam], - type_constraints, # type: List[TypeConstraint] - ): + type_constraints # type: List[TypeConstraint] + ): # type: (...) -> None """Construct a function type. diff --git a/python/tvm/relay/ty.pyi b/python/tvm/relay/ty.pyi new file mode 100644 index 000000000000..0581847598d4 --- /dev/null +++ b/python/tvm/relay/ty.pyi @@ -0,0 +1,139 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The type nodes of the Relay language.""" +from enum import IntEnum +from .base import NodeBase, register_relay_node +from . import _make + + +class Type(NodeBase): + """The base type for all Relay types.""" + + def __eq__(self, other): + """Compare two Relay types for structural equivalence using + alpha equivalence. + """ + return bool(_make._type_alpha_eq(self, other)) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Compares two Relay types by referential equality.""" + return super().__eq__(other) + + +@register_relay_node +class TensorType(Type): + """A concrete TensorType in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to tensor's with a known dype and shape. For + example a tensor of `float32` and `(5, 5)`. + """ + + def __init__(self, shape, dtype): + """Construct a tensor type. + + Parameters + ---------- + shape: list of tvm.Expr + dtype: str + + Returns + ------- + tensor_type: The TensorType + """ + self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) + + +class Kind(IntEnum): + """The kind of a type parameter, represents a variable shape, + base type, type, or dimension. + + This controls what a type parameter is allowed to be instantiated + with. For example one's of kind BaseType can only be `float32`, `int32`, + and so on. + """ + ShapeVar = 0 + Shape = 1 + BaseType = 2 + Type = 3 + + +@register_relay_node +class TypeParam(Type): + """A type parameter used for generic types in Relay, + see tvm/relay/type.h for more details. + + A type parameter represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. + """ + + def __init__(self, var, kind): + """Construct a TypeParam. + + Parameters + ---------- + var: tvm.expr.Var + The tvm.Var which backs the type parameter. + + kind: Kind + The kind of the type parameter. + + Returns + ------- + type_param: TypeParam + The type parameter. + """ + self.__init_handle_by_constructor__(_make.TypeParam, var, kind) + + +@register_relay_node +class TypeConstraint(Type): + """Abstract class representing a type constraint.""" + pass + + +@register_relay_node +class FuncType(Type): + """A function type in Relay, see tvm/relay/type.h for more details. + + This is the type assigned to functions in Relay. They consist of + a list of type parameters which enable the definition of generic + fucntions, a set of type constraints which we omit for the time + being, a sequence of argument types, and a return type. + + We informally write them as: + `forall (type_params), (arg_types) -> ret_type where type_constraints` + """ + + def __init__(self, + arg_types, + ret_type, + type_params, + type_constraints, + ): + """Construct a function type. + + Parameters + ---------- + arg_types: list of Type + ret_type: Type + type_params: list of TypeParam + type_constraints: list of TypeConstraint + + Returns + ------- + func_type: FuncType + The function type. + """ + self.__init_handle_by_constructor__( + _make.FuncType, arg_types, ret_type, type_params, type_constraints) + + +@register_relay_node +class IncompleteType(Type): + """An incomplete type.""" + + def __init__(self, kind): + self.__init_handle_by_constructor__(_make.IncompleteType, kind) From b1c8ecd49b253544fb976172ddde344a51732605 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 19 Sep 2018 02:33:29 -0700 Subject: [PATCH 131/136] Repair tests --- python/tvm/relay/base.py | 6 +--- python/tvm/relay/expr.py | 7 ++-- python/tvm/relay/ty.py | 30 ++++------------- tests/python/relay/test_ir_builder.py | 5 +-- tests/python/relay/test_ir_nodes.py | 32 +++++++++---------- .../relay/test_tyck_eval_integration.py | 1 - 6 files changed, 30 insertions(+), 51 deletions(-) diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 60015551bd83..d683c96739cd 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -22,9 +22,5 @@ def register_relay_node(type_key=None): @register_relay_node class Span(NodeBase): - source = None # type: FileSource - lineno = None # type: int - col_offset = None # type: int - - def __init__(self, source, lineno, col_offset): # type: (FileSource, int, int) -> None + def __init__(self, source, lineno, col_offset): self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 1f34294d6c15..3bddbc89b56e 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -1,8 +1,10 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression nodes of Relay.""" +from __future__ import absolute_import from .base import NodeBase, register_relay_node from ._ir_pass import _get_checked_type from . import _make +from .. import convert class Expr(NodeBase): @@ -76,8 +78,9 @@ def __init__(self, body, type_params=None ): - if not type_params: - type_params = [] + if type_params is None: + type_params = convert([]) + self.__init_handle_by_constructor__( _make.Function, params, ret_type, body, type_params) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 768b01609ca8..10e267a53977 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -8,16 +8,16 @@ class Type(NodeBase): """The base type for all Relay types.""" - def __eq__(self, other): # type: (Type) -> bool + def __eq__(self, other): """Compare two Relay types for structural equivalence using alpha equivalence. """ return bool(_make._type_alpha_eq(self, other)) - def __ne__(self, other): # (Type) -> bool + def __ne__(self, other): return not self.__eq__(other) - def same_as(self, other): # (Type) -> bool + def same_as(self, other): """Compares two Relay types by referential equality.""" return super().__eq__(other) @@ -29,12 +29,8 @@ class TensorType(Type): This is the type assigned to tensor's with a known dype and shape. For example a tensor of `float32` and `(5, 5)`. """ - shape = None # type: List[expr.Expr] - dtype = None # type: str - span = None # type: Span def __init__(self, shape, dtype): - # type: (List[expr.Expr], str) -> None """Construct a tensor type. Parameters @@ -72,12 +68,8 @@ class TypeParam(Type): be filled in later on. This allows the user to write functions which are generic over types. """ - var = None # type: expr.Var - kind = None # type: Kind - span = None # type: Span def __init__(self, var, kind): - # type: (expr.Var, Kind) -> None """Construct a TypeParam. Parameters @@ -114,19 +106,12 @@ class FuncType(Type): We informally write them as: `forall (type_params), (arg_types) -> ret_type where type_constraints` """ - type_params = None # type: List[TypeParam] - type_constraints = None # type: List[TypeConstraint] - arg_types = None # type: List[Type] - ret_type = None # type: Type - span = None # type: Span - def __init__(self, - arg_types, # type: List[Type], - ret_type, # type: Type, - type_params, # type: List[TypeParam], - type_constraints # type: List[TypeConstraint] + arg_types, + ret_type, + type_params, + type_constraints ): - # type: (...) -> None """Construct a function type. Parameters @@ -150,5 +135,4 @@ class IncompleteType(Type): """An incomplete type.""" def __init__(self, kind): - # type: (Kind) -> None self.__init_handle_by_constructor__(_make.IncompleteType, kind) diff --git a/tests/python/relay/test_ir_builder.py b/tests/python/relay/test_ir_builder.py index 666d7ff25659..c98f920ca491 100644 --- a/tests/python/relay/test_ir_builder.py +++ b/tests/python/relay/test_ir_builder.py @@ -6,7 +6,7 @@ def test_let(): b = IRBuilder() x = b.let('x', 1) b.ret(x) - prog = b.get() + prog, _ = b.get() assert isinstance(prog, Let) var = prog.var value = prog.value @@ -16,8 +16,5 @@ def test_let(): assert value.data.asnumpy() == np.array(1) assert prog.value_type == None -# def test_function(): -# b = IRBuilder() - if __name__ == "__main__": test_let() diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index cf035f8a2b19..803b3d0faa0c 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -4,7 +4,7 @@ from tvm.expr import * # Span -def test_span() -> None: +def test_span(): span = relay.Span(None, 1, 1) assert span.source == None assert span.lineno == 1 @@ -16,7 +16,7 @@ def test_span() -> None: # Types -def test_tensor_type() -> None: +def test_tensor_type(): shape = tvm.convert([1, 2, 3]) dtype = 'float32' tt = relay.TensorType(shape, dtype) @@ -26,14 +26,14 @@ def test_tensor_type() -> None: str(tt) -def test_type_param() -> None: +def test_type_param(): tp = relay.TypeParam('name', relay.Kind.Shape) tp.kind == relay.Kind.Shape tp.span # TODO allow us to set span str(tp) -def test_func_type() -> None: +def test_func_type(): type_params = tvm.convert([]) type_constraints = tvm.convert([]) # TODO: fill me in arg_types = tvm.convert([]) @@ -48,7 +48,7 @@ def test_func_type() -> None: str(tf) -def test_constant() -> None: +def test_constant(): arr = tvm.nd.array(10) const = relay.Constant(arr) assert const.data == arr @@ -56,7 +56,7 @@ def test_constant() -> None: str(const) -def test_tuple() -> None: +def test_tuple(): fields = tvm.convert([]) tup = relay.Tuple(fields) assert tup.fields == fields @@ -64,7 +64,7 @@ def test_tuple() -> None: str(tup) -def test_local_var() -> None: +def test_local_var(): name_hint = 's' lv = relay.Var(name_hint) lv.name_hint == name_hint @@ -72,7 +72,7 @@ def test_local_var() -> None: str(lv) -def test_global_var() -> None: +def test_global_var(): name_hint = 'g' gv = relay.GlobalVar(name_hint) gv.name_hint == name_hint @@ -80,7 +80,7 @@ def test_global_var() -> None: str(gv) -def test_param() -> None: +def test_param(): lv = relay.Var('x') ty = None param = relay.Param(lv, ty) @@ -90,7 +90,7 @@ def test_param() -> None: str(param) -def test_function() -> None: +def test_function(): param_names = ['a', 'b', 'c', 'd'] params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names]) ret_type = None @@ -104,7 +104,7 @@ def test_function() -> None: str(fn) -def test_call() -> None: +def test_call(): op = relay.Var('f') arg_names = ['a', 'b', 'c', 'd'] args = tvm.convert([relay.Var(n) for n in arg_names]) @@ -115,13 +115,13 @@ def test_call() -> None: str(call) -def test_let() -> None: +def test_let(): lv = relay.Var('x') ty = None arr = tvm.nd.array(10) value = relay.Constant(arr) # I would prefer that the order of arguments - # matches syntax let x : t = v in b + # matches syntax let x: t = v in b let = relay.Let(lv, value, lv, ty) assert let.var == lv assert let.value == value @@ -131,14 +131,14 @@ def test_let() -> None: str(let) -def test_if() -> None: +def test_if(): cond = relay.Var('cond') left = relay.Var('left') right = relay.Var('right') ife = relay.If(cond, left, right) assert ife.cond == cond - assert ife.true_value == left - assert ife.false_value == right + assert ife.true_branch == left + assert ife.false_branch == right assert ife.span == None str(ife) diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index b8f3dfc0fb34..d8190f0da2e8 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -3,7 +3,6 @@ """ import tvm import numpy as np -from nnvm import graph from tvm.relay.ir_pass import check_expr from tvm.relay.ir_builder import IRBuilder, func_type from tvm.relay.ir_builder import scalar_type, convert, tensor_type From b5c562693a76677456221584eb962104ee8be05c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 19 Sep 2018 03:49:57 -0700 Subject: [PATCH 132/136] Add initial version of concat --- python/tvm/relay/op/tensor.py | 25 ++++++- src/relay/op/tensor/elemwise.cc | 15 +++- src/relay/op/type_relations.cc | 71 ++++++++++++++++++- src/relay/op/type_relations.h | 35 ++++++--- .../relay/test_tyck_eval_integration.py | 33 ++++++--- 5 files changed, 160 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 57fbccf488dc..1b84e92233c7 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -1,6 +1,7 @@ """Basic tensor operations.""" from __future__ import absolute_import as _abs from . import _make +from ..expr import Tuple # We create a wrapper function for each operator in the # python side to call into the positional _make.OpName function. @@ -61,7 +62,7 @@ def sqrt(data): def add(lhs, rhs): - """Take sqrt of data. + """Elementwise addition. Parameters ---------- @@ -78,6 +79,23 @@ def add(lhs, rhs): return _make.add(lhs, rhs) +def subtract(lhs, rhs): + """Elementwise subtraction. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.add(lhs, rhs) + def subtract(lhs, rhs): """Take sqrt of data. @@ -98,3 +116,8 @@ def subtract(lhs, rhs): def equal(lhs, rhs): return _make.equal(lhs, rhs) + +def concat(*args): + tup = Tuple(list(args)) + return _make.concat(tup) + diff --git a/src/relay/op/tensor/elemwise.cc b/src/relay/op/tensor/elemwise.cc index df806899a7a0..8c1823114f44 100644 --- a/src/relay/op/tensor/elemwise.cc +++ b/src/relay/op/tensor/elemwise.cc @@ -106,7 +106,7 @@ RELAY_REGISTER_OP("subtract") // input2: Tensor[dtype, s2] // output: Tensor[dtype, broadcast(s1, s2)] -// Addition +// Equality Comparison TVM_REGISTER_API("relay.op._make.equal") .set_body_typed([](Expr lhs, Expr rhs) { static const Op& op = Op::Get("equal"); @@ -120,5 +120,18 @@ RELAY_REGISTER_OP("equal") .set_support_level(1) .add_type_rel("BroadcastComp", BroadcastCompRel); +// Concat +TVM_REGISTER_API("relay.op._make.concat") + .set_body_typed([](Expr tuple) { + static const Op& op = Op::Get("concat"); + return CallNode::make(op, { tuple }, Attrs(), {}); + }); + +RELAY_REGISTER_OP("concat") + .set_num_inputs(1) + .add_argument("tuple", "Tuple", "The tupled tensor arguments.") + .set_support_level(1) + .add_type_rel("Concat", ConcatRel); + } // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 4a2464044a1f..168307d288eb 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -7,7 +7,9 @@ #include #include #include +#include #include "../pass/incomplete_type.h" +#include "./type_relations.h" namespace tvm { namespace relay { @@ -20,7 +22,7 @@ TensorType ToTensorType(const Type& t) { } } -// TODO(@jroesch) what size value do we extract? +// TODO(@jroesch) what size value do we extract, 64bit or 32bit? int ToInt(const tvm::Expr& e) { CHECK(e.defined()); auto imm = e.as(); @@ -133,5 +135,72 @@ Array BroadcastCompRel(const Array& types, int num_args) { return types; } +/*! \brief Handle concrete concat case from known input to output. */ +inline Type ConcreteConcatRel(const Type& input_type) { + if (auto tuple_node = input_type.as()) { + // NB: For now the axis argument is hardwired to be 0. + std::vector dims; + DataType dtype; + + CHECK_LT(1, tuple_node->fields.size()); + bool skip_first = true; + + // Collect the suffix dimensions since axis is zero. + // TODO(@jroesch): This is a demonstration of how + // to do varargs. It requires a little more work to + // fully type the behavior of concat. + + auto first = Downcast(tuple_node->fields[0]); + dtype = first->dtype; + + for (auto dim_expr : first->shape) { + if (!skip_first) { + dims.push_back(ToInt(dim_expr)); + } else { + skip_first = false; + } + } + + std::vector axis_dims; + for (auto field_ty : tuple_node->fields) { + auto ttype = Downcast(field_ty); + for (size_t i = 0; i < ttype->shape.size(); i++) { + if (i != 0) { + CHECK_EQ(ToInt(dims[i - 1]), ToInt(ttype->shape[i])); + } else { + axis_dims.push_back(ToInt(ttype->shape[i])); + } + } + } + + auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0); + + Array out_shape = { tvm::ir::IntImm::make(HalideIR::Int(64), out_axis_dim) }; + + for (auto dim : dims) { + out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim)); + } + + return TensorTypeNode::make(out_shape, dtype); + + } else { + throw TypeRelationError("concat can only be used with a tuple as its argument"); + } +} + +Array ConcatRel(const Array& types, int num_args) { + CHECK_EQ(types.size(), 2); + + if (types[0].as() && types[1].as()) { + return types; + } else if (types[1].as()) { + return { types[0], ConcreteConcatRel(types[0]) }; + } else { + throw TypeRelationError( + "can not deduce relationship between the " \ + "type of concat's input and output"); + } +} + } // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 521c6f8e1681..9dfc29022ee3 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -7,40 +7,59 @@ #ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_ #define TVM_RELAY_OP_TYPE_RELATIONS_H_ +#include #include #include namespace tvm { namespace relay { +/*! \brief The error raised by a type relation. + * + * This error is how a type relation signals that it has failed. + * + */ +struct TypeRelationError : Error { + explicit TypeRelationError(const std::string& msg) + : Error(msg) {} +}; + /*! \brief The identity type relation maps a single input variable * to the output variable. - * + * * \param types The input and output types to the relation. * \param num_args The number of input arguments. * \return The (potentially partial) solution to the relation. */ -Array IdentityRel(const Array & types, int num_args); +Array IdentityRel(const Array& types, int num_args); /*! \brief The broadcast type relation, implements the broadcasting * rule over the two input types producing the broadcasted type. - * + * * \param types The input and output types to the relation. * \param num_args The number of input arguments. * \return The (potentially partial) solution to the relation. */ -Array BroadcastRel(const Array & types, int num_args); +Array BroadcastRel(const Array& types, int num_args); /*! \brief The broadcast type relation, implements the broadcasting - * rule over the two input types producing the broadcasted type. - * + * rule over the two input types producing the broadcasted type. + * * This differs from BroadcastRel in the return dtype, * it instead returns bool, for use in comparsion operators * such as equal, not_equal, lt, and so on. - * + * * \param types The input and output types to the relation. * \param num_args The number of input arguments. * \return The (potentially partial) solution to the relation. */ -Array BroadcastCompRel(const Array & types, int num_args); +Array BroadcastCompRel(const Array& types, int num_args); + +/*! \brief The concat relation. + * + * This relation takes a single input which must be a single tensor + * or an arbitrary sized tuple. It combines these input dimensions + * together to produce the output example. + */ +Array ConcatRel(const Array& types, int num_args); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_tyck_eval_integration.py b/tests/python/relay/test_tyck_eval_integration.py index d8190f0da2e8..d95cda0ba819 100644 --- a/tests/python/relay/test_tyck_eval_integration.py +++ b/tests/python/relay/test_tyck_eval_integration.py @@ -7,7 +7,7 @@ from tvm.relay.ir_builder import IRBuilder, func_type from tvm.relay.ir_builder import scalar_type, convert, tensor_type from tvm.relay.env import Environment -from tvm.relay.op import log, add, equal, subtract +from tvm.relay.op import log, add, equal, subtract, concat from tvm.relay.expr import Function def assert_has_type(expr, typ, env=Environment({})): @@ -135,11 +135,28 @@ def f(n: i32, data: f32) -> f32 { # TODO(@jroesch): need evaluator or new runtime # to execute this. +def test_concat(): + """ + Program: + def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) { + return concat(x, y); + } + """ + ib = IRBuilder() + try_concat2 = ib.global_var('try_concat2') + x = ib.param('x', ty=tensor_type(3, 2)) + y = ib.param('y', ty=tensor_type(2, 2)) + with ib.decl(try_concat2, x, y): + ib.ret(concat(x, y)) + fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2)) + assert_decl_has_type(ib.env, try_concat2, fn_ty) + if __name__ == "__main__": - test_monomorphic_let() - test_single_op() - test_add_op() - test_add_broadcast_op() - test_dual_op() - test_decl() - test_recursion() + # test_monomorphic_let() + # test_single_op() + # test_add_op() + # test_add_broadcast_op() + # test_dual_op() + # test_decl() + # test_recursion() + test_concat() From 388b27575fa6b210ee67312dc78051e0aa78a0e0 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 19 Sep 2018 13:28:13 -0700 Subject: [PATCH 133/136] Fix issue in type_relations.cc --- src/relay/op/type_relations.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 168307d288eb..94550dbd5075 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -69,8 +69,8 @@ static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, rev_sh2++; } - Array larger; - Array smaller; + Array larger; + Array smaller; for (int i = 0; i < (full_len - suffix_len); i++) { smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1)); From 81c9d6d03679b47b8f0819b7f2086a5b360e3226 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 19 Sep 2018 13:29:09 -0700 Subject: [PATCH 134/136] Fix PyLint ... again --- python/tvm/relay/op/tensor.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 1b84e92233c7..5864881be4dc 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -96,28 +96,9 @@ def subtract(lhs, rhs): """ return _make.add(lhs, rhs) -def subtract(lhs, rhs): - """Take sqrt of data. - - Parameters - ---------- - lhs : relay.Expr - The left hand side input data - rhs : relay.Expr - The right hand side input data - - Returns - ------- - result : relay.Expr - The computed result. - """ - return _make.add(lhs, rhs) - - def equal(lhs, rhs): return _make.equal(lhs, rhs) def concat(*args): tup = Tuple(list(args)) return _make.concat(tup) - From 4e0c7b8b558ce00b8e1ed10b541798b21fb95d2c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 19 Sep 2018 14:11:50 -0700 Subject: [PATCH 135/136] Fix issue with test case --- python/tvm/relay/ir_builder.py | 9 +++++---- python/tvm/relay/op/tensor.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 1ef09f1874a5..6e52f209d0c6 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -3,6 +3,7 @@ Enables users to construct Relay programs with a Python API. """ +from collections import OrderedDict import numpy as np import tvm from .ty import Type, FuncType, TensorType @@ -143,8 +144,8 @@ class IRBuilder(object): """ def __init__(self): - self.bindings = [{}] - self.scopes = [{}] + self.bindings = [OrderedDict({})] + self.scopes = [OrderedDict({})] self.params = [] self.ret_values = [None] self.env = Environment({}) @@ -153,8 +154,8 @@ def enter_scope(self, params=None): if not params: params = [] - self.bindings.append({}) - self.scopes.append({}) + self.bindings.append(OrderedDict({})) + self.scopes.append(OrderedDict({})) self.params.append(params) self.ret_values.append(None) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 5864881be4dc..fa54d8b53dd8 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -100,5 +100,15 @@ def equal(lhs, rhs): return _make.equal(lhs, rhs) def concat(*args): + """Concatenate the input tensors along the zero axis. + + Parameters + ---------- + args: list of Tensor + + Returns + ------- + tensor: The concatenated tensor. + """ tup = Tuple(list(args)) return _make.concat(tup) From 052d61913e67d4f31eef59007dae699db512e9b5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 19 Sep 2018 14:58:49 -0700 Subject: [PATCH 136/136] Fix & style --- src/relay/pass/type_infer.cc | 104 +++++++++++++++++------------------ 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index bdda3940fc6e..f4f6d82eb5e1 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -70,24 +70,24 @@ struct TypeContext { TypeContext() { constraints.push_back({}); } - void Insert(const Var &id, const Type &t) { var_map[id] = t; } + void Insert(const Var& id, const Type& t) { var_map[id] = t; } - void AddConstraint(const TypeConstraint &constraint) { + void AddConstraint(const TypeConstraint& constraint) { constraints.back().push_back(TypeRelationData(Downcast(constraint))); } - Type Lookup(const Var &id) { - auto type = var_map.find(id); + Type Lookup(const Var& var) { + auto type = var_map.find(var); if (type != var_map.end()) { return (*type).second; } else { - throw FatalTypeError("Could not resolve local id"); + throw FatalTypeError(std::string("undeclared local variable: ") + var->name_hint); } } struct Scope { - TypeContext &tc; - explicit Scope(TypeContext &tc) : tc(tc) { tc.constraints.push_back({}); } + TypeContext& tc; + explicit Scope(TypeContext& tc) : tc(tc) { tc.constraints.push_back({}); } ~Scope() { tc.constraints.pop_back(); } }; }; @@ -101,7 +101,7 @@ struct CheckedExpr { enum SolverResult : int; -class TypeInferencer : private ExprFunctor { +class TypeInferencer : private ExprFunctor { private: TypeContext context; @@ -110,7 +110,7 @@ class TypeInferencer : private ExprFunctor { TypeUnifier unifier; template - T WithScope(const std::function &f) { + T WithScope(const std::function& f) { TypeContext::Scope fr(context); return f(); } @@ -124,41 +124,41 @@ class TypeInferencer : private ExprFunctor { FuncType Instantiate(FuncType fn_ty, tvm::Array &ty_args); - Type Normalize(const Type &t); + Type Normalize(const Type& t); - void ReportError(const std::string &msg, Span sp); - [[noreturn]] void FatalError(const std::string &msg, Span sp); + void ReportError(const std::string& msg, Span sp); + [[noreturn]] void FatalError(const std::string& msg, Span sp); - Type Unify(const Type &t1, const Type &t2, Span sp); + Type Unify(const Type &t1, const Type& t2, Span sp); Type Resolve(const Type &t); Expr Resolve(const Expr &e); /*! \brief Attempt to solve a single relation. */ - void Solve(TypeRelationData & ty_rel); + void Solve(TypeRelationData& ty_rel); /*! \brief Attempt to solve all pending relations. * * If the solver */ - SolverResult Solve(std::vector &rels); + SolverResult Solve(std::vector& rels); /*! \brief Check that all relations hold. */ bool RelationsHold(bool scope_only = false); /*! \brief Visit a function node, extra flag controls behavior. */ - CheckedExpr VisitFunction(const Function &f, bool generalize); + CheckedExpr VisitFunction(const Function& f, bool generalize); private: - CheckedExpr VisitExpr_(const VarNode *op) override; - CheckedExpr VisitExpr_(const GlobalVarNode *op) override; - CheckedExpr VisitExpr_(const ConstantNode *op) override; - CheckedExpr VisitExpr_(const TupleNode *op) override; - CheckedExpr VisitExpr_(const ParamNode *op) override; - CheckedExpr VisitExpr_(const FunctionNode *op) override; - CheckedExpr VisitExpr_(const CallNode *op) override; - CheckedExpr VisitExpr_(const LetNode *op) override; - CheckedExpr VisitExpr_(const IfNode *op) override; - CheckedExpr VisitExpr_(const OpNode *op) override; + CheckedExpr VisitExpr_(const VarNode* op) override; + CheckedExpr VisitExpr_(const GlobalVarNode* op) override; + CheckedExpr VisitExpr_(const ConstantNode* op) override; + CheckedExpr VisitExpr_(const TupleNode* op) override; + CheckedExpr VisitExpr_(const ParamNode* op) override; + CheckedExpr VisitExpr_(const FunctionNode* op) override; + CheckedExpr VisitExpr_(const CallNode* op) override; + CheckedExpr VisitExpr_(const LetNode* op) override; + CheckedExpr VisitExpr_(const IfNode* op) override; + CheckedExpr VisitExpr_(const OpNode* op) override; }; TypeInferencer::TypeInferencer() { @@ -170,7 +170,7 @@ TypeInferencer::TypeInferencer(Environment env) : env(env) { this->unifier = TypeUnifierNode::make(UnionFindNode::make({})); } -CheckedExpr TypeInferencer::Infer(const Expr &expr) { +CheckedExpr TypeInferencer::Infer(const Expr& expr) { RELAY_LOG(INFO) << "TypeInferencer::Check expr=" << expr << std::endl; CheckedExpr checked_expr = this->VisitExpr(expr); RELAY_LOG(INFO) << "TypeInferencer::Check type=" << checked_expr.type @@ -182,22 +182,22 @@ CheckedExpr TypeInferencer::Infer(const Expr &expr) { return checked_expr; } -CheckedExpr TypeInferencer::VisitExpr_(const VarNode *op) { +CheckedExpr TypeInferencer::VisitExpr_(const VarNode* op) { auto var = GetRef(op); return {var, this->context.Lookup(var)}; } -CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode *op) { +CheckedExpr TypeInferencer::VisitExpr_(const GlobalVarNode* op) { GlobalVar var = GetRef(op); Expr e = this->env->Lookup(var); return {var, e->checked_type()}; } -CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode *const_node) { +CheckedExpr TypeInferencer::VisitExpr_(const ConstantNode* const_node) { return {GetRef(const_node), const_node->tensor_type()}; } -CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { +CheckedExpr TypeInferencer::VisitExpr_(const TupleNode* op) { Tuple pl = GetRef(op); std::vector field_exprs; @@ -211,7 +211,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const TupleNode *op) { return {TupleNode::make(field_exprs), TupleTypeNode::make(field_types)}; } -CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { +CheckedExpr TypeInferencer::VisitExpr_(const ParamNode* param) { // We should trigger error here and move param code direclty into function // checking. auto rtype = this->Resolve(param->type); @@ -221,7 +221,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const ParamNode *param) { return {ParamNode::make(param->var, rtype), rtype}; } -CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { +CheckedExpr TypeInferencer::VisitFunction(const Function& f, bool generalize) { // First we add the parameters to the context allowing us to check their // types. @@ -258,12 +258,12 @@ CheckedExpr TypeInferencer::VisitFunction(const Function &f, bool generalize) { }); } -CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode *op) { +CheckedExpr TypeInferencer::VisitExpr_(const FunctionNode* op) { return this->VisitFunction(GetRef(op), false); } FuncType TypeInferencer::Instantiate(FuncType fn_ty, - tvm::Array &ty_args) { + tvm::Array& ty_args) { tvm::Map subst_map; // Build a subsitituion map up from the function type and type arguments. @@ -284,7 +284,7 @@ FuncType TypeInferencer::Instantiate(FuncType fn_ty, return GetRef(inst_ty.as()); } -CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { +CheckedExpr TypeInferencer::VisitExpr_(const CallNode* op) { Call c = GetRef(op); auto checked_op = this->Infer(c->op); @@ -352,7 +352,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const CallNode *op) { return {new_call, fn_ty->ret_type}; } -CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { +CheckedExpr TypeInferencer::VisitExpr_(const LetNode* op) { Let let = GetRef(op); CheckedExpr checked_value; @@ -382,7 +382,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const LetNode *op) { return {checked_let, checked_body.type}; } -CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { +CheckedExpr TypeInferencer::VisitExpr_(const IfNode* op) { If ifn = GetRef(op); // Ensure the type of the guard is of Tensor[Bool, ()], @@ -401,7 +401,7 @@ CheckedExpr TypeInferencer::VisitExpr_(const IfNode *op) { return {checked_if, unified_type}; } -CheckedExpr TypeInferencer::VisitExpr_(const OpNode *op_node) { +CheckedExpr TypeInferencer::VisitExpr_(const OpNode* op_node) { auto op = GetRef(op_node); return {op, op->op_type}; } @@ -436,7 +436,7 @@ void TypeInferencer::Solve(TypeRelationData & ty_rel) { } } -int NumSolvedVars(const Array & vars) { +int NumSolvedVars(const Array& vars) { int num = 0; for (auto var : vars) { if (!var.as()) { @@ -452,7 +452,7 @@ enum SolverResult : int { Done = 1, }; -SolverResult TypeInferencer::Solve(std::vector &rels) { +SolverResult TypeInferencer::Solve(std::vector& rels) { // We start in the done state with zero progress. SolverResult status = SolverResult::Done; int progress = 0; @@ -466,7 +466,7 @@ SolverResult TypeInferencer::Solve(std::vector &rels) { int i = 0; // We will now process each relation in order. - for (TypeRelationData &ty_rel : rels) { + for (TypeRelationData& ty_rel : rels) { int arity = ty_rel.args.size(); int pre_solved = NumSolvedVars(ty_rel.args); RELAY_LOG(INFO) << "TypeInferencer::Solve: " @@ -548,15 +548,15 @@ bool TypeInferencer::RelationsHold(bool scope_only) { return all_hold; } -Expr InferType(const Environment &env, const Expr &e) { +Expr InferType(const Environment& env, const Expr& e) { TypeInferencer ti(env); auto checked_expr = ti.Infer(e); CHECK(ti.RelationsHold()); return ti.Resolve(checked_expr.expr); } -Expr InferType(const Environment &env, const GlobalVar &var, - const Function &func) { +Expr InferType(const Environment& env, const GlobalVar& var, + const Function& func) { TypeInferencer ti(env); auto func_copy = FunctionNode::make(func->params, func->ret_type, func->body, func->type_params); @@ -569,14 +569,14 @@ Expr InferType(const Environment &env, const GlobalVar &var, return ti.Resolve(checked_expr.expr); } -void TypeInferencer::FatalError(const std::string &msg, Span sp) { +void TypeInferencer::FatalError(const std::string& msg, Span sp) { throw FatalTypeError( "internal error: this exception should" "be handled and errors reported with Environment::display_errors\n" + msg); } -Type TypeInferencer::Unify(const Type &t1, const Type &t2, Span sp) { +Type TypeInferencer::Unify(const Type& t1, const Type& t2, Span sp) { try { return this->unifier->Unify(t1, t2); } catch (const dmlc::Error &e) { @@ -591,7 +591,7 @@ Type TypeInferencer::Unify(const Type &t1, const Type &t2, Span sp) { } TVM_REGISTER_API("relay._ir_pass.check_expr") - .set_body([](TVMArgs args, TVMRetValue *ret) { + .set_body([](TVMArgs args, TVMRetValue* ret) { Environment env = args[0]; Expr e = args[1]; *ret = InferType(env, e); @@ -599,7 +599,7 @@ TVM_REGISTER_API("relay._ir_pass.check_expr") // TODO(@jroesch): put in a better namespace. TVM_REGISTER_API("relay._ir_pass._get_checked_type") - .set_body([](TVMArgs args, TVMRetValue *ret) { + .set_body([](TVMArgs args, TVMRetValue* ret) { Expr e = args[0]; *ret = e->checked_type(); }); @@ -614,14 +614,14 @@ IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { } TVM_REGISTER_API("relay._make.IncompleteType") - .set_body([](TVMArgs args, TVMRetValue *ret) { + .set_body([](TVMArgs args, TVMRetValue* ret) { int kind = args[0]; *ret = IncompleteTypeNode::make(static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const IncompleteTypeNode *node, - tvm::IRPrinter *p) { + .set_dispatch([](const IncompleteTypeNode* node, + tvm::IRPrinter* p) { p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; });