Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
85f937f
[PR-15983][FFI] Allow IntImm arguments to PackedFunc with int parameter
Lunderberg Oct 25, 2023
acfeb78
Use IntImm unwrapping in relax VM
Lunderberg Oct 25, 2023
29e9327
Compiling with Expr TupleGetItem::index
Lunderberg Oct 26, 2023
aa5dc39
Added unit tests
Lunderberg Oct 26, 2023
38c0cd1
Passing unit tests
Lunderberg Oct 26, 2023
1f569eb
Resolve failing unit tests
Lunderberg Oct 30, 2023
0d9a78a
Fix printing of non-int64 relax integers
Lunderberg Oct 30, 2023
3e19873
Correct conversion of python bool to PrimValue
Lunderberg Oct 30, 2023
7be26d5
Update to fix failing unit tests
Lunderberg Oct 31, 2023
c98041a
Revert majority of implementation, in preparation for re-implementation
Lunderberg Nov 6, 2023
b4245fa
Re-implement dynamic tuple access in terms of intrinsic
Lunderberg Nov 6, 2023
00c4aff
Rename unit test file
Lunderberg Nov 6, 2023
b9bf03b
Implement printing/parsing of dynamic tuple indices
Lunderberg Nov 6, 2023
1dba439
fix lint errors
Lunderberg Nov 7, 2023
8f43de7
Revert "[PR-15983][FFI] Allow IntImm arguments to PackedFunc with int…
Lunderberg Nov 7, 2023
8d357fb
[Container] Support non-nullable types in Array::Map
Lunderberg Nov 28, 2023
e739ba9
[FFI] Separate runtime types from IR types for int/float/bool
Lunderberg Nov 13, 2023
d376d76
[UnitTest] Update apache/main unit tests for Box<int>
Lunderberg Nov 27, 2023
45c1133
[TIR] Update FFI conversion registration
Lunderberg Nov 28, 2023
9fc7ebf
relax tuple get item unit test
Lunderberg Nov 29, 2023
924c8ff
Fixed type used in target tags
Lunderberg Nov 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 71 additions & 29 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -769,53 +769,95 @@ inline const TTypeNode* RelayExprNode::type_as() const {

namespace tvm {
namespace runtime {
// common rule for RetValue and ArgValue

// Automatic conversion into IntImm, Integer, and Bool, when called
// through the FFI. Automatic conversions into PrimExpr are
// registered in "tvm/tir/expr.h", as it includes conversions to the
// TIR-only StringImm.
//
// While the FFI only requires the From() method, these
// implementations also define a TryFrom() method to avoid duplicate
// logic in the PrimExpr conversion.

template <>
struct PackedFuncValueConverter<PrimExpr> {
static PrimExpr From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return PrimExpr(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
int64_t value = val.operator int64_t();
if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
return IntImm(runtime::DataType::Int(64), value);
}
return IntImm(runtime::DataType::Int(32), val.operator int());
}
if (val.type_code() == kDLFloat) {
return FloatImm(runtime::DataType::Float(32), val.operator double());
struct PackedFuncValueConverter<tvm::IntImm> {
static Optional<tvm::IntImm> TryFrom(const TVMPODValue_& val) {
if (auto opt = val.TryAsInt()) {
int64_t value = opt.value();
auto dtype =
(value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min())
? DataType::Int(64)
: DataType::Int(32);
return IntImm(dtype, value);
} else if (auto opt = val.TryAsBool()) {
return IntImm(DataType::Int(32), opt.value());
} else {
return NullOpt;
}
}

return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
static tvm::IntImm From(const TVMPODValue_& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.AsObjectRef<tvm::IntImm>();
}
}
};

template <>
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Integer(ObjectPtr<Object>(nullptr));
if (auto opt = val.TryAsInt()) {
return Integer(opt.value());
} else if (auto opt = val.TryAsBool()) {
return Integer(opt.value());
} else {
return val.AsObjectRef<tvm::Integer>();
}
if (val.type_code() == kTVMArgInt) {
return Integer(val.operator int());
}
return val.AsObjectRef<tvm::Integer>();
}
};

template <>
struct PackedFuncValueConverter<tvm::Bool> {
static Optional<tvm::Bool> TryFrom(const TVMPODValue_& val) {
if (auto opt = val.TryAsBool()) {
return Bool(opt.value());
} else if (auto opt = val.TryAsInt()) {
int value = opt.value();
ICHECK(value == 0 || value == 1)
<< "ValueError: boolean value can only be 0 or 1, but get " << value;
return Bool(static_cast<bool>(value));
} else {
return NullOpt;
}
}

static tvm::Bool From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Bool(ObjectPtr<Object>(nullptr));
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.AsObjectRef<tvm::Bool>();
}
if (val.type_code() == kTVMArgInt) {
int v = val.operator int();
ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
return Bool(static_cast<bool>(v));
}
};

template <>
struct PackedFuncValueConverter<tvm::FloatImm> {
static Optional<tvm::FloatImm> TryFrom(const TVMPODValue_& val) {
if (auto opt = val.TryAsFloat()) {
return FloatImm(runtime::DataType::Float(32), opt.value());
} else {
return NullOpt;
}
}

static tvm::FloatImm From(const TVMPODValue_& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.AsObjectRef<tvm::FloatImm>();
}
return val.AsObjectRef<tvm::Bool>();
}
};

