diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 594e2b86e9f9..53c0e6be0fc4 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -769,53 +769,95 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -// common rule for RetValue and ArgValue + +// Automatic conversion into IntImm, Integer, and Bool, when called +// through the FFI. Automatic conversions into PrimExpr are +// registered in "tvm/tir/expr.h", as it includes conversions to the +// TIR-only StringImm. +// +// While the FFI only requires the From() method, these +// implementations also define a TryFrom() method to avoid duplicate +// logic in the PrimExpr conversion. + template <> -struct PackedFuncValueConverter { - static PrimExpr From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return PrimExpr(ObjectPtr(nullptr)); - } - if (val.type_code() == kDLInt) { - int64_t value = val.operator int64_t(); - if (value > std::numeric_limits::max() || value < std::numeric_limits::min()) { - return IntImm(runtime::DataType::Int(64), value); - } - return IntImm(runtime::DataType::Int(32), val.operator int()); - } - if (val.type_code() == kDLFloat) { - return FloatImm(runtime::DataType::Float(32), val.operator double()); +struct PackedFuncValueConverter { + static Optional TryFrom(const TVMPODValue_& val) { + if (auto opt = val.TryAsInt()) { + int64_t value = opt.value(); + auto dtype = + (value > std::numeric_limits::max() || value < std::numeric_limits::min()) + ? 
DataType::Int(64) + : DataType::Int(32); + return IntImm(dtype, value); + } else if (auto opt = val.TryAsBool()) { + return IntImm(DataType::Int(32), opt.value()); + } else { + return NullOpt; } + } - return PrimExpr::FromObject_(val.AsObjectRef()); + static tvm::IntImm From(const TVMPODValue_& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.AsObjectRef(); + } } }; template <> struct PackedFuncValueConverter { static tvm::Integer From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Integer(ObjectPtr(nullptr)); + if (auto opt = val.TryAsInt()) { + return Integer(opt.value()); + } else if (auto opt = val.TryAsBool()) { + return Integer(opt.value()); + } else { + return val.AsObjectRef(); } - if (val.type_code() == kTVMArgInt) { - return Integer(val.operator int()); - } - return val.AsObjectRef(); } }; template <> struct PackedFuncValueConverter { + static Optional TryFrom(const TVMPODValue_& val) { + if (auto opt = val.TryAsBool()) { + return Bool(opt.value()); + } else if (auto opt = val.TryAsInt()) { + int value = opt.value(); + ICHECK(value == 0 || value == 1) + << "ValueError: boolean value can only be 0 or 1, but get " << value; + return Bool(static_cast(value)); + } else { + return NullOpt; + } + } + static tvm::Bool From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Bool(ObjectPtr(nullptr)); + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.AsObjectRef(); } - if (val.type_code() == kTVMArgInt) { - int v = val.operator int(); - ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v; - return Bool(static_cast(v)); + } +}; + +template <> +struct PackedFuncValueConverter { + static Optional TryFrom(const TVMPODValue_& val) { + if (auto opt = val.TryAsFloat()) { + return FloatImm(runtime::DataType::Float(32), opt.value()); + } else { + return NullOpt; + } + } + + static tvm::FloatImm From(const TVMPODValue_& 
val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.AsObjectRef(); } - return val.AsObjectRef(); } }; diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index bb1b2c8dd74a..66fa2e8d2353 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -287,6 +287,25 @@ If WithFields(If if_expr, Optional opt_cond = Optional(), Optional opt_false_branch = Optional(), Optional opt_span = Optional()); +/*! \brief Perform tuple access + * + * Use of this method is recommended, rather than constructing a + * `TupleGetItem` directly. + * + * 1. May resolve to the tuple's contents, avoiding the intermediate + * `TupleGetItem`. + * + * 2. Handles access of a tuple at a dynamic index, where + * `TupleGetItem` requires a statically-known index. + * + * \param tuple The tuple to be accessed + * + * \param index The index at which the access occurs + * + * \return An expression for the access of the tuple + */ +Expr tuple_get_item(Expr tuple, Expr index); + /*! \brief Tuple container */ class TupleNode : public ExprNode { public: @@ -320,6 +339,20 @@ class Tuple : public Expr { */ TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); + /*! \brief Helper to delegate access to the tuple + * + * The `tuple_get_item` can be applied to any `relax::Expr`. + * However, this helper function is only provided for + * `relax::Tuple`, because `relax::Expr` is a typedef for + * `RelayExpr`, and we should avoid updating relay classes to + * provide relax-specific functionality.. 
+ * + * \param index The index at which the tuple is accessed + * + * \return The contents of the tuple at the specified index + */ + inline Expr operator[](Expr index) { return tuple_get_item(*this, index); } + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode); }; diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index ff0bd03ab9cb..ba8fdfac5565 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -827,8 +827,13 @@ class Array : public ObjectRef { // consisting of any previous elements that had mapped to // themselves (if any), and the element that didn't map to // itself. + // + // We cannot use `U()` as the default object, as `U` may be + // a non-nullable type. Since the default `ObjectRef()` + // will be overwritten before returning, all objects will be + // of type `U` for the calling scope. all_identical = false; - output = ArrayNode::CreateRepeated(arr->size(), U()); + output = ArrayNode::CreateRepeated(arr->size(), ObjectRef()); output->InitRange(0, arr->begin(), it); output->SetItem(it - arr->begin(), std::move(mapped)); it++; @@ -843,7 +848,12 @@ class Array : public ObjectRef { // compatible types isn't strictly necessary, as the first // mapped.same_as(*it) would return false, but we might as well // avoid it altogether. - output = ArrayNode::CreateRepeated(arr->size(), U()); + // + // We cannot use `U()` as the default object, as `U` may be a + // non-nullable type. Since the default `ObjectRef()` will be + // overwritten before returning, all objects will be of type `U` + // for the calling scope. 
+ output = ArrayNode::CreateRepeated(arr->size(), ObjectRef()); } // Normal path for incompatible types, or post-copy path for diff --git a/include/tvm/runtime/container/boxed_primitive.h b/include/tvm/runtime/container/boxed_primitive.h new file mode 100644 index 000000000000..bf47e354fb33 --- /dev/null +++ b/include/tvm/runtime/container/boxed_primitive.h @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/boxed_primitive.h + * \brief Runtime container types for primitives stored as ObjectRef. + */ +#ifndef TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ +#define TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +namespace detail { +/* \brief Provide the BoxNode type key in templated contexts + * + * The Box class is used in many templated contexts, and is easier + * to have templated over the primitive type. However, much of the + * TVM type system depends on classes having a unique name. For + * example, the use of `Object::IsInstance` depends on + * `Object::GetOrAllocRuntimeTypeIndex`. Any duplicate names will + * result in duplicate indices, and invalid downcasting. 
+ * + * Furthermore, the name must be specified in the Python FFI using + * `tvm._ffi.register_object`. This prevents use of + * `typeid(T)::name()` to build a unique name, as the name is not + * required to be human-readable or consistent across compilers. + * + * This utility struct exists to bridge that gap, providing a unique + * name where required. + */ +template +struct BoxNodeTypeKey; + +template <> +struct BoxNodeTypeKey { + static constexpr const char* _type_key = "runtime.BoxInt"; +}; + +template <> +struct BoxNodeTypeKey { + static constexpr const char* _type_key = "runtime.BoxFloat"; +}; + +template <> +struct BoxNodeTypeKey { + static constexpr const char* _type_key = "runtime.BoxBool"; +}; +} // namespace detail + +template +class BoxNode : public Object { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + BoxNode(Prim value) : value(value) {} + + /*! \brief The boxed value */ + Prim value; + + static constexpr const char* _type_key = detail::BoxNodeTypeKey::_type_key; + static constexpr bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(BoxNode, Object); +}; + +template +class Box : public ObjectRef { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + Box(Prim value) : ObjectRef(make_object>(value)) {} + + operator Prim() const { return (*this)->value; } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Box, ObjectRef, BoxNode); +}; + +/*! \brief Runtime equivalent of IntImm */ +using BoxInt = Box; + +/*! \brief Runtime equivalent of FloatImm */ +using BoxFloat = Box; + +/*! \brief Runtime equivalent of IntImm with DataType::Bool() + * + * When passing from Python to C++, TVM PackedFunc conversion follow + * C++ conversion rules, and allow bool->int and int->bool + * conversions. When passing from C++ to Python, the types are + * returned as bool or int. 
If the C++ function uses ObjectRef to + * hold the object, a Python to C++ to Python round trip will preserve + * the distinction between bool and int. + */ +using BoxBool = Box; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..327ff78deefa 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -37,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -429,9 +431,11 @@ inline const char* ArgTypeCode2Str(int type_code); inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*) +#define TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) \ + "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) + // macro to check type code. -#define TVM_CHECK_TYPE_CODE(CODE, T) \ - ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) +#define TVM_CHECK_TYPE_CODE(CODE, T) ICHECK_EQ(CODE, T) << TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) /*! * \brief Type traits for runtime type check during FFI conversion. @@ -555,29 +559,43 @@ class TVMPODValue_ { // Allow automatic conversion from int to float // This avoids errors when user pass in int from // the frontend while the API expects a float. 
- if (type_code_ == kDLInt) { - return static_cast(value_.v_int64); + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsFloat()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat); } - TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); - return value_.v_float64; } operator int64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; - } - operator uint64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else if (IsObjectRef()) { + auto obj = AsObjectRef(); + LOG(FATAL) << "Expected integer, but found object with type key " << obj->GetTypeKey(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } } + operator uint64_t() const { return operator int64_t(); } operator int() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - ICHECK_LE(value_.v_int64, std::numeric_limits::max()); - ICHECK_GE(value_.v_int64, std::numeric_limits::min()); - return static_cast(value_.v_int64); + int64_t value = operator int64_t(); + ICHECK_LE(value, std::numeric_limits::max()); + ICHECK_GE(value, std::numeric_limits::min()); + return value; } operator bool() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64 != 0; + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } } operator void*() const { if (type_code_ == kTVMNullptr) return nullptr; @@ -635,6 +653,38 @@ class TVMPODValue_ { template inline TObjectRef AsObjectRef() const; + std::optional TryAsInt() const { + // Helper function to reduce duplication in the variable integer + // conversions. 
This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (auto opt = FromBoxed()) { + return opt.value(); + } else if (type_code_ == kDLInt) { + return value_.v_int64; + } else { + return std::nullopt; + } + } + + std::optional TryAsFloat() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (auto opt = FromBoxed()) { + return opt.value(); + } else if (type_code_ == kDLFloat) { + return value_.v_float64; + } else { + return std::nullopt; + } + } + + std::optional TryAsBool() const { + // Booleans may be kept distinct from Int by using Box and + // Box. + return FromBoxed(); + } + protected: friend class TVMArgsSetter; friend class TVMRetValue; @@ -642,6 +692,15 @@ class TVMPODValue_ { TVMPODValue_() : type_code_(kTVMNullptr) {} TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} + template + std::optional FromBoxed() const { + if (IsObjectRef>()) { + return AsObjectRef>()->value; + } else { + return std::nullopt; + } + } + /*! \brief The value */ TVMValue value_; /*! \brief the type code */ @@ -901,9 +960,12 @@ class TVMRetValue : public TVMPODValue_ { } TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { - this->SwitchToPOD(kDLInt); - value_.v_int64 = value; - return *this; + // While a boolean could be stored using the primitive kDLInt + // type, this causes round-trip inconsistencies for languages that + // distinguish between integer and boolean types (i.e. Anything + // after C89). Rather than adding another type for booleans, this + // is stored in the Box container. 
+ return operator=(Box(value)); } TVMRetValue& operator=(std::string value) { this->SwitchToClass(kTVMStr, value); @@ -989,9 +1051,9 @@ class TVMRetValue : public TVMPODValue_ { } // ObjectRef handling template ::value>::type> + typename = typename std::enable_if_t>> inline TVMRetValue& operator=(TObjectRef other); - template ::value>::type> + template >> inline operator T() const; private: @@ -1019,9 +1081,11 @@ class TVMRetValue : public TVMPODValue_ { break; } case kTVMObjectHandle: { - // Avoid operator ObjectRef as we already know it is not NDArray/Module - SwitchToObject(kTVMObjectHandle, - GetObjectPtr(static_cast(other.value_.v_handle))); + // We already known it is not NDArray/Module, but + // operator=(ObjectRef) also handles conversions from wrappers + // around primitive types. For NDArray/Module, the duplicate + // checks are removed with if constexpr. + operator=(other.operator ObjectRef()); break; } case kTVMObjectRValueRefArg: { @@ -1951,33 +2015,96 @@ inline T TVMArgs::At(int i) const { template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { using ContainerType = typename std::remove_reference::type::ContainerType; - if (value.defined()) { - Object* ptr = value.data_.data_; - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + if (!value.defined()) { + type_codes_[i] = kTVMNullptr; + values_[i].v_handle = nullptr; + return; + } + + Object* ptr = value.data_.data_; + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = NDArray::FFIGetHandle(value); type_codes_[i] = kTVMNDArrayHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; - } else if (std::is_base_of::value || - 
(std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMPackedFuncHandle; - } else if (std::is_rvalue_reference::value) { - values_[i].v_handle = const_cast(&(value.data_.data_)); - type_codes_[i] = kTVMObjectRValueRefArg; - } else { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kTVMObjectHandle; + return; } + } + + // If a boxed integer is being returned, always unbox it to the + // primitive type. This must be checked at the PackedFunc level to + // ensure that a boxed primitive argument is round-tripped correctly + // when the boxing is no longer required. + // + // For example, consider a PackedFunc with signature `ObjectRef + // func(Array)`, and returns the first element of that + // array. When passing a Python array `[5, 17.5, "hello"]`, the + // items are converted to `[Box(5), Box(17.5), + // String("hello")]` in order to provide an `Array`. + // + // If we had no additional conversions, the caller would receive the + // return value as a `Box(5)`, which would be unexpected and + // require additional unwrapping. We could perform this check + // inside the PackedFunc, but that would require a large amount of + // duplicated checked, and would require explicit handling of + // `TVMRetValue`. Instead, this conversion is checked in the FFI + // return value, to ensure that boxing/unboxing is applied + // consistently. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_int64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgInt; + return; + } + } + + // Like with BoxInt, unwrap any BoxFloat instances. See the BoxInt + // explanation for more detail. 
+ if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_float64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgFloat; + return; + } + } + + // Deliberately do *not* unwrap BoxBool instances. If BoxBool were + // unwrapped to kTVMArgInt, it would be ambiguous whether the + // user-defined object was a bool or an int. + + // Final fallback, if the ObjectRef has no special cases that must + // be expressed within the TVMRetValue. + if constexpr (std::is_rvalue_reference_v) { + values_[i].v_handle = const_cast(&(value.data_.data_)); + type_codes_[i] = kTVMObjectRValueRefArg; } else { - type_codes_[i] = kTVMNullptr; - values_[i].v_handle = nullptr; + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kTVMObjectHandle; } } @@ -2023,8 +2150,10 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } - // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + + // NOTE: The following code uses "if constexpr" wherever possible to + // minimize the number of runtime checks. 
+ if constexpr (std::is_base_of_v) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); ObjectPtr data = @@ -2033,7 +2162,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2041,7 +2171,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of PackedFunc TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2049,6 +2180,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } + if (type_code_ == kTVMObjectHandle) { // normal object type check. 
Object* ptr = static_cast(value_.v_handle); @@ -2062,46 +2194,100 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker::TypeName() << ", but got " << checked_type.value(); return TObjectRef(GetObjectPtr(ptr)); - } else if (std::is_base_of::value && - type_code_ == kTVMNDArrayHandle) { - // Casting to a base class that NDArray can sub-class - ObjectPtr data = - NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); - return TObjectRef(data); - } else if (std::is_base_of::value && - type_code_ == kTVMModuleHandle) { - // Casting to a base class that Module can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else if (std::is_base_of::value && - type_code_ == kTVMPackedFuncHandle) { - // Casting to a base class that PackedFunc can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else { - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - return TObjectRef(ObjectPtr(nullptr)); } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMNDArrayHandle) { + // Casting to a base class that NDArray can sub-class + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); + return TObjectRef(data); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMModuleHandle) { + // Casting to a base class that Module can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMPackedFuncHandle) { + // Casting to a base class that PackedFunc can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgInt) { + return BoxInt(value_.v_int64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgFloat) { + return BoxFloat(value_.v_float64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMStr) 
{ + return String(value_.v_str); + } + } + + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); } template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); - if (ptr != nullptr) { - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (ptr && (std::is_base_of_v || + ptr->IsInstance())) { return operator=(NDArray(std::move(other.data_))); } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (ptr && (std::is_base_of_v || + ptr->IsInstance())) { return operator=(Module(std::move(other.data_))); } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (ptr && (std::is_base_of_v || + ptr->IsInstance())) { return operator=(PackedFunc(std::move(other.data_))); } + } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (ptr && + (std::is_base_of_v || ptr->IsInstance())) { + int64_t value = static_cast(ptr)->value; + return operator=(value); + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (ptr && + (std::is_base_of_v || ptr->IsInstance())) { + double value = static_cast(ptr)->value; + return operator=(value); + } + } + + if (ptr) { SwitchToObject(kTVMObjectHandle, std::move(other.data_)); } else { SwitchToPOD(kTVMNullptr); @@ -2156,6 +2342,42 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { } }; +template +struct PackedFuncValueConverter> { + static Array From(const TVMArgValue& val) { + auto untyped_array = val.AsObjectRef>(); + + // Attempt to convert each item of the array into the desired + // type. If the items do not require a conversion, no copies are + // made. 
+ return untyped_array.Map([](ObjectRef item) { + // The TVMArgValue is intentionally defined through + // `TVMArgsSetter`, rather than defining it with + // `value.data_ = item.get();` and type code + // `kTVMObjectHandle`. `TVMArgsSetter::operator()` includes + // special handling for unwrapping boxed primitives, + // PackedFunc, runtime::Module, etc, which should be checked + // before delegating to the array element's + // PackedFuncValueConverter implementation. + TVMValue value; + int type_code; + TVMArgsSetter setter(&value, &type_code); + setter(0, item); + TVMArgValue arg(value, type_code); + return PackedFuncValueConverter::From(arg); + }); + } + static Array From(const TVMRetValue& val) { + auto untyped_array = val.AsObjectRef>(); + + return untyped_array.Map([](ObjectRef item) { + TVMRetValue item_val; + item_val = std::move(item); + return PackedFuncValueConverter::From(item_val); + }); + } +}; + template struct PackedFuncValueConverter> { static Optional From(const TVMArgValue& val) { @@ -2207,7 +2429,7 @@ struct PackedFuncValueConverter> { static Optional TryValueConverter(const PODSubclass& val) { try { return VType(PackedFuncValueConverter::From(val)); - } catch (const InternalError&) { + } catch (const Error&) { } if constexpr (sizeof...(VarRest)) { diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index d47ac94e067e..4c1d1fc1f3d2 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -113,7 +113,15 @@ class TargetNode : public Object { "Can only call GetAttr with ObjectRef types."); auto it = attrs.find(attr_key); if (it != attrs.end()) { - return Downcast>((*it).second); + // For backwards compatibility, return through TVMRetValue. + // This triggers any automatic conversions registered with + // PackedFuncValueConverter. Importantly, this allows use of + // `GetAttr` and `GetAttr` for properties that + // are stored internally as `runtime::Box` and + // `runtime::Box`. 
+ TVMRetValue ret; + ret = (*it).second; + return ret; } else { return default_value; } diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 130aea32f844..fb9c0f17b011 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -445,8 +445,8 @@ constexpr const char* kRelayToTIR = "RelayToTIR"; .add_attr_option("model") \ .add_attr_option>("libs") \ .add_attr_option("host") \ - .add_attr_option("from_device") \ - .add_attr_option("target_device_type") + .add_attr_option("from_device") \ + .add_attr_option("target_device_type") } // namespace tvm diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 4e29eddadd8c..7d1e5e3768de 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1150,6 +1150,66 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } // namespace tir } // namespace tvm +namespace tvm { +namespace runtime { + +// Automatic conversion into PrimExpr, when called through the FFI. +// Automatic conversions into IntImm, Integer, and Bool are registered +// in "tvm/ir/expr.h", as they are currently in use outside of TIR. + +template <> +struct PackedFuncValueConverter { + template + static Optional TryFrom(const PODSubclass& val) { + auto type_code = val.type_code(); + bool can_convert = type_code == kTVMDataType || type_code == kTVMBytes || + type_code == kTVMStr || val.template IsObjectRef(); + if (can_convert) { + return tvm::tir::StringImm(PackedFuncValueConverter::From(val)); + } else { + return NullOpt; + } + } + + template + static tvm::tir::StringImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } + } +}; + +template <> +struct PackedFuncValueConverter { + // Common rule for RetValue and ArgValue. Templated to ensure + // correct delegation to `operator std::string()` for either + // TVMArgValue or TVMRetValue. 
+ template + static PrimExpr From(const PODSubclass& val) { + if (auto opt = val.TryAsBool()) { + // Check against val.TryAsBool directly, to avoid the + // bounds-checking in PackedFuncValueConverter::TryFrom. + return Bool(opt.value()); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (val.template IsObjectRef()) { + // Delegate to the implicit conversion from IterVar to PrimExpr + return val.template AsObjectRef(); + } else { + return val.template AsObjectRef(); + } + } +}; + +} // namespace runtime +} // namespace tvm + namespace std { template <> struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {}; diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 520e0e42ebbe..2c739b7cfbab 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -60,14 +60,27 @@ def _return_object(x): tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) + + # Handle return values that subclass from both TVM objects and + # python native objects (e.g. runtime.String, a subclass of str). if issubclass(cls, PyNativeObject): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) obj.handle = handle return cls.__from_tvm_object__(cls, obj) + # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) obj.handle = handle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. 
+ if hasattr(obj, '__into_pynative_object__'): + return obj.__into_pynative_object__() + return obj diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 6465e0335db0..dee5e0cc2f1a 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -134,6 +134,13 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode + elif isinstance(arg, bool): + # A python `bool` is a subclass of `int`, so this check + # must occur before `Integral`. + arg = _FUNC_CONVERT_TO_OBJECT(arg) + values[i].v_handle = arg.handle + type_codes[i] = ArgTypeCode.OBJECT_HANDLE + temp_args.append(arg) elif isinstance(arg, Integral): values[i].v_int64 = arg type_codes[i] = ArgTypeCode.INT @@ -147,7 +154,7 @@ def _make_tvm_args(args, temp_args): values[i].v_int64 = _device_to_int64(arg) type_codes[i] = ArgTypeCode.DLDEVICE elif isinstance(arg, (bytearray, bytes)): - # from_buffer only taeks in bytearray. + # from_buffer only takes in bytearray. if isinstance(arg, bytes): byte_arr = bytearray(arg) temp_args.append(byte_arr) diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 94a9310d7815..50ec3ce0fb00 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -60,6 +60,15 @@ cdef inline object make_ret_object(void* chandle): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = chandle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. 
+ if hasattr(obj, '__into_pynative_object__'): + return obj.__into_pynative_object__() + return obj diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 71f23577e70d..0be9d1a6acea 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -215,7 +215,7 @@ def __call__(self, *args: List[Expr], attrs: Optional[Dict[str, Any]] = None) -> """ return Call(self, args, attrs=attrs) - def __getitem__(self, index: int) -> "ExprWithOp": + def __getitem__(self, index: Union[Expr, PrimExpr, int]) -> "ExprWithOp": """Get the i-th element of the tuple or Expr with TupleType. Parameters @@ -232,17 +232,7 @@ def __getitem__(self, index: int) -> "ExprWithOp": result: ExprWithOp The result expression. """ - try: - return TupleGetItem(self, index) - except tvm.TVMError as err: - # For Python objects with __getitem__, but without - # __len__, tuple unpacking is done by iterating over - # sequential indices until IndexError is raised. - # Therefore, convert from TVMError to IndexError for - # compatibility. 
- if "Index out of bounds" in err.args[0]: - raise IndexError from err - raise + return tvm.relax.op.tuple_get_item(self, index) @tvm._ffi.register_object("relax.expr.Call") @@ -560,7 +550,9 @@ class PrimValue(Expr, Scriptable): value: PrimExpr def __init__(self, value: Union[PrimExpr, int], span: Optional[Span] = None) -> None: - if isinstance(value, int): + if isinstance(value, bool): + value = tvm.tir.IntImm("bool", value) + elif isinstance(value, int): value = tvm.tir.IntImm("int64", value) self.__init_handle_by_constructor__(_ffi_api.PrimValue, value, span) # type: ignore diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 60a4332d838c..61535df8ec37 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -101,6 +101,7 @@ from .set import unique from .statistical import cumsum, max, mean, min, prod, std, sum, variance from .ternary import ewise_fma +from .tuple import tuple_get_item, tuple_get_item_dyn from .unary import ( abs, acos, diff --git a/python/tvm/relax/op/tuple.py b/python/tvm/relax/op/tuple.py new file mode 100644 index 000000000000..e04cbf63be2d --- /dev/null +++ b/python/tvm/relax/op/tuple.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tuple operators.""" +from typing import Union + +import tvm +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ..expr import Expr, PrimValue + + +def tuple_get_item(tuple_expr: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: + """Perform tuple access + + Use of this method is recommended, rather than constructing a + `relax.TupleGetItem` directly. + + 1. May resolve to the tuple's contents, avoiding the intermediate + `TupleGetItem`. + + 2. Handles access of a tuple at a dynamic index, where + `TupleGetItem` requires a statically-known index. + + Parameters + ---------- + tuple_expr: Expr + + The tuple to be accessed. The tuple is not required to be an + in-line `relax.Tuple`, but must have `TupleStructInfo` + + index: Union[int, PrimExpr, Expr] + + The index at which the tuple is accessed. The index may be + static or dynamic. + + Returns + ------- + Expr + + An expression representing the item in the tuple. + """ + + if not isinstance(index, Expr): + index = PrimValue(index) + + return _ffi_api.tuple_get_item(tuple_expr, index) # type: ignore + + +def tuple_get_item_dyn(tuple_expr: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: + """Explicitly generate a call to tuple_get_item_dyn + + This method is not recommended for general use, and is provided to + ensure round-trip consistency in TVMScript. In most cases, the + `tuple_get_item` method should be used, which will delegate to the + dynamic builtin for cases where the index is dynamic. + + Parameters + ---------- + tuple_expr: Expr + + The tuple to be accessed. The tuple is not required to be an + in-line `relax.Tuple`, but must have `TupleStructInfo` + + index: Union[int, PrimExpr, Expr] + + The index at which the tuple is accessed. The index may be + static or dynamic. + + Returns + ------- + Expr + + An expression representing the item in the tuple. 
+ + """ + if not isinstance(index, Expr): + index = PrimValue(index) + return tvm.relax.Call(tvm.ir.Op.get("relax.tuple_get_item_dyn"), [tuple_expr, index]) diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index eccdcbad9520..1f11cba76249 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -27,11 +27,11 @@ from .profiling import Report # function exposures -from .object_generic import convert_to_object, convert, const from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library -from .container import String, ShapeTuple +from .container import String, ShapeTuple, BoxBool +from .object_generic import convert_to_object, convert, const from .params import ( save_param_dict, load_param_dict, diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 686b4a26c80c..7c754ba1e622 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Runtime container structures.""" +from typing import Union + import tvm._ffi from .object import Object, PyNativeObject from .object_generic import ObjectTypes @@ -172,3 +174,41 @@ def __eq__(self, other): return False return True + + +@tvm._ffi.register_object("runtime.BoxBool") +class BoxBool(Object): + """A boolean wrapped as a tvm Object + + Parameters + ---------- + value: bool + + The value to hold + """ + + def __init__(self, value: bool): + # Convert to int to avoid an infinite recursion, because + # BoxBool may be constructed in _make_tvm_args, and calling + # the packed func `_ffi_api.BoxBool` internally calls + # `_make_tvm_args`. 
+ self.__init_handle_by_constructor__(_ffi_api.BoxBool, int(value)) + + def __into_pynative_object__(self) -> bool: + return self.value + + @property + def value(self) -> bool: + """Unwrap the boxed value. + + This is implemented explicitly rather than using the usual + PackedFunc handling or AttrVisitor mechanics for two reasons. + First, because the PackedFunc handling would require ambiguous + representations between `True`/`1` and `False`/`0`. Second, + because the boxing/unboxing must be available in + `libtvm_runtime.so`, and AttrVisitor is only available in + `libtvm.so`. + """ + unboxed_bool = _ffi_api.UnBoxBool(self) + assert unboxed_bool is not None + return bool(unboxed_bool) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 887c2faaeb2b..b0b29426dbee 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -41,6 +41,14 @@ def asobject(self): def convert_to_object(value, span=None): """Convert a Python value to corresponding object type. + Type conversions performed by this function must *only* produce + types that are supported by `libtvm_runtime.so`. This function + must be usable in environments where only TVM runtime support is + present. Automatic conversions to compile-time representations + (e.g. `tir.IntImm` or `relax.PrimValue`) should not be done as + part of this conversion, as these types are not available in + `libtvm_runtime.so`. + Parameters ---------- value : str @@ -53,38 +61,46 @@ def convert_to_object(value, span=None): ------- obj : Object The corresponding object value. + """ + + # Import inside function call to avoid circular import from + # uninitialized tvm.runtime module. 
+ from .container import BoxBool # pylint: disable=import-outside-toplevel + if isinstance(value, ObjectTypes): return value - if isinstance(value, bool): - return const(value, "uint1x1", span=span) - if isinstance(value, Number): + + elif isinstance(value, bool): + # Python types int and float will be converted to C++ types + # Box and Box using kDLInt. Boolean types need + # to be explicitly converted to Box to avoid ambiguous + # representation. This allows `bool(True)` and `int(1)` to be + # unambiguously passed to the C++ implementations. + return BoxBool(value) + + elif isinstance(value, Number): return const(value, span=span) - if isinstance(value, string_types): + elif isinstance(value, string_types): return _ffi_api.String(value) - if isinstance(value, (list, tuple)): - value = [convert_to_object(x) for x in value] + elif isinstance(value, (list, tuple)): + # The call to _ffi_api.Array will convert its own arguments, + # so we don't need to apply any explicit conversions here. return _ffi_api.Array(*value) - if isinstance(value, dict): - vlist = [] - for item in value.items(): - if ( - not isinstance(item[0], ObjectTypes) - and not isinstance(item[0], string_types) - and not isinstance(item[0], Number) - ): - raise ValueError("key of map must already been a container type") - vlist.append(convert_to_object(item[0])) - vlist.append(convert_to_object(item[1])) + elif isinstance(value, dict): + if any(not isinstance(key, (ObjectTypes, string_types, Number)) for key in value): + raise ValueError("key of map must already been a container type") + + vlist = [kv for item in value.items() for kv in item] return _ffi_api.Map(*vlist) - if isinstance(value, ObjectGeneric): + elif isinstance(value, ObjectGeneric): return value.asobject() - if callable(value): + elif callable(value): return convert_to_tvm_func(value) - if value is None: + elif value is None: return None - - raise ValueError(f"don't know how to convert type {type(value)} to object") + else: + raise 
TypeError(f"don't know how to convert type {type(value)} to object") def convert(value, span=None): @@ -107,29 +123,29 @@ def convert(value, span=None): This function is redirected to `convert_to_object` as it is widely used in the codebase. We can choose one to keep and discard the other one later. """ + return convert_to_object(value, span=span) def _scalar_type_inference(value): if hasattr(value, "dtype"): - dtype = str(value.dtype) + return str(value.dtype) elif isinstance(value, bool): - dtype = "bool" + return "bool" elif isinstance(value, float): # We intentionally prefer convert the float to float32 since it's more common in DL. if -3.40282347e38 <= value <= 3.40282347e38: - dtype = "float32" + return "float32" else: - dtype = "float64" + return "float64" elif isinstance(value, int): # We intentionally prefer convert the python int to int32 since it's more common in DL. if -2147483648 <= value <= 2147483647: - dtype = "int32" + return "int32" else: - dtype = "int64" + return "int64" else: raise NotImplementedError(f"Cannot automatically inference the type. value={value}") - return dtype def const(value, dtype=None, span=None): diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 142d0e6d96aa..6068880d3a96 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -146,6 +146,8 @@ tile, tril, triu, + tuple_get_item, + tuple_get_item_dyn, unique, vm, where, @@ -775,6 +777,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "tril", "triu", "tuple", + "tuple_get_item", + "tuple_get_item_dyn", "unique", "variance", "vm", diff --git a/src/node/boxed_primitive.cc b/src/node/boxed_primitive.cc new file mode 100644 index 000000000000..10b308fcf4db --- /dev/null +++ b/src/node/boxed_primitive.cc @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file node/boxed_primitive.cc + * + * \brief Reflection utilities for runtime-supported classes + * + * The fundamental support for boxing and unboxing of primitives + * during FFI calls is implemented in runtime/boxed_primitive.cc. In + * addition, boxed primitives may be registered with compile-time + * utilities (e.g. reflection, JSON import/export) that can provide + * additional functionality and improved debugging ability. However, + * neither these compile-time utilities nor any registration of + * `Box` into the compile-time utilities should be included as + * part of `libtvm_runtime.so`. + * + * This file contains the registration of the `libtvm_runtime.so` + * class `Box` for utilities that are contained in `libtvm.so`. 
+ */ +#include +#include +#include +#include + +namespace tvm { +namespace runtime_ext { + +using runtime::Box; +using runtime::BoxNode; + +template +struct BoxNodeTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const BoxNode* node, SHashReducer hash_reduce) { + hash_reduce(node->value); + } + + static bool SEqualReduce(const BoxNode* lhs, const BoxNode* rhs, + SEqualReducer equal) { + return equal(lhs->value, rhs->value); + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeTrait) + .set_creator([](const std::string& blob) -> ObjectPtr { + int64_t value = std::atoll(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + int64_t value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeTrait) + .set_creator([](const std::string& blob) -> ObjectPtr { + if (blob == "true") { + return make_object>(true); + } else if (blob == "false") { + return make_object>(false); + } else { + LOG(FATAL) << "Invalid string '" << blob << "' for boolean"; + } + }) + .set_repr_bytes([](const Object* n) -> std::string { + bool value = GetRef(n).as>().value()->value; + if (value) { + return "true"; + } else { + return "false"; + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << (box->value ? 
"true" : "false") << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeTrait) + .set_creator([](const std::string& blob) -> ObjectPtr { + double value = std::atof(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + double value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +} // namespace runtime_ext + +} // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index f2d985279f12..301216d04cad 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -48,7 +48,7 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->binding_names.push_back(Downcast(v)); } if (auto v = config_dict.Get("show_meta")) { - n->show_meta = Downcast(v)->value; + n->show_meta = Downcast(v)->value; } if (auto v = config_dict.Get("ir_prefix")) { n->ir_prefix = Downcast(v); @@ -72,16 +72,16 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->float_dtype = DataType(runtime::String2DLDataType(Downcast(v))); } if (auto v = config_dict.Get("verbose_expr")) { - n->verbose_expr = Downcast(v)->value; + n->verbose_expr = Downcast(v)->value; } if (auto v = config_dict.Get("indent_spaces")) { - n->indent_spaces = Downcast(v)->value; + n->indent_spaces = Downcast(v)->value; } if (auto v = config_dict.Get("print_line_numbers")) { - n->print_line_numbers = Downcast(v)->value; + n->print_line_numbers = Downcast(v)->value; } if (auto v = config_dict.Get("num_context_lines")) { - n->num_context_lines = Downcast(v)->value; + n->num_context_lines = Downcast(v)->value; } if (auto v = config_dict.Get("path_to_underline")) { n->path_to_underline = Downcast>>(v).value_or(Array()); @@ -98,10 +98,10 @@ PrinterConfig::PrinterConfig(Map config_dict) { 
Downcast>>(v).value_or(Map()); } if (auto v = config_dict.Get("syntax_sugar")) { - n->syntax_sugar = Downcast(v)->value; + n->syntax_sugar = Downcast(v)->value; } if (auto v = config_dict.Get("show_object_address")) { - n->show_object_address = Downcast(v)->value; + n->show_object_address = Downcast(v)->value; } // Checking prefixes if they are valid Python identifiers. diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 218fe6b1202c..a9402149b9a2 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -171,6 +171,8 @@ class CodeGenVM : public ExprFunctor { EmitAllocTensor(call, dst_reg); } else if (call_node->op == kill_object_op_) { dst_reg = EmitKillObject(call); + } else if (call_node->op == tuple_getitem_op_) { + EmitTupleAccess(call, dst_reg); } else { // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those // ops are handled in a pass when lowering them to TIR. @@ -373,6 +375,12 @@ class CodeGenVM : public ExprFunctor { return dst_reg; } + void EmitTupleAccess(const Call& call_node, RegName dst_register) { + ICHECK_EQ(call_node->args.size(), 2); + std::vector args = VisitArray(call_node->args); + builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register); + } + void EmitCallBuiltinWithCtx(const Call& call_node, RegName dst_reg) { std::vector args; args.push_back(Instruction::Arg::Register(Instruction::kVMRegister)); @@ -425,6 +433,7 @@ class CodeGenVM : public ExprFunctor { const Op& kill_object_op_ = Op::Get("relax.vm.kill_object"); const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); const Op& null_value_op_ = Op::Get("relax.null_value"); + const Op& tuple_getitem_op_ = Op::Get("relax.tuple_get_item_dyn"); }; /*! 
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index ec1678e9e0f3..baee0062e99d 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -240,6 +240,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { EmitAllocTensor(call, dst_reg); } else if (call_node->op == kill_object_op_) { dst_reg = EmitKillObject(call); + } else if (call_node->op == tuple_getitem_op_) { + EmitTupleAccess(call, dst_reg); } else { // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those // ops are handled in a pass when lowering them to TIR. @@ -433,6 +435,12 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return dst_reg; } + void EmitTupleAccess(const Call& call_node, int64_t dst_register) { + ICHECK_EQ(call_node->args.size(), 2); + auto args = call_node->args.Map([this](Expr expr) { return VisitExpr(expr).value(); }); + EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register); + } + void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) { Array args; // if context is required, pass as first argument. @@ -522,6 +530,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { const Op& kill_object_op_ = Op::Get("relax.vm.kill_object"); const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); const Op& null_value_op_ = Op::Get("relax.null_value"); + const Op& tuple_getitem_op_ = Op::Get("relax.tuple_get_item_dyn"); }; /*! diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index fda31e44a920..f82612ce1662 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -686,15 +686,19 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorNormalizeArgument(op->tuple); - TupleGetItem node = new_tuple.same_as(op->tuple) ? 
GetRef(op) - : TupleGetItem(new_tuple, op->index); + TupleGetItem node = [&]() { + if (new_tuple.same_as(op->tuple) && op->struct_info_.defined()) { + return GetRef(op); + } else { + return TupleGetItem(new_tuple, op->index); + } + }(); - if (!node->struct_info_.defined()) { - auto opt = MatchStructInfo(node->tuple); - ICHECK(opt) << "The struct info of Tuple must be TupleStructInfo, " - << "but expression " << node << " has struct info " << node->struct_info_; - UpdateStructInfo(node, opt.value()->fields[node->index]); - } + ICHECK(node->struct_info_.defined()) + << "InternalError: " + << "TupleGetItem expected to define its struct info on construction, " + << "but access of " << node->tuple << " (struct info = " << node->tuple->struct_info_ + << ") at index " << node->index << " produced empty struct info"; return node; } diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 00ad252ec4a4..1a1d41e6ced3 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -166,13 +166,14 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o } TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { - CHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple - << " cannot be accessed with negative index " << index; + CHECK_GE(index, 0) << "IndexError: " + << "Tuple " << tuple << " cannot be accessed with negative index " << index; ObjectPtr n = make_object(); if (auto* tuple_info = tuple->struct_info_.as()) { CHECK_LT(index, tuple_info->fields.size()) - << "Index out of bounds: Tuple " << tuple << " is of size " << tuple_info->fields.size() + << "IndexError: " + << "Tuple " << tuple << " is of size " << tuple_info->fields.size() << ", and cannot be accessed with index " << index; auto sinfo = tuple_info->fields[index]; n->struct_info_ = sinfo; diff --git a/src/relax/op/tuple.cc b/src/relax/op/tuple.cc new file mode 100644 index 000000000000..6defce67cd71 --- /dev/null +++ b/src/relax/op/tuple.cc @@ -0,0 +1,163 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/op/tuple.cc + * + * builtin intrinsic operators for manipulating tuples + */ +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +namespace { +/*! \brief Utility function for NormalizeTupleGetItem and tuple_get_item + * + * \param index The index at which the tuple is accessed + * + * \return The known index, if static, otherwise std::nullopt. 
+ */ +std::optional FindStaticIndex(const Expr& index) { + if (auto index_sinfo = index->struct_info_.as()) { + if (auto known_index = index_sinfo->value.as()) { + return known_index->value; + } + } + return std::nullopt; +} +} // namespace + +StructInfo InferStructInfoTupleGetItem(const Call& call, const BlockBuilder&) { + CHECK_EQ(call->args.size(), 2) << "Operator " << call->op + << " expects exactly two arguments [tuple, index], " + << "but received " << call->args.size() + << " arguments in expression " << call; + auto tuple = call->args[0]; + auto index = call->args[1]; + + auto tuple_sinfo = tuple->struct_info_.as(); + CHECK(tuple_sinfo) << "Operator " << call->op + << " expects its first argument to specify a tuple, " + << "but expression " << call << " has tuple argument " << tuple + << ", which has struct info " << tuple->struct_info_; + + auto index_sinfo = index->struct_info_.as(); + CHECK(index_sinfo && index_sinfo->dtype == DataType::Int(64)) + << "TupleGetItem requires the index to be a R.Prim('int64'), " + << "but expression " << call << " has index argument " << index << ", which has struct info " + << index->struct_info_; + + auto known_index = index_sinfo->value.as(); + + if (known_index) { + // The exact index used to access the tuple is known. We can + // apply bounds-checking, and can provide the exact StructInfo of + // the accessed element. + int int_index = known_index->value; + + CHECK_GE(int_index, 0) << "IndexError: " + << "Operator " << call->op << " attempted to access tuple " << tuple + << " at index " << index << ". " + << "However, the index " << index << " is known to be " << int_index + << ", and negative indices are not allowed."; + + CHECK_LT(int_index, tuple_sinfo->fields.size()) + << "IndexError: " + << "Operator " << call->op << " attempted to access tuple " << tuple << " at index " + << index << ". 
" + << "However, tuple " << tuple << " is of size " << tuple_sinfo->fields.size() + << ", the index expression has a known value of " << int_index + << ", outside the bounds of the tuple"; + return tuple_sinfo->fields[int_index]; + + } else { + // The exact index used to access the tuple is unknown. We can't + // apply bounds checking, but we can check that an index might + // exist. We can't provide an exact StructInfo for the accessed + // type, but we can provide the common base type of all items in + // the tuple. + CHECK_GT(tuple_sinfo->fields.size(), 0) + << "IndexError: " + << "The exact value of index " << index << " is unknown, " + << "but expression " << tuple << " has struct info " << tuple->struct_info_ << ". " + << "This is a tuple of length zero, and there is no index such that 0 <= index < 0."; + + StructInfo reduce_lca = tuple_sinfo->fields[0]; + for (size_t i = 1; i < tuple_sinfo->fields.size(); i++) { + reduce_lca = StructInfoLCA(reduce_lca, tuple_sinfo->fields[1]); + } + return reduce_lca; + } +} + +Expr NormalizeTupleGetItem(const BlockBuilder&, const Call& call) { + ICHECK_EQ(call->args.size(), 2); + auto tuple = call->args[0]; + auto index = call->args[1]; + + if (auto index_sinfo = index->struct_info_.as()) { + if (auto known_index = index_sinfo->value.as()) { + return TupleGetItem(tuple, known_index->value); + } + } + return std::move(call); +} + +RELAY_REGISTER_OP("relax.tuple_get_item_dyn") + .set_num_inputs(2) + .add_argument("tuple", "Expr (R.Tuple([...]))", "The tuple to access") + .add_argument("index", "Expr (R.Prim(dtype='int64'))", + "The index at which to access the tuple.") + .set_attr("FInferStructInfo", InferStructInfoTupleGetItem) + .set_attr("FNormalize", NormalizeTupleGetItem) + .set_attr("FPurity", Bool(true)); + +Expr tuple_get_item(Expr tuple, Expr index) { + auto opt_static_index = FindStaticIndex(index); + auto known_tuple = tuple.as(); + + if (opt_static_index && known_tuple) { + // Both the tuple and index are known. 
We can return the accessed + // expression directly. + return known_tuple->fields[opt_static_index.value()]; + } else if (opt_static_index) { + // The index is known, but the tuple is bound to a variable. We + // can return a static TupleGetItem, which is useful in many + // passes. + return TupleGetItem(tuple, opt_static_index.value()); + } else { + // The index isn't known, so fall back to the most general case. + // If a later pass (e.g. BindParams) provides a statically-known + // index, then this will be normalized back to a TupleGetItem at + // that point. + static const auto op = Op::Get("relax.tuple_get_item_dyn"); + return Call(op, {tuple, index}); + } +} + +TVM_REGISTER_GLOBAL("relax.op.tuple_get_item").set_body_typed(tuple_get_item); + +} // namespace relax +} // namespace tvm diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index fde6daa4d851..d50aaeb8de12 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -4157,11 +4157,13 @@ bool ScanopRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) { +Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Optional exclusive) { auto attrs = make_object(); attrs->dtype = dtype; attrs->axis = axis; - attrs->exclusive = exclusive; + if (exclusive.defined()) { + attrs->exclusive = exclusive.value(); + } static const Op& op = Op::Get("cumsum"); return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/runtime/boxed_primitive.cc b/src/runtime/boxed_primitive.cc new file mode 100644 index 000000000000..9ab83a7b471c --- /dev/null +++ b/src/runtime/boxed_primitive.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/boxed_primitive.cc + * \brief Implementations of ObjectRef wrapper. + */ + +#include +#include + +namespace tvm { +namespace runtime { + +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); + +/* \brief Allow explicit construction of Box + * + * Convert a `bool` to `Box`. For use in FFI handling, to + * provide an umambiguous representation between `bool(true)` and + * `int(1)`. Will be automatically unboxed in the case where a + * `Box` is provided to a PackedFunc that requires `int` input, + * mimicking C++'s default conversions. + * + * This is only needed for Box, as Box and Box + * can be converted in C++ as part of `TVMArgValue::operator + * ObjectRef()` without ambiguity, postponing conversions until + * required. + */ +TVM_REGISTER_GLOBAL("runtime.BoxBool").set_body_typed([](bool value) { return Box(value); }); + +/* \brief Return the underlying boolean object. + * + * Used while unboxing a boolean return value during FFI handling. + * The return type is intentionally `int` and not `bool`, to avoid + * recursive unwrapping of boolean values. + * + * This is only needed for Box, as Box and Box + * can be unambiguously unboxed as part of + * `TVMRetValue::operator=(ObjectRef)`. 
+ */ +TVM_REGISTER_GLOBAL("runtime.UnBoxBool").set_body_typed([](Box obj) -> int { + return obj->value; +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index d6b086f201af..394a8e16d070 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -33,6 +33,7 @@ #include #include "../runtime_base.h" +#include "tvm/ir/expr.h" namespace tvm { namespace runtime { diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/ir/distributed.cc index 29e45bc5c598..5b2c42e99cd7 100644 --- a/src/script/printer/ir/distributed.cc +++ b/src/script/printer/ir/distributed.cc @@ -32,7 +32,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array results; results.reserve(s); for (int i = 0; i < s; ++i) { - results.push_back(d->AsDoc(Integer(n[i]), n_p->ArrayIndex(i))); + results.push_back(LiteralDoc::Int(n[i], n_p->ArrayIndex(i))); } return TupleDoc(results); }); diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 785dc6d96320..3d85477d1ab1 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -249,6 +249,24 @@ Optional PrintRelaxPrint(const relax::Call& n, const ObjectPath& n_p, return Relax(d, "print")->Call(args, {"format"}, {first_arg}); } +Optional PrintTupleGetItem(const relax::Call& call, const ObjectPath& path, + const IRDocsifier& doc) { + static const Op& print_op = Op::Get("relax.tuple_get_item_dyn"); + if (!call->op.same_as(print_op)) { + return NullOpt; + } + + if (!doc->cfg->syntax_sugar) { + // Fall back to the default printing for builtins as `R.tuple_get_item_dyn` + return NullOpt; + } + + ICHECK_EQ(call->args.size(), 2); + ExprDoc tuple = doc->AsDoc(call->args[0], path->Attr("args")->ArrayIndex(0)); + ExprDoc index = doc->AsDoc(call->args[1], path->Attr("args")->ArrayIndex(1)); + return tuple[{index}]; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Call 
n, ObjectPath n_p, IRDocsifier d) -> Doc { @@ -272,6 +290,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (Optional doc = PrintRelaxPrint(n, n_p, d)) { return doc.value(); } + // Special case: tuple_get_item_dyn + if (Optional doc = PrintTupleGetItem(n, n_p, d)) { + return doc.value(); + } ExprDoc prefix{nullptr}; Array args; Array kwargs_keys; diff --git a/src/script/printer/relax/distributed.cc b/src/script/printer/relax/distributed.cc index 9bf49a2830db..b36d7f4480f0 100644 --- a/src/script/printer/relax/distributed.cc +++ b/src/script/printer/relax/distributed.cc @@ -20,6 +20,7 @@ #include #include "../ir/utils.h" +#include "../tir/utils.h" #include "./utils.h" namespace tvm { @@ -101,8 +102,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) f = ir_frame; } } + if (!has_relax_frame || !f) { Array args; + + // Device mesh uses the TIR integer conversion rules, so + // we print the arguments using the TIR printer. + With frame(d, n); + (*frame)->AddDispatchToken(d, "tir"); + args.push_back(d->AsDoc(n->shape, n_p->Attr("shape"))); if (n->device_range.defined()) { args.push_back(d->AsDoc(n->device_range, n_p->Attr("device_range"))); diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 7c7752cfe65d..a513a36df2f0 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -18,6 +18,7 @@ */ #include +#include "../tir/utils.h" #include "./utils.h" namespace tvm { @@ -72,8 +73,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", P TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // - // TODO(@junrushao): support non-int64 cases - return LiteralDoc::Int(n->value, n_p); + ExprDoc doc = LiteralDoc::Int(n->value, n_p); + if (n->dtype != DataType::Int(64)) { + doc = TIR(d, DType2Str(n->dtype))->Call({doc}); + } + return doc; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -111,11 +115,9 @@ 
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("relax", [](Range range, ObjectPath p, IRDocsifier d) -> Doc { - return Relax(d, "Range") - ->Call({ - d->AsDoc(range->min, p->Attr("min")), - d->AsDoc(range->extent + range->min, p->Attr("extent")), - }); + With frame(d, range); + (*frame)->AddDispatchToken(d, "tir"); + return d->AsDoc(range, p); }); } // namespace printer diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 75b5a2527f76..f7994d940bc0 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -189,4 +189,23 @@ TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Varian TVM_REGISTER_GLOBAL("testing.AcceptsVariant") .set_body_typed([](Variant arg) -> String { return arg->GetTypeKey(); }); +TVM_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRef").set_body_typed([](ObjectRef arg) -> ObjectRef { + return arg; +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray") + .set_body_typed([](Array arg) -> ObjectRef { return arg[0]; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") + .set_body_typed([](Map map, ObjectRef key) -> ObjectRef { + return map[key]; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") + .set_body_typed([](Map map) -> ObjectRef { return map; }); + } // namespace tvm diff --git a/src/target/tag.cc b/src/target/tag.cc index e6521d384397..4360cf389bcf 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -75,46 +75,46 @@ TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}, + {"num-cores", runtime::BoxInt(4)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, 
{"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}}}}); + {"num-cores", runtime::BoxInt(4)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") .set_config({{"kind", String("cuda")}, {"arch", String("sm_72")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::BoxInt(49152)}, + {"max_threads_per_block", runtime::BoxInt(1024)}, + {"thread_warp_size", runtime::BoxInt(32)}, + {"registers_per_block", runtime::BoxInt(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(8)}}}}); + {"num-cores", runtime::BoxInt(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::BoxInt(49152)}, + {"max_threads_per_block", runtime::BoxInt(1024)}, + {"thread_warp_size", runtime::BoxInt(32)}, + {"registers_per_block", runtime::BoxInt(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(6)}}}}); + {"num-cores", runtime::BoxInt(6)}}}}); -#define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ - TVM_REGISTER_TARGET_TAG(Name).set_config({ \ - {"kind", String("cuda")}, \ - {"keys", Array{"cuda", "gpu"}}, \ - {"arch", String(Arch)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"max_threads_per_block", Integer(1024)}, \ - {"thread_warp_size", Integer(32)}, \ - {"registers_per_block", Integer(RegPerBlock)}, \ +#define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ + 
TVM_REGISTER_TARGET_TAG(Name).set_config({ \ + {"kind", String("cuda")}, \ + {"keys", Array{"cuda", "gpu"}}, \ + {"arch", String(Arch)}, \ + {"max_shared_memory_per_block", runtime::BoxInt(SharedMem)}, \ + {"max_threads_per_block", runtime::BoxInt(1024)}, \ + {"thread_warp_size", runtime::BoxInt(32)}, \ + {"registers_per_block", runtime::BoxInt(RegPerBlock)}, \ }) // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus @@ -130,7 +130,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2075", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(41943040)); + .with_config("l2_cache_size_bytes", runtime::BoxInt(41943040)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); @@ -233,7 +233,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(75497472)); + .with_config("l2_cache_size_bytes", runtime::BoxInt(75497472)); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536); @@ -386,7 +386,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768); TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", String("llvm")}, \ {"keys", Array{"x86", "cpu"}}, \ {"mcpu", String(Arch)}, \ - {"num-cores", Integer(Cores)}}); + {"num-cores", runtime::BoxInt(Cores)}}); 
TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512"); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.xlarge", 2, "skylake-avx512"); @@ -402,9 +402,9 @@ TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake"); #define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ TVM_REGISTER_TARGET_TAG(Name).set_config( \ {{"kind", String("metal")}, \ - {"max_threads_per_block", Integer(ThreadsPerBlock)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"thread_warp_size", Integer(WarpSize)}, \ + {"max_threads_per_block", runtime::BoxInt(ThreadsPerBlock)}, \ + {"max_shared_memory_per_block", runtime::BoxInt(SharedMem)}, \ + {"thread_warp_size", runtime::BoxInt(WarpSize)}, \ {"host", Map{{"kind", String("llvm")}, \ {"mtriple", String("arm64-apple-macos")}, \ {"mcpu", String("apple-latest")}}}}); diff --git a/src/target/target.cc b/src/target/target.cc index cd2e3714e422..10223b239385 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -359,24 +359,31 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { std::string interp_str = Interpret(str); - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing integer + if (info.type_index == runtime::BoxInt::ContainerType::_GetOrAllocRuntimeTypeIndex() || + info.type_index == runtime::BoxBool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + // Parsing integer or boolean std::istringstream is(interp_str); int v; if (!(is >> v)) { std::string lower(interp_str.size(), '\x0'); std::transform(interp_str.begin(), interp_str.end(), lower.begin(), [](unsigned char c) { return std::tolower(c); }); - // Bool is a subclass of IntImm, so allow textual boolean values. + // Mimic C++ automatic conversions, allowing bool to be used for + // integer parameters. 
if (lower == "true") { v = 1; } else if (lower == "false") { v = 0; } else { - throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); + throw Error(": Cannot parse integer from string: " + interp_str); } } - return Integer(v); + + if (info.type_index == runtime::BoxInt::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return runtime::BoxInt(v); + } else { + return runtime::BoxBool(v); + } } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing string, strip leading/trailing spaces, and enclosing quotes if any auto start = interp_str.find_first_not_of(' '); @@ -410,13 +417,14 @@ ObjectRef TargetInternal::ParseType(const std::string& str, ObjectRef TargetInternal::ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info) { - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + if (info.type_index == runtime::BoxInt::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer - return GetRef(ObjTypeCheck(obj, "Integer")); - } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return GetRef( + ObjTypeCheck(obj, "runtime.BoxInt")); + } else if (info.type_index == String::ContainerType::RuntimeTypeIndex()) { // Parsing string return GetRef(ObjTypeCheck(obj, "String")); - } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target if (auto opt = obj.as()) { return opt.value(); @@ -483,7 +491,9 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, /********** Stringifying **********/ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { - if (const auto* p = obj.as()) { + if (const auto* p = obj.as>()) { + return std::to_string(p->value); + } else if (const auto* p = obj.as>()) { return std::to_string(p->value); } if (auto tvm_str = obj.as()) { @@ -953,7 +963,7 @@ ObjectPtr 
TargetInternal::FromConfig(Map config) { // If requested, query attributes from the device. User-specified // parameters take precedence over queried parameters. if (attrs.count("from_device")) { - int device_id = Downcast(attrs.at("from_device")).IntValue(); + int device_id = Downcast(attrs.at("from_device"))->value; attrs.erase("from_device"); auto device_params = QueryDevice(device_id, target.get()); @@ -1006,38 +1016,13 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, for (const auto& kv : target->kind->key2vtype_) { const String& key = kv.first; - const TargetKindNode::ValueTypeInfo& type_info = kv.second; TVMRetValue ret; api->GetTargetProperty(device, key, &ret); - switch (ret.type_code()) { - case kTVMNullptr: - // Nothing returned for this parameter, move on to the next one. - continue; - - case kTVMArgInt: - if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Integer(static_cast(ret)); - } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Bool(static_cast(ret)); - } else { - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received integer from device api"; - } - break; - - case kTVMStr: - ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) - << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received string from device api"; - output[key] = String(ret.operator std::string()); - break; - - default: - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api"; - break; + // Delegate conversion from TVMRetValue to the FFI's default conversions. 
+ if (Optional opt = ret) { + output[key] = opt.value(); } } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index aa4499ec9667..c659859c6b2b 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -266,7 +266,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { * \return The updated attributes */ TargetJSON TestTargetParser(TargetJSON target) { - Map features = {{"is_test", Bool(true)}}; + Map features = {{"is_test", runtime::BoxBool(true)}}; target.Set("features", features); return target; } @@ -279,22 +279,22 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mtriple") .add_attr_option("mfloat-abi") .add_attr_option("mabi") - .add_attr_option("num-cores") + .add_attr_option("num-cores") // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags - .add_attr_option("fast-math") // implies all the below - .add_attr_option("fast-math-nnan") - .add_attr_option("fast-math-ninf") - .add_attr_option("fast-math-nsz") - .add_attr_option("fast-math-arcp") - .add_attr_option("fast-math-contract") - .add_attr_option("fast-math-reassoc") - .add_attr_option("opt-level") + .add_attr_option("fast-math") // implies all the below + .add_attr_option("fast-math-nnan") + .add_attr_option("fast-math-ninf") + .add_attr_option("fast-math-nsz") + .add_attr_option("fast-math-arcp") + .add_attr_option("fast-math-contract") + .add_attr_option("fast-math-reassoc") + .add_attr_option("opt-level") // LLVM command line flags, see below .add_attr_option>("cl-opt") .set_default_keys({"cpu"}) // Force the external codegen kind attribute to be registered, even if no external // codegen targets are enabled by the TVM build. 
- .set_attr(tvm::attr::kIsExternalCodegen, Bool(false)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::BoxBool(false)) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); // Note regarding the "cl-opt" attribute: @@ -322,28 +322,29 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("mcpu") .add_attr_option("march") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constants-byte-alignment") + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constants-byte-alignment") .set_default_keys({"cpu"}) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("arch") - .add_attr_option("max_shared_memory_per_block") - .add_attr_option("max_threads_per_block") - .add_attr_option("thread_warp_size", Integer(32)) - .add_attr_option("registers_per_block") - .add_attr_option("l2_cache_size_bytes") - .add_attr_option("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it + .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_threads_per_block") + .add_attr_option("thread_warp_size", runtime::BoxInt(32)) + .add_attr_option("registers_per_block") + .add_attr_option("l2_cache_size_bytes") + .add_attr_option("max_num_threads", + runtime::BoxInt(1024)) // TODO(@zxybazh): deprecate it .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateCUDAAttrs); TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("mtriple") - .add_attr_option("max_num_threads", Integer(1024)) - .add_attr_option("thread_warp_size", Integer(32)) + .add_attr_option("max_num_threads", runtime::BoxInt(1024)) + .add_attr_option("thread_warp_size", runtime::BoxInt(32)) .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateNVPTXAttrs); @@ -353,24 +354,24 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option>("mattr") // TODO(masahi): Support querying from a target 
device // On RDNA cards, thread_warp_size should be 32 - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(65536)) - .add_attr_option("thread_warp_size", Integer(64)) + .add_attr_option("max_num_threads", runtime::BoxInt(256)) + .add_attr_option("max_threads_per_block", runtime::BoxInt(256)) + .add_attr_option("max_shared_memory_per_block", runtime::BoxInt(65536)) + .add_attr_option("thread_warp_size", runtime::BoxInt(64)) .set_default_keys({"rocm", "gpu"}) .set_target_parser(UpdateROCmAttrs); TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(16384)) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("texture_spatial_limit", Integer(16384)) + .add_attr_option("max_threads_per_block", runtime::BoxInt(256)) + .add_attr_option("max_shared_memory_per_block", runtime::BoxInt(16384)) + .add_attr_option("max_num_threads", runtime::BoxInt(256)) + .add_attr_option("thread_warp_size", runtime::BoxInt(1)) + .add_attr_option("texture_spatial_limit", runtime::BoxInt(16384)) // Faced that Qualcomm OpenCL runtime crashed without any error message in // the case when the number of kernel arguments was pretty big. OpenCL doesn't // specify any limitations on the number of kernel arguments. max_function_args // equals to 128 looks like a reasonable number of kernel arguments. - .add_attr_option("max_function_args", Integer(128)) + .add_attr_option("max_function_args", runtime::BoxInt(128)) .set_default_keys({"opencl", "gpu"}); // The metal has some limitations on the number of input parameters. 
This is why attribute @@ -379,55 +380,55 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) // https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc // See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf TVM_REGISTER_TARGET_KIND("metal", kDLMetal) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(32768)) - .add_attr_option("thread_warp_size", Integer(16)) - .add_attr_option("max_function_args", Integer(31)) + .add_attr_option("max_num_threads", runtime::BoxInt(256)) + .add_attr_option("max_threads_per_block", runtime::BoxInt(256)) + .add_attr_option("max_shared_memory_per_block", runtime::BoxInt(32768)) + .add_attr_option("thread_warp_size", runtime::BoxInt(16)) + .add_attr_option("max_function_args", runtime::BoxInt(31)) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option>("mattr") // Feature support - .add_attr_option("supports_float16") - .add_attr_option("supports_float32", Bool(true)) - .add_attr_option("supports_float64") - .add_attr_option("supports_int8") - .add_attr_option("supports_int16") - .add_attr_option("supports_int32", Bool(true)) - .add_attr_option("supports_int64") - .add_attr_option("supports_8bit_buffer") - .add_attr_option("supports_16bit_buffer") - .add_attr_option("supports_storage_buffer_storage_class") - .add_attr_option("supports_push_descriptor") - .add_attr_option("supports_dedicated_allocation") - .add_attr_option("supports_integer_dot_product") - .add_attr_option("supports_cooperative_matrix") - .add_attr_option("supported_subgroup_operations") + .add_attr_option("supports_float16") + .add_attr_option("supports_float32", runtime::BoxBool(true)) + .add_attr_option("supports_float64") + .add_attr_option("supports_int8") + .add_attr_option("supports_int16") + .add_attr_option("supports_int32", runtime::BoxBool(true)) + 
.add_attr_option("supports_int64") + .add_attr_option("supports_8bit_buffer") + .add_attr_option("supports_16bit_buffer") + .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_push_descriptor") + .add_attr_option("supports_dedicated_allocation") + .add_attr_option("supports_integer_dot_product") + .add_attr_option("supports_cooperative_matrix") + .add_attr_option("supported_subgroup_operations") // Physical device limits - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("max_block_size_x") - .add_attr_option("max_block_size_y") - .add_attr_option("max_block_size_z") - .add_attr_option("max_push_constants_size") - .add_attr_option("max_uniform_buffer_range") - .add_attr_option("max_storage_buffer_range") - .add_attr_option("max_per_stage_descriptor_storage_buffer") - .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_num_threads", runtime::BoxInt(256)) + .add_attr_option("max_threads_per_block", runtime::BoxInt(256)) + .add_attr_option("thread_warp_size", runtime::BoxInt(1)) + .add_attr_option("max_block_size_x") + .add_attr_option("max_block_size_y") + .add_attr_option("max_block_size_z") + .add_attr_option("max_push_constants_size") + .add_attr_option("max_uniform_buffer_range") + .add_attr_option("max_storage_buffer_range") + .add_attr_option("max_per_stage_descriptor_storage_buffer") + .add_attr_option("max_shared_memory_per_block") // Other device properties .add_attr_option("device_type") .add_attr_option("device_name") .add_attr_option("driver_name") - .add_attr_option("driver_version") - .add_attr_option("vulkan_api_version") - .add_attr_option("max_spirv_version") + .add_attr_option("driver_version") + .add_attr_option("vulkan_api_version") + .add_attr_option("max_spirv_version") // Tags .set_default_keys({"vulkan", "gpu"}); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) - 
.add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_num_threads", runtime::BoxInt(256)) .set_default_keys({"webgpu", "gpu"}); TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL) // line break @@ -444,8 +445,8 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) .add_attr_option("mcpu") .add_attr_option("mtriple") .add_attr_option>("llvm-options") - .add_attr_option("num-cores") - .add_attr_option("vtcm-capacity") + .add_attr_option("num-cores") + .add_attr_option("vtcm-capacity") .set_default_keys({"hexagon", "cpu"}); TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 41500051fa89..7064433a2874 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -510,7 +510,9 @@ Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, Array args, Span span) { + .set_body_typed([](DataType type, RelayExpr op, + Array> args, + Span span) { Array prim_expr_args; for (const auto& it : args) { ICHECK(it->IsInstance() || it->IsInstance() || diff --git a/tests/python/ir/test_ir_container.py b/tests/python/ir/test_ir_container.py index aa482dd65cd7..1e3249197851 100644 --- a/tests/python/ir/test_ir_container.py +++ b/tests/python/ir/test_ir_container.py @@ -23,16 +23,19 @@ def test_array(): a = tvm.runtime.convert([1, 2, 3]) assert len(a) == 3 - assert a[-1].value == 3 + assert a[-1] == 3 a_slice = a[-3:-1] - assert (a_slice[0].value, a_slice[1].value) == (1, 2) + assert (a_slice[0], a_slice[1]) == (1, 2) def test_array_save_load_json(): - a = tvm.runtime.convert([1, 2, 3]) + a = tvm.runtime.convert([1, 2, 3.5, True]) json_str = tvm.ir.save_json(a) a_loaded = tvm.ir.load_json(json_str) - assert a_loaded[1].value == 2 + assert a_loaded[1] == 2 + assert a_loaded[2] == 3.5 + assert a_loaded[3] == True + assert isinstance(a_loaded[3], bool) def test_dir_array(): @@ -66,7 +69,7 @@ def test_str_map(): assert "a" in 
amap assert len(amap) == 2 dd = dict(amap.items()) - assert amap["a"].value == 2 + assert amap["a"] == 2 assert "a" in dd assert "b" in dd @@ -78,7 +81,7 @@ def test_map_save_load_json(): json_str = tvm.ir.save_json(amap) amap = tvm.ir.load_json(json_str) assert len(amap) == 2 - dd = {kv[0].name: kv[1].value for kv in amap.items()} + dd = {kv[0].name: kv[1] for kv in amap.items()} assert dd == {"a": 2, "b": 3} diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py index f1709c449d16..31e4818e950c 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py @@ -40,7 +40,7 @@ def test_constant(): ) assert ( constant.__str__() - == """R.dist.const(1, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" + == """R.dist.const(1, R.DTensor((), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "R, R"))""" ) @@ -52,7 +52,7 @@ def test_dtensor_struct_info(): ) assert ( obj0.__str__() - == """R.DTensor((32, 32), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[1], R")""" + == """R.DTensor((32, 32), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "S[1], R")""" ) obj1 = DTensorStructInfo( @@ -60,7 +60,7 @@ def test_dtensor_struct_info(): ) assert ( obj1.__str__() - == """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), R.Range(0, 4)), placement="S[1], R")""" + == """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), T.Range(0, 4)), placement="S[1], R")""" ) obj2 = DTensorStructInfo( @@ -113,11 +113,12 @@ def test_func(): _assert_print( TestModule["foo"], """ +# from tvm.script import tir as T # from tvm.script import relax as R @R.function -def foo(x: R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R")) -> R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R"): - gv0 = R.dist.call_tir(tir_func, 
(x,), out_sinfo=R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R")) +def foo(x: R.DTensor((128, 128), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "S[0], R")) -> R.DTensor((128, 128), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "S[0], R"): + gv0 = R.dist.call_tir(tir_func, (x,), out_sinfo=R.DTensor((128, 128), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "S[0], R")) return gv0 """, ) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 2a554f16e23f..089c41ce76c1 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -559,8 +559,8 @@ def foo(x: R.Tensor): ) ) assert 'Op(name="relax.unique")' in foo_str - # the sorted argument is true, so it will be a PrimValue of 1 - assert "PrimExpr(value=`T.int64(1)`)" in foo_str + # the sorted argument is true, so it will be a PrimValue of True + assert "PrimExpr(value=`T.bool(True)`)" in foo_str # axis is -1 assert "PrimExpr(value=`T.int64(-1)`)" in foo_str diff --git a/tests/python/relax/test_tuple_get_item.py b/tests/python/relax/test_tuple_get_item.py new file mode 100644 index 000000000000..d9a4a4ebacde --- /dev/null +++ b/tests/python/relax/test_tuple_get_item.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import relax as R, tir as T + +import pytest + +exec_mode = tvm.testing.parameter("bytecode", "compiled") + +tuple_type_annotation = tvm.testing.parameter( + by_dict={ + "tuple_of_obj": R.Tuple([R.Object, R.Object]), + "tuple_of_known_types": R.Tuple([R.Prim("int64"), R.Prim("float32")]), + } +) + +tuple_index_type = tvm.testing.parameter("static", "dynamic") + +syntax_sugar = tvm.testing.parameter(by_dict={"sugared": True, "unsugared": False}) + + +def test_vm_tuple_get_item(exec_mode, tuple_type_annotation, tuple_index_type): + def access_tuple(tuple_obj, dyn_index): + if tuple_index_type == "static": + return tuple_obj[0] + elif tuple_index_type == "dynamic": + return tuple_obj[dyn_index] + + @R.function(private=True) + def func(arg: tuple_type_annotation, index_param: R.Prim(value="index_var")): + index_var = T.int64() + # Trivial binding provides a usage of + # `tuple_type_annotation` within the body of the function, + # which is required to expose it as a meta-variable for + # TVMScript. + arg: tuple_type_annotation = arg + return access_tuple(arg, index_param) + + mod = tvm.IRModule({"main": func}) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + res = vm["main"]((17, 42.5), 0) + assert res == 17 + + +def test_dynamic_index_printing(syntax_sugar: bool): + """Check syntax-sugar for dynamic tuple indices + + The "relax.tuple_get_item_dyn" operator should be printed as + `my_tuple[my_index]` by default, which will regenerate the + original operator when parsed. If syntax sugar is disabled, it + should display the `R.tuple_get_item_dyn` directly. 
+ """ + + @R.function(private=True) + def func( + arg_tuple: R.Tuple([R.Prim("int64"), R.Prim("float32")]), + arg_index: R.Prim(value="index_var"), + ): + return arg_tuple[arg_index] + + script = func.script(syntax_sugar=syntax_sugar) + + if syntax_sugar: + assert "arg_tuple[arg_index]" in script + assert "tuple_get_item_dyn" not in script + else: + assert "arg_tuple[arg_index]" not in script + assert "tuple_get_item_dyn" in script + + roundtrip = tvm.script.from_source(script) + + tvm.ir.assert_structural_equal(func, roundtrip) + + +def test_tuple_get_item_simple(): + exec_mode = "bytecode" + + @R.function(private=True) + def func(arg: R.Tuple([R.Prim("int64"), R.Prim("float32")])): + return arg[0] + + mod = tvm.IRModule({"main": func}) + + target = tvm.target.Target("llvm", host="llvm") + ex = tvm.relax.build(mod, target, exec_mode=exec_mode) + vm = tvm.relax.VirtualMachine(ex, tvm.cpu()) + + res = vm["main"]((17, 42.5)) + assert res == 17 + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index dc3334f216c0..3fc9a26c9fc9 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -41,11 +41,12 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore _assert_print( func, """ +# from tvm.script import tir as T # from tvm.script import relax as R @R.function def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): - R.func_attr({"some_attr": 1}) + R.func_attr({"some_attr": T.int32(1)}) return a""", ) @@ -60,11 +61,12 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore _assert_print( func, """ +# from tvm.script import tir as T # from tvm.script import relax as R @R.function(private=True) def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): - R.func_attr({"some_attr": 1}) + R.func_attr({"some_attr": T.int32(1)}) return a""", ) @@ -266,9 +268,19 @@ 
def test_func_type(): ) -def test_prim_value(): - obj = relax.PrimValue(1) - _assert_print(obj, "R.prim_value(1)") +def test_prim_value_int64(): + obj = relax.PrimValue(T.int64(1)) + _assert_print(obj, "1") + + +def test_prim_value_int32(): + obj = relax.PrimValue(T.int32(1)) + _assert_print(obj, "R.prim_value(T.int32(1))") + + +def test_prim_value_int16(): + obj = relax.PrimValue(T.int16(1)) + _assert_print(obj, "R.prim_value(T.int16(1))") def test_string_imm(): @@ -721,6 +733,7 @@ def quux(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): obj, """ # from tvm.script import ir as I +# from tvm.script import tir as T # from tvm.script import relax as R @I.ir_module @@ -732,7 +745,7 @@ def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": T.bool(1)}) y: R.Tuple = R.print(format=R.str("Hi there!")) z: R.Tensor((), dtype="int32") = R.add(x, x) return z @@ -744,7 +757,7 @@ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function(private=True) def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": T.bool(1)}) y: R.Tuple = R.print(format=R.str("Lol")) z: R.Tensor((), dtype="int32") = R.multiply(x, x) return z diff --git a/tests/python/runtime/test_runtime_container.py b/tests/python/runtime/test_runtime_container.py index 7538075ae7f8..ec9da414cc36 100644 --- a/tests/python/runtime/test_runtime_container.py +++ b/tests/python/runtime/test_runtime_container.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. 
-import numpy as np +import pickle import random + +import numpy as np + import tvm import tvm.testing -import pickle -from tvm import te from tvm import nd, relay from tvm.runtime import container as _container @@ -96,8 +97,66 @@ def test_shape_tuple(): assert stuple == z +def test_bool_argument(): + """Boolean arguments survive the FFI round-trip as bool (ints 0/1 are coerced to bool)""" + func = tvm.get_global_func("testing.AcceptsBool") + + assert isinstance(func(True), bool) + assert isinstance(func(1), bool) + assert isinstance(func(0), bool) + + +def test_int_argument(): + func = tvm.get_global_func("testing.AcceptsInt") + + assert isinstance(func(True), int) + assert isinstance(func(1), int) + assert isinstance(func(0), int) + + +def test_object_ref_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRef") + + assert isinstance(func(True), bool) + assert isinstance(func(1), int) + assert isinstance(func(3.5), float) + assert func(3.5) == 3.5 + + +def test_object_ref_array_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRefArray") + + assert isinstance(func([True, 17, "hello"]), bool) + assert isinstance(func([True]), bool) + assert isinstance(func([17]), int) + assert isinstance(func(["hello"]), str) + + +def test_map_argument_returns_value(): + func = tvm.get_global_func("testing.AcceptsMapReturnsValue") + + res = func({"a": 1, "b": 2}, "a") + assert isinstance(res, int) + assert res == 1 + + res = func({"a": True, "b": False}, "a") + assert isinstance(res, bool) + assert res == True + + +def test_map_argument_returns_map(): + func = tvm.get_global_func("testing.AcceptsMapReturnsMap") + + res = func({"a": 1, "b": 2}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, int) + + res = func({"a": False, "b": True}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, bool) + + if __name__ == "__main__": - test_string() - test_adt_constructor() - test_tuple_object() - test_shape_tuple() + 
tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 5b3e68e22fa9..b2f9b7d51235 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2689,14 +2689,14 @@ def test_match_buffer_region(): outer_block = root.body.body.body.block assert len(outer_block.match_buffers) == 1 buffer_C = outer_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_C.shape, [16, 1, 4]) + tvm.ir.assert_structural_equal(buffer_C.shape, [T.int32(16), T.int32(1), T.int32(4)]) assert isinstance(outer_block.body, tir.stmt.For) assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) inner_block = outer_block.body.body.block assert len(inner_block.match_buffers) == 1 buffer_D = inner_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) + tvm.ir.assert_structural_equal(buffer_D.shape, [T.int32(4), T.int32(1), T.int32(4)]) def block_elements():