Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions ffi/include/tvm/ffi/reflection/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ class ObjectDef : public ReflectionDefBase {
*
* \return The reflection definition.
*/
template <typename T, typename... Extra>
TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) {
template <typename T, typename BaseClass, typename... Extra>
TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::*field_ptr, Extra&&... extra) {
RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
return *this;
}
Expand All @@ -181,8 +181,8 @@ class ObjectDef : public ReflectionDefBase {
*
* \return The reflection definition.
*/
template <typename T, typename... Extra>
TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) {
template <typename T, typename BaseClass, typename... Extra>
TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::*field_ptr, Extra&&... extra) {
static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields");
RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
return *this;
Expand Down Expand Up @@ -239,9 +239,10 @@ class ObjectDef : public ReflectionDefBase {
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info));
}

template <typename T, typename... ExtraArgs>
void RegisterField(const char* name, T Class::*field_ptr, bool writable,
template <typename T, typename BaseClass, typename... ExtraArgs>
void RegisterField(const char* name, T BaseClass::*field_ptr, bool writable,
ExtraArgs&&... extra_args) {
static_assert(std::is_base_of_v<BaseClass, Class>, "BaseClass must be a base class of Class");
TVMFFIFieldInfo info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.field_static_type_index = TypeToFieldStaticTypeIndex<T>::value;
Expand Down
40 changes: 40 additions & 0 deletions ffi/include/tvm/ffi/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,46 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; }
};

// Enum Integer POD values
template <typename IntEnum>
struct TypeTraits<IntEnum, std::enable_if_t<std::is_enum_v<IntEnum> &&
std::is_integral_v<std::underlying_type_t<IntEnum>>>>
: public TypeTraitsBase {
static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt;

static TVM_FFI_INLINE void CopyToAnyView(const IntEnum& src, TVMFFIAny* result) {
result->type_index = TypeIndex::kTVMFFIInt;
result->v_int64 = static_cast<int64_t>(src);
}

static TVM_FFI_INLINE void MoveToAny(IntEnum src, TVMFFIAny* result) {
CopyToAnyView(src, result);
}

static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) {
// NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny
return src->type_index == TypeIndex::kTVMFFIInt;
}

static TVM_FFI_INLINE IntEnum CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
return static_cast<IntEnum>(src->v_int64);
}

static TVM_FFI_INLINE IntEnum MoveFromAnyAfterCheck(TVMFFIAny* src) {
// POD type, we can just copy the value
return CopyFromAnyViewAfterCheck(src);
}

static TVM_FFI_INLINE std::optional<IntEnum> TryCastFromAnyView(const TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) {
return static_cast<IntEnum>(src->v_int64);
}
return std::nullopt;
}

static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; }
};

// Float POD values
template <typename Float>
struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>>
Expand Down
18 changes: 18 additions & 0 deletions ffi/tests/cpp/test_any.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,24 @@ TEST(Any, Int) {
EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2);
}

TEST(Any, Enum) {
enum class ENum : int {
A = 1,
B = 2,
};

AnyView view0;
Optional<ENum> opt_v0 = view0.as<ENum>();
EXPECT_TRUE(!opt_v0.has_value());

AnyView view1 = ENum::A;
EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt);
EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1);

ENum v1 = view1.cast<ENum>();
EXPECT_EQ(v1, ENum::A);
}

TEST(Any, bool) {
AnyView view0;
Optional<bool> opt_v0 = view0.as<bool>();
Expand Down
21 changes: 15 additions & 6 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_ARITH_ANALYZER_H_

#include <tvm/arith/int_set.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/expr.h>
#include <tvm/support/with.h>

Expand Down Expand Up @@ -86,11 +87,15 @@ class ConstIntBoundNode : public Object {
int64_t min_value;
int64_t max_value;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ConstIntBoundNode>()
.def_ro("min_value", &ConstIntBoundNode::min_value)
.def_ro("max_value", &ConstIntBoundNode::max_value);
}

static constexpr bool _type_has_method_visit_attrs = false;

bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
return equal(min_value, other->min_value) && equal(max_value, other->max_value);
}
Expand Down Expand Up @@ -208,11 +213,15 @@ class ModularSetNode : public Object {
/*! \brief The base */
int64_t base;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("coeff", &coeff);
v->Visit("base", &base);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ModularSetNode>()
.def_ro("coeff", &ModularSetNode::coeff)
.def_ro("base", &ModularSetNode::base);
}

static constexpr bool _type_has_method_visit_attrs = false;

bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
return equal(coeff, other->coeff) && equal(base, other->base);
}
Expand Down
37 changes: 23 additions & 14 deletions include/tvm/arith/int_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@ class IntGroupBoundsNode : public Object {
Array<PrimExpr> equal;
Array<PrimExpr> upper;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("coef", &coef);
v->Visit("lower", &lower);
v->Visit("equal", &equal);
v->Visit("upper", &upper);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IntGroupBoundsNode>()
.def_ro("coef", &IntGroupBoundsNode::coef)
.def_ro("lower", &IntGroupBoundsNode::lower)
.def_ro("equal", &IntGroupBoundsNode::equal)
.def_ro("upper", &IntGroupBoundsNode::upper);
}

bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const {
Expand All @@ -81,6 +83,7 @@ class IntGroupBoundsNode : public Object {
hash_reduce(upper);
}

static constexpr const bool _type_has_method_visit_attrs = false;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntGroupBounds";
TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object);
Expand Down Expand Up @@ -152,10 +155,12 @@ class IntConstraintsNode : public Object {
// e.g., A \alpha = \beta or A \alpha <= \beta
Array<PrimExpr> relations;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("variables", &variables);
v->Visit("ranges", &ranges);
v->Visit("relations", &relations);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IntConstraintsNode>()
.def_ro("variables", &IntConstraintsNode::variables)
.def_ro("ranges", &IntConstraintsNode::ranges)
.def_ro("relations", &IntConstraintsNode::relations);
}

bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
Expand All @@ -169,6 +174,7 @@ class IntConstraintsNode : public Object {
hash_reduce(relations);
}

static constexpr const bool _type_has_method_visit_attrs = false;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraints";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
Expand Down Expand Up @@ -213,11 +219,13 @@ class IntConstraintsTransformNode : public Object {
Map<Var, PrimExpr> src_to_dst;
Map<Var, PrimExpr> dst_to_src;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("src", &src);
v->Visit("dst", &dst);
v->Visit("src_to_dst", &src_to_dst);
v->Visit("dst_to_src", &dst_to_src);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IntConstraintsTransformNode>()
.def_ro("src", &IntConstraintsTransformNode::src)
.def_ro("dst", &IntConstraintsTransformNode::dst)
.def_ro("src_to_dst", &IntConstraintsTransformNode::src_to_dst)
.def_ro("dst_to_src", &IntConstraintsTransformNode::dst_to_src);
}

bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
Expand All @@ -232,6 +240,7 @@ class IntConstraintsTransformNode : public Object {
hash_reduce(dst_to_src);
}

static constexpr const bool _type_has_method_visit_attrs = false;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraintsTransform";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
Expand Down
55 changes: 33 additions & 22 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#define TVM_ARITH_ITER_AFFINE_MAP_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/diagnostic.h>
#include <tvm/ir/expr.h>
#include <tvm/tir/var.h>
Expand All @@ -65,9 +66,7 @@ namespace arith {
*/
class IterMapExprNode : public PrimExprNode {
public:
// overrides
void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr bool _type_has_method_visit_attrs = false;
static constexpr const char* _type_key = "arith.IterMapExpr";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
Expand Down Expand Up @@ -100,12 +99,15 @@ class IterMarkNode : public Object {
*/
PrimExpr extent;

// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("source", &source);
v->Visit("extent", &extent);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IterMarkNode>()
.def_ro("source", &IterMarkNode::source)
.def_ro("extent", &IterMarkNode::extent);
}

static constexpr bool _type_has_method_visit_attrs = false;

bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(source, other->source) && equal(extent, other->extent);
Expand Down Expand Up @@ -156,14 +158,17 @@ class IterSplitExprNode : public IterMapExprNode {
/*! \brief Additional scale. */
PrimExpr scale;

// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("source", &source);
v->Visit("lower_factor", &lower_factor);
v->Visit("extent", &extent);
v->Visit("scale", &scale);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IterSplitExprNode>()
.def_ro("source", &IterSplitExprNode::source)
.def_ro("lower_factor", &IterSplitExprNode::lower_factor)
.def_ro("extent", &IterSplitExprNode::extent)
.def_ro("scale", &IterSplitExprNode::scale);
}

static constexpr bool _type_has_method_visit_attrs = false;

bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const {
return equal(source, other->source) && equal(lower_factor, other->lower_factor) &&
equal(extent, other->extent) && equal(scale, other->scale);
Expand Down Expand Up @@ -223,12 +228,15 @@ class IterSumExprNode : public IterMapExprNode {
/*! \brief The base offset. */
PrimExpr base;

// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("args", &args);
v->Visit("base", &base);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IterSumExprNode>()
.def_ro("args", &IterSumExprNode::args)
.def_ro("base", &IterSumExprNode::base);
}

static constexpr bool _type_has_method_visit_attrs = false;

bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const {
return equal(args, other->args) && equal(base, other->base);
}
Expand Down Expand Up @@ -291,13 +299,16 @@ class IterMapResultNode : public Object {
*/
PrimExpr padding_predicate;

// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("errors", &errors);
v->Visit("indices", &indices);
v->Visit("padding_predicate", &padding_predicate);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IterMapResultNode>()
.def_ro("indices", &IterMapResultNode::indices)
.def_ro("errors", &IterMapResultNode::errors)
.def_ro("padding_predicate", &IterMapResultNode::padding_predicate);
}

static constexpr bool _type_has_method_visit_attrs = false;

static constexpr const char* _type_key = "arith.IterMapResult";
TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object);
};
Expand Down
Loading
Loading