Expand Down
33 changes: 33 additions & 0 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,25 @@ If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
Optional<Expr> opt_false_branch = Optional<Expr>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Perform tuple access
*
* Use of this method is recommended, rather than constructing a
* `TupleGetItem` directly.
*
* 1. May resolve to the tuple's contents, avoiding the intermediate
* `TupleGetItem`.
*
* 2. Handles access of a tuple at a dynamic index, where
* `TupleGetItem` requires a statically-known index.
*
* \param tuple The tuple to be accessed
*
* \param index The index at which the access occurs
*
* \return An expression for the access of the tuple
*/
Expr tuple_get_item(Expr tuple, Expr index);

/*! \brief Tuple container */
class TupleNode : public ExprNode {
public:
Expand Down Expand Up @@ -320,6 +339,20 @@ class Tuple : public Expr {
*/
TVM_DLL explicit Tuple(tvm::Array<Expr> fields, Span span = Span());

/*! \brief Helper to delegate access to the tuple
*
* The `tuple_get_item` can be applied to any `relax::Expr`.
* However, this helper function is only provided for
* `relax::Tuple`, because `relax::Expr` is a typedef for
* `RelayExpr`, and we should avoid updating relay classes to
* provide relax-specific functionality..
*
* \param index The index at which the tuple is accessed
*
* \return The contents of the tuple at the specified index
*/
inline Expr operator[](Expr index) { return tuple_get_item(*this, index); }

TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};
Expand Down
14 changes: 12 additions & 2 deletions include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,13 @@ class Array : public ObjectRef {
// consisting of any previous elements that had mapped to
// themselves (if any), and the element that didn't map to
// itself.
//
// We cannot use `U()` as the default object, as `U` may be
// a non-nullable type. Since the default `ObjectRef()`
// will be overwritten before returning, all objects will be
// of type `U` for the calling scope.
all_identical = false;
output = ArrayNode::CreateRepeated(arr->size(), U());
output = ArrayNode::CreateRepeated(arr->size(), ObjectRef());
output->InitRange(0, arr->begin(), it);
output->SetItem(it - arr->begin(), std::move(mapped));
it++;
Expand All @@ -843,7 +848,12 @@ class Array : public ObjectRef {
// compatible types isn't strictly necessary, as the first
// mapped.same_as(*it) would return false, but we might as well
// avoid it altogether.
output = ArrayNode::CreateRepeated(arr->size(), U());
//
// We cannot use `U()` as the default object, as `U` may be a
// non-nullable type. Since the default `ObjectRef()` will be
// overwritten before returning, all objects will be of type `U`
// for the calling scope.
output = ArrayNode::CreateRepeated(arr->size(), ObjectRef());
}

// Normal path for incompatible types, or post-copy path for
Expand Down
121 changes: 121 additions & 0 deletions include/tvm/runtime/container/boxed_primitive.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/runtime/container/boxed_primitive.h
* \brief Runtime container types for primitives stored as ObjectRef.
*/
#ifndef TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_
#define TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_

#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

namespace tvm {
namespace runtime {

namespace detail {
/* \brief Provide the BoxNode<T> type key in templated contexts
*
* The Box<T> class is used in many templated contexts, and is easier
* to have templated over the primitive type. However, much of the
* TVM type system depends on classes having a unique name. For
* example, the use of `Object::IsInstance` depends on
* `Object::GetOrAllocRuntimeTypeIndex`. Any duplicate names will
* result in duplicate indices, and invalid downcasting.
*
* Furthermore, the name must be specified in the Python FFI using
* `tvm._ffi.register_object`. This prevents use of
* `typeid(T)::name()` to build a unique name, as the name is not
* required to be human-readable or consistent across compilers.
*
* This utility struct exists to bridge that gap, providing a unique
* name where required.
*/
template <typename Prim>
struct BoxNodeTypeKey;

template <>
struct BoxNodeTypeKey<int64_t> {
static constexpr const char* _type_key = "runtime.BoxInt";
};

template <>
struct BoxNodeTypeKey<double> {
static constexpr const char* _type_key = "runtime.BoxFloat";
};

template <>
struct BoxNodeTypeKey<bool> {
static constexpr const char* _type_key = "runtime.BoxBool";
};
} // namespace detail

template <typename Prim>
class BoxNode : public Object {
public:
/*! \brief Constructor
*
* \param value The value to be boxed
*/
BoxNode(Prim value) : value(value) {}

/*! \brief The boxed value */
Prim value;

static constexpr const char* _type_key = detail::BoxNodeTypeKey<Prim>::_type_key;
static constexpr bool _type_has_method_visit_attrs = false;
TVM_DECLARE_FINAL_OBJECT_INFO(BoxNode, Object);
};

template <typename Prim>
class Box : public ObjectRef {
public:
/*! \brief Constructor
*
* \param value The value to be boxed
*/
Box(Prim value) : ObjectRef(make_object<BoxNode<Prim>>(value)) {}

operator Prim() const { return (*this)->value; }

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Box, ObjectRef, BoxNode<Prim>);
};

/*! \brief Runtime equivalent of IntImm */
using BoxInt = Box<int64_t>;

/*! \brief Runtime equivalent of FloatImm */
using BoxFloat = Box<double>;

/*! \brief Runtime equivalent of IntImm with DataType::Bool()
*
* When passing from Python to C++, TVM PackedFunc conversion follow
* C++ conversion rules, and allow bool->int and int->bool
* conversions. When passing from C++ to Python, the types are
* returned as bool or int. If the C++ function uses ObjectRef to
* hold the object, a Python to C++ to Python round trip will preserve
* the distinction between bool and int.
*/
using BoxBool = Box<bool>;

} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_
Loading