From f6ec97810f4967dc074eb4e2a804366efe8d99ee Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 14 Jun 2024 14:52:02 -0500 Subject: [PATCH] [Cleanup] Accept Variant<...> instead of ObjectRef when possible Prior to the implementation of `Variant<...>` in https://github.com/apache/tvm/pull/15672, functions that were polymorphic over an argument type would typically accept an `ObjectRef` argument, then downcast to an allowed type. This delays the catching of an error, and can accidentally omit automatic conversions applied by the FFI. This commit updates several locations using this pattern to instead accept a `Variant`, templated over the allowed types. This enables C++ type checking for C++ callers, standardizes the type-checking in the FFI for non-C++ callers, and ensures that FFI type conversions are uniformly applied. --- include/tvm/runtime/container/variant.h | 2 +- include/tvm/tir/function.h | 2 +- src/relax/op/tensor/create.cc | 2 +- src/relax/op/tensor/create.h | 2 +- src/relax/op/tensor/manipulate.cc | 6 +++--- src/relax/op/tensor/manipulate.h | 4 ++-- src/relay/transforms/to_mixed_precision.cc | 4 ++-- src/tir/ir/expr.cc | 4 +++- src/tir/ir/specialize.cc | 2 +- src/tir/transforms/inline_private_functions.cc | 2 +- 10 files changed, 16 insertions(+), 14 deletions(-) diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h index 7953ac47c1cf..e8defa4e6fee 100644 --- a/include/tvm/runtime/container/variant.h +++ b/include/tvm/runtime/container/variant.h @@ -82,7 +82,7 @@ class Variant : public ObjectRef { public: /* \brief Helper utility to check if the type is part of the variant */ template - static constexpr bool is_variant = (std::is_same_v || ...); + static constexpr bool is_variant = (std::is_base_of_v || ...); /* \brief Helper utility for SFINAE if the type is part of the variant */ template diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 274ebd0a6558..1d218c6a7c61 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -264,7 +264,7 @@ class TensorIntrin : public ObjectRef { * B[vi, vj] = A[vi, vj] * \endcode */ -PrimFunc Specialize(PrimFunc func, const Map& param_map); +PrimFunc Specialize(PrimFunc func, const Map>& param_map); /*! * \brief PrimFunc specific attribute names. diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index fd6fea6e703c..7aca1470aee4 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -36,7 +36,7 @@ namespace relax { TVM_REGISTER_NODE_TYPE(InitAttrs); /* relax.full */ -Expr full(ObjectRef shape, Expr fill_value, DataType dtype) { +Expr full(Variant> shape, Expr fill_value, DataType dtype) { Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as()) { shape_in_expr = GetRef(expr); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 989eaa12fdbf..6e7c8255238a 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -39,7 +39,7 @@ namespace relax { * If dtype is not given, it will by default use the dtype of fill_value. * \return The result tensor. */ -Expr full(ObjectRef shape, Expr fill_value, DataType dtype); +Expr full(Variant> shape, Expr fill_value, DataType dtype); /*! * \brief Construct a tensor such that diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index ad2a812c8254..f7c7ff4adac3 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -652,7 +652,7 @@ TVM_REGISTER_OP("relax.permute_dims") .set_attr("FPurity", Bool(true)); /* relax.reshape */ -Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { +Expr ConvertNewShapeToExpr(const Expr& data, const Variant>& shape) { const ArrayNode* array; // Treat shape expressions as constant arrays to handle special values. if (const auto* e = shape.as()) { @@ -745,7 +745,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { return ShapeExpr(array_ref); } -Expr reshape(Expr x, ObjectRef shape) { +Expr reshape(Expr x, Variant> shape) { Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); static const Op& op = Op::Get("relax.reshape"); return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); @@ -810,7 +810,7 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ TVM_REGISTER_NODE_TYPE(SplitAttrs); -Expr split(Expr x, ObjectRef indices_or_sections, int axis) { +Expr split(Expr x, Variant> indices_or_sections, int axis) { ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index b19e3b85070d..be39b053604a 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -88,7 +88,7 @@ Expr permute_dims(Expr x, Optional> axes); * It is required to be either an Array of PrimExpr, or a Shape in Relax * \return The reshaped result. */ -Expr reshape(Expr x, ObjectRef shape); +Expr reshape(Expr x, Variant> shape); /*! * \brief Split input tensor along axis by sections or indices. @@ -103,7 +103,7 @@ Expr reshape(Expr x, ObjectRef shape); * \param axis The axis over which to split. * \return The computed result. */ -Expr split(Expr x, ObjectRef indices_or_sections, int axis); +Expr split(Expr x, Variant> indices_or_sections, int axis); /*! * \brief Squeeze axes in the array. diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 5026b1bcba79..1112755b76a0 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -66,7 +66,7 @@ using CachedCastNodes = std::unordered_map, // Return array is of type : [MixedTypeConversionCategory (int), String, String] // The fields are : [ConversionCategory, accumulation_datatype, output_datatype] // Call is a call node, DataType is the mixed precision type -using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc( +using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc>( const Call& call_node, const std::string& target_dtype_str)>; /*! \brief This class transforms the given relay module into a version where @@ -372,7 +372,7 @@ class MixedPrecisionPass : public MixedModeMutator { if (attr_map.count(op)) { // Calculate the conversion category and dtypes from registered attribute. FTVMMixedPrecisionConversionType func = attr_map[op]; - Array op_descriptor = + Array> op_descriptor = func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type_)); ICHECK(op_descriptor.size() == 3) << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size() diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 1506082003fd..7143e7b79c3d 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -546,7 +546,9 @@ Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, Array args, Span span) { + .set_body_typed([](DataType type, RelayExpr op, + Array> args, + Span span) { Array prim_expr_args; for (const auto& it : args) { ICHECK(it->IsInstance() || it->IsInstance() || diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index b30d0caf6af3..78fb9365cc71 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -414,7 +414,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx /**************** Implementation ****************/ -PrimFunc Specialize(PrimFunc func, const Map& param_map) { +PrimFunc Specialize(PrimFunc func, const Map>& param_map) { VarMap var_map; for (const auto& kv : param_map) { const Var& param = kv.first; diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index cc33ba9f86c2..14672f568549 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -231,7 +231,7 @@ class PrimFuncInliner : StmtExprMutator { << "Inlining of PrimFuncs with buffer arguments is not yet supported, " << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; - Map param_map; + Map> param_map; for (size_t i = 0; i < callee->params.size(); i++) { param_map.Set(callee->params[i], args[i]); }