Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/runtime/container/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class Variant : public ObjectRef {
public:
/* \brief Helper utility to check if the type is part of the variant */
template <typename T>
static constexpr bool is_variant = (std::is_same_v<T, V> || ...);
static constexpr bool is_variant = (std::is_base_of_v<V, T> || ...);

/* \brief Helper utility for SFINAE if the type is part of the variant */
template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class TensorIntrin : public ObjectRef {
* B[vi, vj] = A[vi, vj]
* \endcode
*/
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
PrimFunc Specialize(PrimFunc func, const Map<Var, Variant<Buffer, PrimExpr>>& param_map);

/*!
* \brief PrimFunc specific attribute names.
Expand Down
2 changes: 1 addition & 1 deletion src/relax/op/tensor/create.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace relax {
TVM_REGISTER_NODE_TYPE(InitAttrs);

/* relax.full */
Expr full(ObjectRef shape, Expr fill_value, DataType dtype) {
Expr full(Variant<Expr, Array<PrimExpr>> shape, Expr fill_value, DataType dtype) {
Expr shape_in_expr{nullptr};
if (const auto* expr = shape.as<ExprNode>()) {
shape_in_expr = GetRef<Expr>(expr);
Expand Down
2 changes: 1 addition & 1 deletion src/relax/op/tensor/create.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace relax {
* If dtype is not given, it will by default use the dtype of fill_value.
* \return The result tensor.
*/
Expr full(ObjectRef shape, Expr fill_value, DataType dtype);
Expr full(Variant<Expr, Array<PrimExpr>> shape, Expr fill_value, DataType dtype);

/*!
* \brief Construct a tensor such that
Expand Down
6 changes: 3 additions & 3 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ TVM_REGISTER_OP("relax.permute_dims")
.set_attr<Bool>("FPurity", Bool(true));

/* relax.reshape */
Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
Expr ConvertNewShapeToExpr(const Expr& data, const Variant<Expr, Array<PrimExpr>>& shape) {
const ArrayNode* array;
// Treat shape expressions as constant arrays to handle special values.
if (const auto* e = shape.as<ShapeExprNode>()) {
Expand Down Expand Up @@ -745,7 +745,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
return ShapeExpr(array_ref);
}

Expr reshape(Expr x, ObjectRef shape) {
Expr reshape(Expr x, Variant<Expr, Array<PrimExpr>> shape) {
Expr shape_in_expr = ConvertNewShapeToExpr(x, shape);
static const Op& op = Op::Get("relax.reshape");
return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {});
Expand Down Expand Up @@ -810,7 +810,7 @@ TVM_REGISTER_OP("relax.reshape")
/* relax.split */
TVM_REGISTER_NODE_TYPE(SplitAttrs);

Expr split(Expr x, ObjectRef indices_or_sections, int axis) {
Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis) {
ObjectPtr<SplitAttrs> attrs = make_object<SplitAttrs>();
if (const auto* indices = indices_or_sections.as<ArrayNode>()) {
for (int i = 0; i < static_cast<int>(indices->size()); ++i) {
Expand Down
4 changes: 2 additions & 2 deletions src/relax/op/tensor/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Expr permute_dims(Expr x, Optional<Array<Integer>> axes);
* It is required to be either an Array of PrimExpr, or a Shape in Relax
* \return The reshaped result.
*/
Expr reshape(Expr x, ObjectRef shape);
Expr reshape(Expr x, Variant<Expr, Array<PrimExpr>> shape);

/*!
* \brief Split input tensor along axis by sections or indices.
Expand All @@ -103,7 +103,7 @@ Expr reshape(Expr x, ObjectRef shape);
* \param axis The axis over which to split.
* \return The computed result.
*/
Expr split(Expr x, ObjectRef indices_or_sections, int axis);
Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis);

/*!
* \brief Squeeze axes in the array.
Expand Down
4 changes: 2 additions & 2 deletions src/relay/transforms/to_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>,
// Return array is of type : [MixedTypeConversionCategory (int), String, String]
// The fields are : [ConversionCategory, accumulation_datatype, output_datatype]
// Call is a call node, DataType is the mixed precision type
using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<Variant<Integer, String>>(
const Call& call_node, const std::string& target_dtype_str)>;

/*! \brief This class transforms the given relay module into a version where
Expand Down Expand Up @@ -372,7 +372,7 @@ class MixedPrecisionPass : public MixedModeMutator {
if (attr_map.count(op)) {
// Calculate the conversion category and dtypes from registered attribute.
FTVMMixedPrecisionConversionType func = attr_map[op];
Array<ObjectRef> op_descriptor =
Array<Variant<Integer, String>> op_descriptor =
func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type_));
ICHECK(op_descriptor.size() == 3)
<< "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size()
Expand Down
4 changes: 3 additions & 1 deletion src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,9 @@ Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span) {
}

TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) {
.set_body_typed([](DataType type, RelayExpr op,
Array<Variant<runtime::String, IterVar, BufferRegion, PrimExpr>> args,
Span span) {
Array<PrimExpr> prim_expr_args;
for (const auto& it : args) {
ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>() ||
Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx

/**************** Implementation ****************/

PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
PrimFunc Specialize(PrimFunc func, const Map<Var, Variant<Buffer, PrimExpr>>& param_map) {
VarMap var_map;
for (const auto& kv : param_map) {
const Var& param = kv.first;
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/inline_private_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class PrimFuncInliner : StmtExprMutator {
<< "Inlining of PrimFuncs with buffer arguments is not yet supported, "
<< "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map;

Map<Var, ObjectRef> param_map;
Map<Var, Variant<tir::Buffer, tvm::PrimExpr>> param_map;
for (size_t i = 0; i < callee->params.size(); i++) {
param_map.Set(callee->params[i], args[i]);
}
Expand Down