From 156701606a2bd426c0a4f4eb8d05f1585aebac7a Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sun, 19 Jan 2020 14:36:12 -0800
Subject: [PATCH 1/2] [REFACTOR][TYPE] Finish moving all types to IR.

- Move the definitions of RefType and TensorType to ir.
- Move type_functor.h to a public header.
- Rename RefType -> RelayRefType for clarity.
---
 include/tvm/ir/tensor_type.h                 | 117 ++++++++++++++++
 include/tvm/ir/type.h                        |  70 ++++++++++
 {src/relay => include/tvm}/ir/type_functor.h |  58 ++++----
 include/tvm/relay/type.h                     | 137 +------------------
 include/tvm/runtime/object.h                 |  16 +--
 src/{relay/ir/type.cc => ir/tensor_type.cc}  |  63 ++-------
 src/ir/type.cc                               |  41 ++++++
 src/{relay => }/ir/type_functor.cc           |  17 ++-
 src/relay/backend/compile_engine.cc          |  10 +-
 src/relay/backend/interpreter.cc             |   4 +-
 src/relay/ir/alpha_equal.cc                  |   6 +-
 src/relay/ir/expr.cc                         |   9 +-
 src/relay/ir/expr_functor.cc                 |   2 +-
 src/relay/ir/hash.cc                         |   6 +-
 src/relay/ir/pretty_printer.cc               |   5 +-
 src/relay/op/algorithm/argsort.cc            |   2 +-
 src/relay/op/algorithm/topk.cc               |   4 +-
 src/relay/op/image/resize.cc                 |   4 +-
 src/relay/op/memory/memory.cc                |   8 +-
 src/relay/op/nn/bitserial.cc                 |   6 +-
 src/relay/op/nn/convolution.cc               |  20 +--
 src/relay/op/nn/convolution.h                |  12 +-
 src/relay/op/nn/nn.cc                        |  44 +++---
 src/relay/op/nn/nn.h                         |   4 +-
 src/relay/op/nn/pad.cc                       |   4 +-
 src/relay/op/nn/pooling.cc                   |  10 +-
 src/relay/op/nn/sparse.cc                    |  10 +-
 src/relay/op/nn/upsampling.cc                |   4 +-
 src/relay/op/tensor/reduce.cc                |   6 +-
 src/relay/op/tensor/transform.cc             |  62 ++++-----
 src/relay/op/tensor/transform.h              |   2 +-
 src/relay/op/tensor/unary.cc                 |   4 +-
 src/relay/op/type_relations.cc               |   2 +-
 src/relay/op/vision/multibox_op.cc           |   6 +-
 src/relay/op/vision/nms.cc                   |   8 +-
 src/relay/op/vision/rcnn_op.cc               |   6 +-
 src/relay/op/vision/yolo.cc                  |   2 +-
 src/relay/pass/de_duplicate.cc               |   3 +-
 src/relay/pass/eta_expand.cc                 |   2 +-
 src/relay/pass/gradient.cc                   |   5 +-
 src/relay/pass/kind_check.cc                 |   6 +-
 src/relay/pass/partial_eval.cc               |   4 +-
 src/relay/pass/quantize/quantize.cc          |   6 +-
 src/relay/pass/to_cps.cc                     |   2 +-
 src/relay/pass/type_infer.cc                 |  37 +++--
 src/relay/pass/type_solver.cc                |  14 +-
 src/relay/pass/util.cc                       |   2 +-
 src/relay/qnn/op/dequantize.cc               |   2 +-
 src/relay/qnn/op/quantize.cc                 |   2 +-
 src/relay/qnn/op/requantize.cc               |   2 +-
 src/relay/qnn/util.h                         |   2 +-
 tests/cpp/relay_build_module_test.cc         |   2 +-
 tests/cpp/relay_pass_type_infer_test.cc      |   2 +-
 tests/cpp/relay_transform_sequential.cc      |   2 +-
 tests/cpp/utvm_runtime_standalone_test.cc    |   2 +-
 55 files changed, 483 insertions(+), 405 deletions(-)
 create mode 100644 include/tvm/ir/tensor_type.h
 rename {src/relay => include/tvm}/ir/type_functor.h (77%)
 rename src/{relay/ir/type.cc => ir/tensor_type.cc} (50%)
 rename src/{relay => }/ir/type_functor.cc (94%)

diff --git a/include/tvm/ir/tensor_type.h b/include/tvm/ir/tensor_type.h
new file mode 100644
index 000000000000..70a2df19db6a
--- /dev/null
+++ b/include/tvm/ir/tensor_type.h
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/tensor_type.h
+ * \brief Polymorphic tensor types.
+ */
+#ifndef TVM_IR_TENSOR_TYPE_H_
+#define TVM_IR_TENSOR_TYPE_H_
+
+#include <tvm/ir/type.h>
+#include <tvm/ir/expr.h>
+
+namespace tvm {
+/*!
+ * \brief Base of all Tensor types.
+ *  This container can hold TensorType or GenericTensorType.
+ * \sa BaseTensorType, TensorTypeNode
+ */
+class BaseTensorTypeNode : public TypeNode {
+ public:
+  static constexpr const char* _type_key = "relay.BaseTensorType";
+  TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to BaseTensorTypeNode.
+ * \sa BaseTensorTypeNode.
+ */
+class BaseTensorType : public Type {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
+};
+
+/*!
+ * \brief This is the most commonly used type in relay.
+ *  TensorType has a fixed dimension and data type.
+ *
+ *  The elements of shape can be either IntImm (a constant integer)
+ *  or any symbolic integer expression.
+ *  Symbolic integers allow generic shape inference in certain cases.
+ * \sa TensorType
+ */
+class TensorTypeNode : public BaseTensorTypeNode {
+ public:
+  /*!
+   * \brief The shape of the tensor,
+   *  represented by PrimExpr (tvm::Expr).
+   */
+  Array<PrimExpr> shape;
+  /*! \brief The content data type */
+  DataType dtype;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+    v->Visit("span", &span);
+  }
+
+  /*! \brief Return product of elements in the shape.
+   *  \return (d_1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
+   */
+  TVM_DLL PrimExpr Size() const;
+
+  static constexpr const char* _type_key = "relay.TensorType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
+};
+
+/*!
+ * \brief Managed reference to TensorTypeNode.
+ * \sa TensorTypeNode.
+ */
+class TensorType : public Type {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param shape The shape of the tensor.
+   * \param dtype The runtime dtype of the tensor's elements.
+   */
+  TVM_DLL TensorType(Array<PrimExpr> shape, DataType dtype);
+
+  /*!
+   * \brief Construct a scalar containing elements of dtype.
+   * \param dtype The runtime dtype of the tensor's elements.
+   * \return The constructed type.
+   */
+  TVM_DLL static TensorType Scalar(DataType dtype);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
+};
+
+// The following fields contain advanced typing.
+// Only the class names are kept, reserved for future use.
+class GenericTensorType;
+// stores a DataType.
+class GenericDataType;
+// stores a DataType.
+class GenericShape;
+
+}  // namespace tvm
+#endif  // TVM_IR_TENSOR_TYPE_H_
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index e143588ee4a3..56f2389ad385 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -352,5 +352,75 @@ class FuncType : public Type {
   TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
 };
 
+/*!
+ * \brief Intermediate values that are used to indicate incomplete types
+ *  during type inference.
+ *
+ * If we view the type relations as a "computational graph of types",
+ * then IncompleteType represents intermediate values of the graph,
+ * and TypeVar represents the inputs to the graph.
+ *
+ * \sa IncompleteType
+ */
+class IncompleteTypeNode : public TypeNode {
+ public:
+  /*! \brief kind of the type. */
+  TypeKind kind;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("kind", &kind);
+    v->Visit("span", &span);
+  }
+
+  static constexpr const char* _type_key = "relay.IncompleteType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to IncompleteTypeNode.
+ * \sa IncompleteTypeNode
+ */
+class IncompleteType : public Type {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param kind kind of the type.
+   */
+  TVM_DLL explicit IncompleteType(TypeKind kind);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
+};
+
+
+/*!
+ * \brief Reference type in the high-level Relay IR.
+ *
+ * \sa RelayRefType.
+ */
+class RelayRefTypeNode : public TypeNode {
+ public:
+  /*! \brief The type of value in the Reference. */
+  Type value;
+
+  RelayRefTypeNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("value", &value);
+    v->Visit("span", &span);
+  }
+
+  static constexpr const char* _type_key = "relay.RefType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to RelayRefTypeNode.
+ * \sa RelayRefTypeNode.
+ */
+class RelayRefType : public Type {
+ public:
+  TVM_DLL explicit RelayRefType(Type value);
+  TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode);
+};
 }  // namespace tvm
 #endif  // TVM_IR_TYPE_H_
diff --git a/src/relay/ir/type_functor.h b/include/tvm/ir/type_functor.h
similarity index 77%
rename from src/relay/ir/type_functor.h
rename to include/tvm/ir/type_functor.h
index 09049cf83f86..476538c5da36 100644
--- a/src/relay/ir/type_functor.h
+++ b/include/tvm/ir/type_functor.h
@@ -18,11 +18,11 @@
  */
 
 /*!
- * \file type_functor.h
+ * \file tvm/ir/type_functor.h
  * \brief A way to define arbitrary function signatures with dispatch on types.
  */
-#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
-#define TVM_RELAY_IR_TYPE_FUNCTOR_H_
+#ifndef TVM_IR_TYPE_FUNCTOR_H_
+#define TVM_IR_TYPE_FUNCTOR_H_
 
 #include <tvm/node/functor.h>
 #include <tvm/relay/adt.h>
 #include <tvm/relay/expr.h>
 #include <string>
 #include <utility>
 
 namespace tvm {
-namespace relay {
 
 template <typename FType>
 class TypeFunctor;
 
 // functions to be overridden.
-#define TYPE_FUNCTOR_DEFAULT \
+#define TYPE_FUNCTOR_DEFAULT                                            \
   { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
 
-#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                                 \
+#define TVM_TYPE_FUNCTOR_DISPATCH(OP)                                   \
   vtable.template set_dispatch<OP>(                                     \
       [](const ObjectRef& n, TSelf* self, Args... args) {               \
         return self->VisitType_(static_cast<const OP*>(n.get()),        \
@@ -89,10 +88,11 @@ class TypeFunctor<R(const Type& n, Args...)> {
   virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-  virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const RelayRefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitTypeDefault_(const Object* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     throw;  // unreachable, written to stop compiler warning
@@ -103,25 +103,29 @@ class TypeFunctor<R(const Type& n, Args...)> {
   static FType InitVTable() {
     FType vtable;
     // Set dispatch
-    RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(RelayRefTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
     return vtable;
   }
 };
 
+#undef TVM_TYPE_FUNCTOR_DISPATCH
+
 /*!
  * \brief A type visitor that recursively visits types.
  */
-class TypeVisitor : public TypeFunctor<void(const Type& n)> {
+class TVM_DLL TypeVisitor :
+      public TypeFunctor<void(const Type& n)> {
  public:
   void VisitType_(const TypeVarNode* op) override;
   void VisitType_(const IncompleteTypeNode* op) override;
@@ -129,14 +133,18 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
   void VisitType_(const FuncTypeNode* op) override;
   void VisitType_(const TupleTypeNode* op) override;
   void VisitType_(const TypeRelationNode* op) override;
-  void VisitType_(const RefTypeNode* op) override;
+  void VisitType_(const RelayRefTypeNode* op) override;
   void VisitType_(const GlobalTypeVarNode* op) override;
   void VisitType_(const TypeCallNode* op) override;
   void VisitType_(const TypeDataNode* op) override;
+  void VisitType_(const PrimTypeNode* op) override;
 };
 
-// Mutator that transform a type to another one.
-class TypeMutator : public TypeFunctor<Type(const Type& n)> {
+/*!
+ * \brief TypeMutator that mutates types.
+ */
+class TVM_DLL TypeMutator :
+      public TypeFunctor<Type(const Type& n)> {
  public:
   Type VisitType(const Type& t) override;
   Type VisitType_(const TypeVarNode* op) override;
@@ -145,10 +153,11 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
   Type VisitType_(const FuncTypeNode* op) override;
   Type VisitType_(const TupleTypeNode* op) override;
   Type VisitType_(const TypeRelationNode* type_rel) override;
-  Type VisitType_(const RefTypeNode* op) override;
+  Type VisitType_(const RelayRefTypeNode* op) override;
   Type VisitType_(const GlobalTypeVarNode* op) override;
   Type VisitType_(const TypeCallNode* op) override;
   Type VisitType_(const TypeDataNode* op) override;
+  Type VisitType_(const PrimTypeNode* op) override;
 
  private:
   Array<Type> MutateArray(Array<Type> arr);
 };
 
 /*!
@@ -161,6 +170,5 @@
  */
 Type Bind(const Type& type, const Map<TypeVar, Type>& args_map);
 
-}  // namespace relay
 }  // namespace tvm
-#endif  // TVM_RELAY_IR_TYPE_FUNCTOR_H_
+#endif  // TVM_IR_TYPE_FUNCTOR_H_
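For reference, here is a minimal standalone sketch (illustrative only, not part
of the patch) of the constructor-style API introduced above, which replaces the
removed TensorTypeNode::make, IncompleteTypeNode::make and RefTypeNode::make
helpers. The function name ExampleTypes and the concrete shape/dtype values are
hypothetical:

    #include <tvm/ir/tensor_type.h>
    #include <tvm/ir/type.h>

    using namespace tvm;

    Type ExampleTypes() {
      // A (3, 4) float32 tensor type; shape elements are PrimExprs,
      // so symbolic dimensions are also allowed.
      TensorType t({IntImm(DataType::Int(32), 3),
                    IntImm(DataType::Int(32), 4)},
                   DataType::Float(32));
      // Convenience constructor for a rank-0 (scalar) tensor type.
      TensorType s = TensorType::Scalar(DataType::Int(64));
      // The renamed Relay reference type wraps the type of its contents.
      RelayRefType ref(t);
      // A hole to be filled in later by type inference.
      IncompleteType hole(TypeKind::kType);
      return ref;
    }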
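In the same spirit, a sketch of a user-defined visitor written against the now
public tvm/ir/type_functor.h header; the TensorLeafCounter class and its
counting logic are hypothetical:

    #include <tvm/ir/tensor_type.h>
    #include <tvm/ir/type_functor.h>

    // Counts tensor leaves in a type, reusing TypeVisitor's default
    // recursive traversal through tuples, function types, etc.
    class TensorLeafCounter : public tvm::TypeVisitor {
     public:
      int count{0};
      void VisitType_(const tvm::TensorTypeNode* op) override { ++count; }
    };

    // Usage sketch: TensorLeafCounter c; c.VisitType(ty);  // then read c.count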
diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h
index adf1380eecb9..e8f402ac961d 100644
--- a/include/tvm/relay/type.h
+++ b/include/tvm/relay/type.h
@@ -25,6 +25,7 @@
 #define TVM_RELAY_TYPE_H_
 
 #include <tvm/ir/type.h>
+#include <tvm/ir/tensor_type.h>
 #include <tvm/ir/type_relation.h>
 #include <tvm/ir/env_func.h>
 #include <string>
@@ -54,6 +55,12 @@
 using TypeConstraint = tvm::TypeConstraint;
 using TypeConstraintNode = tvm::TypeConstraintNode;
 using FuncType = tvm::FuncType;
 using FuncTypeNode = tvm::FuncTypeNode;
+using IncompleteType = tvm::IncompleteType;
+using IncompleteTypeNode = tvm::IncompleteTypeNode;
+using RelayRefType = tvm::RelayRefType;
+using RelayRefTypeNode = tvm::RelayRefTypeNode;
+using TensorType = tvm::TensorType;
+using TensorTypeNode = tvm::TensorTypeNode;
 using TypeCall = tvm::TypeCall;
 using TypeCallNode = tvm::TypeCallNode;
 using TypeRelation = tvm::TypeRelation;
@@ -62,136 +69,6 @@
 using TypeRelationFn = tvm::TypeRelationFn;
 using TypeReporter = tvm::TypeReporter;
 using TypeReporterNode = tvm::TypeReporterNode;
 
-/*!
- * \brief Base of all Tensor types
- *  This container can hold TensorType or GenericTensorType.
- */
-class BaseTensorTypeNode : public TypeNode {
- public:
-  static constexpr const char* _type_key = "relay.BaseTensorType";
-  TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
-};
-
-class BaseTensorType : public Type {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
-};
-
-/*!
- * \brief This is the most commonly used type in relay.
- *  TensorType have a fixed dimension, data type.
- *
- *  The elements of shape can be either IntImm(constant integer),
- *  or any symbolic integer expression.
- *  The symbolic integer allows generic shape inference in certain cases.
- * \sa TensorTypeNode The container class of TensorType.
- */
-class TensorType;
-/*! \brief TensorType container node */
-class TensorTypeNode : public BaseTensorTypeNode {
- public:
-  /*!
-   * \brief The shape of the tensor,
-   *  represented by IndexExpr(tvm::Expr).
-   */
-  Array<IndexExpr> shape;
-  /*! \brief The content data type */
-  DataType dtype;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("shape", &shape);
-    v->Visit("dtype", &dtype);
-    v->Visit("span", &span);
-  }
-
-  /*! \brief Return product of elements in the shape.
-   *  \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
-   */
-  TVM_DLL IndexExpr Size() const;
-
-  TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
-
-  /*! \brief Construct an scalar containing elements of dtype.
-   */
-  TVM_DLL static TensorType Scalar(DataType dtype);
-
-  static constexpr const char* _type_key = "relay.TensorType";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
-};
-
-class TensorType : public Type {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
-};
-
-/*!
- * \brief IncompleteType.
- *  This is intermediate values that is used during type inference.
- *
- *  If we view the type relations as "computational graph of types",
- *  then IncompleteType represents intermediate values of the graph,
- *  TypeVar represents the input to the graph.
- */
-class IncompleteType;
-
-/*! \brief IncompleteType container node */
-class IncompleteTypeNode : public TypeNode {
- public:
-  Kind kind;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("kind", &kind);
-    v->Visit("span", &span);
-  }
-
-  TVM_DLL static IncompleteType make(Kind kind);
-
-  static constexpr const char* _type_key = "relay.IncompleteType";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
-};
-
-class IncompleteType : public Type {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
-};
-
-/*!
- * \brief The type of reference values.
- */
-class RefType;
-/*!
- * \brief Reference Type in relay.
- */
-class RefTypeNode : public TypeNode {
- public:
-  /*! \brief The type of value in the Reference. */
-  Type value;
-
-  RefTypeNode() {}
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("value", &value);
-    v->Visit("span", &span);
-  }
-
-  TVM_DLL static RefType make(Type value);
-
-  static constexpr const char* _type_key = "relay.RefType";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RefTypeNode, TypeNode);
-};
-
-class RefType : public Type {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode);
-};
-
-// The following fields contains advanced typing
-// Only keep the class name and reserved for future usage.
-class GenericTensorType;
-// stores a DataType.
-class GenericDataType;
-// stores a DataType.
-class GenericShape;
-
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_TYPE_H_
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 1989afdf0787..ba84b5f10119 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -317,8 +317,8 @@ class Object {
  * \tparam ObjectType The object type
  * \return The corresponding RefType
  */
-template <typename RefType, typename ObjectType>
-inline RefType GetRef(const ObjectType* ptr);
+template <typename RelayRefType, typename ObjectType>
+inline RelayRefType GetRef(const ObjectType* ptr);
 
 /*!
  * \brief Downcast a base reference type to a more specific type.
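Note that the GetRef change above (and in the hunks that follow) only renames
the template parameter from RefType to RelayRefType; call sites are unaffected.
An illustrative use, assuming some tvm::Type value ty:

    // Standard downcast-and-rewrap pattern, unchanged by the rename.
    const tvm::TensorTypeNode* node = ty.as<tvm::TensorTypeNode>();
    tvm::TensorType ref = tvm::GetRef<tvm::TensorType>(node);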
@@ -484,8 +484,8 @@ class ObjectPtr { friend class TVMArgsSetter; friend class TVMRetValue; friend class TVMArgValue; - template - friend RefType GetRef(const ObjType* ptr); + template + friend RelayRefType GetRef(const ObjType* ptr); template friend ObjectPtr GetObjectPtr(ObjType* ptr); }; @@ -848,11 +848,11 @@ inline const ObjectType* ObjectRef::as() const { } } -template -inline RefType GetRef(const ObjType* ptr) { - static_assert(std::is_base_of::value, +template +inline RelayRefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, "Can only cast to the ref of same container type"); - return RefType(ObjectPtr(const_cast(static_cast(ptr)))); + return RelayRefType(ObjectPtr(const_cast(static_cast(ptr)))); } template diff --git a/src/relay/ir/type.cc b/src/ir/tensor_type.cc similarity index 50% rename from src/relay/ir/type.cc rename to src/ir/tensor_type.cc index f1e59a4b97c0..0a9ed4eed327 100644 --- a/src/relay/ir/type.cc +++ b/src/ir/tensor_type.cc @@ -18,35 +18,35 @@ */ /*! - * \file src/tvm/ir/type.cc + * \file src/tvm/ir/tensor_type.cc * \brief The type system AST nodes of Relay. */ -#include +#include +#include #include namespace tvm { -namespace relay { using tvm::NodePrinter; using namespace tvm::runtime; -TensorType TensorTypeNode::make(Array shape, DataType dtype) { +TensorType::TensorType(Array shape, DataType dtype) { ObjectPtr n = make_object(); n->shape = std::move(shape); n->dtype = std::move(dtype); - return TensorType(n); + data_ = std::move(n); } -TensorType TensorTypeNode::Scalar(DataType dtype) { - return TensorTypeNode::make({}, dtype); +TensorType TensorType::Scalar(DataType dtype) { + return TensorType({}, dtype); } -IndexExpr TensorTypeNode::Size() const { +PrimExpr TensorTypeNode::Size() const { if (shape.size() == 0) { return tir::make_const(DataType::Int(64), 1); } - IndexExpr size = shape[0]; + PrimExpr size = shape[0]; for (size_t i = 1; i < shape.size(); ++i) { size *= shape[i]; } @@ -56,7 +56,9 @@ IndexExpr TensorTypeNode::Size() const { TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_GLOBAL("relay._make.TensorType") -.set_body_typed(TensorTypeNode::make); +.set_body_typed([](Array shape, DataType dtype) { + return TensorType(shape, dtype); +}); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& ref, NodePrinter* p) { @@ -64,45 +66,4 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); -IncompleteType IncompleteTypeNode::make(Kind kind) { - auto n = make_object(); - n->kind = std::move(kind); - return IncompleteType(n); -} - -TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); - -TVM_REGISTER_GLOBAL("relay._make.IncompleteType") -.set_body_typed([](int kind) { - return IncompleteTypeNode::make(static_cast(kind)); - }); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; - }); - - -RefType RefTypeNode::make(Type value) { - ObjectPtr n = make_object(); - n->value = std::move(value); - return RefType(n); -} - -TVM_REGISTER_GLOBAL("relay._make.RefType") -.set_body_typed(RefTypeNode::make); - -TVM_REGISTER_NODE_TYPE(RefTypeNode); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefTypeNode(" << node->value << ")"; -}); - -TVM_REGISTER_GLOBAL("relay._make.Any") 
-.set_body_typed([]() { return Any::make(); }); - -} // namespace relay } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index 9e250db44875..233274a79e02 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -118,6 +118,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) << node->type_constraints << ")"; }); + TupleType::TupleType(Array fields) { ObjectPtr n = make_object(); n->fields = std::move(fields); @@ -141,4 +142,44 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "TupleTypeNode(" << node->fields << ")"; }); + +IncompleteType::IncompleteType(TypeKind kind) { + auto n = make_object(); + n->kind = std::move(kind); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); + +TVM_REGISTER_GLOBAL("relay._make.IncompleteType") +.set_body_typed([](int kind) { + return IncompleteType(static_cast(kind)); + }); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); + + +RelayRefType::RelayRefType(Type value) { + ObjectPtr n = make_object(); + n->value = std::move(value); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relay._make.RefType") +.set_body_typed([](Type value) { + return RelayRefType(value); +}); + +TVM_REGISTER_NODE_TYPE(RelayRefTypeNode); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RelayRefTypeNode(" << node->value << ")"; +}); + } // namespace tvm diff --git a/src/relay/ir/type_functor.cc b/src/ir/type_functor.cc similarity index 94% rename from src/relay/ir/type_functor.cc rename to src/ir/type_functor.cc index 0180a0c64ba0..cbd3538b066c 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -21,11 +21,10 @@ * \file type_functor.cc * \brief Implementations of type functors. */ +#include #include -#include "type_functor.h" namespace tvm { -namespace relay { void TypeVisitor::VisitType_(const TypeVarNode* op) { } @@ -57,7 +56,7 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) { } } -void TypeVisitor::VisitType_(const RefTypeNode* op) { +void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { this->VisitType(op->value); } @@ -91,6 +90,9 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { } } +void TypeVisitor::VisitType_(const PrimTypeNode* op) { +} + Type TypeMutator::VisitType(const Type& t) { return t.defined() ? TypeFunctor::VisitType(t) : t; } @@ -169,8 +171,8 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) { } } -Type TypeMutator::VisitType_(const RefTypeNode* op) { - return RefTypeNode::make(this->VisitType(op->value)); +Type TypeMutator::VisitType_(const RelayRefTypeNode* op) { + return RelayRefType(this->VisitType(op->value)); } Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { @@ -203,6 +205,10 @@ Type TypeMutator::VisitType_(const TypeDataNode* op) { return GetRef(op); } +Type TypeMutator::VisitType_(const PrimTypeNode* op) { + return GetRef(op); +} + // Implements bind. 
class TypeBinder : public TypeMutator { public: @@ -227,5 +233,4 @@ Type Bind(const Type& type, const tvm::Map& args_map) { return TypeBinder(args_map).VisitType(type); } -} // namespace relay } // namespace tvm diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 86d16a898c92..8ba6eb43e618 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -21,8 +21,7 @@ * \file relay/backend/compile_engine.cc * \brief Internal compialtion engine. */ -#include "compile_engine.h" - +#include #include #include #include @@ -42,7 +41,8 @@ #include #include #include -#include "../ir/type_functor.h" + +#include "compile_engine.h" namespace tvm { namespace relay { @@ -239,12 +239,12 @@ class ScheduleGetter : // TODO(@icemelon): Support recursive tuple Type call_node_type = call_node->checked_type(); if (const auto* tt = call_node->checked_type().as()) { - call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype); + call_node_type = TensorType(GetShape(tt->shape), tt->dtype); } else if (const auto* tuple_t = call_node->checked_type().as()) { std::vector new_fields; for (auto field : tuple_t->fields) { if (const auto* tt = field.as()) { - new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype)); + new_fields.push_back(TensorType(GetShape(tt->shape), tt->dtype)); } else { new_fields.push_back(field); } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 95d667fcf0ae..224ff778ff34 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -529,7 +529,7 @@ class Interpreter : if (is_dyn) { auto sh = out_shapes[i]; auto tt = Downcast(rtype->fields[i]); - fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype))); + fields.push_back(fset_output(i, TensorType(sh, tt->dtype))); } else { fields.push_back(fset_output(i, rtype->fields[i])); } @@ -542,7 +542,7 @@ class Interpreter : CHECK_EQ(out_shapes.size(), 1); auto sh = out_shapes[0]; auto tt = Downcast(ret_type); - out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype)); + out_tensor = fset_output(0, TensorType(sh, tt->dtype)); } else { out_tensor = fset_output(0, ret_type); } diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index b55a4afd22e2..2d07f6131f13 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -21,6 +21,7 @@ * \file src/tvm/relay/ir/alpha_equal.cc * \brief Alpha equality check by deep comparing two nodes. 
*/ +#include #include #include #include @@ -28,7 +29,6 @@ #include #include #include -#include "type_functor.h" #include "../../ir/attr_functor.h" namespace tvm { namespace relay { @@ -277,8 +277,8 @@ class AlphaEqualHandler: } } - bool VisitType_(const RefTypeNode* lhs, const Type& other) final { - if (const RefTypeNode* rhs = other.as()) { + bool VisitType_(const RelayRefTypeNode* lhs, const Type& other) final { + if (const RelayRefTypeNode* rhs = other.as()) { return TypeEqual(lhs->value, rhs->value); } return false; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 7e19d5173ad4..3d8cc3a85b2b 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -59,7 +59,7 @@ TensorType ConstantNode::tensor_type() const { tvm::IntImm(DataType::Int(32), data->shape[i])); } - return TensorTypeNode::make(shape, dtype); + return TensorType(shape, dtype); } Tuple TupleNode::make(tvm::Array fields) { @@ -129,12 +129,12 @@ FuncType FunctionNode::func_type_annotation() const { Array param_types; for (auto param : this->params) { Type param_type = (param->type_annotation.defined()) ? param->type_annotation - : IncompleteTypeNode::make(Kind::kType); + : IncompleteType(Kind::kType); param_types.push_back(param_type); } Type ret_type = (this->ret_type.defined()) ? this->ret_type - : IncompleteTypeNode::make(Kind::kType); + : IncompleteType(Kind::kType); return FuncType(param_types, ret_type, this->type_params, {}); } @@ -359,5 +359,8 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr") return FunctionSetAttr(func, name, ref); }); +TVM_REGISTER_GLOBAL("relay._make.Any") +.set_body_typed([]() { return Any::make(); }); + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 0da763ab4083..c525b9eb7324 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -24,10 +24,10 @@ * ExprMutator uses memoization and self return in order to amortize * the cost of using functional updates. */ +#include #include #include #include -#include "type_functor.h" namespace tvm { namespace relay { diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index b1906d3e0feb..9977b5ccdbea 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -21,13 +21,13 @@ * \file src/tvm/relay/ir/hash.cc * \brief Hash functions for Relay types and expressions. */ +#include #include #include #include #include #include #include -#include "type_functor.h" #include "../../ir/attr_functor.h" namespace tvm { @@ -201,8 +201,8 @@ class RelayHashHandler: return hash; } - size_t VisitType_(const RefTypeNode* rtn) final { - size_t hash = std::hash()(RefTypeNode::_type_key); + size_t VisitType_(const RelayRefTypeNode* rtn) final { + size_t hash = std::hash()(RelayRefTypeNode::_type_key); hash = Combine(hash, TypeHash(rtn->value)); return hash; } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index ae2089d5b765..c21f565f430c 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -30,13 +30,12 @@ * - Var * - Otherwise, inline if the node is at the end of a scope and is used at most once. 
*/ - +#include #include #include #include #include #include "doc.h" -#include "type_functor.h" #include "../pass/dependency_graph.h" #include "../../ir/attr_functor.h" @@ -779,7 +778,7 @@ class PrettyPrinter : return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type); } - Doc VisitType_(const RefTypeNode* node) final { + Doc VisitType_(const RelayRefTypeNode* node) final { Doc doc; return doc << "ref(" << Print(node->value) << ")"; } diff --git a/src/relay/op/algorithm/argsort.cc b/src/relay/op/algorithm/argsort.cc index 0d68b446f0fc..13d89a7e1af5 100644 --- a/src/relay/op/algorithm/argsort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -43,7 +43,7 @@ bool ArgsortRel(const Array& types, << types[0]; return false; } - reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype)); + reporter->Assign(types[1], TensorType(data->shape, param->dtype)); return true; } diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index 888d431579a8..0ff30bbd3933 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -52,8 +52,8 @@ bool TopKRel(const Array& types, out_shape.push_back(param->k); } } - auto values_ty = TensorTypeNode::make(out_shape, data->dtype); - auto indices_ty = TensorTypeNode::make(out_shape, param->dtype); + auto values_ty = TensorType(out_shape, data->dtype); + auto indices_ty = TensorType(out_shape, param->dtype); if (param->ret_type == "both") { reporter->Assign(types[1], TupleType({values_ty, indices_ty})); } else if (param->ret_type == "values") { diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index e796a044f388..4349e0944a92 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -60,7 +60,7 @@ bool ResizeRel(const Array& types, // assign output type reporter->Assign(types[1], - TensorTypeNode::make(layout_converter.BackwardShape(oshape), + TensorType(layout_converter.BackwardShape(oshape), out_dtype)); return true; } @@ -143,7 +143,7 @@ bool CropAndResizeRel(const Array& types, auto bshape = layout_converter.BackwardShape(oshape); // assign output type reporter->Assign(types[3], - TensorTypeNode::make(layout_converter.BackwardShape(oshape), + TensorType(layout_converter.BackwardShape(oshape), out_dtype)); return true; } diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 6c4b3ea87b0e..aa0ba2daaba6 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -154,11 +154,11 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs for (auto i = 0u; i < dims; i++) { out_shape.push_back(tvm::Integer(sh[i])); } - alloc_type = TensorTypeNode::make(out_shape, alloc_attrs->dtype); + alloc_type = TensorType(out_shape, alloc_attrs->dtype); } else { CHECK(alloc_attrs->assert_shape.defined()) << "the assert_shape must be set when const_shape is not"; - alloc_type = TensorTypeNode::make(alloc_attrs->assert_shape, alloc_attrs->dtype); + alloc_type = TensorType(alloc_attrs->assert_shape, alloc_attrs->dtype); return true; } @@ -309,13 +309,13 @@ bool ShapeFuncRel(const Array& types, int num_inputs, const Attrs& attrs, shape_func_ins.push_back(in_type); } else { auto shape = RankShape(in_type->shape); - shape_func_ins.push_back(TensorTypeNode::make(shape, DataType::Int(64))); + shape_func_ins.push_back(TensorType(shape, DataType::Int(64))); } } for (auto out_type : out_types) { auto rank_shape = RankShape(out_type->shape); - shape_func_outs.push_back(TensorTypeNode::make(rank_shape, 
DataType::Int(64))); + shape_func_outs.push_back(TensorType(rank_shape, DataType::Int(64))); } auto input_type = TupleType(shape_func_ins); diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index eccffc8f3f0f..c9e05e1f5158 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -81,7 +81,7 @@ bool BitPackRel(const Array& types, int num_inputs, const Attrs& attrs, out_shape.push_back(bits); } - reporter->Assign(types[1], TensorTypeNode::make(out_shape, pack_type)); + reporter->Assign(types[1], TensorType(out_shape, pack_type)); return true; } @@ -144,7 +144,7 @@ bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attr DataType out_dtype = param->out_dtype; oshape = trans_in_layout.BackwardShape(oshape); // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } @@ -220,7 +220,7 @@ bool BinaryDenseRel(const Array& types, int num_inputs, const Attrs& attrs } // Assign output type. - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 82f4ba50467d..6977ac9b8575 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -271,7 +271,7 @@ bool Conv2DTransposeRel(const Array& types, channels = param->channels; // assign result to reporter - reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); + reporter->Assign(types[1], TensorType(wshape, data->dtype)); } else { // use weight to infer the conv shape. if (weight == nullptr) return false; @@ -310,7 +310,7 @@ bool Conv2DTransposeRel(const Array& types, out_dtype = data->dtype; } oshape = trans_out_layout.BackwardShape(oshape); - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } @@ -434,7 +434,7 @@ bool Conv1DTransposeRel(const Array& types, channels = param->channels; // assign result to reporter - reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); + reporter->Assign(types[1], TensorType(wshape, data->dtype)); } else { // use weight to infer the conv shape. 
if (weight == nullptr) return false; @@ -469,7 +469,7 @@ bool Conv1DTransposeRel(const Array& types, out_dtype = data->dtype; } oshape = trans_out_layout.BackwardShape(oshape); - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } @@ -616,7 +616,7 @@ bool Conv2DWinogradRel(const Array& types, } oshape = trans_out_layout.BackwardShape(oshape); // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } @@ -702,7 +702,7 @@ bool Conv2DWinogradWeightTransformRel(const Array& types, data->shape[1], }; - reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } @@ -817,7 +817,7 @@ bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, if (out_dtype.bits() == 0) { out_dtype = data->dtype; } - reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), out_dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), out_dtype)); return true; } @@ -1025,7 +1025,7 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; // assign result to reporter - reporter->Assign(types[2], TensorTypeNode::make(wshape, data->dtype)); + reporter->Assign(types[2], TensorType(wshape, data->dtype)); } else { // use weight to infer the conv shape. if (weight == nullptr) return false; @@ -1066,12 +1066,12 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& // infer offset shape Array offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]}); - reporter->Assign(types[1], TensorTypeNode::make(offset_shape, data->dtype)); + reporter->Assign(types[1], TensorType(offset_shape, data->dtype)); if (out_dtype.bits() == 0) { out_dtype = data->dtype; } - reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[3], TensorType(oshape, out_dtype)); return true; } diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index f858efca62bd..40619091656f 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -81,7 +81,7 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, weight_dtype = weight->dtype; } // assign result to reporter - reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); } else { // use weight to infer the conv shape. if (weight == nullptr) return false; @@ -117,7 +117,7 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, } oshape = trans_out_layout.BackwardShape(oshape); // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } @@ -179,7 +179,7 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, weight_dtype = weight->dtype; } // assign result to reporter - reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); } else { // use weight to infer the conv shape. 
if (weight == nullptr) return false; @@ -226,7 +226,7 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, } oshape = trans_out_layout.BackwardShape(oshape); // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } @@ -290,7 +290,7 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, } // assign result to reporter - reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); } else { // use weight to infer the conv shape. if (weight == nullptr) return false; @@ -346,7 +346,7 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, } oshape = trans_out_layout.BackwardShape(oshape); // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 2ff439a527ba..1f6ad8f692b8 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -61,7 +61,7 @@ bool BiasAddRel(const Array& types, << "axis " << param->axis << " is out of range"; // assign output type - reporter->Assign(types[1], TensorTypeNode::make( + reporter->Assign(types[1], TensorType( {data->shape[axis]}, data->dtype)); reporter->Assign(types[2], types[0]); return true; @@ -138,7 +138,7 @@ bool FIFOBufferRel(const Array& types, Array oshape = buffer->shape; - reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype)); + reporter->Assign(types[2], TensorType(oshape, buffer->dtype)); return true; } @@ -260,10 +260,10 @@ bool PReluRel(const Array& types, // assign alpha type Array alpha_shape({data->shape[param->axis]}); - reporter->Assign(types[1], TensorTypeNode::make(alpha_shape, data->dtype)); + reporter->Assign(types[1], TensorType(alpha_shape, data->dtype)); // assign output type - reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype)); + reporter->Assign(types[2], TensorType(data->shape, data->dtype)); return true; } @@ -419,7 +419,7 @@ bool BatchFlattenRel(const Array& types, std::vector oshape({data->shape[0], target_dim}); // assign output type - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } @@ -585,7 +585,7 @@ bool DropoutRel(const Array& types, // dropout returns the original tensor with dropout applied // and a mask tensor (1.0 where element not dropped, 0.0 where dropped) - auto ret_type = TensorTypeNode::make(data->shape, data->dtype); + auto ret_type = TensorType(data->shape, data->dtype); reporter->Assign(types[1], TupleType(Array({ret_type, ret_type}))); return true; } @@ -661,17 +661,17 @@ bool BatchNormRel(const Array& types, auto axis_size = data->shape[axis]; // if we are using beta and gamma, they need to be of shape (dim,) - reporter->Assign(types[1], TensorTypeNode::make({axis_size}, data->dtype)); - reporter->Assign(types[2], TensorTypeNode::make({axis_size}, data->dtype)); - reporter->Assign(types[3], TensorTypeNode::make({axis_size}, data->dtype)); - reporter->Assign(types[4], TensorTypeNode::make({axis_size}, data->dtype)); + reporter->Assign(types[1], TensorType({axis_size}, data->dtype)); + reporter->Assign(types[2], TensorType({axis_size}, data->dtype)); + reporter->Assign(types[3], TensorType({axis_size}, data->dtype)); + reporter->Assign(types[4], 
TensorType({axis_size}, data->dtype)); // output is a tuple of the normed data (same shape as input), new running mean, // and new running average (the latter two are both vectors of length dim) std::vector fields; - auto vec_ty = TensorTypeNode::make(Array({data->shape[axis]}), + auto vec_ty = TensorType(Array({data->shape[axis]}), data->dtype); - fields.push_back(TensorTypeNode::make(data->shape, data->dtype)); + fields.push_back(TensorType(data->shape, data->dtype)); fields.push_back(vec_ty); fields.push_back(vec_ty); reporter->Assign(types[5], TupleType(Array(fields))); @@ -754,9 +754,9 @@ bool InstanceNormRel(const Array& types, const InstanceNormAttrs* param = attrs.as(); int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size(); CHECK(axis >= 0 && axis < (int)data->shape.size()); - reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype)); - reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype)); - reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype)); + reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype)); + reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype)); + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); return true; } @@ -824,9 +824,9 @@ bool LayerNormRel(const Array& types, const LayerNormAttrs* param = attrs.as(); int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size(); CHECK(axis >= 0 && axis < (int)data->shape.size()); - reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype)); - reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype)); - reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype)); + reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype)); + reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype)); + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); return true; } @@ -881,7 +881,7 @@ bool BatchMatmulRel(const Array& types, oshape.Set(2, y->shape[1]); // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, x->dtype)); + reporter->Assign(types[2], TensorType(oshape, x->dtype)); return true; } @@ -940,7 +940,7 @@ bool CrossEntropyRel(const Array& types, << "x shape = " << x->shape << ", " << "y shape = " << y->shape; // assign output type - reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype)); + reporter->Assign(types[2], TensorType({}, x->dtype)); return true; } @@ -1016,7 +1016,7 @@ bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attr // Assign output type reporter->Assign(types[1], - TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype)); + TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } @@ -1074,7 +1074,7 @@ bool SpaceToDepthRel(const Array& types, int num_inputs, const Attrs& attr // Assign output type reporter->Assign(types[1], - TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype)); + TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 73899097fe52..dc876e863ad0 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -52,7 +52,7 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, // data dtype as the weight dtype. However if weight dtype is explicitly // present we will use that. auto weight_dtype = (weight == nullptr ? 
data->dtype : weight->dtype); - reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); + reporter->Assign(types[1], TensorType(wshape, weight_dtype)); oshape.Set((oshape.size() - 1), param->units); } else { if (weight == nullptr) return false; @@ -70,7 +70,7 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, out_dtype = data->dtype; } // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index e33f751fb638..16561588d6b2 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -155,7 +155,7 @@ bool PadRel(const Array& types, } } - reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } @@ -260,7 +260,7 @@ bool MirrorPadRel(const Array& types, oshape.push_back(data->shape[i] + padding); } - reporter->Assign(types[1], TensorTypeNode::make(Array(oshape), + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 0c74b2711f01..7b6deffee7d7 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -161,7 +161,7 @@ bool Pool2DRel(const Array& types, } // assign output type - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } @@ -327,7 +327,7 @@ bool GlobalPool2DRel(const Array& types, oshape.Set(widx, 1); // assign output type - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } @@ -462,7 +462,7 @@ bool AdaptivePool2DRel(const Array& types, oshape.Set(widx, output_width); // assign output type - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } @@ -792,7 +792,7 @@ bool Pool1DRel(const Array& types, } // assign output type - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } @@ -987,7 +987,7 @@ bool Pool3DRel(const Array& types, } // assign output type - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index caad01b9e66b..c01b760c4ad0 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -47,7 +47,7 @@ bool SparseDenseRel(const Array& types, int num_inputs, const Attrs& attrs if (weight_data->shape.size() == 1) { // CSR case. 
     Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
-    reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
+    reporter->Assign(types[4], TensorType(oshape, data->dtype));
     return true;
   }
@@ -56,7 +56,7 @@ bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
     Array<IndexExpr> oshape({
         data->shape[0],
         (weight_indptr->shape[0] - 1) * weight_data->shape[1]});
-    reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
+    reporter->Assign(types[4], TensorType(oshape, data->dtype));
     return true;
   }
   LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
@@ -105,9 +105,9 @@ bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
   const auto* sparse_indptr = types[2].as<TensorTypeNode>();
 
   std::vector<Type> output_types;
-  output_types.push_back(TensorTypeNode::make(sparse_data->shape, sparse_data->dtype));
-  output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype));
-  output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype));
+  output_types.push_back(TensorType(sparse_data->shape, sparse_data->dtype));
+  output_types.push_back(TensorType(sparse_indices->shape, sparse_indices->dtype));
+  output_types.push_back(TensorType(sparse_indptr->shape, sparse_indptr->dtype));
 
   reporter->Assign(types[3], TupleType(Array<Type>(output_types)));
   return true;
diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc
index e78f7fdade68..477cec76133d 100644
--- a/src/relay/op/nn/upsampling.cc
+++ b/src/relay/op/nn/upsampling.cc
@@ -87,7 +87,7 @@ bool UpSamplingRel(const Array<Type>& types,
 
   // assign output type
   reporter->Assign(types[1],
-                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+                   TensorType(layout_converter.BackwardShape(oshape),
                                         data->dtype));
   return true;
 }
@@ -167,7 +167,7 @@ bool UpSampling3DRel(const Array<Type>& types,
 
   // assign output type
   reporter->Assign(types[1],
-                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+                   TensorType(layout_converter.BackwardShape(oshape),
                                         data->dtype));
   return true;
 }
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index 5156330d7601..880a337c4ef8 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -272,7 +272,7 @@ bool ArgReduceRel(const Array<Type>& types,
 
   // assign output type and shape
   auto oshape = ReduceShapeImpl(in_shape, param, reporter);
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, DataType::Int(32)));
+  reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
   return true;
 }
@@ -297,7 +297,7 @@ bool ReduceRel(const Array<Type>& types,
 
   // assign output type and shape
   auto oshape = ReduceShapeImpl(in_shape, param, reporter);
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
@@ -594,7 +594,7 @@ bool VarianceRel(const Array<Type>& types,
 
   // assign output type and shape
   auto oshape = ReduceShapeImpl(in_shape, param, reporter);
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 538c92ed42df..b95875517e2c 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -60,7 +60,7 @@ bool CastRel(const Array<Type>& types,
     return false;
   }
   const auto* param = attrs.as<CastAttrs>();
-  reporter->Assign(types[1], TensorTypeNode::make(
+  reporter->Assign(types[1], TensorType(
       data->shape, param->dtype));
   return true;
 }
@@ -120,7 +120,7 @@ bool CastLikeRel(const Array<Type>& types,
         << types[1];
     return false;
   }
-  reporter->Assign(types[2], TensorTypeNode::make(data->shape, dtype_like->dtype));
+  reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype));
   return true;
 }
@@ -226,7 +226,7 @@ bool ExpandDimsRel(const Array<Type>& types,
   for (int i = pivot; i < ndim; ++i) {
     oshape.emplace_back(data->shape[i]);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
@@ -408,7 +408,7 @@ bool StackRel(const Array<Type>& types,
   for (int i = axis; i < ndim; ++i) {
     oshape.emplace_back(first->shape[i]);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype));
+  reporter->Assign(types[1], TensorType(oshape, dtype));
   return true;
 }
@@ -500,7 +500,7 @@ bool TransposeRel(const Array<Type>& types,
   for (int axis : int_axes) {
     oshape.push_back(data->shape[axis]);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
@@ -679,10 +679,10 @@ bool ReshapeRel(const Array<Type>& types,
   }
 
   if (param->reverse) {
-    reporter->Assign(types[1], TensorTypeNode::make(
+    reporter->Assign(types[1], TensorType(
         Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
   } else {
-    reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+    reporter->Assign(types[1], TensorType(oshape, data->dtype));
   }
   return true;
 }
@@ -809,7 +809,7 @@ bool ReshapeLikeRel(const Array<Type>& types,
     CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
         << "Reshape inputs size should be compatible.";
   }
-  reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype));
+  reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype));
   return true;
 }
@@ -853,7 +853,7 @@ bool ArgWhereRel(const Array<Type>& types,
   std::vector<IndexExpr> result_shape;
   result_shape.push_back(Any::make());
   result_shape.push_back(IntImm(DataType::Int(32), input_rank));
-  reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32)));
+  reporter->Assign(types[1], TensorType(result_shape, DataType::Int(32)));
   return true;
 }
@@ -894,7 +894,7 @@ bool TakeRel(const Array<Type>& types,
 
   if (!param->axis.defined()) {
     std::vector<IndexExpr> oshape(indices->shape.begin(), indices->shape.end());
-    reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+    reporter->Assign(types[2], TensorType(oshape, data->dtype));
     return true;
   }
@@ -918,7 +918,7 @@ bool TakeRel(const Array<Type>& types,
     oshape.emplace_back(data->shape[i]);
   }
 
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
@@ -1005,7 +1005,7 @@ bool FullRel(const Array<Type>& types,
       << "Fill value should be a scalar but has dimension "
       << fill_value->shape.size() << ".";
 
-  reporter->Assign(types[1], TensorTypeNode::make(param->shape, out_dtype));
+  reporter->Assign(types[1], TensorType(param->shape, out_dtype));
   return true;
 }
@@ -1049,7 +1049,7 @@ bool InitOpRel(const Array<Type>& types,
   CHECK_EQ(types.size(), 1);
   const InitOpAttrs* param = attrs.as<InitOpAttrs>();
 
-  reporter->Assign(types[0], TensorTypeNode::make(param->shape, param->dtype));
+  reporter->Assign(types[0], TensorType(param->shape, param->dtype));
   return true;
 }
@@ -1113,7 +1113,7 @@ bool FullLikeRel(const Array<Type>& types,
       << "The fill value should be a scalar but here it has dimension "
       << fill_value->shape.size() << ".";
 
-  reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
+  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
   return true;
 }
@@ -1197,7 +1197,7 @@ bool ArangeRel(const Array<Type>& types,
   reporter->Assign(types[0], types[1]);
   reporter->Assign(types[1], types[2]);
-  reporter->Assign(types[2], TensorTypeNode::make({}, attrs->dtype));
+  reporter->Assign(types[2], TensorType({}, attrs->dtype));
 
   if ((cstart = attrs->start.as()) &&
       (cstop = attrs->stop.as()) &&
@@ -1209,10 +1209,10 @@ bool ArangeRel(const Array<Type>& types,
     CHECK_GT(num_elem, 0)
         << "Invalid arange attributes (start, stop, step): " << attrs->start
        << ", " << attrs->stop << ", " << attrs->step;
-    reporter->Assign(types[3], TensorTypeNode::make({num_elem}, attrs->dtype));
+    reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype));
     return true;
   } else {
-    reporter->Assign(types[3], TensorTypeNode::make({Any::make()}, attrs->dtype));
+    reporter->Assign(types[3], TensorType({Any::make()}, attrs->dtype));
     return true;
   }
 }
@@ -1320,7 +1320,7 @@ bool RepeatRel(const Array<Type>& types,
   for (int i = pivot + 1; i < ndim; ++i) {
     oshape.emplace_back(data->shape[i]);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
@@ -1431,7 +1431,7 @@ bool TileRel(const Array<Type>& types,
       oshape.emplace_back(data_shape[i] * reps_shape[i]);
     }
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
@@ -1560,7 +1560,7 @@ bool WhereRel(const Array<Type>& types,
           << "condition and x must have the same shape: " << cond_shape
           << " vs " << x_shape;
     }
   }
-  reporter->Assign(types[3], TensorTypeNode::make(x_shape, x->dtype));
+  reporter->Assign(types[3], TensorType(x_shape, x->dtype));
   return true;
 }
@@ -1683,7 +1683,7 @@ bool SqueezeRel(const Array<Type>& types,
       }
     }
   }
-  reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype));
+  reporter->Assign(types[1], TensorType(result_shape, data->dtype));
   return true;
 }
@@ -1761,7 +1761,7 @@ bool BroadCastToRel(const Array<Type>& types,
   CHECK(ioattrs);
   auto intt = types[0].as<TensorTypeNode>();
   if (intt == nullptr) { return false; }
-  auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype);
+  auto type = TensorType(ioattrs->shape, intt->dtype);
   reporter->Assign(types[1], type);
   return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
 }
@@ -1942,7 +1942,7 @@ bool StridedSliceRel(const Array<Type>& types,
     }
     oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
@@ -2147,7 +2147,7 @@ bool SplitRel(const Array<Type>& types,
     for (int i = 0; i < sections->value; ++i) {
       std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
       oshape[axis] = indexdiv(oshape[axis], sections->value);
-      auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+      auto vec_type = TensorType(oshape, data->dtype);
       fields.push_back(vec_type);
     }
     reporter->Assign(types[1], TupleType(Array<Type>(fields)));
@@ -2161,14 +2161,14 @@ bool SplitRel(const Array<Type>& types,
       std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
       oshape[axis] = Downcast<IndexExpr>(indices[i]) - begin;
       begin = Downcast<IndexExpr>(indices[i]);
-      auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+      auto vec_type = TensorType(oshape, data->dtype);
       fields.push_back(vec_type);
     }
     CHECK(reporter->Assert(begin < data->shape[axis]))
        << "The sum of sections must match the input.shape[axis]";
     std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
     oshape[axis] = data->shape[axis] - begin;
-    auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+    auto vec_type = TensorType(oshape, data->dtype);
     fields.push_back(vec_type);
     reporter->Assign(types[1], TupleType(Array<Type>(fields)));
   }
@@ -2290,7 +2290,7 @@ bool SliceLikeRel(const Array<Type>& types,
     }
   }
 
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
@@ -2400,7 +2400,7 @@ bool LayoutTransformRel(const Array<Type>& types,
       << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
 
   const auto& out_shape = layout_converter.ForwardShape(data->shape);
-  reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
+  reporter->Assign(types[1], TensorType(out_shape, data->dtype));
   return true;
 }
@@ -2499,7 +2499,7 @@ bool GatherNDRel(const Array<Type>& types,
     oshape.push_back(indices->shape[i]);
   for (size_t i = mdim->value; i < ndim; ++i)
     oshape.push_back(data->shape[i]);
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
@@ -2552,7 +2552,7 @@ bool SequenceMaskRel(const Array<Type>& types,
   Array<IndexExpr> valid_length_shape;
   CHECK(param->axis == 0 || param->axis == 1);
   valid_length_shape.push_back(data->shape[1 - param->axis]);
-  reporter->Assign(types[1], TensorTypeNode::make(valid_length_shape, valid_length->dtype));
+  reporter->Assign(types[1], TensorType(valid_length_shape, valid_length->dtype));
   reporter->Assign(types[2], types[0]);
   return true;
 }
@@ -2666,7 +2666,7 @@ bool OneHotRel(const Array<Type>& types,
     }
   }
 
-  reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype));
+  reporter->Assign(types[3], TensorType(oshape, param->dtype));
   return true;
 }
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index a1cbf7aa45c2..b69f6e7178ff 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -119,7 +119,7 @@ bool ConcatenateRel(const Array<Type>& types,
     concat_dim = Any::make();
   }
 
-  auto rtype = TensorTypeNode::make(oshape, dtype);
+  auto rtype = TensorType(oshape, dtype);
   reporter->Assign(types[1], rtype);
   return true;
 }
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 331653b49445..98ff09911d75 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -286,7 +286,7 @@ bool ShapeOfRel(const Array<Type>& types,
   const auto* param = attrs.as<ShapeOfAttrs>();
   CHECK(param != nullptr);
   auto rank_shape = RankShape(tt->shape);
-  reporter->Assign(types[1], TensorTypeNode::make(rank_shape, param->dtype));
+  reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
   return true;
 }
@@ -337,7 +337,7 @@ bool NdarraySizeRel(const Array<Type>& types,
   CHECK(tt != nullptr);
   const auto* param = attrs.as<NdarraySizeAttrs>();
   CHECK(param != nullptr);
-  reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype));
+  reporter->Assign(types[1], TensorType({1}, param->dtype));
   return true;
 }
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index cd476fd49b87..3c7d148ab803 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -96,7 +96,7 @@ Type ConcreteBroadcast(const TensorType& t1,
   for (; i <= max_ndim; ++i) {
     oshape.push_back(rshape[max_ndim - i]);
   }
-  return TensorTypeNode::make(Array<IndexExpr>(
+  return TensorType(Array<IndexExpr>(
       oshape.rbegin(), oshape.rend()), output_dtype);
 }
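NOTE (illustration, not part of the patch): the rewrite in the relations above is mechanical -- the node's static TensorTypeNode::make() becomes the constructor of the managed reference class TensorType introduced in include/tvm/ir/tensor_type.h. A minimal sketch of a type relation written in the post-refactor style; the relation name is hypothetical, not an op shipped with TVM:

    // Hypothetical "cast like" relation, sketched in the new style.
    bool MyCastLikeRel(const Array<Type>& types, int num_inputs,
                       const Attrs& attrs, const TypeReporter& reporter) {
      const auto* data = types[0].as<TensorTypeNode>();
      const auto* like = types[1].as<TensorTypeNode>();
      if (data == nullptr || like == nullptr) return false;
      // Construct the reference directly; no Node::make() indirection.
      reporter->Assign(types[2], TensorType(data->shape, like->dtype));
      return true;
    }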
diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc
index b801186fb5b5..eb5012fbfb35 100644
--- a/src/relay/op/vision/multibox_op.cc
+++ b/src/relay/op/vision/multibox_op.cc
@@ -50,7 +50,7 @@ bool MultiboxPriorRel(const Array<Type>& types,
       {1, in_height * in_width * (num_sizes + num_ratios - 1), 4});
 
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
@@ -122,8 +122,8 @@ bool MultiBoxTransformLocRel(const Array<Type>& types,
   std::vector<IndexExpr> oshape0({cls_shape[0], anchor_shape[1], 6});
   std::vector<IndexExpr> oshape1({cls_shape[0]});
   std::vector<Type> fields;
-  fields.push_back(TensorTypeNode::make(oshape0, cls_prob->dtype));
-  fields.push_back(TensorTypeNode::make(oshape1, DataType::Int(32)));
+  fields.push_back(TensorType(oshape0, cls_prob->dtype));
+  fields.push_back(TensorType(oshape1, DataType::Int(32)));
 
   // assign output type
   reporter->Assign(types[3], TupleType(Array<Type>(fields)));
diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc
index 452477928593..bec0c1d8d45a 100644
--- a/src/relay/op/vision/nms.cc
+++ b/src/relay/op/vision/nms.cc
@@ -40,8 +40,8 @@ bool GetValidCountRel(const Array<Type>& types,
   std::vector<IndexExpr> oshape({data->shape[0]});
   std::vector<Type> fields;
-  fields.push_back(TensorTypeNode::make(oshape, DataType::Int(32)));
-  fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
+  fields.push_back(TensorType(oshape, DataType::Int(32)));
+  fields.push_back(TensorType(data->shape, data->dtype));
 
   // assign output type
   reporter->Assign(types[1], TupleType(Array<Type>(fields)));
@@ -95,9 +95,9 @@ bool NMSRel(const Array<Type>& types,
   // assign output type
   if (param->return_indices) {
     std::vector<IndexExpr> oshape({dshape[0], dshape[1]});
-    reporter->Assign(types[2], TensorTypeNode::make(oshape, DataType::Int(32)));
+    reporter->Assign(types[2], TensorType(oshape, DataType::Int(32)));
   } else {
-    reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
+    reporter->Assign(types[2], TensorType(dshape, data->dtype));
   }
   return true;
 }
diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc
index 7b3533d65ff3..65efd0495656 100644
--- a/src/relay/op/vision/rcnn_op.cc
+++ b/src/relay/op/vision/rcnn_op.cc
@@ -45,7 +45,7 @@ bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // assign output type
   std::vector<IndexExpr> oshape(
       {rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]});
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
@@ -96,7 +96,7 @@ bool ROIPoolRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // assign output type
   std::vector<IndexExpr> oshape(
       {rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]});
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
@@ -155,7 +155,7 @@ bool ProposalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   std::vector<IndexExpr> oshape(
       {batch * proposal_attrs->rpn_post_nms_top_n, 5});
-  reporter->Assign(types[3], TensorTypeNode::make(oshape, cls_prob->dtype));
+  reporter->Assign(types[3], TensorType(oshape, cls_prob->dtype));
   return true;
 }
diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc
index 9964a8274392..5a59a74e7369 100644
--- a/src/relay/op/vision/yolo.cc
+++ b/src/relay/op/vision/yolo.cc
@@ -56,7 +56,7 @@ bool YoloReorgRel(const Array<Type>& types,
   oshape[1] = oshape[1] * param->stride * param->stride;
   oshape[2] = indexdiv(oshape[2], param->stride);
   oshape[3] = indexdiv(oshape[3], param->stride);
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc
index fc7f820e6d86..d8167601cb07 100644
--- a/src/relay/pass/de_duplicate.cc
+++ b/src/relay/pass/de_duplicate.cc
@@ -22,11 +22,10 @@
  * \file de_duplicate.cc
  * \brief Use a fresh Id for every Var to make the result well-formed.
  */
-
+#include <tvm/ir/type_functor.h>
 #include
 #include
 #include
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc
index 8dece3fa3528..b274460bbcff 100644
--- a/src/relay/pass/eta_expand.cc
+++ b/src/relay/pass/eta_expand.cc
@@ -23,10 +23,10 @@
  * \brief Add an abstraction over constructors and/or global variables bound to a function.
  *
  */
+#include <tvm/ir/type_functor.h>
 #include
 #include
 #include
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc
index 20958ab598da..7d94d4e39353 100644
--- a/src/relay/pass/gradient.cc
+++ b/src/relay/pass/gradient.cc
@@ -21,7 +21,7 @@
  * \file ad.cc
  * \brief API for Automatic Differentiation for the Relay IR.
  */
-
+#include <tvm/ir/type_functor.h>
 #include
 #include
 #include
@@ -30,7 +30,6 @@
 #include "pattern_util.h"
 #include "pass_util.h"
 #include "let_list.h"
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -265,7 +264,7 @@ TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient")
 struct ReverseADType : TypeMutator {
   Type VisitType_(const TensorTypeNode* ttn) final {
     Type t = GetRef(ttn);
-    return TupleType({t, RefTypeNode::make(t)});
+    return TupleType({t, RelayRefType(t)});
   }
 };
diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc
index 55fd78a9b8db..d43059cf9f6a 100644
--- a/src/relay/pass/kind_check.cc
+++ b/src/relay/pass/kind_check.cc
@@ -31,9 +31,9 @@
  * We check this by ensuring the `dtype` field of a Tensor always
 * contains a data type such as `int`, `float`, `uint`.
  */
+#include <tvm/ir/type_functor.h>
 #include
 #include
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -107,9 +107,9 @@ struct KindChecker : TypeFunctor {
     return Kind::kType;
   }
 
-  Kind VisitType_(const RefTypeNode* op) override {
+  Kind VisitType_(const RelayRefTypeNode* op) override {
     // ref types should only contain normal types
-    RefType rt = GetRef<RefType>(op);
+    RelayRefType rt = GetRef<RelayRefType>(op);
     CheckKindMatches(op->value, rt, Kind::kType, "ref contents");
     return Kind::kType;
   }
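NOTE (illustration, not part of the patch): as the kind_check and gradient hunks above show, the reference type only changes its C++ spelling from RefType to RelayRefType; the construction pattern is the same constructor-style call as for TensorType. A minimal sketch of building and unpacking one, assuming the usual relay headers:

    // Build a reference type around a scalar tensor type and inspect it.
    Type elem = TensorType::Scalar(DataType::Float(32));
    RelayRefType ref_ty(elem);                         // was: RefTypeNode::make(elem)
    const auto* node = ref_ty.as<RelayRefTypeNode>();
    CHECK(node != nullptr && node->value.same_as(elem));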
diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc
index e9e37d2e9102..37ce34817787 100644
--- a/src/relay/pass/partial_eval.cc
+++ b/src/relay/pass/partial_eval.cc
@@ -89,12 +89,12 @@
  *
  * These assumptions do not affect the correctness of the algorithm, however.
  */
+#include <tvm/ir/type_functor.h>
 #include
 #include
 #include
 #include
-#include "../ir/type_functor.h"
 #include "pass_util.h"
 #include "let_list.h"
@@ -863,7 +863,7 @@ class PartialEvaluator : public ExprFunctor
         subst.Set(func->type_params[i], type_args[i]);
       }
       for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
-        subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
+        subst.Set(func->type_params[i], IncompleteType(kType));
       }
       return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
     } else {
diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc
index 41a5a8e5726d..2441f6e65d88 100644
--- a/src/relay/pass/quantize/quantize.cc
+++ b/src/relay/pass/quantize/quantize.cc
@@ -48,9 +48,9 @@ bool SimulatedQuantizeRel(const Array<Type>& types,
   CHECK(data != nullptr);
   CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
 
-  reporter->Assign(types[1], TensorTypeNode::make({}, DataType::Float(32)));   // dom_scale
-  reporter->Assign(types[2], TensorTypeNode::make({}, DataType::Float(32)));   // clip_min
-  reporter->Assign(types[3], TensorTypeNode::make({}, DataType::Float(32)));   // clip_max
+  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));   // dom_scale
+  reporter->Assign(types[2], TensorType({}, DataType::Float(32)));   // clip_min
+  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));   // clip_max
   reporter->Assign(types[4], types[0]);                              // output
   return true;
 }
diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc
index f88a7c91b27b..293d69667ed7 100644
--- a/src/relay/pass/to_cps.cc
+++ b/src/relay/pass/to_cps.cc
@@ -50,10 +50,10 @@
  * All cases in the transform must return via the mcont,
 * whether directly invoking it, or indirectly by recursion.
  */
+#include <tvm/ir/type_functor.h>
 #include
 #include
 #include
-#include "../ir/type_functor.h"
 #include "let_list.h"
 #include "pass_util.h"
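NOTE (illustration, not part of the patch): the type-inference changes that follow replace IncompleteTypeNode::make(...) with the IncompleteType(...) constructor. An incomplete type is a fresh "hole" the solver later unifies with a concrete type. A minimal sketch using only constructs that appear in the hunks below:

    // Mint a fresh type variable and use it as an unknown return type,
    // mirroring how the inferencer instantiates missing type arguments.
    Type hole = IncompleteType(Kind::kType);   // was: IncompleteTypeNode::make(Kind::kType)
    Type fn_ty = FuncType({TensorType::Scalar(DataType::Int(32))}, hole, {}, {});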
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index a513f3e51a10..ed5f91a3f1e0 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -37,7 +37,7 @@
  * If we can not infer a type or there are conflicting typing
  * constraints we will trigger an error.
  */
-
+#include <tvm/ir/type_functor.h>
 #include
 #include
 #include
@@ -45,7 +45,6 @@
 #include
 #include "./pass_util.h"
 #include "type_solver.h"
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -180,7 +179,7 @@ class TypeInferencer : private ExprFunctor
     if (op->type_annotation.defined()) {
       return op->type_annotation;
     } else {
-      return IncompleteTypeNode::make(Kind::kType);
+      return IncompleteType(Kind::kType);
     }
   }
@@ -215,7 +214,7 @@ class TypeInferencer : private ExprFunctor
           EnvFunc::Get("tvm.relay.type_relation.TupleGetItem"));
     }
     Type tuple_type = GetType(op->tuple);
-    Type rtype = IncompleteTypeNode::make(Kind::kType);
+    Type rtype = IncompleteType(Kind::kType);
     auto attrs = make_object<TupleGetItemAttrs>();
     attrs->index = op->index;
     solver_.AddConstraint(TypeRelation(
@@ -233,7 +232,7 @@ class TypeInferencer : private ExprFunctor
     // we can expect a certain number of arguments
     Array<Type> unknown_args;
     for (size_t i = 0; i < td->type_vars.size(); i++) {
-      unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
+      unknown_args.push_back(IncompleteType(Kind::kType));
     }
     Type expected = TypeCall(con->constructor->belong_to, unknown_args);
     Type unified = Unify(t, expected, GetRef(con));
@@ -275,7 +274,7 @@ class TypeInferencer : private ExprFunctor
     // we can expect a certain number of arguments
     Array<Type> unknown_args;
     for (size_t i = 0; i < tup->patterns.size(); i++) {
-      unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
+      unknown_args.push_back(IncompleteType(Kind::kType));
     }
     Type expected = TupleType(unknown_args);
     Type unified = Unify(t, expected, GetRef(tup));
@@ -302,7 +301,7 @@ class TypeInferencer : private ExprFunctor
     for (const auto& c : op->clauses) {
       VisitPattern(c->lhs, dtype);
     }
-    Type rtype = IncompleteTypeNode::make(Kind::kType);
+    Type rtype = IncompleteType(Kind::kType);
     for (const auto& c : op->clauses) {
       rtype = this->Unify(rtype,
                           GetType(c->rhs),
@@ -336,7 +335,7 @@ class TypeInferencer : private ExprFunctor
   Type VisitExpr_(const LetNode* let) final {
     // if the definition is a function literal, permit recursion
     bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
-    Type let_type = IncompleteTypeNode::make(Kind::kType);
+    Type let_type = IncompleteType(Kind::kType);
 
     if (is_functional_literal) {
       let_type = GetType(let->var);
@@ -362,7 +361,7 @@ class TypeInferencer : private ExprFunctor
     // that is a rank-0 boolean tensor.
     Type cond_type = this->GetType(ite->cond);
     this->Unify(cond_type,
-                TensorTypeNode::Scalar(tvm::DataType::Bool()),
+                TensorType::Scalar(tvm::DataType::Bool()),
                 ite->cond);
     Type checked_true = this->GetType(ite->true_branch);
     Type checked_false = this->GetType(ite->false_branch);
@@ -385,7 +384,7 @@ class TypeInferencer : private ExprFunctor
     for (size_t i = 0; i < op->type_params.size(); ++i) {
       if (!op->type_params[i].same_as(rel->args[i])) return Type();
     }
-    Type rtype = IncompleteTypeNode::make(Kind::kType);
+    Type rtype = IncompleteType(Kind::kType);
     arg_types.push_back(rtype);
     // we can do simple replacement here
     solver_.AddConstraint(TypeRelation(
@@ -404,7 +403,7 @@ class TypeInferencer : private ExprFunctor
     }
 
     for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) {
-      subst_map.Set(fn_ty->type_params[i], IncompleteTypeNode::make(Kind::kType));
+      subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType));
     }
 
     Type ret_type = fn_ty->ret_type;
@@ -415,7 +414,7 @@ class TypeInferencer : private ExprFunctor
     // This is a temporary work around to check recursive functions whose
     // return type is not yet known.
     if (!ret_type.defined()) {
-      ret_type = IncompleteTypeNode::make(Kind::kType);
+      ret_type = IncompleteType(Kind::kType);
     }
 
     Type inst_ty = FuncType(fn_ty->arg_types,
@@ -433,7 +432,7 @@ class TypeInferencer : private ExprFunctor
     Array<Type> type_args;
     for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
-      type_args.push_back(IncompleteTypeNode::make(Kind::kType));
+      type_args.push_back(IncompleteType(Kind::kType));
     }
     return InstantiateFuncType(fn_ty, type_args);
   }
@@ -466,7 +465,7 @@ class TypeInferencer : private ExprFunctor
     // incomplete type => it must be a function taking the arg types
     // with an unknown return type
     if (inc_ty_node != nullptr) {
-      Type ret_type = IncompleteTypeNode::make(Kind::kType);
+      Type ret_type = IncompleteType(Kind::kType);
       Type func_type = FuncType(arg_types, ret_type, {}, {});
       Type unified = this->Unify(ftype, func_type, GetRef(call));
       fn_ty_node = unified.as<FuncTypeNode>();
@@ -562,18 +561,18 @@ class TypeInferencer : private ExprFunctor
   Type VisitExpr_(const RefCreateNode* op) final {
-    return RefTypeNode::make(GetType(op->value));
+    return RelayRefType(GetType(op->value));
   }
 
   Type VisitExpr_(const RefReadNode* op) final {
-    Type it = IncompleteTypeNode::make(Kind::kType);
-    this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef(op));
+    Type it = IncompleteType(Kind::kType);
+    this->Unify(GetType(op->ref), RelayRefType(it), GetRef(op));
     return it;
   }
 
   Type VisitExpr_(const RefWriteNode* op) final {
-    Type it = IncompleteTypeNode::make(Kind::kType);
-    this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef(op));
+    Type it = IncompleteType(Kind::kType);
+    this->Unify(GetType(op->ref), RelayRefType(it), GetRef(op));
     this->Unify(GetType(op->value), it, GetRef(op));
     return TupleType::Empty();
   }
diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc
index ec6d721cf7b0..0ad43d03b60d 100644
--- a/src/relay/pass/type_solver.cc
+++ b/src/relay/pass/type_solver.cc
@@ -21,13 +21,13 @@
  * \file type_solver.cc
  * \brief Type solver implementations.
  */
+#include <tvm/ir/type_functor.h>
 #include
 #include
 #include
 #include
 #include
 #include "type_solver.h"
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -270,7 +270,7 @@ class TypeSolver::Unifier : public TypeFunctor {
       return Type(nullptr);
     }
 
-    return TensorTypeNode::make(shape, tt1->dtype);
+    return TensorType(shape, tt1->dtype);
   }
 
   Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
@@ -312,7 +312,7 @@ class TypeSolver::Unifier : public TypeFunctor {
     }
 
     for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
-      subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
+      subst_map.Set(op->type_params[i], IncompleteType(kType));
     }
 
     FuncType ft = FuncType(op->arg_types,
@@ -343,12 +343,12 @@ class TypeSolver::Unifier : public TypeFunctor {
     return FuncType(arg_types, ret_type, ft2->type_params, type_constraints);
   }
 
-  Type VisitType_(const RefTypeNode* op, const Type& tn) final {
-    const auto* rtn = tn.as<RefTypeNode>();
+  Type VisitType_(const RelayRefTypeNode* op, const Type& tn) final {
+    const auto* rtn = tn.as<RelayRefTypeNode>();
     if (!rtn) {
       return Type(nullptr);
     }
-    return RefTypeNode::make(Unify(op->value, rtn->value));
+    return RelayRefType(Unify(op->value, rtn->value));
   }
 
   Type VisitType_(const TypeCallNode* op, const Type& tn) override {
@@ -690,7 +690,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
     } else if (name == "AddConstraint") {
       return TypedPackedFunc([solver](TypeConstraint c) {
           Expr e = VarNode::make("dummy_var",
-                                 IncompleteTypeNode::make(Kind::kType));
+                                 IncompleteType(Kind::kType));
           return solver->AddConstraint(c, e);
         });
     } else {
diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc
index e45b15a0c4da..b3419668cf0d 100644
--- a/src/relay/pass/util.cc
+++ b/src/relay/pass/util.cc
@@ -23,12 +23,12 @@
  *
  * \brief Utility functions for Relay.
  */
+#include <tvm/ir/type_functor.h>
 #include
 #include
 #include
 #include
 #include "pass_util.h"
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
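NOTE (illustration, not part of the patch): for code outside src/relay, the visible effect of moving type_functor.h is the include path; every pass touched above makes the same one-line swap from the relative include to the public header:

    #include <tvm/ir/type_functor.h>   // was: #include "../ir/type_functor.h"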
diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc
index e01a47dc5f1b..ee6799799011 100644
--- a/src/relay/qnn/op/dequantize.cc
+++ b/src/relay/qnn/op/dequantize.cc
@@ -52,7 +52,7 @@ bool DequantizeRel(const Array<Type>& types,
   const Array<IndexExpr> oshape = data->shape;
   // assign output type, output will always be float 32.
-  reporter->Assign(types[3], TensorTypeNode::make(oshape, DataType::Float(32)));
+  reporter->Assign(types[3], TensorType(oshape, DataType::Float(32)));
   return true;
 }
diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc
index f53d2c5ee438..e2472c6c5453 100644
--- a/src/relay/qnn/op/quantize.cc
+++ b/src/relay/qnn/op/quantize.cc
@@ -63,7 +63,7 @@ bool QuantizeRel(const Array<Type>& types,
         out_dtype == DataType::Int(32))
       << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
   // assign output type
-  reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[3], TensorType(oshape, out_dtype));
   return true;
 }
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 2686965e7b62..cf5b31377784 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -197,7 +197,7 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         out_dtype == DataType::UInt(8) ||
         out_dtype == DataType::Int(32))
       << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
-  reporter->Assign(types[5], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[5], TensorType(oshape, out_dtype));
   return true;
 }
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
index 2d4bcb4fd6ab..6362421f051b 100644
--- a/src/relay/qnn/util.h
+++ b/src/relay/qnn/util.h
@@ -176,7 +176,7 @@ static inline void AssignType(const Type& expr_type, const DataType& dtype, cons
   const auto tensor_dtype = tensor_type->dtype;
   CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
   if (tensor_type->shape.size() != 0) {
-    reporter->Assign(expr_type, TensorTypeNode::make({shape}, tensor_type->dtype));
+    reporter->Assign(expr_type, TensorType({shape}, tensor_type->dtype));
   }
 }
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index 5ddb6d485946..9d954ea02e26 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -36,7 +36,7 @@ TVM_REGISTER_GLOBAL("test.sch")
 TEST(Relay, BuildModule) {
   using namespace tvm;
-  auto tensor_type = relay::TensorTypeNode::make({2, 3}, DataType::Float(32));
+  auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
   auto a = relay::VarNode::make("a", tensor_type);
   auto b = relay::VarNode::make("b", tensor_type);
   auto add_op = relay::Op::Get("add");
diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc
index 68d5d0d67f74..dcb44430994f 100644
--- a/tests/cpp/relay_pass_type_infer_test.cc
+++ b/tests/cpp/relay_pass_type_infer_test.cc
@@ -26,7 +26,7 @@
 TEST(Relay, SelfReference) {
   using namespace tvm;
-  auto tensor_type = relay::TensorTypeNode::make({}, DataType::Bool());
+  auto tensor_type = relay::TensorType({}, DataType::Bool());
   auto x = relay::VarNode::make("x", relay::Type());
   auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
   CHECK(f->IsInstance());
diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc
index d4295548e0c0..5593f070cd71 100644
--- a/tests/cpp/relay_transform_sequential.cc
+++ b/tests/cpp/relay_transform_sequential.cc
@@ -36,7 +36,7 @@ TVM_REGISTER_GLOBAL("schedule")
 TEST(Relay, Sequential) {
   using namespace tvm;
-  auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, DataType::Float(32));
+  auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32));
   auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc
index 7d3c80978be9..55f5c972bae3 100644
--- a/tests/cpp/utvm_runtime_standalone_test.cc
+++ b/tests/cpp/utvm_runtime_standalone_test.cc
@@ -51,7 +51,7 @@ TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue*
 TEST(MicroStandaloneRuntime, BuildModule) {
   using namespace tvm;
-  auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32));
+  auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32));
   auto a = relay::VarNode::make("a", tensor_type);
   auto b = relay::VarNode::make("b", tensor_type);
   auto add_op = relay::Op::Get("add");

From 2df73c7da7b36f0f89899129b19a3de9e59df4ae Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sun, 19 Jan 2020 21:23:59 -0800
Subject: [PATCH 2/2] Add atol

---
 tests/python/frontend/mxnet/test_forward.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 18250d0ea5a6..7381a0728567 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -793,7 +793,7 @@ def verify(shape, axis=-1):
         for kind in ["graph", "debug"]:
             intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
             op_res = intrp.evaluate()(x, gamma, beta)
-            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
     verify((2, 5))
     verify((2, 5), axis=0)
     verify((2, 5, 6))
@@ -809,7 +809,7 @@ def verify(indices_shape, depth, on_value, off_value, dtype):
         for kind in ["graph", "debug"]:
             intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
             op_res = intrp.evaluate()(x.astype("float32"))
-            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
     verify((3,), 3, 1, 0, "int32")
     verify((3,), 3, 1.0, 0.0, "float32")
     verify((2, 2), 5, 2, -2, "int32")
@@ -898,7 +898,7 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
         for kind in ["graph", "debug"]:
             intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
             op_res = intrp.evaluate()(x, weight, bias)
-            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
     verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
     verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
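NOTE (not part of the patch): the atol added in PATCH 2/2 matters because numpy-style assert_allclose accepts a result when |actual - desired| <= atol + rtol * |desired|. With an absolute tolerance of effectively zero, any reference element at or near zero collapses the bound to roughly zero, so float32 results can fail on rounding noise alone; atol=1e-5 gives those elements a small absolute budget while leaving the relative check for large values unchanged.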