diff --git a/CMakeLists.txt b/CMakeLists.txt index e55b7174dc8e..88bf6472ce9b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -289,6 +289,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/driver/*.cc src/support/*.cc src/script/*.cc + src/relax/ir/*.cc src/relax/backend/vm/*.cc ) diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h new file mode 100644 index 000000000000..d21c8db86b3f --- /dev/null +++ b/include/tvm/relax/struct_info.h @@ -0,0 +1,430 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_STRUCT_INFO_H_ +#define TVM_RELAX_STRUCT_INFO_H_ + +#include +#include +#include +// #include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Opaque object. + */ +class ObjectStructInfoNode : public StructInfoNode { + public: + void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.ObjectStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ObjectStructInfoNode. + * \sa ObjectStructInfoNode + */ +class ObjectStructInfo : public StructInfo { + public: + TVM_DLL ObjectStructInfo(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode); +}; + +/*! + * \brief Primitive value. + */ +class PrimStructInfoNode : public StructInfoNode { + public: + /*! \brief Underlying data type of the primitive value */ + DataType dtype; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } + + static constexpr const char* _type_key = "relax.PrimStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to PrimStructInfoNode. + * \sa PrimStructInfoNode + */ +class PrimStructInfo : public StructInfo { + public: + TVM_DLL PrimStructInfo(DataType dtype, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode); +}; + +/*! + * \brief StructInfo of shape value. + */ +class ShapeStructInfoNode : public StructInfoNode { + public: + /*! \brief optionally stores the symbolic value patterns of the shape */ + Optional> values; + /*! + * \brief The number of dimension of the shape, can be unknown. + * \sa kUnknownNDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal) const { + return equal(values, other->values) && equal(ndim, other->ndim); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(values); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.ShapeStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ShapeStructInfoNode. + * \sa ShapeStructInfoNode + */ +class ShapeStructInfo : public StructInfo { + public: + /*! + * \brief Construction with known symbolic shape patterns + * \param values The symbolic shape values + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(Array values, Span span = Span()); + /*! + * \brief Construction with known unknown symbolic shape patterns. + * \param ndim Number of dimensions -- can be kUnknownNDim + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode); +}; + +/*! + * \brief StructInfo of Tensor. + */ +class TensorStructInfoNode : public StructInfoNode { + public: + /*! + * \brief optionally store the shape expression of the tensor. + * \note shape must be normalized: it can only be NullOpt or ShapeExpr or Var. + */ + Optional shape; + /*! \brief The content data type, use void to denote the dtype is unknown. */ + DataType dtype; + /*! + * \brief The number of dimension of the tensor, can be unknown. + * \sa kUnknownNDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + /*! \return Whether the struct info contains unknown dtype. */ + bool IsUnknownDtype() const { return dtype.is_void(); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const { + return equal(shape, other->shape) && equal(ndim, other->ndim) && equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(shape); + hash_reduce(dtype); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.TensorStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TensorStructInfoNode. + * \sa TensorStructInfoNode + */ +class TensorStructInfo : public StructInfo { + public: + /*! + * \brief Construction with a known shape expression. + * \param shape The shape of the tensor. + * \param dtype The data type of tensor's elements. + * \param span The span of the AST. + * + * \note shape must already be normalized. + */ + TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span()); + + /*! + * \brief Construction with an unknown shape expression. + * \param dtype The data type of tensor's elements. + * \param ndim The number of dimensions + * \param span The span of the AST. + */ + TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode); +}; + +/*! + * \brief StructInfo of Tuple. + */ +class TupleStructInfoNode : public StructInfoNode { + public: + /*! \brief The struct info of tuple fields. */ + Array fields; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("fields", &fields); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal) const { + return equal(fields, other->fields); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } + + static constexpr const char* _type_key = "relax.TupleStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TupleStructInfoNode. + * \sa TupleStructInfoNode + */ +class TupleStructInfo : public StructInfo { + public: + /*! + * \brief Constructor + * \param fields Struct info of tuple fields. + * \param span The span of the AST. + */ + TVM_DLL TupleStructInfo(Array fields, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode); +}; + +class BlockBuilder; + +/*! + * \brief custom-defined StructInfo derivation function. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \return The derived struct info of the call. + */ +using StructInfoDeriveFunc = TypedEnvFunc; + +/*! + * \brief Structure information about function. + * + * This data structure contains enough information for us to + * do best-effort structure information deduction. + */ +class FuncStructInfoNode : public StructInfoNode { + public: + /*! + * \brief The parameter struct info of the function. + * \note When params is NullOpt means the function can take arbitrary number of arguments. + * We define such functions as Opaque function. + */ + Optional> params; + /*! + * \brief The struct info of the function's return value. + */ + StructInfo ret; + /*! + * \brief Derivation function of opaque functions that may take any number of parameters. + * \note When derive_func is not empty, then params should be NullOpt, + * ret should be ObjectStructInfo() + */ + Optional derive_func; + + /*! + * \return Whether the func struct info is opaque. + * \note We define a function as opaque we have no constraints on params. + */ + bool IsOpaque() const { return !params.defined(); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("ret", &ret); + v->Visit("derive_func", &derive_func); + v->Visit("span", &span); + } + + bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const { + return equal.DefEqual(params, other->params) && equal(ret, other->ret) && + equal(derive_func, other->derive_func); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(params); + hash_reduce(ret); + hash_reduce(derive_func); + } + + static constexpr const char* _type_key = "relax.FuncStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to FuncStructInfoNode. + * \sa FuncStructInfoNode + */ +class FuncStructInfo : public StructInfo { + public: + /*! + * \brief Constructor from parameter struct info and return value struct info. + * \param params The struct info of function parameters. + * \param ret The return value struct info. + * \param span The span of the AST. + * + * \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from + * params. If you are unsure, you can always erase ret to static. + */ + TVM_DLL FuncStructInfo(Array params, StructInfo ret, Span span = Span()); + + /*! + * \brief Constructing an opaque function struct info using derive_func. + * + * \param derive_func Derivation function. + * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, Span span = Span()); + + /*! + * \brief Construct an opaque function using from return struct info. + * + * \param ret The struct info of the return value. + * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode); +}; + +/*! + * \brief Match and check if expr have StructInfo T and return it. + * + * \param expr The input expression. + * \return The result of match. + * \tparam T the underlying structure info type + */ +template +inline Optional MatchStructInfo(const Expr& expr) { + using TNode = typename T::ContainerType; + if (const TNode* ptr = expr->struct_info_.as()) { + return GetRef(ptr); + } else { + return NullOpt; + } +} + +/*! + * \brief Get the structure info of a given expr and try to cast it as const T*. + * + * \param expr The input expression. + * \return The pointer. Returns nullptr if the type does not match + * \tparam T the underlying structure info type + */ +template +inline const T* GetStructInfoAs(const Expr& expr) { + ICHECK(expr->struct_info_.defined()) + << "The struct_info is not populated, check if you have normalized the expr"; + return expr->struct_info_.as(); +} + +/*! + * \brief Get the underlying structure info of expr. + * + * \param expr The input expression. + * \return underlying struct info. + */ +inline StructInfo GetStructInfo(const Expr& expr) { + auto* ptr = expr->struct_info_.as(); + ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; + return GetRef(ptr); +} + +/*! + * \brief Whether the expr has void struct info. + * + * \param expr The input expression. + * \return Whether the expr has void struct info. + */ +inline bool HasVoidStructInfo(const Expr& expr) { + auto* ptr = expr->struct_info_.as(); + return ptr != nullptr && ptr->fields.size() == 0; +} + +/*! + * \brief Update the struct info of an Expr. + * \param expr The Expr whose struct info to be updated. + * \param struct_info The struct_info assigned. + * \note We ensure idempotence, that is we can only update the struct_info of an Expr only + * if the original one is nullptr. + */ +TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_STRUCT_INFO_H_ diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index abc25e89c48c..08ce082c4586 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -40,6 +40,7 @@ pub mod attrs; pub struct ExprNode { pub base: BaseExprNode, pub checked_type: Type, + pub struct_info: ObjectRef, pub virtual_device: ObjectRef, } @@ -48,6 +49,7 @@ impl ExprNode { ExprNode { base: BaseExprNode::base::(span.clone()), checked_type: Type::null(), + struct_info: ObjectRef::null(), virtual_device: ObjectRef::null(), } } diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc new file mode 100644 index 000000000000..88046ed81f10 --- /dev/null +++ b/src/relax/ir/struct_info.cc @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/struct_info.cc + * \brief Relax struct info. + */ +#include +#include + +namespace tvm { +namespace relax { + +ObjectStructInfo::ObjectStructInfo(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { + return ObjectStructInfo(span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "ObjectStructInfo()"; + }); + +// Prim +PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PrimStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) { + return PrimStructInfo(dtype, span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "PrimStructInfo(" << node->dtype << ")"; + }); + +// Shape +ShapeStructInfo::ShapeStructInfo(Array values, Span span) { + ObjectPtr n = make_object(); + n->ndim = static_cast(values.size()); + n->values = values.Map([](PrimExpr value) { + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeStructInfo can only have dtype of int64"; + return value; + }); + n->span = span; + data_ = std::move(n); +} + +ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { + ObjectPtr n = make_object(); + CHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim; + n->ndim = ndim; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") + .set_body_typed([](Optional> values, int ndim, Span span) { + if (values.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; + return ShapeStructInfo(values.value(), span); + } else { + return ShapeStructInfo(ndim, span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + if (node->values.defined()) { + p->stream << "ShapeStructInfo(" << node->values.value() << ")"; + } else { + p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")"; + } + }); + +// Tensor +TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) { + ObjectPtr n = make_object(); + // assign ndim before move + Optional sinfo = MatchStructInfo(shape); + ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info"; + ICHECK(shape.defined()) << "Must provide a shape in this constructor"; + ICHECK(shape->IsInstance() || shape->IsInstance()) + << "We require shape to be normalized when constructing TensorStructInfo"; + n->ndim = sinfo.get()->ndim; + // assign rest of the fields. + n->shape = std::move(shape); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) { + ObjectPtr n = make_object(); + CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim; + n->ndim = ndim; + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TensorStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TensorStructInfo") + .set_body_typed([](Optional shape, DataType dtype, int ndim, Span span) { + if (shape.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; + return TensorStructInfo(shape.value(), dtype, span); + } else { + return TensorStructInfo(dtype, ndim, span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + if (node->shape.defined()) { + p->stream << "TensorStructInfo(" << node->shape.value() << ", " << node->dtype << ")"; + } else { + p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" << node->ndim << ")"; + } + }); + +// Tuple +TupleStructInfo::TupleStructInfo(Array fields, Span span) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TupleStructInfo") + .set_body_typed([](Array fields, Span span) { + return TupleStructInfo(fields, span); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "TupleStructInfo(" << node->fields << ")"; + }); + +// Func +FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, Span span) { + ObjectPtr n = make_object(); + n->params = std::move(params); + n->ret = std::move(ret); + n->span = span; + data_ = std::move(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, Span span) { + ObjectPtr n = make_object(); + n->derive_func = std::move(derive_func); + n->ret = ObjectStructInfo(); + n->span = span; + return FuncStructInfo(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, Span span) { + ObjectPtr n = make_object(); + n->ret = std::move(ret); + n->span = span; + return FuncStructInfo(n); +} + +TVM_REGISTER_NODE_TYPE(FuncStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfo") + .set_body_typed([](Array params, StructInfo ret, Span span) { + return FuncStructInfo(params, ret, span); + }); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") + .set_body_typed([](Optional ret, Optional derive_func, + Span span) { + if (derive_func.defined()) { + ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; + return FuncStructInfo::OpaqueFunc(derive_func.value(), span); + } else { + return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "FuncStructInfo(" << node->params << ", " << node->ret << ")"; + }); + +// Helper functions +// TODO(unity-team): add UpdateStructInfo once analysis.cc is upstreamed + +TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { + return GetStructInfo(expr); +}); + +} // namespace relax +} // namespace tvm