diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h
index 7953ac47c1cf..e8defa4e6fee 100644
--- a/include/tvm/runtime/container/variant.h
+++ b/include/tvm/runtime/container/variant.h
@@ -82,7 +82,7 @@ class Variant : public ObjectRef {
  public:
   /* \brief Helper utility to check if the type is part of the variant */
   template <typename T>
-  static constexpr bool is_variant = (std::is_same_v<T, V> || ...);
+  static constexpr bool is_variant = (std::is_base_of_v<V, T> || ...);

   /* \brief Helper utility for SFINAE if the type is part of the variant */
   template <typename T>
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 274ebd0a6558..1d218c6a7c61 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -264,7 +264,7 @@ class TensorIntrin : public ObjectRef {
  *        B[vi, vj] = A[vi, vj]
  * \endcode
  */
-PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
+PrimFunc Specialize(PrimFunc func, const Map<Var, Variant<Buffer, PrimExpr>>& param_map);

 /*!
  * \brief PrimFunc specific attribute names.
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index fd6fea6e703c..7aca1470aee4 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -36,7 +36,7 @@ namespace relax {
 TVM_REGISTER_NODE_TYPE(InitAttrs);

 /* relax.full */
-Expr full(ObjectRef shape, Expr fill_value, DataType dtype) {
+Expr full(Variant<Expr, Array<PrimExpr>> shape, Expr fill_value, DataType dtype) {
   Expr shape_in_expr{nullptr};
   if (const auto* expr = shape.as<ExprNode>()) {
     shape_in_expr = GetRef<Expr>(expr);
diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h
index 989eaa12fdbf..6e7c8255238a 100644
--- a/src/relax/op/tensor/create.h
+++ b/src/relax/op/tensor/create.h
@@ -39,7 +39,7 @@ namespace relax {
  * If dtype is not given, it will by default use the dtype of fill_value.
  * \return The result tensor.
  */
-Expr full(ObjectRef shape, Expr fill_value, DataType dtype);
+Expr full(Variant<Expr, Array<PrimExpr>> shape, Expr fill_value, DataType dtype);

 /*!
  * \brief Construct a tensor such that
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index ad2a812c8254..f7c7ff4adac3 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -652,7 +652,7 @@ TVM_REGISTER_OP("relax.permute_dims")
     .set_attr<Bool>("FPurity", Bool(true));

 /* relax.reshape */
-Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
+Expr ConvertNewShapeToExpr(const Expr& data, const Variant<Expr, Array<PrimExpr>>& shape) {
   const ArrayNode* array;
   // Treat shape expressions as constant arrays to handle special values.
   if (const auto* e = shape.as<ShapeExprNode>()) {
@@ -745,7 +745,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
   return ShapeExpr(array_ref);
 }

-Expr reshape(Expr x, ObjectRef shape) {
+Expr reshape(Expr x, Variant<Expr, Array<PrimExpr>> shape) {
   Expr shape_in_expr = ConvertNewShapeToExpr(x, shape);
   static const Op& op = Op::Get("relax.reshape");
   return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {});
@@ -810,7 +810,7 @@ TVM_REGISTER_OP("relax.reshape")
 /* relax.split */
 TVM_REGISTER_NODE_TYPE(SplitAttrs);

-Expr split(Expr x, ObjectRef indices_or_sections, int axis) {
+Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis) {
   ObjectPtr<SplitAttrs> attrs = make_object<SplitAttrs>();
   if (const auto* indices = indices_or_sections.as<ArrayNode>()) {
     for (int i = 0; i < static_cast<int>(indices->size()); ++i) {
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index b19e3b85070d..be39b053604a 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -88,7 +88,7 @@ Expr permute_dims(Expr x, Optional<Array<Integer>> axes);
 * It is required to be either an Array of PrimExpr, or a Shape in Relax
 * \return The reshaped result.
 */
-Expr reshape(Expr x, ObjectRef shape);
+Expr reshape(Expr x, Variant<Expr, Array<PrimExpr>> shape);

 /*!
  * \brief Split input tensor along axis by sections or indices.
@@ -103,7 +103,7 @@ Expr reshape(Expr x, ObjectRef shape);
 * \param axis The axis over which to split.
 * \return The computed result.
 */
-Expr split(Expr x, ObjectRef indices_or_sections, int axis);
+Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis);

 /*!
  * \brief Squeeze axes in the array.
diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index 5026b1bcba79..1112755b76a0 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -66,7 +66,7 @@ using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>,
 // Return array is of type : [MixedTypeConversionCategory (int), String, String]
 // The fields are : [ConversionCategory, accumulation_datatype, output_datatype]
 // Call is a call node, DataType is the mixed precision type
-using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<Variant<Integer, String>>(
     const Call& call_node, const std::string& target_dtype_str)>;

 /*! \brief This class transforms the given relay module into a version where
@@ -372,7 +372,7 @@ class MixedPrecisionPass : public MixedModeMutator {
     if (attr_map.count(op)) {
       // Calculate the conversion category and dtypes from registered attribute.
       FTVMMixedPrecisionConversionType func = attr_map[op];
-      Array<ObjectRef> op_descriptor =
+      Array<Variant<Integer, String>> op_descriptor =
           func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type_));
       ICHECK(op_descriptor.size() == 3)
           << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size()
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 1506082003fd..7143e7b79c3d 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -546,7 +546,9 @@ Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span) {
 }

 TVM_REGISTER_GLOBAL("tir.Call")
-    .set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) {
+    .set_body_typed([](DataType type, RelayExpr op,
+                       Array<Variant<tir::Var, tir::IterVar, PrimExpr, runtime::String>> args,
+                       Span span) {
       Array<PrimExpr> prim_expr_args;
       for (const auto& it : args) {
         ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<IterVarNode>() ||
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index b30d0caf6af3..78fb9365cc71 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -414,7 +414,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx

 /**************** Implementation ****************/

-PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
+PrimFunc Specialize(PrimFunc func, const Map<Var, Variant<Buffer, PrimExpr>>& param_map) {
   VarMap var_map;
   for (const auto& kv : param_map) {
     const Var& param = kv.first;
diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc
index cc33ba9f86c2..14672f568549 100644
--- a/src/tir/transforms/inline_private_functions.cc
+++ b/src/tir/transforms/inline_private_functions.cc
@@ -231,7 +231,7 @@ class PrimFuncInliner : StmtExprMutator {
         << "Inlining of PrimFuncs with buffer arguments is not yet supported, "
         << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map;

-    Map<Var, ObjectRef> param_map;
+    Map<Var, Variant<Buffer, PrimExpr>> param_map;
     for (size_t i = 0; i < callee->params.size(); i++) {
       param_map.Set(callee->params[i], args[i]);
     }
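
For reviewers reading the patch in isolation: the is_same_v to is_base_of_v switch in variant.h is what lets the new Variant<...> signatures above accept subclasses of an alternative (for example a relax::ShapeExpr where a Variant<Expr, Array<PrimExpr>> is expected). Below is a minimal self-contained sketch of that rule; Expr, ShapeExpr, and ArrayOfPrimExpr are toy stand-ins for illustration, not TVM's real classes.

// Toy model of the variant.h change; compiles standalone with C++17.
#include <type_traits>

struct Expr {};
struct ShapeExpr : Expr {};  // subclass, as relax::ShapeExpr derives from Expr
struct ArrayOfPrimExpr {};   // stand-in for Array<PrimExpr>

template <typename... V>
struct Variant {
  // Old check: T must be exactly one of the alternatives.
  template <typename T>
  static constexpr bool is_variant_old = (std::is_same_v<T, V> || ...);
  // New check: T may also be a subclass of an alternative.
  template <typename T>
  static constexpr bool is_variant = (std::is_base_of_v<V, T> || ...);
};

using ShapeArg = Variant<Expr, ArrayOfPrimExpr>;

static_assert(ShapeArg::is_variant_old<Expr>);        // exact match, accepted before and after
static_assert(!ShapeArg::is_variant_old<ShapeExpr>);  // subclass, rejected before the patch
static_assert(ShapeArg::is_variant<ShapeExpr>);       // subclass, accepted after the patch

int main() { return 0; }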
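And a hedged caller-side sketch of the new Specialize signature from include/tvm/tir/function.h. It assumes a TVM build to compile and link against; the wrapper name SpecializeN is illustrative, not part of the patch.

#include <tvm/tir/function.h>

using namespace tvm;
using namespace tvm::tir;

// Bind one scalar parameter of `func` to a constant. The map's value type is
// now Variant<Buffer, PrimExpr>, and the is_base_of_v rule lets a PrimExpr
// subclass such as IntImm be passed directly, without upcasting to ObjectRef.
PrimFunc SpecializeN(PrimFunc func, Var n) {
  Map<Var, Variant<Buffer, PrimExpr>> param_map;
  param_map.Set(n, IntImm(DataType::Int(32), 16));
  return Specialize(func, param_map);
}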