diff --git a/CMakeLists.txt b/CMakeLists.txt index c23d403bcb6a..534a9f80b1ac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,6 @@ if(MSVC) add_definitions(-D_CRT_SECURE_NO_WARNINGS) add_definitions(-D_SCL_SECURE_NO_WARNINGS) add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE) - add_definitions(-DHalide_SHARED) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj") @@ -112,8 +111,8 @@ else(MSVC) endif(MSVC) # add source group -FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "3rdparty/HalideIR/src/*.cpp" "nnvm/src/*.cc") -FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h" "3rdparty/HalideIR/src/*.h" +FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "nnvm/src/*.cc") +FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h" "nnvm/src/*.h" "nnvm/include/*.h") assign_source_group("Source" ${GROUP_SOURCE}) assign_source_group("Include" ${GROUP_INCLUDE}) @@ -127,6 +126,7 @@ file(GLOB COMPILER_SRCS src/lang/*.cc src/pass/*.cc src/op/*.cc + src/node/*.cc src/schedule/*.cc ) @@ -154,12 +154,7 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS file(GLOB TOPI_SRCS topi/src/*.cc ) -file(GLOB_RECURSE HALIDEIR_SRCS - 3rdparty/HalideIR/src/base/*.cpp - 3rdparty/HalideIR/src/ir/*.cpp - 3rdparty/HalideIR/src/tvm/*.cpp -) -list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS}) + file(GLOB RUNTIME_SRCS src/runtime/*.cc src/runtime/vm/*.cc @@ -245,7 +240,6 @@ target_link_libraries(nnvm_compiler tvm) # Related headers target_include_directories( tvm - PUBLIC "3rdparty/HalideIR/src" PUBLIC "topi/include") target_include_directories( tvm_topi @@ -294,11 +288,6 @@ if (INSTALL_DEV) FILES_MATCHING PATTERN "*.h" ) - install( - DIRECTORY "3rdparty/HalideIR/src/." DESTINATION "include/HalideIR" - FILES_MATCHING - PATTERN "*.h" - ) install( DIRECTORY "3rdparty/dlpack/include/." DESTINATION "include" FILES_MATCHING @@ -319,8 +308,6 @@ endif(INSTALL_DEV) # More target definitions if(MSVC) - target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS) - target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS) target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS) target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS) diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index 736add6e9fa3..6ebad8177cd5 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 446c4c0c19a9..105cbf7af8e9 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -591,7 +591,7 @@ IntSet EvalSet(Range r, const std::unordered_map& dom_map); /*! \brief Map from Expr to IntSet */ -using ExprIntSetMap = std::unordered_map; +using ExprIntSetMap = std::unordered_map; /*! * \brief Find the integer set of every sub-expression, given the * domain of each iteration variables. diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index ed021beb5d35..10fbe9f2ce4d 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -89,8 +89,8 @@ inline TNodeRef NullValue() { } template<> -inline Type NullValue() { - return Type(Type::Handle, 0, 0); +inline DataType NullValue() { + return DataType(kHandle, 0, 0); } /*! \brief Error thrown during attribute checking. */ diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index ff5f8e37dbb6..a703d928ba5f 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -221,7 +221,7 @@ class Layout : public NodeRef { if (!this->defined()) return -1; const auto axes = operator->()->axes; for (size_t i = 0; i < axes.size(); ++i) { - if (axes[i]->var.get()->name_hint == axis.name()) return static_cast(i); + if (axes[i]->var->name_hint == axis.name()) return static_cast(i); } return -1; } @@ -243,7 +243,7 @@ class Layout : public NodeRef { bool Contains(const LayoutAxis& axis) const { if (!defined()) return false; for (const IterVar var : operator->()->axes) { - if (var->var.get()->name_hint == axis.name()) { + if (var->var->name_hint == axis.name()) { return true; } } diff --git a/include/tvm/dtype.h b/include/tvm/dtype.h new file mode 100644 index 000000000000..60a96a3cc320 --- /dev/null +++ b/include/tvm/dtype.h @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file tvm/dtype.h + * \brief Data type used in IR. + */ +#ifndef TVM_DTYPE_H_ +#define TVM_DTYPE_H_ + +#include "runtime/packed_func.h" + +namespace tvm { +class Expr; + +/*! + * \brief Primitive data types in tvm. + */ +class DataType { + public: + /*! \brief default constructor */ + DataType() {} + /*! + * \brief Constructor + * \param dtype The DLDataType + */ + explicit DataType(DLDataType dtype) + : data_(dtype) {} + /*! + * \brief Constructor + * \param code The type code. + * \param bits The number of bits in the type. + * \param lanes The number of lanes. + */ + DataType(int code, int bits, int lanes) { + data_.code = static_cast(code); + data_.bits = static_cast(bits); + data_.lanes = static_cast(lanes); + } + /*! \return The type code. */ + int code() const { + return static_cast(data_.code); + } + /*! \return number of bits in the data. */ + int bits() const { + return static_cast(data_.bits); + } + /*! \return number of bytes to store each scalar. */ + int bytes() const { + return (bits() + 7) / 8; + } + /*! \return number of lanes in the data. */ + int lanes() const { + return static_cast(data_.lanes); + } + /*! \return whether type is a scalar type. */ + bool is_scalar() const { + return lanes() == 1; + } + /*! \return whether type is a scalar type. */ + bool is_bool() const { + return code() == kDLUInt && bits() == 1; + } + /*! \return whether type is a float type. */ + bool is_float() const { + return code() == kDLFloat; + } + /*! \return whether type is an int type. */ + bool is_int() const { + return code() == kDLInt; + } + /*! \return whether type is an uint type. */ + bool is_uint() const { + return code() == kDLUInt; + } + /*! \return whether type is a handle type. */ + bool is_handle() const { + return code() == kHandle; + } + /*! \return whether type is a vector type. */ + bool is_vector() const { + return lanes() > 1; + } + /*! + * \brief Create a new data type by change lanes to a specified value. + * \param lanes The target number of lanes. + * \return the result type. + */ + DataType with_lanes(int lanes) const { + return DataType(data_.code, data_.bits, lanes); + } + /*! + * \brief Create a new data type by change bits to a specified value. + * \param bits The target number of bits. + * \return the result type. + */ + DataType with_bits(int bits) const { + return DataType(data_.code, bits, data_.lanes); + } + /*! + * \brief Get the scalar version of the type. + * \return the result type. + */ + DataType element_of() const { + return with_lanes(1); + } + // operator overloadings + bool operator==(const DataType& other) const { + return + data_.code == other.data_.code && + data_.bits == other.data_.bits && + data_.lanes == other.data_.lanes; + } + bool operator!=(const DataType& other) const { + return !operator==(other); + } + operator DLDataType () const { + return data_; + } + /*! \return the maximum possible value in this format. */ + TVM_DLL Expr max() const; + /*! \return the minimum possible value in this format. */ + TVM_DLL Expr min() const; + + private: + DLDataType data_; +}; + +/*! + * \brief Construct an int type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes. + * \return The constructed data type. + */ +inline DataType Int(int bits, int lanes = 1) { + return DataType(kDLInt, bits, lanes); +} + +/*! + * \brief Construct an uint type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes + * \return The constructed data type. + */ +inline DataType UInt(int bits, int lanes = 1) { + return DataType(kDLUInt, bits, lanes); +} + +/*! + * \brief Construct a bool type. + * \param lanes The number of lanes + * \return The constructed data type. + */ +inline DataType Bool(int lanes = 1) { + return UInt(1, lanes); +} + +/*! + * \brief Construct an uint type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes + * \return The constructed data type. + */ +inline DataType Float(int bits, int lanes = 1) { + return DataType(kDLFloat, bits, lanes); +} + +/*! + * \brief Construct a handle type. + * \param bits The number of bits in the type. + * \param lanes The number of lanes + * \return The constructed data type. + */ +inline DataType Handle(int bits = 64, int lanes = 1) { + return DataType(kHandle, bits, lanes); +} + +/*! + * \brief Get the corresponding type of TVMShapeIndex. + * \return The type of TVM shape index. + */ +inline DataType TVMShapeIndexType() { + if (std::is_signed::value) { + return Int(sizeof(tvm_index_t) * 8); + } else { + return UInt(sizeof(tvm_index_t) * 8); + } +} + +/*! + * \brief Convert DLDataType to DataType. + * \param t The original type. + * \return The conversion result. + */ +inline DataType TVMType2Type(DLDataType t) { + return DataType(t.code, t.bits, t.lanes); +} + +/*! + * \brief Convert DataType to DataType. + * \param t The original type. + * \return The conversion result. + */ +inline DLDataType Type2TVMType(DataType t) { + return t.operator DLDataType(); +} + +/*! + * \brief Get the number of bytes needed in a vector. + * \param dtype The data type. + * \return Number of bytes needed. + */ +inline int GetVectorBytes(DataType dtype) { + int data_bits = dtype.bits() * dtype.lanes(); + // allow bool to exist + if (dtype == Bool()) return 1; + CHECK_EQ(data_bits % 8, 0U) + << "Need to load/store by multiple of bytes"; + return data_bits / 8; +} + +// Overload print function. +inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*) + using namespace tvm::runtime; + return os << dtype.operator DLDataType(); +} + +// Backward compatibility +using Type = DataType; +} // namespace tvm +#endif // TVM_DTYPE_H_ diff --git a/include/tvm/expr.h b/include/tvm/expr.h index c7e69d59d636..07cfbc7791da 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -25,72 +24,107 @@ #ifndef TVM_EXPR_H_ #define TVM_EXPR_H_ -#include -#include #include #include #include #include "base.h" +#include "dtype.h" +#include "node/container.h" +#include "node/ir_functor.h" #include "runtime/c_runtime_api.h" namespace tvm { -using HalideIR::Type; -using HalideIR::Float; -using HalideIR::Bool; -using HalideIR::Int; -using HalideIR::UInt; -using HalideIR::Handle; -using HalideIR::ExprHash; -using HalideIR::ExprEqual; - -using HalideIR::Expr; -using HalideIR::VarExpr; -using HalideIR::IR::RangeNode; -using HalideIR::IR::FunctionRef; -using HalideIR::IR::FunctionBaseNode; -using HalideIR::Internal::IntImm; -using HalideIR::Internal::Stmt; -using HalideIR::Internal::IRPrinter; -using HalideIR::Internal::Variable; - -inline Type TVMShapeIndexType() { - if (std::is_signed::value) { - return Int(sizeof(tvm_index_t) * 8); - } else { - return UInt(sizeof(tvm_index_t) * 8); +/*! \brief Base node of all expressions. */ +class ExprNode : public Node { + public: + /*! \brief The data type of the expression. */ + DataType type; + + static constexpr const char* _type_key = "Expr"; + TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node); +}; + +/*! \brief Container of all expressions. */ +class Expr : public NodeRef { + public: + Expr() {} + explicit Expr(NodePtr ptr) : NodeRef(ptr) {} + /*! + * \brief construct from integer. + * \param value The value to be constructed. + */ + TVM_DLL Expr(int32_t value); // NOLINT(*) + /*! + * \brief construct from float. + * \param value The value to be constructed. + */ + TVM_DLL Expr(float value); // NOLINT(*) + /*! + * \brief construct from string. + * \param str The value to be constructed. + */ + TVM_DLL Expr(std::string str); // NOLINT(*) + + /*! \return the data type of this expression. */ + DataType type() const { + return static_cast(get())->type; } -} -inline Type TVMType2Type(TVMType t) { - return Type(static_cast(t.code), t.bits, t.lanes); -} + /*! \brief type indicate the container type */ + using ContainerType = ExprNode; +}; -inline TVMType Type2TVMType(Type t) { - TVMType ret; - ret.code = static_cast(t.code()); - ret.bits = static_cast(t.bits()); - ret.lanes = static_cast(t.lanes()); - return ret; -} +/*! \brief Base node of all statements. */ +class StmtNode : public Node { + public: + static constexpr const char* _type_key = "Stmt"; + TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node); +}; -// Get number of bytes considering vector type. -inline int GetVectorBytes(Type dtype) { - int data_bits = dtype.bits() * dtype.lanes(); - // allow bool to exist - if (dtype == Bool()) return 1; - CHECK_EQ(data_bits % 8, 0U) - << "Need to load/store by multiple of bytes"; - return data_bits / 8; -} +/*! \brief Container of all statements */ +class Stmt : public NodeRef { + public: + TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode); +}; + +class Var; +/*! + * \brief A variable node in the IR. + * + * A vraible is uniquely identified by its address. + * + * Each variable is only binded once in the following nodes: + * - Allocate + * - For + * - Let + * - LetStmt + */ +class Variable : public ExprNode { + public: + /*! + * \brief The hint to the variable name. + * \note Each variable is uniquely identified by its address. + */ + std::string name_hint; + + static Var make(DataType dtype, std::string name_hint); + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("name", &name_hint); + } + + static constexpr const char* _type_key = "Variable"; + TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode); +}; /*! \brief a named variable in TVM */ -class Var : public HalideIR::VarExpr { +class Var : public Expr { public: - EXPORT explicit Var(const std::string& name_hint = "v", - Type t = Int(32)) : VarExpr(name_hint, t) {} - explicit Var(NodePtr n) : VarExpr(n) {} - explicit Var(VarExpr v) : VarExpr(v) {} + explicit Var(NodePtr n) : Expr(n) {} + TVM_DLL explicit Var(std::string name_hint = "v", + Type t = Int(32)); /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. @@ -99,10 +133,47 @@ class Var : public HalideIR::VarExpr { Var copy_with_suffix(const std::string& suffix) const { return Var((*this)->name_hint + suffix, (*this)->type); } + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const Variable* operator->() const { + return get(); + } + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const Variable* get() const { + return static_cast(node_.get()); + } /*! \brief type indicate the container type */ using ContainerType = Variable; }; +// Backward compatibility, will be removed later. +using VarExpr = Var; +using BaseExprNode = ExprNode; +using ExprHash = NodeHash; +using ExprEqual = NodeEqual; + +class Integer; +/*! \brief ExprNode: constant integer. */ +class IntImm : public ExprNode { + public: + /*! \brief the Internal value. */ + int64_t value; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("value", &value); + } + + TVM_DLL static Integer make(DataType t, int64_t value); + + static constexpr const char* _type_key = "IntImm"; + TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode); +}; /*! * \brief Container of constant integer (IntImm). @@ -148,34 +219,52 @@ class Integer : public Expr { using ContainerType = IntImm; }; +/*! \brief range over one dimension */ +class RangeNode : public Node { + public: + /*! \brief beginning of the node */ + Expr min; + /*! \brief the extend of range */ + Expr extent; + /*! \brief constructor */ + RangeNode() {} + RangeNode(Expr min, Expr extent) : min(min), extent(extent) {} -/*! \brief container class of iteration variable. */ -class IterVarNode; + void VisitAttrs(AttrVisitor* v) final { + v->Visit("min", &min); + v->Visit("extent", &extent); + } -/*! - * \brief same as HalideIR::IR::Range - * except it provide an constructor with (begin, end) - * - * \note Traditional Halide's Range have a constructor with - * (begin, extent), which does not match the convention in e.g. python. - * We decided to correct it by removing the constructor in HalideIR, - * and add it back in TVM's range. - */ -class Range : public HalideIR::IR::Range { + static constexpr const char* _type_key = "Range"; + TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node); +}; + +/*! \brief Range constainer */ +class Range : public NodeRef { public: - /*! \brief constructor */ - Range() {} - explicit Range(NodePtr n) : HalideIR::IR::Range(n) {} /*! * \brief constructor by begin and end * \param begin The begin of the range. * \param end The end of the range. */ TVM_DLL Range(Expr begin, Expr end); - - TVM_DLL static Range make_by_min_extent(Expr min, Expr extent); + /*! + * \brief construct a new range with min and extent + * The corresponding constructor is removed, + * because that is counter convention of tradition meaning + * of range(begin, end) + * + * \param min The minimum range. + * \param extent The extent of the range. + */ + static Range make_by_min_extent(Expr min, Expr extent); + // declare range. + TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode); }; +/*! \brief container class of iteration variable. */ +class IterVarNode; + using Region = Array; /*! @@ -289,9 +378,6 @@ TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); using Domain = Array; -// print functions for expr -TVM_DLL std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*) - /*! * \brief Dump the node to stderr, used for debug purposes. * \param node The input node @@ -364,7 +450,7 @@ inline const char* IterVarType2String(IterVarType t) { * \param name_hint The name hint for the expression * \param t The type of the expression */ -TVM_DLL Var var(const std::string& name_hint, Type t = Int(32)); +TVM_DLL Var var(std::string name_hint, Type t = Int(32)); /* * \brief Template function to convert Map to unordered_map @@ -382,6 +468,32 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } return ret; } + +// Printer infra. +/*! \brief A Pretty printer class to print the IR. */ +class IRPrinter { + public: + /*! \brief The output stream */ + std::ostream& stream; + /*! \brief The indentation level. */ + int indent{0}; + explicit IRPrinter(std::ostream& stream) // NOLINT(*) + : stream(stream) {} + + /*! \brief The node to be printed. */ + TVM_DLL void Print(const NodeRef& node); + /*! \brief Print indent to the stream */ + TVM_DLL void PrintIndent(); + // Allow registration to be printer. + using FType = IRFunctor; + TVM_DLL static FType& vtable(); +}; + +// default print function for all nodes +inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) + IRPrinter(os).Print(n); + return os; +} } // namespace tvm namespace std { diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 7524109ec48b..547386154f76 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -16,18 +16,19 @@ * specific language governing permissions and limitations * under the License. */ - /*! * \file tvm/ir.h * \brief Additional high level nodes in the IR */ +// Acknowledgement: Most low-level IR nodes originate from Halide. + #ifndef TVM_IR_H_ #define TVM_IR_H_ -#include -#include #include #include +#include +#include #include "base.h" #include "expr.h" #include "runtime/util.h" @@ -35,17 +36,561 @@ namespace tvm { namespace ir { -using HalideIR::Internal::BaseExprNode; -using HalideIR::Internal::ExprNode; -using HalideIR::Internal::StmtNode; -using HalideIR::Internal::IRNodeType; -using HalideIR::Internal::ForType; -using HalideIR::DeviceAPI; +using IntImm = tvm::IntImm; +using Variable = tvm::Variable; + +/*! \brief constant unsigned integer. */ +class UIntImm : public ExprNode { + public: + /*! \brief The constant value content. */ + uint64_t value; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("value", &value); + } + + TVM_DLL static Expr make(Type t, uint64_t value); + + static constexpr const char* _type_key = "UIntImm"; + TVM_DECLARE_NODE_TYPE_INFO(UIntImm, ExprNode); +}; + +/*! \brief Floating point constants. */ +class FloatImm : public ExprNode { + public: + /*! \brief The constant value content. */ + double value; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("value", &value); + } + + TVM_DLL static Expr make(Type t, double value); + + static constexpr const char* _type_key = "FloatImm"; + TVM_DECLARE_NODE_TYPE_INFO(FloatImm, ExprNode); +}; + +/*! \brief String constants, only used in asserts. */ +class StringImm : public ExprNode { + public: + /*! \brief The constant value content. */ + std::string value; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("value", &value); + } + + TVM_DLL Expr static make(std::string value); + + static constexpr const char* _type_key = "StringImm"; + TVM_DECLARE_NODE_TYPE_INFO(StringImm, ExprNode); +}; + +/*! + * \brief Cast value from one data type to another. + * \note The lanes of value should keep fixed. + */ +class Cast : public ExprNode { + public: + /*! \brief Original data type. */ + Expr value; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("value", &value); + } + + TVM_DLL static Expr make(Type t, Expr v); + + static constexpr const char* _type_key = "Cast"; + TVM_DECLARE_NODE_TYPE_INFO(Cast, ExprNode); +}; + +/*! + * \brief Base template to implement binary ops. + * \tparam T The type of the child class. + */ +template +class BinaryOpNode : public ExprNode { + public: + /*! \brief The left operand. */ + Expr a; + /*! \brief The right operand. */ + Expr b; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &(this->type)); + v->Visit("a", &a); + v->Visit("b", &b); + } + + static Expr make(Expr a, Expr b) { + CHECK(a.defined()) << "ValueError: a is undefined\n"; + CHECK(b.defined()) << "ValueError: b is undefined\n"; + CHECK(a.type() == b.type()) << "TypeError: mismatched types\n"; + NodePtr node = make_node(); + node->type = a.type(); + node->a = std::move(a); + node->b = std::move(b); + return Expr(node); + } + + TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode); +}; + +/*! \brief a + b */ +class Add : public BinaryOpNode { + public: + static constexpr const char* _type_key = "Add"; +}; + +/*! \brief a - b */ +class Sub : public BinaryOpNode { + public: + static constexpr const char* _type_key = "Sub"; +}; + +/*! \brief a * b */ +class Mul : public BinaryOpNode { + public: + static constexpr const char* _type_key = "Mul"; +}; + +/*! + * \brief a / b in the C semnatics. + * \note For integer division, C standard uses trunc div. + */ +class Div : public BinaryOpNode
{ + public: + static constexpr const char* _type_key = "Div"; +}; + +/*! + * \brief a % b in the C semnatics. + * \note For integer division, C standard uses trunc div. + */ +class Mod : public BinaryOpNode { + public: + static constexpr const char* _type_key = "Mod"; +}; + +/*! \brief min(a, b) */ +class Min : public BinaryOpNode { + public: + static constexpr const char* _type_key = "Min"; +}; + +/*! \brief max(a, b) */ +class Max : public BinaryOpNode { + public: + static constexpr const char* _type_key = "Max"; +}; + +/*! + * \brief Base template to implement comparison ops. + * \tparam T The type of the child class. + */ +template +class CmpOpNode : public ExprNode { + public: + /*! \brief The left operand. */ + Expr a; + /*! \brief The right operand. */ + Expr b; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &(this->type)); + v->Visit("a", &a); + v->Visit("b", &b); + } + + static Expr make(Expr a, Expr b) { + CHECK(a.defined()) << "ValueError: a is undefined\n"; + CHECK(b.defined()) << "ValueError: b is undefined\n"; + CHECK(a.type() == b.type()) << "TypeError: mismatched types\n"; + NodePtr node = make_node(); + node->type = Bool(a.type().lanes()); + node->a = std::move(a); + node->b = std::move(b); + return Expr(node); + } + + TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode); +}; + +/*! \brief a == b */ +class EQ : public CmpOpNode { + public: + static constexpr const char* _type_key = "EQ"; +}; + +/*! \brief a != b */ +class NE : public CmpOpNode { + public: + static constexpr const char* _type_key = "NE"; +}; + +/*! \brief a < b */ +class LT : public CmpOpNode { + public: + static constexpr const char* _type_key = "LT"; +}; + +/*! \brief a <= b */ +struct LE : public CmpOpNode { + public: + static constexpr const char* _type_key = "LE"; +}; + +/*! \brief a > b */ +class GT : public CmpOpNode { + public: + static constexpr const char* _type_key = "GT"; +}; + +/*! \brief a >= b */ +class GE : public CmpOpNode { + public: + static constexpr const char* _type_key = "GE"; +}; + +/*! \brief a && b */ +class And : public ExprNode { + public: + /*! \brief The left operand. */ + Expr a; + /*! \brief The right operand. */ + Expr b; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &(this->type)); + v->Visit("a", &a); + v->Visit("b", &b); + } + + TVM_DLL static Expr make(Expr a, Expr b); + + static constexpr const char* _type_key = "And"; + TVM_DECLARE_NODE_TYPE_INFO(And, ExprNode); +}; + +/*! \brief a || b */ +class Or : public ExprNode { + public: + /*! \brief The left operand. */ + Expr a; + /*! \brief The right operand. */ + Expr b; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("a", &a); + v->Visit("b", &b); + } + + TVM_DLL static Expr make(Expr a, Expr b); + + static constexpr const char* _type_key = "Or"; + TVM_DECLARE_NODE_TYPE_INFO(Or, ExprNode); +}; + +/*! \brief !a */ +class Not : public ExprNode { + public: + /*! \brief The input operand. */ + Expr a; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("a", &a); + } + + TVM_DLL static Expr make(Expr a); + + static constexpr const char* _type_key = "Not"; + TVM_DECLARE_NODE_TYPE_INFO(Not, ExprNode); +}; + +/*! + * \brief return true_value if condition is true, otherwise return false_value. + * \note Both true_value and false_value could be evaluated + * regardless of the condition value. + * Do not use it to guard against out of bound access, + * please use if_then_else instead. + */ +class Select : public ExprNode { + public: + /*! \brief The condition */ + Expr condition; + /*! \brief value to be returned when condition is true. */ + Expr true_value; + /*! \brief value to be returned when condition is false. */ + Expr false_value; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("condition", &condition); + v->Visit("true_value", &true_value); + v->Visit("false_value", &false_value); + } + + TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value); + + static constexpr const char* _type_key = "Select"; + TVM_DECLARE_NODE_TYPE_INFO(Select, ExprNode); +}; + +/*! + * \brief Load the value from buffer_var. + * + * Equivalent to ((DType*)buffer_var)[index] + * where DType is the type specified by type().element_of(). + * + * For example, if type = float32x3, then the load will corresponds to + * + * \code + * + * auto buffer = static_cast(buffer_var); + * auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]); + * + * \endcode + */ +class Load : public ExprNode { + public: + /*! \brief The buffer variable. */ + Var buffer_var; + /*! \brief The index locations to be loaded. */ + Expr index; + /*! \brief The predicate to mask which lanes would be loaded. */ + Expr predicate; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("buffer_var", &buffer_var); + v->Visit("index", &index); + v->Visit("predicate", &predicate); + } + + TVM_DLL static Expr make(Type type, Var buffer_var, Expr index, Expr predicate); + + static constexpr const char* _type_key = "Load"; + TVM_DECLARE_NODE_TYPE_INFO(Load, ExprNode); +}; + +/*! + * \brief Construct a vector with lanes elements + * where its i-th element equals base + i * stride. + * This is useful to construct a index for a continuous vector load. + * + * Examples: + * - ramp(0, 1, 3) = [0, 1, 2] + * - ramp(1, 2, 4) = [1, 3, 5, 7] + */ +class Ramp : public ExprNode { + public: + /*! \brief The base value. */ + Expr base; + /*! \brief The stride of each step. */ + Expr stride; + /*! \brief Total number of lanes. */ + int lanes; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("base", &base); + v->Visit("stride", &stride); + v->Visit("lanes", &lanes); + } + + TVM_DLL static Expr make(Expr base, Expr stride, int lanes); + + static constexpr const char* _type_key = "Ramp"; + TVM_DECLARE_NODE_TYPE_INFO(Ramp, ExprNode); +}; + +/*! \brief Create a vector where all the elements are value. */ +class Broadcast : public ExprNode { + public: + /*! \brief The base value. */ + Expr value; + /*! \brief The numerb of lanes. */ + int lanes; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("value", &value); + v->Visit("lanes", &lanes); + } + + TVM_DLL static Expr make(Expr value, int lanes); + + static constexpr const char* _type_key = "Broadcast"; + TVM_DECLARE_NODE_TYPE_INFO(Broadcast, ExprNode); +}; + +/*! + * \brief Let binding. Bind var to value then evaluate body. + */ +class Let : public ExprNode { + public: + /*! \brief The variable. */ + Var var; + /*! \brief The value to be binded. */ + Expr value; + /*! \brief The result expression. */ + Expr body; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("body", &body); + } + + TVM_DLL static Expr make(Var var, Expr value, Expr body); + + static constexpr const char* _type_key = "Let"; + TVM_DECLARE_NODE_TYPE_INFO(Let, ExprNode); +}; + +// Call node, represent a function call or a multi-dimensional array load. +// +// TODO(tvm-team): +// Refactor call with more explicit property registrations. +// rather than calling a string symbol. +// We should move most information into function itself and remove name. + +/*! \brief Base node of internal functions. */ +class FunctionBaseNode : public Node { + public: + /*! \return the name of the function */ + virtual const std::string& func_name() const = 0; + /*! \return the number of outputs of this function */ + virtual int num_outputs() const = 0; +}; + +/*! \brief reference to a function */ +class FunctionRef : public NodeRef { + public: + TVM_DEFINE_NODE_REF_METHODS(FunctionRef, NodeRef, FunctionBaseNode); +}; + +/*! + * \brief Call node. + */ +class Call : public ExprNode { + public: + /*! \brief Possible types of calls. */ + enum CallType : int { + /*! \brief Extern "C" function. */ + Extern = 0, + /*! \brief Extern CXX function. */ + ExternCPlusPlus = 1, + /*! \brief Extern "C" without side-effect. */ + PureExtern = 2, + /*! \brief Halide-style call, evaluates func(args). */ + Halide = 3, + /*! \brief Intrinsic functions. */ + Intrinsic = 4, + /*! \brief Intrinsic functions that are pure. */ + PureIntrinsic = 5 + }; + /*! \brief The name of the function/intrinsic. */ + std::string name; + /*! \brief The arguments. */ + Array args; + /*! \brief Type of calls. */ + CallType call_type; + /*! \brief The function to be called. */ + FunctionRef func; + /*! \brief The output value index if func's value is a tuple. */ + int value_index{0}; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("dtype", &type); + v->Visit("name", &name); + v->Visit("args", &args); + v->Visit("call_type", &call_type); + v->Visit("func", &func); + v->Visit("value_index", &value_index); + } + + TVM_DLL static Expr make(Type type, + std::string name, + Array args, + CallType call_type, + FunctionRef func = FunctionRef(), + int value_index = 0); + + /*! \return Whether call node is pure. */ + bool is_pure() const { + return (call_type == PureExtern || + call_type == PureIntrinsic || + call_type == Halide); + } + + /*! + * \return Whether call node corresponds to a defined intrinsic. + * \param intrin_name The name of the intrinsic. + */ + bool is_intrinsic(const char* intrin_name) const { + return + ((call_type == Intrinsic || + call_type == PureIntrinsic) && + name == intrin_name); + } + + static constexpr const char* _type_key = "Call"; + TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode); + + // Build-in intrinsics + static constexpr const char* reinterpret = "reinterpret"; + static constexpr const char* bitwise_and = "bitwise_and"; + static constexpr const char* bitwise_not = "bitwise_not"; + static constexpr const char* bitwise_xor = "bitwise_xor"; + static constexpr const char* bitwise_or = "bitwise_or"; + static constexpr const char* shift_left = "shift_left"; + static constexpr const char* shift_right = "shift_right"; + static constexpr const char* popcount = "popcount"; + static constexpr const char* likely = "likely"; + static constexpr const char* glsl_texture_store = "glsl_texture_store"; + static constexpr const char* prefetch = "prefetch"; +}; + +/*! + * \brief Shuffle instruction. + * vec = concat(vectors) + * result = (vec[indices[0]], vec[indices[1]] ...) + */ +class Shuffle : public ExprNode { + public: + /*! \brief the input vectors. */ + Array vectors; + /*! \brief The indices of each element. */ + Array indices; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("vectors", &vectors); + v->Visit("indices", &indices); + } + + TVM_DLL static Expr make(Array vectors, Array indices); + TVM_DLL static Expr make_concat(Array vectors); + TVM_DLL static Expr make_extract_element(Expr vector, int index); -// Node container for CommReducer -struct CommReducerNode; + static constexpr const char* _type_key = "Shuffle"; + TVM_DECLARE_NODE_TYPE_INFO(Shuffle, ExprNode); +}; -struct CommReducer : public NodeRef { +// Reduce operator +class CommReducerNode; + +class CommReducer : public NodeRef { + public: CommReducer() {} explicit CommReducer(NodePtr n) : NodeRef(n) {} /*! @@ -66,7 +611,8 @@ struct CommReducer : public NodeRef { * \brief A commutative reducer node to represent a commutative * binary operator with identity element */ -struct CommReducerNode : public Node { +class CommReducerNode : public Node { + public: /*! \brief The left argument of reducer */ Array lhs; /*! \brief The right argument of reducer */ @@ -82,8 +628,10 @@ struct CommReducerNode : public Node { /*! \brief Function call operator to combine a and b */ Array operator()(Array a, Array b) const; /*! \brief construct CommReducer from args, result and identity_element */ - TVM_DLL static CommReducer make(Array lhs, Array rhs, - Array result, Array identity_element); + TVM_DLL static CommReducer make(Array lhs, + Array rhs, + Array result, + Array identity_element); void VisitAttrs(AttrVisitor* v) final { v->Visit("lhs", &lhs); @@ -104,7 +652,8 @@ inline const CommReducerNode* CommReducer::operator->() const { } /*! \brief Reduction operator operator */ -struct Reduce : public ExprNode { +class Reduce : public ExprNode { + public: /*! \brief The commutative combiner */ CommReducer combiner; /*! \brief The source operand */ @@ -134,17 +683,483 @@ struct Reduce : public ExprNode { v->Visit("condition", &condition); v->Visit("value_index", &value_index); } - static const IRNodeType _type_info = IRNodeType::ExtensionExpr; + static constexpr const char* _type_key = "Reduce"; + TVM_DECLARE_NODE_TYPE_INFO(Reduce, ExprNode); }; /*! \brief Any shape. */ -struct Any : public ExprNode { +class Any : public ExprNode { + public: + void VisitAttrs(AttrVisitor* v) final {} + TVM_DLL static Expr make(); - void VisitAttrs(AttrVisitor* v) final {} - static const IRNodeType _type_info = IRNodeType::ExtensionExpr; static constexpr const char* _type_key = "Any"; + TVM_DECLARE_NODE_TYPE_INFO(Any, ExprNode); +}; + +// Statements +/*! + * \brief Let binding, bind var to value, then run body. + */ +class LetStmt : public StmtNode { + public: + /*! \brief The variable. */ + Var var; + /*! \brief The value to be binded. */ + Expr value; + /*! \brief The body block. */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("body", &body); + } + + TVM_DLL static Stmt make(Var var, Expr value, Stmt body); + + static constexpr const char* _type_key = "LetStmt"; + TVM_DECLARE_NODE_TYPE_INFO(LetStmt, StmtNode); +}; + +/*! + * \brief Define certain auxiliary attribute for the body to be a symbolic value. + * This provide auxiliary information for IR passes that transforms body. + * + * In terms of effect, this is equivalent to Block(Evaluate(value), body). + * + * Examples of possible usage: + * - Bound of function, variables. + * - Hint which block corresponds to a parallel region. + */ +class AttrStmt : public StmtNode { + public: + /*! \brief this is attribute about certain node */ + NodeRef node; + /*! \brief the type key of the attribute */ + std::string attr_key; + /*! \brief The attribute value, value is well defined at current scope. */ + Expr value; + /*! \brief The body statement to be executed */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("node", &node); + v->Visit("attr_key", &attr_key); + v->Visit("value", &value); + v->Visit("body", &body); + } + + TVM_DLL static Stmt make(NodeRef node, + std::string type_key, + Expr value, + Stmt body); + + static constexpr const char* _type_key = "AttrStmt"; + TVM_DECLARE_NODE_TYPE_INFO(AttrStmt, StmtNode); +}; + +/*! + * \brief Assert condition, if an error occurs, return the error message. + */ +class AssertStmt : public StmtNode { + public: + /*! \brief Condition to be checked. */ + Expr condition; + /*! \brief Error message when assertion failed. */ + Expr message; + /*! + * \brief Body which this assertion holds true. + * Will be executed after the assertion. + */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("condition", &condition); + v->Visit("message", &message); + v->Visit("body", &body); + } + + TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body); + + static constexpr const char* _type_key = "AssertStmt"; + TVM_DECLARE_NODE_TYPE_INFO(AssertStmt, StmtNode); +}; + +// TODO(tvm-team): consider consolidate with AttrStmt. +/*! \brief annotation node of producer/consumer relation. */ +class ProducerConsumer : public StmtNode { + public: + /*! \brief The corresponding tensor. */ + FunctionRef func; + /*! \brief Whether the relation is producer. */ + bool is_producer; + /*! \brief Body to be executed. */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("is_producer", &is_producer); + v->Visit("body", &body); + } + + TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body); + + static constexpr const char* _type_key = "ProducerConsumer"; + TVM_DECLARE_NODE_TYPE_INFO(ProducerConsumer, StmtNode); +}; + +/*! + * \brief Store value to the buffer. + * + * Equivalent to ((DType*)buffer_var)[index] = value. + * where DType is the type specified by type().element_of(). + * + * For example, if type = float32x3, then the load will corresponds to + * + * \code + * + * auto buffer = static_cast(buffer_var); + * buffer[index.v0] = value.v0; + * buffer[index.v1] = value.v1; + * buffer[index.v2] = value.v2; + * + * \endcode + * \sa Load + */ +class Store : public StmtNode { + public: + /*! \brief The buffer variable. */ + Var buffer_var; + /*! \brief The value to be stored. */ + Expr value; + /*! \brief The index locations to be stored. */ + Expr index; + /*! \brief The predicate to mask which lanes would be stored. */ + Expr predicate; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("buffer_var", &buffer_var); + v->Visit("value", &value); + v->Visit("index", &index); + v->Visit("predicate", &predicate); + } + + TVM_DLL static Stmt make(Var buffer_var, + Expr value, + Expr index, + Expr predicate); + + static constexpr const char* _type_key = "Store"; + TVM_DECLARE_NODE_TYPE_INFO(Store, StmtNode); +}; + +/*! + * \brief Store value into mult-dimensional array defined by func. + */ +class Provide : public StmtNode { + public: + /*! \brief The function to be updated. */ + FunctionRef func; + /*! \brief The output value index if func's value is a tuple. */ + int value_index{0}; + /*! \brief The value to be stored. */ + Expr value; + /*! \brief The index arguments of the function. */ + Array args; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("value_index", &value_index); + v->Visit("value", &value); + v->Visit("args", &args); + } + + TVM_DLL static Stmt make(FunctionRef func, + int value_index, + Expr value, + Array args); + + static constexpr const char* _type_key = "Provide"; + TVM_DECLARE_NODE_TYPE_INFO(Provide, StmtNode); +}; + +/*! + * \brief Allocate a buffer that can be used in body. + */ +class Allocate : public StmtNode { + public: + /*! \brief The buffer variable. */ + Var buffer_var; + /*! \brief The type of the buffer. */ + DataType type; + /*! \brief The extents of the buffer. */ + Array extents; + /*! \brief Only allocate buffer when condition is satisfied. */ + Expr condition; + /*! \brief The body to be executed. */ + Stmt body; + // The following two fields are deprecated + // kept for backward compatibility and will be refactored later. + Expr new_expr; + std::string free_function; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("buffer_var", &buffer_var); + v->Visit("dtype", &type); + v->Visit("extents", &extents); + v->Visit("condition", &condition); + v->Visit("body", &body); + } + + TVM_DLL static Stmt make(Var buffer_var, + DataType type, + Array extents, + Expr condition, + Stmt body, + Expr new_expr = Expr(), + std::string free_function = std::string()); + + /*! + * \brief If the buffer size is constant, return the size. + * Otherwise return 0. + * \return The result. + */ + int32_t constant_allocation_size() const { + return constant_allocation_size(extents); + } + /*! + * \brief If the buffer size is constant, return the size. + * Otherwise return 0. + * \param extents The extents of the buffer. + * \return The result. + */ + TVM_DLL static int32_t constant_allocation_size( + const Array& extents); + + static constexpr const char* _type_key = "Allocate"; + TVM_DECLARE_NODE_TYPE_INFO(Allocate, StmtNode); +}; + +/*! \brief Free the resources in the buffer before the scope ends. */ +class Free : public StmtNode { + public: + /*! \brief The buffer variable. */ + Var buffer_var; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("buffer_var", &buffer_var); + } + + TVM_DLL static Stmt make(Var buffer_var); + + static constexpr const char* _type_key = "Free"; + TVM_DECLARE_NODE_TYPE_INFO(Free, StmtNode); +}; + +/*! + * \brief Annotate the bounds where func need to be written and read in body. + * We will need to allocate space for the corresponding regions. + */ +class Realize : public StmtNode { + public: + /*! \brief The function to be realized. */ + FunctionRef func; + /*! \brief The output value index if func's value is a tuple. */ + int value_index; + /*! \brief The data type of the array. */ + DataType type; + /*! \brief Bounds to be realized. */ + Region bounds; + /*! \brief Only realize if condition holds. */ + Expr condition; + /*! \brief The body of realization. */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("value_index", &value_index); + v->Visit("dtype", &type); + v->Visit("bounds", &bounds); + v->Visit("condition", &condition); + v->Visit("body", &body); + } + + TVM_DLL static Stmt make(FunctionRef func, + int value_index, + DataType type, + Region bounds, + Expr condition, + Stmt body); + + static constexpr const char* _type_key = "Realize"; + TVM_DECLARE_NODE_TYPE_INFO(Realize, StmtNode); +}; + +/*! + * \brief A sequence of statements. + */ +class Block : public StmtNode { + public: + /*! \brief The first statement. */ + Stmt first; + /*! \brief The restof statments. */ + Stmt rest; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("first", &first); + v->Visit("rest", &rest); + } + + TVM_DLL static Stmt make(Stmt first, Stmt rest); + TVM_DLL static Stmt make(const std::vector &stmts); + + static constexpr const char* _type_key = "Block"; + TVM_DECLARE_NODE_TYPE_INFO(Block, StmtNode); +}; + +/*! + * \brief IfThenElse statment. + */ +class IfThenElse : public StmtNode { + public: + /*! \brief The condition. */ + Expr condition; + /*! \brief The branch to be executed when condition is true. */ + Stmt then_case; + /*! \brief The branch to be executed when condition is false, can be null. */ + Stmt else_case; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("condition", &condition); + v->Visit("then_case", &then_case); + v->Visit("else_case", &else_case); + } + + TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt()); + + static constexpr const char* _type_key = "IfThenElse"; + TVM_DECLARE_NODE_TYPE_INFO(IfThenElse, StmtNode); +}; + +/*! + * \brief Evaluates an expression. + * This is mostly used for putting a Call node into Stmt. + * + * If value do not have side-effect, this node can be safely removed. + */ +class Evaluate : public StmtNode { + public: + /*! \brief The expression to be evaluated. */ + Expr value; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("value", &value); + } + + TVM_DLL static Stmt make(Expr v); + + static constexpr const char* _type_key = "Evaluate"; + TVM_DECLARE_NODE_TYPE_INFO(Evaluate, StmtNode); +}; + +/*! \brief Additional annotation of for loop. */ +enum class ForType : int { + /*! \brief serial execution. */ + Serial = 0, + /*! \brief parallel execution on CPU. */ + Parallel = 1, + /*! \brief Vector SIMD loop annotaion. */ + Vectorized = 2, + /*! \brief Unroll annotation. */ + Unrolled = 3 +}; + +// Kevice api of for loop +// kept for backward compatibility +// consider refactor and remove later. +enum class DeviceAPI: int { + None = 0 +}; + +/*! + * \brief A for loop, with poissible type annotations. + * + * \code + * + * for (loop_var = min; loop_var < min + extent; ++loop_var) { + * // body + * } + * \endcode + */ +class For : public StmtNode { + public: + /*! \brief The loop variable. */ + Var loop_var; + /*! \brief The minimum value of iteration. */ + Expr min; + /*! \brief The extent of the iteration. */ + Expr extent; + /*! \brief The type of the for loop. */ + ForType for_type; + /*! + * \brief Deprecated, reserved for backward compatibility. + * Consider refactor and remove later. + */ + DeviceAPI device_api; + /*! \brief The body of the for loop. */ + Stmt body; + + TVM_DLL static Stmt make(Var loop_var, + Expr min, + Expr extent, + ForType for_type, + DeviceAPI device_api, + Stmt body); + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("loop_var", &loop_var); + v->Visit("min", &min); + v->Visit("extent", &extent); + v->Visit("for_type", &for_type); + v->Visit("device_api", &device_api); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "For"; + TVM_DECLARE_NODE_TYPE_INFO(For, StmtNode); +}; + +/*! + * \brief A prefetch hint of func. + */ +class Prefetch : public StmtNode { + public: + /*! \brief The function to be prefetched. */ + FunctionRef func; + /*! \brief The output value index if func's value is a tuple. */ + int value_index; + /*! \brief The data type of the array. */ + DataType type; + /*! \brief Bounds to be prefetched. */ + Region bounds; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("value_index", &value_index); + v->Visit("type", &type); + v->Visit("bounds", &bounds); + } + + TVM_DLL static Stmt make(FunctionRef func, + int value_index, + DataType type, + Region bounds); + + static constexpr const char* _type_key = "Prefetch"; + TVM_DECLARE_NODE_TYPE_INFO(Prefetch, StmtNode); }; /*! @@ -517,50 +1532,6 @@ constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; } // namespace intrinsic -// Reuse IR node defintiion from HalideIR -using HalideIR::Internal::IntImm; -using HalideIR::Internal::UIntImm; -using HalideIR::Internal::FloatImm; -using HalideIR::Internal::StringImm; -using HalideIR::Internal::Cast; -using HalideIR::Internal::Add; -using HalideIR::Internal::Sub; -using HalideIR::Internal::Mul; -using HalideIR::Internal::Div; -using HalideIR::Internal::Mod; -using HalideIR::Internal::Min; -using HalideIR::Internal::Max; -using HalideIR::Internal::EQ; -using HalideIR::Internal::NE; -using HalideIR::Internal::LT; -using HalideIR::Internal::LE; -using HalideIR::Internal::GT; -using HalideIR::Internal::GE; -using HalideIR::Internal::And; -using HalideIR::Internal::Or; -using HalideIR::Internal::Not; -using HalideIR::Internal::Select; -using HalideIR::Internal::Load; -using HalideIR::Internal::Ramp; -using HalideIR::Internal::Broadcast; -using HalideIR::Internal::Call; -using HalideIR::Internal::Let; -using HalideIR::Internal::LetStmt; -using HalideIR::Internal::AttrStmt; -using HalideIR::Internal::AssertStmt; -using HalideIR::Internal::ProducerConsumer; -using HalideIR::Internal::For; -using HalideIR::Internal::Store; -using HalideIR::Internal::Provide; -using HalideIR::Internal::Allocate; -using HalideIR::Internal::Free; -using HalideIR::Internal::Realize; -using HalideIR::Internal::Prefetch; -using HalideIR::Internal::Block; -using HalideIR::Internal::IfThenElse; -using HalideIR::Internal::Evaluate; -using HalideIR::Internal::Shuffle; - /*! * \brief Create a type annotation expression * \param dtype The data type @@ -571,6 +1542,10 @@ inline Expr TypeAnnotation(Type dtype) { "type_annotation", {}, ir::Call::PureIntrinsic); } + +// overload printing of for type. +TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type); + } // namespace ir } // namespace tvm diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 61080078c176..2837a6601136 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -25,6 +25,7 @@ #define TVM_IR_MUTATOR_H_ #include +#include #include "expr.h" #include "ir.h" #include "tvm/node/ir_functor.h" diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index cb03f6c9dae7..4da93b80c2ab 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -25,7 +25,6 @@ #ifndef TVM_LOWERED_FUNC_H_ #define TVM_LOWERED_FUNC_H_ -#include #include #include "base.h" @@ -42,7 +41,7 @@ class LoweredFuncNode; * \brief LoweredFunc represents function after lowering. * This is the final IR representation before codegen. */ -class LoweredFunc : public FunctionRef { +class LoweredFunc : public ir::FunctionRef { public: LoweredFunc() {} explicit LoweredFunc(NodePtr n) : FunctionRef(n) {} @@ -66,7 +65,7 @@ enum LoweredFuncType : int { }; /*! \brief Node container of LoweredFunc */ -class LoweredFuncNode : public FunctionBaseNode { +class LoweredFuncNode : public ir::FunctionBaseNode { public: /*! \brief The name of the function */ std::string name; diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h new file mode 100644 index 000000000000..71808902aeca --- /dev/null +++ b/include/tvm/node/container.h @@ -0,0 +1,612 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/container.h + * \brief Array/Map container in the DSL graph. + */ +#ifndef TVM_NODE_CONTAINER_H_ +#define TVM_NODE_CONTAINER_H_ + +#include +#include +#include +#include +#include +#include +#include "node.h" +#include "memory.h" + +namespace tvm { + +/*! \brief array node content in array */ +class ArrayNode : public Node { + public: + /*! \brief the data content */ + std::vector > data; + + void VisitAttrs(AttrVisitor* visitor) final { + // Visitor to array have no effect. + } + + static constexpr const char* _type_key = "Array"; + TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node); +}; + +/*! \brief map node content */ +class MapNode : public Node { + public: + void VisitAttrs(AttrVisitor* visitor) final { + // Visitor to map have no effect. + } + // hash function + struct Hash { + size_t operator()(const NodePtr& n) const { + return std::hash()(n.get()); + } + }; + // comparator + struct Equal { + bool operator()( + const NodePtr& a, + const NodePtr& b) const { + return a.get() == b.get(); + } + }; + + /*! \brief The corresponding conatiner type */ + using ContainerType = std::unordered_map< + NodePtr, + NodePtr, + Hash, Equal>; + + /*! \brief the data content */ + ContainerType data; + + static constexpr const char* _type_key = "Map"; + TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node); +}; + + +/*! \brief specialized map node with string as key */ +class StrMapNode : public Node { + public: + void VisitAttrs(AttrVisitor* visitor) final { + // Visitor to map have no effect. + } + /*! \brief The corresponding conatiner type */ + using ContainerType = std::unordered_map< + std::string, + NodePtr >; + + /*! \brief the data content */ + ContainerType data; + + static constexpr const char* _type_key = "StrMap"; + TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node); +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class IterAdapter { + public: + explicit IterAdapter(TIter iter) : iter_(iter) {} + inline IterAdapter& operator++() { // NOLINT(*) + ++iter_; + return *this; + } + inline IterAdapter& operator++(int) { // NOLINT(*) + ++iter_; + return *this; + } + inline IterAdapter operator+(int offset) const { // NOLINT(*) + return IterAdapter(iter_ + offset); + } + inline bool operator==(IterAdapter other) const { + return iter_ == other.iter_; + } + inline bool operator!=(IterAdapter other) const { + return !(*this == other); + } + inline const typename Converter::ResultType operator*() const { + return Converter::convert(*iter_); + } + + private: + TIter iter_; +}; + +/*! + * \brief Array container of NodeRef in DSL graph. + * Array implements copy on write semantics, which means array is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const acces, use Set to mutate the content. + * \tparam T The content NodeRef type. + */ +template::value>::type > +class Array : public NodeRef { + public: + /*! + * \brief default constructor + */ + Array() { + node_ = make_node(); + } + /*! + * \brief move constructor + * \param other source + */ + Array(Array && other) { // NOLINT(*) + node_ = std::move(other.node_); + } + /*! + * \brief copy constructor + * \param other source + */ + Array(const Array &other) : NodeRef(other.node_) { // NOLINT(*) + } + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(NodePtr n) : NodeRef(n) {} + /*! + * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Array(IterType begin, IterType end) { + assign(begin, end); + } + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Array(std::initializer_list init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector& init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(size_t n, const T& val) { + auto tmp_node = make_node(); + for (size_t i = 0; i < n; ++i) { + tmp_node->data.push_back(val.node_); + } + node_ = std::move(tmp_node); + } + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(Array && other) { + node_ = std::move(other.node_); + return *this; + } + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(const Array & other) { + node_ = other.node_; + return *this; + } + /*! + * \brief reset the array to content from iterator. + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + void assign(IterType begin, IterType end) { + auto n = make_node(); + for (IterType it = begin; it != end; ++it) { + n->data.push_back((*it).node_); + } + node_ = std::move(n); + } + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + inline const T operator[](size_t i) const { + return T(static_cast(node_.get())->data[i]); + } + /*! \return The size of the array */ + inline size_t size() const { + if (node_.get() == nullptr) return 0; + return static_cast(node_.get())->data.size(); + } + /*! + * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + inline ArrayNode* CopyOnWrite() { + if (node_.get() == nullptr || !node_.unique()) { + NodePtr n = make_node(); + n->data = static_cast(node_.get())->data; + NodePtr(std::move(n)).swap(node_); + } + return static_cast(node_.get()); + } + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + inline void push_back(const T& item) { + ArrayNode* n = this->CopyOnWrite(); + n->data.push_back(item.node_); + } + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + inline void Set(size_t i, const T& value) { + ArrayNode* n = this->CopyOnWrite(); + n->data[i] = value.node_; + } + /*! \return whether array is empty */ + inline bool empty() const { + return size() == 0; + } + /*! \brief specify container node */ + using ContainerType = ArrayNode; + + struct Ptr2NodeRef { + using ResultType = T; + static inline T convert(const NodePtr& n) { + return T(n); + } + }; + using iterator = IterAdapter >::const_iterator>; + + using reverse_iterator = IterAdapter< + Ptr2NodeRef, + std::vector >::const_reverse_iterator>; + + /*! \return begin iterator */ + inline iterator begin() const { + return iterator(static_cast(node_.get())->data.begin()); + } + /*! \return end iterator */ + inline iterator end() const { + return iterator(static_cast(node_.get())->data.end()); + } + /*! \return rbegin iterator */ + inline reverse_iterator rbegin() const { + return reverse_iterator(static_cast(node_.get())->data.rbegin()); + } + /*! \return rend iterator */ + inline reverse_iterator rend() const { + return reverse_iterator(static_cast(node_.get())->data.rend()); + } +}; + +/*! + * \brief Map container of NodeRef->NodeRef in DSL graph. + * Map implements copy on write semantics, which means map is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const acces, use Set to mutate the content. + * \tparam K The key NodeRef type. + * \tparam V The value NodeRef type. + */ +template::value || + std::is_base_of::value >::type, + typename = typename std::enable_if::value>::type> +class Map : public NodeRef { + public: + /*! + * \brief default constructor + */ + Map() { + node_ = make_node(); + } + /*! + * \brief move constructor + * \param other source + */ + Map(Map && other) { // NOLINT(*) + node_ = std::move(other.node_); + } + /*! + * \brief copy constructor + * \param other source + */ + Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*) + } + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Map(NodePtr n) : NodeRef(n) {} + /*! + * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Map(IterType begin, IterType end) { + assign(begin, end); + } + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Map(std::initializer_list > init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief constructor from vector + * \param init The vector + */ + template + Map(const std::unordered_map& init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map& operator=(Map && other) { + node_ = std::move(other.node_); + return *this; + } + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Map& operator=(const Map & other) { + node_ = other.node_; + return *this; + } + /*! + * \brief reset the array to content from iterator. + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + void assign(IterType begin, IterType end) { + NodePtr n = make_node(); + for (IterType i = begin; i != end; ++i) { + n->data.emplace(std::make_pair(i->first.node_, + i->second.node_)); + } + node_ = std::move(n); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + inline const V operator[](const K& key) const { + return V(static_cast(node_.get())->data.at(key.node_)); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + inline const V at(const K& key) const { + return V(static_cast(node_.get())->data.at(key.node_)); + } + /*! \return The size of the array */ + inline size_t size() const { + if (node_.get() == nullptr) return 0; + return static_cast(node_.get())->data.size(); + } + /*! \return The number of elements of the key */ + inline size_t count(const K& key) const { + if (node_.get() == nullptr) return 0; + return static_cast(node_.get())->data.count(key.node_); + } + /*! + * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + inline MapNode* CopyOnWrite() { + if (node_.get() == nullptr || !node_.unique()) { + NodePtr n = make_node(); + n->data = static_cast(node_.get())->data; + NodePtr(std::move(n)).swap(node_); + } + return static_cast(node_.get()); + } + /*! + * \brief set the Map. + * \param key The index key. + * \param value The value to be setted. + */ + inline void Set(const K& key, const V& value) { + MapNode* n = this->CopyOnWrite(); + n->data[key.node_] = value.node_; + } + + /*! \return whether array is empty */ + inline bool empty() const { + return size() == 0; + } + /*! \brief specify container node */ + using ContainerType = MapNode; + + struct Ptr2NodeRef { + using ResultType = std::pair; + static inline ResultType convert(const std::pair< + NodePtr, + NodePtr >& n) { + return std::make_pair(K(n.first), V(n.second)); + } + }; + + using iterator = IterAdapter< + Ptr2NodeRef, MapNode::ContainerType::const_iterator>; + + /*! \return begin iterator */ + inline iterator begin() const { + return iterator(static_cast(node_.get())->data.begin()); + } + /*! \return end iterator */ + inline iterator end() const { + return iterator(static_cast(node_.get())->data.end()); + } + /*! \return begin iterator */ + inline iterator find(const K& key) const { + return iterator(static_cast(node_.get())->data.find(key.node_)); + } +}; + +// specialize of string map +template +class Map : public NodeRef { + public: + // for code reuse + Map() { + node_ = make_node(); + } + Map(Map && other) { // NOLINT(*) + node_ = std::move(other.node_); + } + Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*) + } + explicit Map(NodePtr n) : NodeRef(n) {} + template + Map(IterType begin, IterType end) { + assign(begin, end); + } + Map(std::initializer_list > init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + + template + Map(const std::unordered_map& init) { // NOLINT(*) + assign(init.begin(), init.end()); + } + Map& operator=(Map && other) { + node_ = std::move(other.node_); + return *this; + } + Map& operator=(const Map & other) { + node_ = other.node_; + return *this; + } + template + void assign(IterType begin, IterType end) { + auto n = make_node(); + for (IterType i = begin; i != end; ++i) { + n->data.emplace(std::make_pair(i->first, + i->second.node_)); + } + node_ = std::move(n); + } + inline const V operator[](const std::string& key) const { + return V(static_cast(node_.get())->data.at(key)); + } + inline const V at(const std::string& key) const { + return V(static_cast(node_.get())->data.at(key)); + } + inline size_t size() const { + if (node_.get() == nullptr) return 0; + return static_cast(node_.get())->data.size(); + } + inline size_t count(const std::string& key) const { + if (node_.get() == nullptr) return 0; + return static_cast(node_.get())->data.count(key); + } + inline StrMapNode* CopyOnWrite() { + if (node_.get() == nullptr || !node_.unique()) { + NodePtr n = make_node(); + n->data = static_cast(node_.get())->data; + NodePtr(std::move(n)).swap(node_); + } + return static_cast(node_.get()); + } + inline void Set(const std::string& key, const V& value) { + StrMapNode* n = this->CopyOnWrite(); + n->data[key] = value.node_; + } + inline bool empty() const { + return size() == 0; + } + using ContainerType = StrMapNode; + + struct Ptr2NodeRef { + using ResultType = std::pair; + static inline ResultType convert(const std::pair< + std::string, + NodePtr >& n) { + return std::make_pair(n.first, V(n.second)); + } + }; + + using iterator = IterAdapter< + Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>; + + /*! \return begin iterator */ + inline iterator begin() const { + return iterator(static_cast(node_.get())->data.begin()); + } + /*! \return end iterator */ + inline iterator end() const { + return iterator(static_cast(node_.get())->data.end()); + } + /*! \return begin iterator */ + inline iterator find(const std::string& key) const { + return iterator(static_cast(node_.get())->data.find(key)); + } +}; + +} // namespace tvm +#endif // TVM_NODE_CONTAINER_H_ diff --git a/include/tvm/node/ir_functor.h b/include/tvm/node/ir_functor.h new file mode 100644 index 000000000000..23c5a3fafdab --- /dev/null +++ b/include/tvm/node/ir_functor.h @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/ir_functor.h + * \brief Defines the IRFunctor data structures. + */ +#ifndef TVM_NODE_IR_FUNCTOR_H_ +#define TVM_NODE_IR_FUNCTOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "node.h" + +namespace tvm { +/*! + * \brief A dynamically dispatched functor on NodeRef in the first argument. + * + * \code + * IRFunctor tostr; + * tostr.set_dispatch([](const Add* op, std::string prefix) { + * return prefix + "Add"; + * }); + * tostr.set_dispatch([](const IntImm* op) { + * return prefix + "IntImm" + * }); + * + * Expr x = make_const(1); + * Expr y = x + x; + * // dispatch to IntImm, outputs "MyIntImm" + * LOG(INFO) << tostr(x, "My"); + * // dispatch to IntImm, outputs "MyAdd" + * LOG(INFO) << tostr(y, "My"); + * \endcode + * + * \tparam FType function signiture + * This type if only defined for FType with function signature + */ +template +class IRFunctor; + +template +class IRFunctor { + private: + using Function = std::function; + using TSelf = IRFunctor; + /*! \brief internal function table */ + std::vector func_; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! + * \brief Whether the functor can dispatch the corresponding Node + * \param n The node to be dispatched + * \return Whether dispatching function is registered for n's type. + */ + inline bool can_dispatch(const NodeRef& n) const { + uint32_t type_index = n.type_index(); + return type_index < func_.size() && func_[type_index] != nullptr; + } + /*! + * \brief invoke the functor , dispatch on type of n + * \param n The Node argument + * \param args The additional arguments + * \return The result. + */ + inline R operator()(const NodeRef& n, Args... args) const { + uint32_t type_index = n.type_index(); + CHECK(type_index < func_.size() && + func_[type_index] != nullptr) + << "IRFunctor calls un-registered function on type " + << Node::TypeIndex2Key(type_index); + return func_[type_index](n, std::forward(args)...); + } + /*! + * \brief set the dispacher for type TNode + * \param f The function to be set. + * \tparam TNode the type of Node to be dispatched. + * \return reference to self. + */ + template + inline TSelf& set_dispatch(Function f) { // NOLINT(*) + uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + if (func_.size() <= tindex) { + func_.resize(tindex + 1, nullptr); + } + CHECK(func_[tindex] == nullptr) + << "Dispatch for " << Node::TypeIndex2Key(tindex) + << " is already set"; + func_[tindex] = f; + return *this; + } + /*! + * \brief set the dispacher for type TNode + * This allows f to used detailed const Node pointer to replace NodeRef + * + * \param f The function to be set. + * \tparam TNode the type of Node to be dispatched. + * \return reference to self. + */ + template + inline TSelf& set_dispatch(std::function f) { // NOLINT(*) + Function fun = [f](const NodeRef& n, Args... args) { + return f(static_cast(n.node_.get()), + std::forward(args)...); + }; + return this->set_dispatch(fun); + } + /*! + * \brief unset the dispacher for type TNode + * + * \tparam TNode the type of Node to be dispatched. + * \return reference to self. + */ + template + inline TSelf& clear_dispatch() { // NOLINT(*) + uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; + func_[tindex] = nullptr; + return *this; + } +}; + +#if defined(__GNUC__) +#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define TVM_ATTRIBUTE_UNUSED +#endif + +/*! \brief helper macro to generate string concat */ +#define TVM_STR_CONCAT_(__x, __y) __x##__y +#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) + +#define TVM_REGISTER_VAR_DEF(ClsName) \ + static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName + +/*! + * \brief Useful macro to set IRFunctor dispatch in a global static field. + * + * \code + * // Use IRFunctor to implement IRPrinter similar to Visitor Pattern. + * // vtable allows easy patch in of new Node types, without changing + * // interface of IRPrinter. + * + * class IRPrinter { + * public: + * std::ostream& stream; + * // the dispatch function. + * void print(Expr e) { + * const static FType& f = *vtable(); + * f(e, this); + * } + * + * using FType = IRFunctor; + * // function to return global function table + * static FType& vtable(); + * }; + * + * // in cpp/cc file + * IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*) + * static FType inst; return inst; + * } + * + * TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) + * .set_dispatch([](const Add* n, IRPrinter* p) { + * p->print(n->a); + * p->stream << '+' + * p->print(n->b); + * }); + * + * + * \endcode + * + * \param ClsName The name of the class + * \param FField The static function that returns a singleton of IRFunctor. + */ +#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ + TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \ + ClsName::FField() + + /*! + * \brief A container for a list of callbacks. All callbacks are invoked when + * the object is destructed. + */ +class IRFunctorCleanList { + public: + ~IRFunctorCleanList() { + for (auto &f : clean_items) { + f(); + } + } + + void append(std::function func) { + clean_items.push_back(func); + } + + private: + std::vector< std::function > clean_items; +}; + +/*! +* \brief A wrapper around IRFunctor that will record calls to set_dispatch +* and make a corresponding call to clear_dispatch when the last copy of +* the IRFunctorStaticRegistry is destructed. When assigned to a static variable, +* this can be used by NNVM and other libraries to unregister callbacks when +* the library is unloaded. This prevents crashes when the underlying IRFunctor +* is destructed as it will no longer contain std::function instances allocated +* by a library that has been unloaded. +*/ +template +class IRFunctorStaticRegistry; + +template +class IRFunctorStaticRegistry { + private: + IRFunctor *irf_; + std::shared_ptr free_list; + + using TSelf = IRFunctorStaticRegistry; + + public: + IRFunctorStaticRegistry(IRFunctor *irf) { + irf_ = irf; + free_list = std::make_shared(); + } + + template + inline TSelf& set_dispatch(std::function f) { // NOLINT(*) + irf_->template set_dispatch(f); + auto irf_copy = irf_; + free_list.get()->append([irf_copy] { + irf_copy->template clear_dispatch(); + }); + return *this; + } +}; + +/*! +* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows +* the compiler to deduce the template types. +*/ +template +IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( + IRFunctor *irf) { + return IRFunctorStaticRegistry(irf); +} + +#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ + static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName + +/*! +* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry. +* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of +* TVM_STATIC_IR_FUNCTOR. +*/ +#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \ + TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \ + MakeIRFunctorStaticRegistry(&ClsName::FField()) + +} // namespace tvm +#endif // TVM_NODE_IR_FUNCTOR_H_ diff --git a/include/tvm/node/memory.h b/include/tvm/node/memory.h new file mode 100644 index 000000000000..1bba57144e19 --- /dev/null +++ b/include/tvm/node/memory.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/memory.h + * \brief Node memory management. + */ +#ifndef TVM_NODE_MEMORY_H_ +#define TVM_NODE_MEMORY_H_ + +#include +#include "node.h" + +namespace tvm { +/*! + * \brief Allocate a node object. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The NodePtr to the allocated object. + */ +template +inline NodePtr make_node(Args&&... args); + +// Detail implementations after this +// +// The current design allows swapping the +// allocator pattern when necessary. +// +// Possible future allocator optimizations: +// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) +// - Thread-local object pools: one pool per size and alignment requirement. +// - Can specialize by type of object to give the specific allocator to each object. +// +template +class SimpleNodeAllocator { + public: + template + static T* New(Args&&... args) { + return new T(std::forward(args)...); + } + static NodeBase::FDeleter Deleter() { + return Deleter_; + } + + private: + static void Deleter_(NodeBase* ptr) { + delete static_cast(ptr); + } +}; + +template +inline NodePtr make_node(Args&&... args) { + using Allocator = SimpleNodeAllocator; + static_assert(std::is_base_of::value, + "make_node can only be used to create NodeBase"); + T* node = Allocator::New(std::forward(args)...); + node->deleter_ = Allocator::Deleter(); + return NodePtr(node); +} + +} // namespace tvm +#endif // TVM_NODE_MEMORY_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h new file mode 100644 index 000000000000..79187b8dc4eb --- /dev/null +++ b/include/tvm/node/node.h @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/node.h + * \brief Node system data structure. + */ +#ifndef TVM_NODE_NODE_H_ +#define TVM_NODE_NODE_H_ + +#include +#include +#include +#include +#include +#include +#include + + +namespace tvm { +// forward declaration +class DataType; +class Node; +class NodeRef; + +namespace runtime { +// forward declaration +class NDArray; +// forward declaration +class Object; +} // namespace runtime + +/*! + * \brief Visitor class to each node content. + * The content is going to be called for each field. + */ +class TVM_DLL AttrVisitor { + public: +//! \cond Doxygen_Suppress + virtual ~AttrVisitor() = default; + virtual void Visit(const char* key, double* value) = 0; + virtual void Visit(const char* key, int64_t* value) = 0; + virtual void Visit(const char* key, uint64_t* value) = 0; + virtual void Visit(const char* key, int* value) = 0; + virtual void Visit(const char* key, bool* value) = 0; + virtual void Visit(const char* key, std::string* value) = 0; + virtual void Visit(const char* key, void** value) = 0; + virtual void Visit(const char* key, DataType* value) = 0; + virtual void Visit(const char* key, NodeRef* value) = 0; + virtual void Visit(const char* key, runtime::NDArray* value) = 0; + virtual void Visit(const char* key, runtime::Object* value) = 0; + template::value>::type> + void Visit(const char* key, ENum* ptr) { + static_assert(std::is_same::type>::value, + "declare enum to be enum int to use visitor"); + this->Visit(key, reinterpret_cast(ptr)); + } +//! \endcond +}; + +/*! + * \brief base class of node container in DSL AST. + */ +class TVM_DLL Node : public NodeBase { + public: + /*! \brief virtual destructor */ + virtual ~Node() {} + /*! \return The unique type key of the node */ + virtual const char* type_key() const = 0; + /*! + * \brief Apply visitor to each field of the Node + * Visitor could mutate the content of the node. + * override if Node contains attribute fields. + * \param visitor The visitor + */ + virtual void VisitAttrs(AttrVisitor* visitor) {} + /*! \return the type index of the node */ + virtual const uint32_t type_index() const = 0; + /*! + * \brief Whether this node derives from node with type_index=tid. + * Implemented by TVM_DECLARE_NODE_TYPE_INFO + * + * \param tid The type index. + * \return the check result. + */ + virtual const bool _DerivedFrom(uint32_t tid) const; + /*! + * \brief get a runtime unique type index given a type key + * \param type_key Type key of a type. + * \return the corresponding type index. + */ + static uint32_t TypeKey2Index(const char* type_key); + /*! + * \brief get type key from type index. + * \param index The type index + * \return the corresponding type key. + */ + static const char* TypeIndex2Key(uint32_t index); + /*! + * \return whether the type is derived from + */ + template + inline bool derived_from() const; + /*! + * \return whether the node is of type T + * \tparam The type to be checked. + */ + template + inline bool is_type() const; + /*! + * \brief Get a NodePtr that holds reference to this Node. + * \return the NodePtr + */ + inline NodePtr GetNodePtr() const; + // node ref can see this + friend class NodeRef; + static constexpr const char* _type_key = "Node"; +}; + +/*! \brief Base class of all node reference object */ +class NodeRef { + public: + /*! \brief type indicate the container type */ + using ContainerType = Node; + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + inline bool operator==(const NodeRef& other) const; + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + inline bool same_as(const NodeRef& other) const; + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + inline bool operator<(const NodeRef& other) const; + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + inline bool operator!=(const NodeRef& other) const; + /*! \return the hash function for NodeRef */ + inline size_t hash() const; + /*! \return whether the expression is null */ + inline bool defined() const; + /*! \return the internal type index of IRNode */ + inline uint32_t type_index() const; + /*! \return the internal node pointer */ + inline const Node* get() const; + /*! \return the internal node pointer */ + inline const Node* operator->() const; + /*! + * \brief Downcast this ir node to its actual type (e.g. Add, or + * Select). This returns nullptr if the node is not of the requested + * type. Example usage: + * + * if (const Add *add = node->as()) { + * // This is an add node + * } + * \tparam T the target type, must be subtype of IRNode + */ + template + inline const T *as() const; + /*! + * \brief A more powerful version of as that also works with + * intermediate base types. + * \tparam T the target type, must be subtype of IRNode + */ + template + inline const T *as_derived() const; + /*! \brief default constructor */ + NodeRef() = default; + explicit NodeRef(NodePtr node) : node_(node) {} + /*! \brief the internal node object, do not touch */ + NodePtr node_; +}; + +/*! + * \brief Get a reference type from a Node ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam NodeType The node type + * \return The corresponding RefType + */ +template +inline RefType GetRef(const NodeType* ptr); + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The inptut reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template +inline SubRef Downcast(BaseRef ref); + +/*! + * \brief helper macro to declare type information in a base node. + */ +#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ + const bool _DerivedFrom(uint32_t tid) const override { \ + static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ + if (tidx == tid) return true; \ + return Parent::_DerivedFrom(tid); \ + } + +/*! + * \brief helper macro to declare type information in a terminal node + */ +#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ + const char* type_key() const final { \ + return TypeName::_type_key; \ + } \ + const uint32_t type_index() const final { \ + static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ + return tidx; \ + } \ + const bool _DerivedFrom(uint32_t tid) const final { \ + static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ + if (tidx == tid) return true; \ + return Parent::_DerivedFrom(tid); \ + } + +// implementations of inline functions after this +template +inline bool Node::derived_from() const { + // use static field so query only happens once. + static uint32_t type_id = Node::TypeKey2Index(T::_type_key); + return this->_DerivedFrom(type_id); +} + + +template +inline bool Node::is_type() const { + // use static field so query only happens once. + static uint32_t type_id = Node::TypeKey2Index(T::_type_key); + return type_id == this->type_index(); +} + + +inline NodePtr Node::GetNodePtr() const { + return NodePtr(const_cast(this)); +} + +template +inline RefType GetRef(const NodeType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return RefType(ptr->GetNodePtr()); +} + +template +inline SubRef Downcast(BaseRef ref) { + CHECK(ref->template is_type() || + ref->template derived_from()) + << "Downcast from " << ref->type_key() << " to " + << SubRef::ContainerType::_type_key << " failed."; + return SubRef(std::move(ref.node_)); +} + +inline const Node* NodeRef::get() const { + return node_.get(); +} + +inline const Node* NodeRef::operator->() const { + return node_.get(); +} + +inline bool NodeRef::defined() const { + return node_.get() != nullptr; +} + +inline bool NodeRef::operator==(const NodeRef& other) const { + return node_.get() == other.node_.get(); +} + +inline bool NodeRef::same_as(const NodeRef& other) const { + return node_.get() == other.node_.get(); +} + +inline bool NodeRef::operator<(const NodeRef& other) const { + return node_.get() < other.node_.get(); +} + +inline bool NodeRef::operator!=(const NodeRef& other) const { + return node_.get() != other.node_.get(); +} + +inline size_t NodeRef::hash() const { + return std::hash()(node_.get()); +} + +inline uint32_t NodeRef::type_index() const { + CHECK(node_.get() != nullptr) + << "null type"; + return get()->type_index(); +} + +template +inline const T* NodeRef::as() const { + const Node* ptr = static_cast(get()); + if (ptr && ptr->is_type()) { + return static_cast(ptr); + } + return nullptr; +} + +template +inline const T* NodeRef::as_derived() const { + const Node* ptr = static_cast(get()); + if (ptr && (ptr->is_type() || ptr->derived_from())) { + return static_cast(ptr); + } + return nullptr; +} + +/*! \brief The hash function for nodes */ +struct NodeHash { + size_t operator()(const NodeRef& a) const { + return a.hash(); + } +}; + +/*! \brief The equal comparator for nodes */ +struct NodeEqual { + bool operator()(const NodeRef& a, const NodeRef& b) const { + return a.get() == b.get(); + } +}; +} // namespace tvm +#endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 38dc39bbe7a7..2602b383aab1 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -53,7 +53,7 @@ struct TensorDom { /*! * \brief Base class of all operation nodes */ -class OperationNode : public FunctionBaseNode { +class OperationNode : public ir::FunctionBaseNode { public: /*! \brief optional name of the operation */ std::string name; @@ -463,7 +463,7 @@ class ExternOpNode : public OperationNode { v->Visit("output_placeholders", &output_placeholders); v->Visit("body", &body); } - EXPORT static Operation make(std::string name, + TVM_DLL static Operation make(std::string name, std::string tag, Map attrs, Array inputs, @@ -530,12 +530,12 @@ class HybridOpNode : public OperationNode { v->Visit("axis", &axis); v->Visit("body", &body); } - EXPORT static Operation make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array outputs, - Stmt body); + TVM_DLL static Operation make(std::string name, + std::string tag, + Map attrs, + Array inputs, + Array outputs, + Stmt body); static constexpr const char* _type_key = "HybridOp"; TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode); diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 8bbde878741d..5951594b873c 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -70,7 +70,9 @@ struct NodeTypeChecker > { if (!sptr->is_type()) return false; ArrayNode* n = static_cast(sptr); for (const auto& p : n->data) { - if (!NodeTypeChecker::Check(p.get())) return false; + if (!NodeTypeChecker::Check(p.get())) { + return false; + } } return true; } @@ -144,7 +146,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const { return TNodeRef(sptr); } -inline TVMArgValue::operator HalideIR::Expr() const { +inline TVMArgValue::operator tvm::Expr() const { if (type_code_ == kNull) return Expr(); if (type_code_ == kDLInt) { CHECK_LE(value_.v_int64, std::numeric_limits::max()); @@ -240,21 +242,21 @@ inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { / } // type related stuffs -inline TVMRetValue& TVMRetValue::operator=(const HalideIR::Type& t) { - return this->operator=(Type2TVMType(t)); +inline TVMRetValue& TVMRetValue::operator=(const DataType& t) { + return this->operator=(t.operator DLDataType()); } -inline TVMRetValue::operator HalideIR::Type() const { - return TVMType2Type(operator TVMType()); +inline TVMRetValue::operator tvm::DataType() const { + return DataType(operator DLDataType()); } -inline TVMArgValue::operator HalideIR::Type() const { - return TVMType2Type(operator TVMType()); +inline TVMArgValue::operator tvm::DataType() const { + return DataType(operator DLDataType()); } inline void TVMArgsSetter::operator()( - size_t i, const HalideIR::Type& t) const { - this->operator()(i, Type2TVMType(t)); + size_t i, const DataType& t) const { + this->operator()(i, t.operator DLDataType()); } } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 17fd626ee51d..f06b2583127a 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -42,14 +42,6 @@ #include "object.h" #include "node_base.h" -namespace HalideIR { -// Forward declare type for extensions -// The header works fine without depending on this. -struct Type; -struct Expr; -} - - // Whether use TVM runtime in header only mode. #ifndef TVM_RUNTIME_HEADER_ONLY #define TVM_RUNTIME_HEADER_ONLY 0 @@ -58,6 +50,8 @@ struct Expr; namespace tvm { // forward declarations class Integer; +class DataType; +class Expr; namespace runtime { @@ -626,8 +620,8 @@ class TVMArgValue : public TVMPODValue_ { typename = typename std::enable_if< std::is_class::value>::type> inline bool IsNodeType() const; - inline operator HalideIR::Type() const; - inline operator HalideIR::Expr() const; + inline operator tvm::DataType() const; + inline operator tvm::Expr() const; inline operator tvm::Integer() const; // get internal node ptr, if it is node inline NodePtr& node_sptr(); @@ -835,8 +829,8 @@ class TVMRetValue : public TVMPODValue_ { inline TVMRetValue& operator=(const NodeRef& other); inline TVMRetValue& operator=(const NodePtr& other); // type related - inline operator HalideIR::Type() const; - inline TVMRetValue& operator=(const HalideIR::Type& other); + inline operator tvm::DataType() const; + inline TVMRetValue& operator=(const tvm::DataType& other); private: template @@ -1184,7 +1178,7 @@ class TVMArgsSetter { inline void operator()(size_t i, const T& value) const; // NodeRef related extenstions: in tvm/packed_func_ext.h inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) - inline void operator()(size_t i, const HalideIR::Type& t) const; + inline void operator()(size_t i, const tvm::DataType& t) const; private: /*! \brief The values fields */ diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 659b42aa1afa..ac37f017436e 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -75,24 +75,24 @@ class Stage : public NodeRef { * \brief set the memory scope of the stage * \param scope The memory scope. */ - EXPORT Stage& set_scope(std::string scope); // NOLINT(*) + TVM_DLL Stage& set_scope(std::string scope); // NOLINT(*) /*! * \brief specify the schedule to be computed at the parent schedule's scope. * \param parent The parent schedule. * \param scope The iteration point to carry the schedule. * \return reference to self. */ - EXPORT Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*) + TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*) /*! * \brief Compute the function inline. * \return reference to self. */ - EXPORT Stage& compute_inline(); // NOLINT(*) + TVM_DLL Stage& compute_inline(); // NOLINT(*) /*! * \brief Compute the function at group root. * \return reference to self. */ - EXPORT Stage& compute_root(); // NOLINT(*) + TVM_DLL Stage& compute_root(); // NOLINT(*) /*! * \brief Bind the IterVar to thread index. * @@ -100,7 +100,7 @@ class Stage : public NodeRef { * \param thread_ivar The thread axis to be bound. * \return reference to self. */ - EXPORT Stage& bind(IterVar ivar, IterVar thread_ivar); + TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar); /*! * \brief Set the predicate to determine whether a store to the array should be performed. * Use this when there are multiple threads performing the same store and we only @@ -111,7 +111,7 @@ class Stage : public NodeRef { * \param predicate The condition to be checked. * \return reference to self. */ - EXPORT Stage& set_store_predicate(Expr predicate); + TVM_DLL Stage& set_store_predicate(Expr predicate); /*! * \brief Specify environment threads that launched around the group's scope. * This can only be used in group stage. @@ -120,7 +120,7 @@ class Stage : public NodeRef { * This is a beta feature. * \return reference to self. */ - EXPORT Stage& env_threads(Array threads); + TVM_DLL Stage& env_threads(Array threads); /*! * \brief Split the parent by factor, generate * \param parent The parent iteration domain. @@ -129,7 +129,7 @@ class Stage : public NodeRef { * \param p_inner The result inner domain. * \return reference to self. */ - EXPORT Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) /*! * \brief Split the iteration with given number of parts. * @@ -139,7 +139,7 @@ class Stage : public NodeRef { * \param p_inner The result inner domain. * \return reference to self. */ - EXPORT Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) /*! * \brief Fuse the inner outer domain to the target * \param outer The outer domain to be fused. @@ -147,7 +147,7 @@ class Stage : public NodeRef { * \param p_target The result target domain. * \return reference to self. */ - EXPORT Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*) + TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*) /*! * \brief Fuse all the axes together into a single axis. * @@ -161,13 +161,13 @@ class Stage : public NodeRef { * * \return reference to self. */ - EXPORT Stage& fuse(const Array& axes, IterVar* p_target); // NOLINT(*) + TVM_DLL Stage& fuse(const Array& axes, IterVar* p_target); // NOLINT(*) /*! * \brief Reorder the iteration * \param order The order of iteration variable. * \return reference to self. */ - EXPORT Stage& reorder(const Array& order); // NOLINT(*) + TVM_DLL Stage& reorder(const Array& order); // NOLINT(*) /*! * \brief Perform tiling on two dimensions * The final loop order from outmost to inner most are @@ -183,7 +183,7 @@ class Stage : public NodeRef { * \param p_y_inner Inner axis of y dimension * \return reference to self. */ - EXPORT Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*) + TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*) Expr x_factor, Expr y_factor, IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner); @@ -192,7 +192,7 @@ class Stage : public NodeRef { * \param var The axis to be vectorized. * \return reference to self. */ - EXPORT Stage& vectorize(IterVar var); // NOLINT(*) + TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*) /*! * \brief Replace computation of the current stage by tensor intrinsic f. * \param var The axis marks beginning of tensorization. @@ -200,19 +200,19 @@ class Stage : public NodeRef { * \param f The Tensor compute intrinsics. * \return reference to self. */ - EXPORT Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*) + TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*) /*! * \brief Unroll iteration. * \param var The axis to be unrolled. * \return reference to self. */ - EXPORT Stage& unroll(IterVar var); // NOLINT(*) + TVM_DLL Stage& unroll(IterVar var); // NOLINT(*) /*! * \brief Parallelize iteration. * \param var The axis to be parallelized. * \return reference to self. */ - EXPORT Stage& parallel(IterVar var); // NOLINT(*) + TVM_DLL Stage& parallel(IterVar var); // NOLINT(*) /*! * \brief Annotate the iteration with pragma * @@ -222,7 +222,7 @@ class Stage : public NodeRef { * * \return reference to self. */ - EXPORT Stage& pragma(IterVar var, + TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type, const Expr& pragma_value = Expr()); // NOLINT(*) /*! @@ -232,7 +232,7 @@ class Stage : public NodeRef { * \param offset the number of iterations be to fetched in advance * \return reference to self */ - EXPORT Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*) + TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*) /*! * \brief Set alignment requirement for specific dimension. * @@ -243,12 +243,12 @@ class Stage : public NodeRef { * \param offset The required offset factor. * \return reference to self */ - EXPORT Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*) + TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*) /*! * \brief Compute current stage with double buffering. * \return reference to self. */ - EXPORT Stage& double_buffer(); // NOLINT(*) + TVM_DLL Stage& double_buffer(); // NOLINT(*) /*! * \brief Schedule for OpenGL fragment shader. * \return reference to self. @@ -289,13 +289,13 @@ class Schedule : public NodeRef { * \brief Get the stage corresponds to the op * \param op The operation. */ - EXPORT Stage operator[](const Operation& op); + TVM_DLL Stage operator[](const Operation& op); /*! * \brief Short hand for getting the stage of tensor's operation. * \param tensor The tensor * \return The stage corresponding to the tensor's op */ - EXPORT Stage operator[](const Tensor& tensor) { + TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); } /*! @@ -307,7 +307,7 @@ class Schedule : public NodeRef { * \param include_inputs Whether include inputs if they are reachable from outputs. * \return The new grouped stage. */ - EXPORT Stage create_group(const Array& outputs, + TVM_DLL Stage create_group(const Array& outputs, const Array& inputs, bool include_inputs = false); /*! @@ -319,7 +319,7 @@ class Schedule : public NodeRef { * \param readers The readers to redirect to the tensor. * \return The created tensor. */ - EXPORT Tensor cache_read(const Tensor& tensor, + TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope, const Array& readers); /*! @@ -338,7 +338,7 @@ class Schedule : public NodeRef { * \param scope The scope of the storage. * \return The created tensor. */ - EXPORT Array cache_write(const Array& tensor, const std::string& scope); + TVM_DLL Array cache_write(const Array& tensor, const std::string& scope); /*! * \brief Create a cache write tensor for producing tensor. * The the tensor will take over body of original tensor op. @@ -355,7 +355,7 @@ class Schedule : public NodeRef { * \param scope The scope of the storage. * \return The created tensor. */ - EXPORT Tensor cache_write(const Tensor& tensor, const std::string& scope); + TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope); /*! * \brief Factor a reduction axis in tensor's schedule to be an explicit axis. * This will create a new stage that generated the new tensor with axis @@ -369,7 +369,7 @@ class Schedule : public NodeRef { * \param factor_axis The position where the new axis is placed. * \return The created factored tensors. */ - EXPORT Array rfactor(const Tensor& tensor, + TVM_DLL Array rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0); /*! @@ -556,14 +556,14 @@ class ScheduleNode : public Node { * \param op The candidate Operation. * \return true if the schedule has the Operation. Otherwise, false. */ - EXPORT bool Contain(const Operation& op) const; + TVM_DLL bool Contain(const Operation& op) const; /*! * \brief Check if the schedule contains a Tensor. * \param tensor The candidate tensor. * \return true if the schedule has the tensor. Otherwise, false. */ - EXPORT bool Contain(const Tensor& tensor) const { + TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); } @@ -572,7 +572,7 @@ class ScheduleNode : public Node { * \param ops The ops to be scheduled. * \return sch The created Schedule. */ - EXPORT static Schedule make(Array ops); + TVM_DLL static Schedule make(Array ops); static constexpr const char* _type_key = "Schedule"; TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node); diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h index 27444ab693cd..3187e2b17727 100644 --- a/include/tvm/schedule_pass.h +++ b/include/tvm/schedule_pass.h @@ -70,7 +70,7 @@ void AutoInlineElemWise(Schedule sch); * * \param sch The schedule to be inlined. */ -EXPORT void AutoInlineInjective(Schedule sch); +TVM_DLL void AutoInlineInjective(Schedule sch); } // namespace schedule } // namespace tvm diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index c6be52181f6c..2b33eea3c9c4 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -24,7 +24,6 @@ #ifndef TVM_TENSOR_H_ #define TVM_TENSOR_H_ -#include #include #include #include @@ -43,8 +42,6 @@ class TensorNode; // internal node container for Operation class OperationNode; -using HalideIR::IR::FunctionRef; - /*! * \brief Tensor structure representing a possible input, * or intermediate computation result. @@ -140,7 +137,7 @@ class Tensor : public NodeRef { }; /*! \brief Operation that produces tensors */ -class Operation : public FunctionRef { +class Operation : public ir::FunctionRef { public: /*! \brief default constructor */ Operation() {} diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 2525059b47ba..e8e43786d3fd 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -59,14 +59,13 @@ TVM_REGISTER_API("make._range_by_min_extent") TVM_REGISTER_API("make.For") .set_body_typed([]( VarExpr loop_var, Expr min, Expr extent, - int for_type, int device_api, Stmt body -) { + int for_type, int device_api, Stmt body) { return For::make(loop_var, - min, - extent, - static_cast(for_type), - static_cast(device_api), - body); + min, + extent, + static_cast(for_type), + static_cast(device_api), + body); }); TVM_REGISTER_API("make.Load") diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 00ac715e8c07..aa0ce47b4a37 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2016 by Contributors * Implementation of API functions related to Higher DSL build. * \file api_lang.cc */ @@ -36,10 +35,10 @@ namespace tvm { TVM_REGISTER_API("_min_value") -.set_body_method(&Type::min); +.set_body_method(&DataType::min); TVM_REGISTER_API("_max_value") -.set_body_method(&Type::max); +.set_body_method(&DataType::max); TVM_REGISTER_API("_const") .set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 1a93c8b3f7d8..3cc642788911 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -52,12 +52,6 @@ class CanonicalExprNode : public BaseExprNode { // overrides void VisitAttrs(tvm::AttrVisitor* v) final { } - void accept(HalideIR::Internal::IRVisitor* v, const Expr& e) const final { - LOG(FATAL) << "not supported"; - } - IRNodeType type_info() const final { - return IRNodeType::ExtensionExpr; - } static constexpr const char* _type_key = "arith.CanonicalExpr"; TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode); diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index b36d8f5625a1..84b452cd7043 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -125,7 +125,7 @@ class ConstIntBoundAnalyzer::Impl : // Override visitor behaviors Entry VisitExprDefault_(const Node* op) final { return Everything( - static_cast(op)->type); + static_cast(op)->type); } Entry VisitExpr(const Expr& expr) final { diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 797a7d1c406e..5b14abbcf7dc 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -224,15 +224,15 @@ std::string CodeGenOpenGL::GetBufferRef( void CodeGenOpenGL::PrintType(Type t, std::ostream& os) { switch (t.code()) { - case halideir_type_int: + case kDLInt: CHECK_EQ(t.bits(), 32) << "Only support 32-bit int."; os << "int"; break; - case halideir_type_uint: + case kDLUInt: CHECK_EQ(t.bits(), 32) << "Only support 32-bit uint."; os << "uint"; break; - case halideir_type_float: + case kDLFloat: CHECK_EQ(t.bits(), 32) << "Only support 32-bit float."; os << "float"; break; diff --git a/src/lang/expr.cc b/src/lang/expr.cc index fed604922ee3..11b72c71fda7 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,19 +18,94 @@ */ /*! - * Copyright (c) 2016 by Contributors * \file expr.cc */ #include #include #include #include -#include #include +#include namespace tvm { -using HalideIR::IR::RangeNode; +// maximum and min values +Expr DataType::max() const { + using namespace ir; + CHECK_EQ(lanes(), 1); + if (is_int()) { + if (bits() == 64) { + return IntImm::make(*this, std::numeric_limits::max()); + } else if (bits() < 64) { + int64_t val = 1; + val = (val << (bits() - 1)) - 1; + return IntImm::make(*this, val); + } + } else if (is_uint()) { + if (bits() == 64) { + return UIntImm::make(*this, std::numeric_limits::max()); + } else if (bits() < 64) { + uint64_t val = 1; + val = (val << static_cast(bits())) - 1; + return UIntImm::make(*this, val); + } + } else if (is_float()) { + if (bits() == 64) { + return FloatImm::make(*this, std::numeric_limits::max()); + } else if (bits() == 32) { + return FloatImm::make(*this, std::numeric_limits::max()); + } else if (bits() == 16) { + return FloatImm::make(*this, 65504.0); + } + } + LOG(FATAL) << "Cannot decide max_value for type" << *this; + return Expr(); +} + +Expr DataType::min() const { + using namespace ir; + CHECK_EQ(lanes(), 1); + if (is_int()) { + if (bits() == 64) { + return IntImm::make(*this, std::numeric_limits::lowest()); + } else if (bits() < 64) { + int64_t val = 1; + val = -(val << (bits() - 1)); + return IntImm::make(*this, val); + } + } else if (is_uint()) { + return UIntImm::make(*this, 0); + } else if (is_float()) { + if (bits() == 64) { + return FloatImm::make(*this, std::numeric_limits::lowest()); + } else if (bits() == 32) { + return FloatImm::make(*this, std::numeric_limits::lowest()); + } else if (bits() == 16) { + return FloatImm::make(*this, -65504.0); + } + } + LOG(FATAL) << "Cannot decide min_value for type" << *this; + return Expr(); +} + +Expr::Expr(int32_t value) + : Expr(IntImm::make(Int(32), value)) {} + +Expr::Expr(float value) + : Expr(ir::FloatImm::make(Float(32), value)) {} + +Expr::Expr(std::string str) + : Expr(ir::StringImm::make(str)) {} + +Var::Var(std::string name_hint, DataType t) + : Var(Variable::make(t, name_hint)) {} + +Var Variable::make(DataType t, std::string name_hint) { + NodePtr node = make_node(); + node->type = t; + node->name_hint = std::move(name_hint); + return Var(node); +} Range::Range(Expr begin, Expr end) : Range(make_node( @@ -38,12 +113,23 @@ Range::Range(Expr begin, Expr end) is_zero(begin) ? end : (end - begin))) { } +Integer IntImm::make(Type t, int64_t value) { + CHECK(t.is_int() && t.is_scalar()) + << "ValueError: IntImm can only take scalar."; + NodePtr node = make_node(); + node->type = t; + node->value = value; + return Integer(node); +} + Range Range::make_by_min_extent(Expr min, Expr extent) { - return Range(make_node(min, extent)); + return Range(make_node(min, extent)); } -IterVar IterVarNode::make(Range dom, Var var, - IterVarType t, std::string thread_tag) { +IterVar IterVarNode::make(Range dom, + Var var, + IterVarType t, + std::string thread_tag) { NodePtr n = make_node(); n->dom = dom; n->var = var; @@ -62,19 +148,48 @@ IterVar reduce_axis(Range dom, std::string name) { dom, Var(name), kCommReduce); } -std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) - IRPrinter(os).print(n); - return os; -} - void Dump(const NodeRef& n) { std::cerr << n << "\n"; } -Var var(const std::string& name_hint, Type t) { +Var var(std::string name_hint, Type t) { return Var(name_hint, t); } +void IRPrinter::Print(const NodeRef& ir) { + static const FType& f = vtable(); + if (!ir.defined()) { + stream << "(nullptr)"; + } else { + if (f.can_dispatch(ir)) { + f(ir, this); + } else { + // default value, output type key and addr. + stream << ir->type_key() << "(" << ir.get() << ")"; + } + } +} + +void IRPrinter::PrintIndent() { + for (int i = 0; i < indent; ++i) { + stream << ' '; + } +} + +IRPrinter::FType& IRPrinter::vtable() { + static FType inst; + return inst; +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const IntImm *op, IRPrinter *p) { + if (op->type == Int(32)) { + p->stream << op->value; + } else { + p->stream << "(" << op->type << ")" << op->value; + } + }); + TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const IterVarNode *op, IRPrinter *p) { p->stream << "iter_var("; @@ -91,11 +206,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const HalideIR::IR::RangeNode *op, IRPrinter *p) { +.set_dispatch([](const RangeNode* op, IRPrinter* p) { p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); - TVM_REGISTER_NODE_TYPE(ArrayNode); TVM_REGISTER_NODE_TYPE(MapNode); TVM_REGISTER_NODE_TYPE(StrMapNode); diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 4eeddd91d80c..0557e287986f 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -18,65 +18,231 @@ */ /*! - * Copyright (c) 2016 by Contributors * \file ir.cc */ #include #include #include #include -#include -#include #include #include "../pass/ir_util.h" -namespace HalideIR { -namespace Internal { +namespace tvm { +namespace ir { -using tvm::ir::CommReducerNode; -using tvm::ir::Reduce; -using tvm::ir::Any; -using tvm::ir::AttrStmt; +// constructors +Expr UIntImm::make(DataType t, uint64_t value) { + CHECK(t.is_uint() && t.lanes() == 1) + << "ValueError: UIntImm can only take scalar"; + NodePtr node = make_node(); + node->type = t; + node->value = value; + return Expr(node); +} -template<> -void ExprNode::accept(IRVisitor *v, const Expr&) const { - LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor"; +Expr FloatImm::make(DataType t, double value) { + CHECK_EQ(t.lanes(), 1) + << "ValueError: FloatImm can only take scalar"; + NodePtr node = make_node(); + node->type = t; + node->value = value; + return Expr(node); } -template<> -void ExprNode::accept(IRVisitor *v, const Expr&) const { - LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor"; +Expr StringImm::make(std::string value) { + NodePtr node = make_node(); + node->type = Handle(); + node->value = std::move(value); + return Expr(node); } -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const Any *op, IRPrinter *p) { - p->stream << "?"; -}); +Expr Cast::make(DataType t, Expr value) { + CHECK(value.defined()); + CHECK_EQ(t.lanes(), value.type().lanes()); -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const Reduce *op, IRPrinter *p) { - p->stream << "reduce(combiner=" - << op->combiner; - p->stream << ", source=" << op->source; - p->stream << ", axis=" << op->axis; - p->stream << ", where=" << op->condition; - p->stream << ", value_index=" << op->value_index; - p->stream << ")"; -}); + NodePtr node = make_node(); + node->type = t; + node->value = std::move(value); + return Expr(node); +} -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const CommReducerNode *op, IRPrinter *p) { - p->stream << "comm_reducer(result=" << op->result - << ", lhs=" << op->lhs - << ", rhs=" << op->rhs - << ", identity_element=" << op->identity_element - << ")"; -}); -} // namespace Internal -} // namespace HalideIR -namespace tvm { -namespace ir { +Expr And::make(Expr a, Expr b) { + CHECK(a.defined()) << "ValueError: a is undefined"; + CHECK(b.defined()) << "ValueError: b is undefined"; + CHECK(a.type().is_bool()); + CHECK(b.type().is_bool()); + CHECK(a.type() == b.type()) << "TypeError: mismatched types"; + + NodePtr node = make_node(); + node->type = Bool(a.type().lanes()); + node->a = std::move(a); + node->b = std::move(b); + return Expr(node); +} + +Expr Or::make(Expr a, Expr b) { + CHECK(a.defined()) << "ValueError: a is undefined"; + CHECK(b.defined()) << "ValueError: b is undefined"; + CHECK(a.type().is_bool()); + CHECK(b.type().is_bool()); + CHECK(a.type() == b.type()) << "TypeError: mismatched types"; + + NodePtr node = make_node(); + node->type = Bool(a.type().lanes()); + node->a = std::move(a); + node->b = std::move(b); + return Expr(node); +} + +Expr Not::make(Expr a) { + CHECK(a.defined()) << "ValueError: a is undefined"; + CHECK(a.type().is_bool()); + + NodePtr node = make_node(); + node->type = Bool(a.type().lanes()); + node->a = std::move(a); + return Expr(node); +} + +Expr Select::make(Expr condition, Expr true_value, Expr false_value) { + CHECK(condition.defined()) << "ValueError: condition is undefined"; + CHECK(true_value.defined()) << "ValueError: true_value is undefined"; + CHECK(false_value.defined()) << "ValueError: true_value is undefined"; + CHECK(condition.type().is_bool()); + CHECK_EQ(condition.type().lanes(), true_value.type().lanes()); + CHECK(false_value.type() == true_value.type()) << "TypeError: mismatched types"; + + NodePtr(); + node->type = true_value.type(); + node->condition = std::move(condition); + node->true_value = std::move(true_value); + node->false_value = std::move(false_value); + return Expr(node); +} + +Expr Load::make(DataType type, Var buffer_var, Expr index, Expr predicate) { + CHECK(buffer_var.defined()); + CHECK(predicate.defined()); + CHECK(index.defined()); + CHECK_EQ(type.lanes(), index.type().lanes()); + CHECK_EQ(type.lanes(), predicate.type().lanes()); + + NodePtr node = make_node(); + node->type = type; + node->buffer_var = std::move(buffer_var); + node->index = std::move(index); + node->predicate = std::move(predicate); + + return Expr(node); +} + +Expr Ramp::make(Expr base, Expr stride, int lanes) { + CHECK(base.defined()); + CHECK(stride.defined()); + CHECK(base.type().is_scalar()); + CHECK(stride.type().is_scalar()); + CHECK_GT(lanes, 1); + CHECK_EQ(stride.type(), base.type()); + + NodePtr node = make_node(); + node->type = base.type().with_lanes(lanes); + node->base = base; + node->stride = stride; + node->lanes = lanes; + return Expr(node); +} + +Expr Broadcast::make(Expr value, int lanes) { + CHECK(value.defined()); + CHECK(value.type().is_scalar()); + CHECK_GT(lanes, 1); + + NodePtr node = make_node(); + node->type = value.type().with_lanes(lanes); + node->value = std::move(value); + node->lanes = lanes; + return Expr(node); +} + +Expr Let::make(Var var, Expr value, Expr body) { + CHECK(value.defined()); + CHECK(body.defined()); + CHECK_EQ(value.type(), var.type()); + + NodePtr node = make_node(); + node->type = body.type(); + node->var = std::move(var); + node->value = std::move(value); + node->body = std::move(body); + return Expr(node); +} + +Expr Call::make(DataType type, + std::string name, + Array args, + CallType call_type, + FunctionRef func, + int value_index) { + for (size_t i = 0; i < args.size(); ++i) { + CHECK(args[i].defined()); + } + + if (call_type == Halide) { + for (size_t i = 0; i < args.size(); ++i) { + CHECK(args[i].type().is_int()); + } + } + + NodePtr node = make_node(); + node->type = type; + node->name = std::move(name); + node->args = std::move(args); + node->call_type = call_type; + node->func = std::move(func); + node->value_index = value_index; + return Expr(node); +} + +Expr Shuffle::make(Array vectors, + Array indices) { + CHECK_NE(vectors.size(), 0U); + CHECK_NE(indices.size(), 0U); + + Type base_type = vectors[0].type().element_of(); + int total_lanes = 0; + + for (Expr val : vectors) { + CHECK(val.type().element_of() == base_type); + total_lanes += val.type().lanes(); + } + CHECK_LE(indices.size(), static_cast(total_lanes)); + + NodePtr node = make_node(); + node->type = base_type.with_lanes(static_cast(indices.size())); + node->vectors = std::move(vectors); + node->indices = std::move(indices); + return Expr(node); +} + +Expr Shuffle::make_concat(Array vectors) { + CHECK_NE(vectors.size(), 0); + if (vectors.size() == 1) { + return vectors[0]; + } + Array indices; + int index = 0; + for (const Expr& e : vectors) { + for (int i = 0; i < e.type().lanes(); ++i) { + indices.push_back(IntImm::make(Int(32), index++)); + } + } + return make(vectors, indices); +} + +Expr Shuffle::make_extract_element(Expr vector, int index) { + return make({vector}, {Integer(index)}); +} CommReducer CommReducerNode::make(Array lhs, Array rhs, @@ -132,6 +298,802 @@ Expr Any::make() { return Expr(n); } +Stmt LetStmt::make(Var var, Expr value, Stmt body) { + CHECK(value.defined()); + CHECK(body.defined()); + CHECK_EQ(value.type(), var.type()); + + NodePtr node = make_node(); + node->var = std::move(var); + node->value = std::move(value); + node->body = std::move(body); + return Stmt(node); +} + +Stmt AttrStmt::make(NodeRef node, + std::string attr_key, + Expr value, + Stmt body) { + auto n = make_node(); + n->node = node; + n->attr_key = std::move(attr_key); + n->value = std::move(value); + n->body = std::move(body); + return Stmt(n); +} + +Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) { + CHECK(condition.defined()); + CHECK(message.type() == Int(32) || + message.as()) + << "TypeError: AssertStmt message must be an int or string:" + << message << "\n"; + + NodePtr node = make_node(); + node->condition = std::move(condition); + node->message = std::move(message); + node->body = std::move(body); + return Stmt(node); +} + +Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) { + CHECK(body.defined()); + + NodePtr node = make_node(); + node->func = std::move(func); + node->is_producer = is_producer; + node->body = std::move(body); + return Stmt(node); +} + +Stmt For::make(Var loop_var, + Expr min, + Expr extent, + ForType for_type, + DeviceAPI device_api, + Stmt body) { + CHECK(min.defined()); + CHECK(extent.defined()); + CHECK(min.type().is_scalar()); + CHECK(extent.type().is_scalar()); + CHECK(loop_var.type().is_scalar()); + CHECK(body.defined()); + + NodePtr node = make_node(); + node->loop_var = std::move(loop_var); + node->min = std::move(min); + node->extent = std::move(extent); + node->for_type = for_type; + node->device_api = device_api; + node->body = std::move(body); + return Stmt(node); +} + +Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) { + CHECK(value.defined()); + CHECK(index.defined()); + CHECK(predicate.defined()); + CHECK_EQ(value.type().lanes(), index.type().lanes()); + CHECK_EQ(value.type().lanes(), predicate.type().lanes()); + + NodePtr node = make_node(); + node->buffer_var = std::move(buffer_var); + node->value = std::move(value); + node->index = std::move(index); + node->predicate = std::move(predicate); + return Stmt(node); +} + +Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array args) { + CHECK(value_index >=0 && value_index < func->num_outputs()) + << "value index output function return value bound"; + CHECK(value.defined()) << "Provide of undefined value\n"; + + for (size_t i = 0; i < args.size(); ++i) { + CHECK(args[i].defined()) << "Provide to undefined location\n"; + } + + NodePtr node = make_node(); + node->func = std::move(func); + node->value_index = value_index; + node->value = std::move(value); + node->args = std::move(args); + return Stmt(node); +} + +Stmt Allocate::make(Var buffer_var, + DataType type, + Array extents, + Expr condition, + Stmt body, + Expr new_expr, + std::string free_function) { + for (size_t i = 0; i < extents.size(); ++i) { + CHECK(extents[i].defined()); + CHECK(extents[i].type().is_scalar()); + } + CHECK(body.defined()); + CHECK(condition.defined()); + CHECK(condition.type().is_bool()); + + NodePtr node = make_node(); + node->buffer_var = std::move(buffer_var); + node->type = type; + node->extents = std::move(extents); + node->condition = std::move(condition); + node->body = std::move(body); + node->new_expr = std::move(new_expr); + node->free_function = std::move(free_function); + return Stmt(node); +} + +int32_t Allocate::constant_allocation_size(const Array& extents) { + int64_t result = 1; + for (size_t i = 0; i < extents.size(); ++i) { + if (const IntImm *int_size = extents[i].as()) { + result *= int_size->value; + if (result > std::numeric_limits::max()) { + return 0; + } + } else { + return 0; + } + } + return static_cast(result); +} + +Stmt Free::make(Var buffer_var) { + NodePtr node = make_node(); + node->buffer_var = buffer_var; + return Stmt(node); +} + +Stmt Realize::make(FunctionRef func, + int value_index, + DataType type, + Region bounds, + Expr condition, + Stmt body) { + for (size_t i = 0; i < bounds.size(); ++i) { + CHECK(bounds[i]->min.defined()); + CHECK(bounds[i]->extent.defined()); + CHECK(bounds[i]->min.type().is_scalar()); + CHECK(bounds[i]->extent.type().is_scalar()); + } + CHECK(body.defined()); + CHECK(condition.defined()); + CHECK(condition.type().is_bool()); + + NodePtr node = make_node(); + node->func = std::move(func); + node->value_index = value_index; + node->type = type; + node->bounds = std::move(bounds); + node->condition = std::move(condition); + node->body = std::move(body); + return Stmt(node); +} + +Stmt Prefetch::make(FunctionRef func, int value_index, DataType type, Region bounds) { + for (size_t i = 0; i < bounds.size(); ++i) { + CHECK(bounds[i]->min.defined()); + CHECK(bounds[i]->extent.defined()); + CHECK(bounds[i]->min.type().is_scalar()); + CHECK(bounds[i]->extent.type().is_scalar()); + } + + NodePtr node = make_node(); + node->func = std::move(func); + node->value_index = value_index; + node->type = type; + node->bounds = std::move(bounds); + return Stmt(node); +} + +Stmt Block::make(Stmt first, Stmt rest) { + CHECK(first.defined()); + CHECK(rest.defined()); + NodePtr node = make_node(); + + // canonicalize. + if (const Block* b = first.as()) { + node->first = b->first; + node->rest = Block::make(b->rest, rest); + } else { + node->first = std::move(first); + node->rest = std::move(rest); + } + return Stmt(node); +} + +Stmt Block::make(const std::vector& stmts) { + if (stmts.empty()) { + return Stmt(); + } + Stmt result = stmts.back(); + for (size_t i = stmts.size() - 1; i != 0; --i) { + result = Block::make(stmts[i - 1], result); + } + return result; +} + +Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { + CHECK(condition.defined()); + CHECK(then_case.defined()); + // else_case may be null. + + NodePtr node = make_node(); + node->condition = std::move(condition); + node->then_case = std::move(then_case); + node->else_case = std::move(else_case); + return Stmt(node); +} + +Stmt Evaluate::make(Expr value) { + CHECK(value.defined()); + + NodePtr node = make_node(); + node->value = std::move(value); + return Stmt(node); +} + +// Printers +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const UIntImm* op, IRPrinter* p) { + p->stream << "(" << op->type << ")" << op->value; + }); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const FloatImm* op, IRPrinter* p) { + auto& stream = p->stream; + switch (op->type.bits()) { + case 64: + stream << op->value; + break; + case 32: + stream << op->value << 'f'; + break; + case 16: + stream << op->value << 'h'; + break; + default: + LOG(FATAL) << "Unknown float type bits=" << op->type.bits(); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const StringImm* op, IRPrinter* p) { + auto& stream = p->stream; + stream << '"'; + for (size_t i = 0; i < op->value.size(); ++i) { + unsigned char c = op->value[i]; + if (c >= ' ' && c <= '~' && c != '\\' && c != '"') { + stream << c; + } else { + stream << '\\'; + switch (c) { + case '"': + stream << '"'; + break; + case '\\': + stream << '\\'; + break; + case '\t': + stream << 't'; + break; + case '\r': + stream << 'r'; + break; + case '\n': + stream << 'n'; + break; + default: + const char* hex_digits = "0123456789ABCDEF"; + stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf]; + } + } + } + stream << '"'; + }); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const Cast* op, IRPrinter* p) { + p->stream << op->type << '('; + p->Print(op->value); + p->stream << ')'; + }) +.set_dispatch([](const Variable* op, IRPrinter* p) { + // omit the type + // stream << op->name << "." << op->type; + p->stream << op->name_hint; + }) +.set_dispatch([](const Add* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " + "; + p->Print(op->b); + p->stream << ')'; + }) +.set_dispatch([](const Sub* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " - "; + p->Print(op->b); + p->stream << ')'; + }) +.set_dispatch([](const Mul* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << "*"; + p->Print(op->b); + p->stream << ')'; + }) +.set_dispatch
([](const Div* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << "/"; + p->Print(op->b); + p->stream << ')'; + }) +.set_dispatch([](const Mod* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " % "; + p->Print(op->b); + p->stream << ')'; +}) +.set_dispatch([](const Min* op, IRPrinter* p) { + p->stream << "min("; + p->Print(op->a); + p->stream << ", "; + p->Print(op->b); + p->stream << ")"; +}) +.set_dispatch([](const Max* op, IRPrinter* p) { + p->stream << "max("; + p->Print(op->a); + p->stream << ", "; + p->Print(op->b); + p->stream << ")"; +}) +.set_dispatch([](const EQ* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " == "; + p->Print(op->b); + p->stream << ')'; +}) +.set_dispatch([](const NE* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " != "; + p->Print(op->b); + p->stream << ')'; +}) +.set_dispatch([](const LT* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " < "; + p->Print(op->b); + p->stream << ')'; +}) +.set_dispatch([](const LE* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " <= "; + p->Print(op->b); + p->stream << ')'; +}) +.set_dispatch([](const GT* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " > "; + p->Print(op->b); + p->stream << ')'; +}) +.set_dispatch([](const GE* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " >= "; + p->Print(op->b); + p->stream << ')'; +}); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const And* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " && "; + p->Print(op->b); + p->stream << ')'; +}); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const Or* op, IRPrinter* p) { + p->stream << '('; + p->Print(op->a); + p->stream << " || "; + p->Print(op->b); + p->stream << ')'; +}); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const Not* op, IRPrinter* p) { + p->stream << '!'; + p->Print(op->a); +}); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch