From 85f937fda8ed8c048c82b4ac613a37797887d804 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 25 Oct 2023 15:23:11 -0500 Subject: [PATCH 01/21] [PR-15983][FFI] Allow IntImm arguments to PackedFunc with int parameter TVM containers, such as tvm::runtime::Array, require the contained objects to inherit from `ObjectRef`. As a result, the wrapper types `IntImm`, `FloatImm`, and `StringImm` are often used to allow native types in the TVM containers. Conversions into these wrapper type may be required when using a container, and may be performed automatically when passing an object across the FFI. By also providing conversion to an unwrapped type, these automatic conversions are transparent become transparent to users. The trait can be specialized to add type specific conversion logic from the TVMArgvalue and TVMRetValue. --- include/tvm/ir/expr.h | 36 +++++++++++++++ include/tvm/runtime/packed_func.h | 74 +++++++++++++++++++++++++++++++ tests/cpp/packed_func_test.cc | 60 +++++++++++++++++++++++++ 3 files changed, 170 insertions(+) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 594e2b86e9f9..eec2811a4764 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -540,6 +540,24 @@ class IntImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode); }; +/* \brief FFI extention, ObjectRef to integer conversion + * + * If a PackedFunc expects an integer type, and the user passes an + * IntImm as the argument, this specialization allows it to be + * converted by the FFI. + */ +template +struct runtime::PackedFuncObjectRefConverter>> { + static std::optional TryFrom(const ObjectRef& obj) { + if (auto ptr = obj.as()) { + return ptr->value; + } else { + return std::nullopt; + } + } +}; + /*! * \brief Constant floating point literals in the program. * \sa FloatImm @@ -587,6 +605,24 @@ class FloatImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); }; +/* \brief FFI extention, ObjectRef to integer conversion + * + * If a PackedFunc expects an integer type, and the user passes an + * IntImm as the argument, this specialization allows it to be + * converted by the FFI. + */ +template +struct runtime::PackedFuncObjectRefConverter< + FloatType, std::enable_if_t>> { + static std::optional TryFrom(const ObjectRef& obj) { + if (auto ptr = obj.as()) { + return ptr->value; + } else { + return std::nullopt; + } + } +}; + /*! * \brief Boolean constant. * diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..d6e3bcd44d1a 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -545,6 +546,42 @@ struct ObjectTypeChecker> { } }; +class TVMPODValue_; + +/*! + * \brief Type trait to specify special value conversion rules from + * ObjectRef to primitive types. + * + * TVM containers, such as tvm::runtime::Array, require the contained + * objects to inherit from ObjectRef. As a result, the wrapper types + * IntImm, FloatImm, and StringImm are often used to hold primitive + * types inside a TVM container. Conversions into this type may be + * required when using a container, and may be performed + * automatically when passing an object across the FFI. By also + * handling conversions from wrapped to unwrapped types, these + * conversions can be transparent to users. + * + * The trait can be specialized to add type specific conversion logic + * from the TVMArgvalue and TVMRetValue. + * + * \tparam T The type (e.g. int64_t) which may be contained within the + * ObjectRef. + * + * \tparam (anonymous) An anonymous and unused type parameter, which + * may be used for SFINAE. + */ +template +struct PackedFuncObjectRefConverter { + /*! + * \brief Attempt to convert an ObjectRef from an argument value. + * + * \param obj The ObjectRef which may be convertible to T + * + * \return The converted result, or std::nullopt if not convertible. + */ + static std::optional TryFrom(const ObjectRef& obj) { return std::nullopt; } +}; + /*! * \brief Internal base class to * handle conversion to POD values. @@ -557,25 +594,41 @@ class TVMPODValue_ { // the frontend while the API expects a float. if (type_code_ == kDLInt) { return static_cast(value_.v_int64); + } else if (auto opt = ThroughObjectRef()) { + return opt.value(); + } else if (auto opt = ThroughObjectRef()) { + return opt.value(); } TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); return value_.v_float64; } operator int64_t() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator uint64_t() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator int() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); ICHECK_LE(value_.v_int64, std::numeric_limits::max()); ICHECK_GE(value_.v_int64, std::numeric_limits::min()); return static_cast(value_.v_int64); } operator bool() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64 != 0; } @@ -646,6 +699,27 @@ class TVMPODValue_ { TVMValue value_; /*! \brief the type code */ int type_code_; + + private: + /* \brief A utility function to check for conversions through + * PackedFuncObjectRefConverter + * + * \tparam T The type to attempt to convert into + * + * \return The converted type, or std::nullopt if the value cannot + * be converted into T. + */ + template + std::optional ThroughObjectRef() const { + if (IsObjectRef()) { + if (std::optional from_obj = + PackedFuncObjectRefConverter::TryFrom(AsObjectRef())) { + return from_obj.value(); + } + } + + return std::nullopt; + } }; /*! diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 183aca1385a7..17f194520aa5 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -319,3 +319,63 @@ TEST(TypedPackedFunc, RValue) { tf(1, true); } } + +TEST(TypedPackedFunc, IntImmWrapper) { + using namespace tvm::runtime; + + TypedPackedFunc typed_func = [](int x) {}; + PackedFunc func = typed_func; + + // Integer argument may be provided + func(5); + + // IntImm argument may be provided, automatically unwrapped. + tvm::IntImm lvalue_intimm(DataType::Int(32), 10); + func(lvalue_intimm); + + // Unwrapping of IntImm argument works for rvalues as well + func(tvm::IntImm(DataType::Int(32), 10)); +} + +TEST(TypedPackedFunc, FloatImmWrapper) { + using namespace tvm::runtime; + + TypedPackedFunc typed_func = [](double x) {}; + PackedFunc func = typed_func; + + // Argument may be provided as a floating point. If provided as an + // integer, it will be converted to a float. + func(static_cast(5.0)); + func(static_cast(5)); + + // IntImm and FloatImm arguments may be provided, and are + // automatically unwrapped. These arguments work correctly for + // either lvalue or rvalue arguments. + + tvm::IntImm lvalue_intimm(DataType::Int(32), 10); + tvm::FloatImm lvalue_floatimm(DataType::Float(32), 10.5); + + func(lvalue_intimm); + func(lvalue_floatimm); + func(tvm::IntImm(DataType::Int(32), 10)); + func(tvm::FloatImm(DataType::Float(32), 10.5)); +} + +TEST(TypedPackedFunc, BoolWrapper) { + using namespace tvm::runtime; + + TypedPackedFunc typed_func = [](bool x) {}; + PackedFunc func = typed_func; + + // Argument may be provided as a floating point. If provided as an + // integer, it will be converted to a float. + func(true); + + tvm::IntImm lvalue_intimm(DataType::Int(32), 10); + func(lvalue_intimm); + func(tvm::IntImm(DataType::Int(32), 10)); + + tvm::Bool lvalue_bool(false); + func(lvalue_bool); + func(tvm::Bool(true)); +} From acfeb7829eb3294ea6858e6bd95aa162e0826899 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 25 Oct 2023 16:11:47 -0500 Subject: [PATCH 02/21] Use IntImm unwrapping in relax VM --- src/runtime/relax_vm/builtin.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index d6b086f201af..394a8e16d070 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -33,6 +33,7 @@ #include #include "../runtime_base.h" +#include "tvm/ir/expr.h" namespace tvm { namespace runtime { From 29e9327d25af12ba612615d9ce821e97a0e0b133 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 26 Oct 2023 16:10:14 -0500 Subject: [PATCH 03/21] Compiling with Expr TupleGetItem::index --- include/tvm/relax/expr.h | 24 ++++- include/tvm/relax/nested_msg.h | 44 +++++---- python/tvm/relax/expr.py | 7 +- src/contrib/msc/core/ir/graph_builder.cc | 10 +- .../msc/core/transform/set_expr_layout.cc | 12 ++- .../msc/core/transform/set_expr_name.cc | 23 +++-- .../contrib/codegen_json/codegen_json.h | 6 +- src/relax/backend/vm/codegen_vm.cc | 7 +- src/relax/backend/vm/codegen_vm_tir.cc | 7 +- src/relax/ir/block_builder.cc | 9 +- src/relax/ir/dataflow_matcher.cc | 18 +++- src/relax/ir/expr.cc | 94 +++++++++++++++---- src/relax/transform/canonicalize_bindings.cc | 9 +- src/relax/transform/convert_layout.cc | 33 ++++++- src/relax/transform/fuse_ops.cc | 23 +++-- src/relax/transform/fuse_tir.cc | 34 +++++-- src/relax/transform/gradient.cc | 12 ++- .../transform/static_plan_block_memory.cc | 19 +++- src/relax/transform/to_mixed_precision.cc | 19 +++- src/script/printer/relax/expr.cc | 5 +- 20 files changed, 314 insertions(+), 101 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index bb1b2c8dd74a..d077fd896102 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -338,7 +338,7 @@ class TupleGetItemNode : public ExprNode { /*! \brief The tuple Expression */ Expr tuple; /*! \brief which value to get */ - int index; + Expr index; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple_value", &tuple); @@ -358,12 +358,29 @@ class TupleGetItemNode : public ExprNode { hash_reduce(index); } + /* \brief Utility to get the exact index, if known + * + * In the most common case where index is known to be an integer, + * this utility allows it to be extracted. + * + * \return The known integer index, or NullOpt if unknown. + */ + Optional GetKnownIndex() const; + static constexpr const char* _type_key = "relax.expr.TupleGetItem"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); }; class TupleGetItem : public Expr { public: + /*! + * \brief The constructor + * \param tuple The tuple to get an element from. + * \param index The index for extracting a value in the tuple. + * \param span The source span of the expression. + */ + TVM_DLL TupleGetItem(Expr tuple, Expr index, Span span = Span()); + /*! * \brief The constructor * \param tuple The tuple to get an element from. @@ -381,9 +398,8 @@ class TupleGetItem : public Expr { * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), - Optional opt_index = Optional(), - Optional opt_span = Optional()); +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = NullOpt, + Optional opt_index = NullOpt, Optional opt_span = NullOpt); /*! * \brief Base type of all (non-function) leaf Exprs. diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 1ad5f02e0763..e879eccb2500 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -388,26 +388,36 @@ TargetType NestedMsgTo(NestedMsg msg, FMapLeaf fmapleaf, FCombine fcombine) { */ template Expr NestedMsgToExpr(NestedMsg msg, FType fmapleaf) { - return NestedMsgTo(msg, fmapleaf, [](Array arr) { + return NestedMsgTo(msg, fmapleaf, [](Array arr) -> Expr { + if (arr.empty()) { + return Tuple(arr); + } + Optional simplified_tuple; - bool simplified_flag = false; - if (arr.size() >= 1) { - simplified_flag = true; - for (size_t i = 0; i < arr.size() && simplified_flag; ++i) { - auto* node = arr[i].as(); - if (node == nullptr || node->index != static_cast(i)) { - simplified_flag = false; - } else { - if (simplified_tuple.defined()) { - simplified_flag &= (simplified_tuple == node->tuple); - } else { - simplified_tuple = node->tuple; - ICHECK(simplified_tuple.defined()); - } - } + for (size_t i = 0; i < arr.size(); ++i) { + auto* node = arr[i].as(); + if (node == nullptr) { + return Tuple(arr); + } + + auto index_sinfo = node->index->struct_info_.as(); + CHECK(index_sinfo && index_sinfo->dtype == DataType::Int(64)) + << "The index of TupleGetItem must be R.Prim('int64'), " + << "but expression " << GetRef(node) << " has index " << node->index + << " with struct info " << node->index->struct_info_; + + auto known_index = index_sinfo->value.as(); + if (!known_index || known_index->value != static_cast(i)) { + return Tuple(arr); + } + + if (simplified_tuple && !simplified_tuple.same_as(node->tuple)) { + return Tuple(arr); + } else if (!simplified_tuple) { + simplified_tuple = node->tuple; } } - return simplified_flag ? simplified_tuple.value() : Tuple(arr); + return simplified_tuple.value(); }); } diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 71f23577e70d..f719641fa733 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -370,7 +370,7 @@ class TupleGetItem(ExprWithOp): tuple_value: Expr The input tuple expression. - index: int + index: Union[int, Expr] The index. span: Optional[Span] @@ -381,7 +381,10 @@ class TupleGetItem(ExprWithOp): index: int span: Optional[Span] - def __init__(self, tuple_value: Expr, index: int, span: Optional[Span] = None): + def __init__(self, tuple_value: Expr, index: Union[int, Expr],span: Optional[Span] = None):): + if isinstance(index, int): + index = PrimValue(index) + self.__init_handle_by_constructor__( _ffi_api.TupleGetItem, tuple_value, index, span # type: ignore ) diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index dab4ae813ea6..dd790b7f93ac 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -64,7 +64,11 @@ void RelaxFuncAttrGetter::VisitExpr_(const relax::CallNode* op) { } void RelaxFuncAttrGetter::VisitExpr_(const relax::TupleGetItemNode* op) { - attrs_.Set("index", std::to_string(op->index)); + if (auto known_index = op->GetKnownIndex()) { + attrs_.Set("index", std::to_string(known_index.value()->value)); + } else { + LOG(FATAL) << "MSC does not support TupleGetItem with dynamic index"; + } } void RelaxFuncValueGetter::VisitExpr_(const relax::CallNode* op) { @@ -280,7 +284,9 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } else if (const auto* shape_node = expr.as()) { attrs.Set("shape", StringUtils::ToString(shape_node->values)); } else if (const auto* get_node = expr.as()) { - attrs.Set("index", std::to_string(get_node->index)); + auto known_value = get_node->GetKnownIndex(); + ICHECK(known_value) << "MSC does not support TupleGetItem with dynamic index"; + attrs.Set("index", std::to_string(known_value.value()->value)); } // Get scope diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index d2023b886a0f..3f33590fb82e 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -48,10 +48,14 @@ LayoutDecision InferLayoutDecision(const Expr& expr, const VarLayoutMap& var_lay } LayoutDecision InferLayoutDecisionAt(const Expr& expr, const VarLayoutMap& var_layout_map, - size_t index = 0) { + Expr index = PrimValue::Int64(0)) { const auto& nlayouts = InferNLayout(expr, var_layout_map); if (nlayouts.IsLeaf()) { - return index == 0 ? nlayouts.LeafValue() : LayoutDecision(""); + auto int_index = Downcast(index->struct_info_) + ->value.as() + .value_or(Integer(0)) + ->value; + return int_index == 0 ? nlayouts.LeafValue() : LayoutDecision(""); } const auto& nlayout = nlayouts.NestedArray()[0]; ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr; @@ -715,7 +719,7 @@ InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 0); + LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -749,7 +753,7 @@ InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, InferLayoutOutput BackwardInferLayoutNormalize(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 0); + LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index 97850c70e8e8..3897e38ff93c 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -145,12 +145,23 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { ExprVisitor::VisitBinding_(binding, val); - String unique_name; - if (expr_names_.count(val->tuple)) { - unique_name = expr_names_[val->tuple] + "." + std::to_string(val->index); - } else if (const auto* v_node = val->tuple.as()) { - unique_name = v_node->name_hint() + "." + std::to_string(val->index); - } + + String unique_name = [&]() { + std::stringstream ss; + if (expr_names_.count(val->tuple)) { + ss << expr_names_[val->tuple]; + } else if (const auto* v_node = val->tuple.as()) { + ss << v_node->name_hint(); + } + ss << "."; + + if (auto known_index = val->GetKnownIndex()) { + ss << known_index.value()->value; + } + + return ss.str(); + }(); + if (unique_name != SpanUtils::GetAttr(val->span, "name")) { val->span = SpanUtils::SetAttr(val->span, "name", unique_name); } diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 7dc6ddc16227..cff040391301 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -363,7 +363,11 @@ class JSONSerializer : public relax::MemoizedExprTranslator { NodeEntries VisitExpr_(const TupleGetItemNode* gtn) { auto vtuple = VisitExpr(gtn->tuple); - return {vtuple[gtn->index]}; + if (auto known_index = gtn->GetKnownIndex()) { + return {vtuple[known_index.value()->value]}; + } else { + return vtuple; + } } NodeEntries VisitExpr_(const FunctionNode* fn) { diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 218fe6b1202c..3aa7f06ba706 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -282,9 +282,10 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const TupleGetItemNode* op) final { TupleGetItem expr = GetRef(op); - std::vector args = {this->VisitExpr(expr->tuple)}; - - args.push_back(builder_->ConvertConstant(expr->index)); + std::vector args = { + VisitExpr(expr->tuple), + VisitExpr(expr->index), + }; size_t dst_register = NewRegister(); builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register); diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index ec1678e9e0f3..804baf8f66f3 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -323,9 +323,10 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Optional VisitExpr_(const TupleGetItemNode* op) final { TupleGetItem expr = GetRef(op); - Array args = {this->VisitExpr(expr->tuple).value()}; - - args.push_back(ConstInt64(expr->index)); + Array args = { + VisitExpr(expr->tuple).value(), + VisitExpr(expr->index).value(), + }; int64_t dst_register = NewRegister(); this->EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register); diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index fda31e44a920..76abefbad7b5 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -689,12 +689,9 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctortuple) ? GetRef(op) : TupleGetItem(new_tuple, op->index); - if (!node->struct_info_.defined()) { - auto opt = MatchStructInfo(node->tuple); - ICHECK(opt) << "The struct info of Tuple must be TupleStructInfo, " - << "but expression " << node << " has struct info " << node->struct_info_; - UpdateStructInfo(node, opt.value()->fields[node->index]); - } + ICHECK(node->struct_info_.defined()) + << "InternalError: " + << "TupleGetItem expected to define its struct info on construction"; return node; } diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 9524c90b577c..c2833680f1e4 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -358,11 +358,19 @@ bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr0) { auto expr = TryGetValOfVar(expr0, var2val_); - if (const auto* tuple_get_item_node = expr.as()) { - return (op->index == -1 || op->index == tuple_get_item_node->index) && - VisitDFPattern(op->tuple, tuple_get_item_node->tuple); - } - return false; + const auto* tuple_get_item_node = expr.as(); + if (!tuple_get_item_node) return false; + + bool is_correct_index = [&]() -> bool { + if (op->index == -1) return true; + + auto known_index = tuple_get_item_node->GetKnownIndex(); + if (!known_index) return false; + + return known_index.value()->value == op->index; + }(); + + return is_correct_index && VisitDFPattern(op->tuple, tuple_get_item_node->tuple); } bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr0) { diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 00ad252ec4a4..2791c7f169d1 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -165,29 +165,91 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o return tuple; } -TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { - CHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple - << " cannot be accessed with negative index " << index; - ObjectPtr n = make_object(); +TupleGetItem::TupleGetItem(Expr tuple, int arg_index, Span span) + : TupleGetItem(tuple, PrimValue::Int64(arg_index), span) {} + +TupleGetItem::TupleGetItem(Expr tuple, Expr arg_index, Span span) { + auto index_sinfo = arg_index->struct_info_.as(); + CHECK(index_sinfo && index_sinfo->dtype == DataType::Int(64)) + << "TupleGetItem requires the index to be a R.Prim('int64'), " + << "but received " << arg_index << " with struct info " << arg_index->struct_info_; + + auto known_index = index_sinfo->value.as(); + + if (known_index) { + // If we know the index, we can check against the lower bound of + // zero. Checking the upper bound will require also knowing the + // tuple's size. + CHECK_GE(known_index->value, 0) + << "IndexError: " + << "Tuple " << tuple << " cannot be accessed with negative index " << arg_index; + } - if (auto* tuple_info = tuple->struct_info_.as()) { - CHECK_LT(index, tuple_info->fields.size()) - << "Index out of bounds: Tuple " << tuple << " is of size " << tuple_info->fields.size() - << ", and cannot be accessed with index " << index; - auto sinfo = tuple_info->fields[index]; - n->struct_info_ = sinfo; - n->checked_type_ = GetStaticType(sinfo); + auto* tuple_info = tuple->struct_info_.as(); + + Optional item_sinfo = NullOpt; + + if (known_index && tuple_info) { + // The exact index used to access the tuple is known. We can + // apply bounds-checking, and can provide the exact StructInfo of + // the accessed element. + int int_index = known_index->value; + + CHECK_LT(int_index, tuple_info->fields.size()) + << "IndexError: " + << "Tuple " << tuple << " is of size " << tuple_info->fields.size() + << ", and cannot be accessed with index " << int_index; + item_sinfo = tuple_info->fields[int_index]; + + } else if (tuple_info) { + // The exact index used to access the tuple is unknown. We can't + // apply bounds checking, but we can check that an index might + // exist. We can't provide an exact StructInfo for the accessed + // type, but we can provide the common base type of all items in + // the tuple. + CHECK_GT(tuple_info->fields.size(), 0) + << "IndexError: " + << "The exact value of index " << arg_index << " is unknown, " + << "but expression " << tuple << " has struct info " << tuple->struct_info_ << ". " + << "This is a tuple of length zero, and there is no index such that 0 <= index < 0."; + + StructInfo reduce_lca = tuple_info->fields[0]; + for (size_t i = 1; i < tuple_info->fields.size(); i++) { + reduce_lca = StructInfoLCA(reduce_lca, tuple_info->fields[1]); + } + item_sinfo = reduce_lca; } + + ObjectPtr n = make_object(); n->tuple = std::move(tuple); - n->index = index; + n->index = arg_index; n->span = std::move(span); + if (item_sinfo) { + n->struct_info_ = item_sinfo; + n->checked_type_ = GetStaticType(item_sinfo.value()); + } + data_ = std::move(n); } +Optional TupleGetItemNode::GetKnownIndex() const { + auto prim_sinfo = index->struct_info_.as(); + CHECK(prim_sinfo->dtype == DataType::Int(64)) + << "The index of TupleGetItem must be R.Prim('int64'), " + << "but expression " << GetRef(this) << " has index " << index << " with struct info " + << index->struct_info_; + + if (auto int_index = prim_sinfo->value.as()) { + return Integer(int_index.value()); + } else { + return NullOpt; + } +} + TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, - Optional opt_index, Optional opt_span) { + Optional opt_index, Optional opt_span) { Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); - Integer index = opt_index.value_or(tuple_get_item->index); + Expr index = opt_index.value_or(tuple_get_item->index); Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && @@ -195,7 +257,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, if (!unchanged) { TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); cow_tuple_get_item_node->tuple = tuple; - cow_tuple_get_item_node->index = index.IntValue(); + cow_tuple_get_item_node->index = index; cow_tuple_get_item_node->span = span; } return tuple_get_item; @@ -203,7 +265,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { +TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, Expr index, Span span) { return TupleGetItem(tuple, index, span); }); diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 246b38f6f83b..6ef07dd11c03 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -138,7 +138,8 @@ class CanonicalizePlanner : public ExprVisitor { LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey(); } - // Unwrap TupleGetItem, if the Tuple being accessed is known. + // Unwrap TupleGetItem, if were know which Tuple is being + // accessed, and the index at which is is being accessed. if (auto tuple_get_item = value.as()) { Expr tuple = tuple_get_item->tuple; while (auto tuple_var = tuple.as()) { @@ -149,8 +150,10 @@ class CanonicalizePlanner : public ExprVisitor { } } - if (auto ptr = tuple.as()) { - value = ptr->fields[tuple_get_item->index]; + auto known_tuple = tuple.as(); + auto known_index = tuple_get_item->GetKnownIndex(); + if (known_tuple && known_index) { + value = known_tuple->fields[known_index.value()->value]; } } diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 6530d0d2cf0c..b2a742ee3ae2 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -234,10 +234,37 @@ class LayoutConvertMutator : public ExprMutator { NLayout input_layout = binding->var->IsInstance() ? GetNLayout(var_layout_map_, val->tuple) : InitialNLayout(val->tuple); - ReEmitBinding(binding, builder_->Normalize( - TupleGetItem(RewriteExpr(val->tuple, input_layout), val->index))); + + Expr new_tuple = RewriteExpr(val->tuple, input_layout); + Expr new_index = RewriteExpr(val->index, input_layout); + + ReEmitBinding(binding, builder_->Normalize(TupleGetItem(new_tuple, new_index))); + + NLayout item_layout = [&]() { + if (auto known_index = val->GetKnownIndex()) { + // Most common case, we know the index at which the tuple is + // being accessed. + return input_layout.NestedArray()[known_index.value()->value]; + } + + std::unordered_set unique_layouts; + for (const auto& layout : input_layout.NestedArray()) { + unique_layouts.insert(layout); + } + if (unique_layouts.size() == 1) { + // Fallback case. We don't know where we are accessing the + // tuple, but it doesn't matter because all elements in the + // tuple are being transformed. + return *unique_layouts.begin(); + } + + LOG(FATAL) << "Cannot determine the layout of " << GetRef(val) + << ". The index is unknown, and the tuple contains more than multiple layouts: " + << Array(unique_layouts.begin(), unique_layouts.end()); + }(); + // update the layout map - var_layout_map_[binding->var] = input_layout.NestedArray()[val->index]; + var_layout_map_[binding->var] = item_layout; } void VisitBinding_(const MatchCastNode* binding) final { diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index b0eeba399e90..f196a84a527c 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -424,7 +424,15 @@ class FunctionCreator : public ExprMutator { if (partially_used_tuple_params_.find(tuple_item->tuple.get()) != partially_used_tuple_params_.end()) { // Appending get-item index to the mapping. - partially_used_tuple_params_[tuple_item->tuple.get()].push_back(tuple_item->index); + auto& used_indices = partially_used_tuple_params_[tuple_item->tuple.get()]; + if (auto known_index = tuple_item->GetKnownIndex()) { + used_indices.push_back(known_index.value()->value); + } else { + auto num_fields = Downcast(tuple_item->struct_info_)->fields.size(); + for (size_t i = 0; i < num_fields; i++) { + used_indices.push_back(i); + } + } } } @@ -494,12 +502,15 @@ class FunctionCreator : public ExprMutator { // Special handing for TupleGetItem. if (const auto* var_binding = binding.as()) { if (const auto* tuple_get_item = var_binding->value.as()) { - auto it = tuple_get_item_remap.find(tuple_get_item->tuple.get()); - if (it != tuple_get_item_remap.end()) { - ICHECK(it->second.find(tuple_get_item->index) != it->second.end()); - var_remap_[var_binding->var->vid] = it->second[tuple_get_item->index]; + if (auto it = tuple_get_item_remap.find(tuple_get_item->tuple.get()); + it != tuple_get_item_remap.end()) { + auto opt_known_index = tuple_get_item->GetKnownIndex(); + ICHECK(opt_known_index) << "FuseOps requires static indices into tuples"; + int known_index = opt_known_index.value()->value; + ICHECK(it->second.find(known_index) != it->second.end()); + var_remap_[var_binding->var->vid] = it->second[known_index]; if (auto output_idx = GetOutputIndex(binding->var)) { - outputs.Set(*output_idx, it->second[tuple_get_item->index]); + outputs.Set(*output_idx, it->second[known_index]); } continue; } diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index df3c85c05ce1..e2513de19f2b 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -27,6 +27,7 @@ #include "../../relay/analysis/graph_partitioner.h" #include "../../support/arena.h" +#include "../../support/ordered_set.h" #include "../../tir/ir/functor_common.h" namespace tvm { @@ -379,7 +380,15 @@ class FusedTIRConstructor : public ExprVisitor { PostOrderVisit(func->body, [=, &tuple_param](Expr e) { if (auto tup_get = e.as(); tup_get && tuple_param.count(tup_get->tuple.get())) { - func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index); + auto& used_indices = func_info_.used_tuple_field_indices[tup_get->tuple.get()]; + if (auto known_index = tup_get->GetKnownIndex()) { + used_indices.insert(known_index.value()->value); + } else { + auto num_fields = Downcast(tup_get->struct_info_)->fields.size(); + for (size_t i = 0; i < num_fields; i++) { + used_indices.insert(i); + } + } } }); @@ -523,12 +532,18 @@ class FusedTIRConstructor : public ExprVisitor { void VisitExpr_(const TupleGetItemNode* tuple_get_item) final { ExprVisitor::VisitExpr_(tuple_get_item); + auto it = func_info_.expr2buffers.find(tuple_get_item->tuple); if (it != func_info_.expr2buffers.end()) { + auto opt_known_index = tuple_get_item->GetKnownIndex(); + ICHECK(opt_known_index) << "FuseTIR requires all tuple indices to be known, " + << "but " << GetRef(tuple_get_item) << " has a dynamic index"; + auto tuple_index = opt_known_index.value()->value; + int begin_buf_idx = 0; int end_buf_idx = 0; const TupleType& tuple_type = Downcast(tuple_get_item->tuple->checked_type()); - for (int i = 0; i < tuple_get_item->index; ++i) { + for (int i = 0; i < tuple_index; ++i) { auto it = func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get()); // If this tuple is not passed as a parameter, or if the field at the index i is actually // used, the corresponding buffer needs to be taken into account by this function. @@ -536,7 +551,7 @@ class FusedTIRConstructor : public ExprVisitor { begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]); } } - end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]); + end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_index]); func_info_.expr2buffers.Set( GetRef(tuple_get_item), {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx}); @@ -851,15 +866,20 @@ class FusedTIRConstructor : public ExprVisitor { std::vector GetTupleAccessedIndices(const FunctionNode* func, const Var& tuple_var) { // Need to be ordered - std::vector indices; + support::OrderedSet indices; PostOrderVisit(func->body, [&indices, tuple_var](Expr e) { if (auto tup_get = e.as(); tup_get && tup_get->tuple.same_as(tuple_var)) { - if (std::find(indices.begin(), indices.end(), tup_get->index) == indices.end()) { - indices.push_back(tup_get->index); + if (auto known_index = tup_get->GetKnownIndex()) { + indices.insert(known_index.value()->value); + } else { + auto num_fields = Downcast(tup_get->struct_info_)->fields.size(); + for (size_t i = 0; i < num_fields; i++) { + indices.insert(i); + } } } }); - return indices; + return std::vector(indices.begin(), indices.end()); } /*! diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 70e3e37876fd..e48c659ca6ed 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -418,15 +418,21 @@ class BackwardBindingGenerator : private ExprVisitor { auto* tuple_sinfo = GetStructInfoAs(tuple_get_item->tuple); ICHECK(tuple_sinfo) << "The tuple field of a TupleGetItem must has a TupleStructInfo"; + auto opt_index = tuple_get_item->GetKnownIndex(); + ICHECK(opt_index) << "relax.transform.Gradient requires all tuple indices to be known, " + << "but expression " << GetRef(tuple_get_item) << " has index " + << tuple_get_item->index << ", whose value isn't known."; + int index = opt_index.value()->value; + const Var& tuple_var = Downcast(tuple_get_item->tuple); if (adjoint_var_map_.count(tuple_var) == 0) { auto nested_zeros = Downcast(NestedZeros(GetRef(tuple_sinfo))); auto tuple_fields = nested_zeros->fields; - tuple_fields.Set(tuple_get_item->index, adjoint_var_map_[binding->var]); + tuple_fields.Set(index, adjoint_var_map_[binding->var]); EmitAdjoint(tuple_var, Tuple(tuple_fields), false); } else { - Expr updated_adjoint = AddInTuple(adjoint_var_map_[tuple_var], tuple_get_item->index, - adjoint_var_map_[binding->var]); + Expr updated_adjoint = + AddInTuple(adjoint_var_map_[tuple_var], index, adjoint_var_map_[binding->var]); EmitAdjoint(tuple_var, updated_adjoint, false); } } diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 4a2a1555ff46..b6c3d5396a1f 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -270,9 +270,22 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { } ICHECK(tokens.IsNested()); Array field_tokens = tokens.NestedArray(); - ICHECK_GT(static_cast(field_tokens.size()), tuple_item->index); - ICHECK_GE(tuple_item->index, 0); - SetTokens(tuple_item, field_tokens[tuple_item->index]); + + auto item_tokens = [&]() -> Tokens { + if (auto known_index = tuple_item->GetKnownIndex()) { + // If the tuple access is at a specific index, the field uses + // the token of that index. + int index = known_index.value()->value; + ICHECK_GT(static_cast(field_tokens.size()), index); + ICHECK_GE(index, 0); + return field_tokens[index]; + } else { + // If the tuple access is at an unknown index, the field may + // require any token from the tuple. + return tokens; + } + }(); + SetTokens(tuple_item, item_tokens); } /******************** Utilities ********************/ diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index c844d5935623..a999d8e6c2a1 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -218,15 +218,24 @@ class DTypeDecisionCollector : public ExprVisitor { // require the i-th field rhs tuple to be the type of the lhs NType lhs_type = GetDType(binding->var); std::vector require_rhs; + const TupleStructInfoNode* sinfo = tuple_get_item_node->tuple->struct_info_.as(); ICHECK(sinfo != nullptr) << "TupleGetItemNode must have TupleStructInfo"; + + auto opt_known_index = tuple_get_item_node->GetKnownIndex(); + ICHECK(opt_known_index) << "ToMixedPrecision pass does not support dynamic tuple indices"; + size_t known_index = opt_known_index.value()->value; + for (size_t i = 0; i < sinfo->fields.size(); ++i) { - if (i == static_cast(tuple_get_item_node->index)) { - require_rhs.push_back(lhs_type); - } else { - require_rhs.push_back(NTypeFrom(sinfo->fields[i], unknown_)); - } + NType field_type = [&]() { + if (i == known_index) { + return lhs_type; + } else { + return NTypeFrom(sinfo->fields[i], unknown_); + } + }(); + require_rhs.push_back(field_type); } RequireArgsToType({tuple_get_item_node->tuple}, {NType(require_rhs)}); } diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index f6cbde0b4b23..595615e9b01a 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -64,8 +64,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::TupleGetItem n, ObjectPath n_p, IRDocsifier d) -> Doc { - ExprDoc idx = LiteralDoc::Int(n->index, n_p->Attr("index")); - return d->AsDoc(n->tuple, n_p->Attr("tuple"))[{idx}]; + auto tuple_doc = d->AsDoc(n->tuple, n_p->Attr("tuple")); + auto index_doc = d->AsDoc(n->index, n_p->Attr("index")); + return tuple_doc[{index_doc}]; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) From aa5dc3965d36b653f83f86e42c9f8db24447eeac Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 26 Oct 2023 16:10:46 -0500 Subject: [PATCH 04/21] Added unit tests --- tests/python/relax/test_vm_tuple_get_item.py | 67 ++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/python/relax/test_vm_tuple_get_item.py diff --git a/tests/python/relax/test_vm_tuple_get_item.py b/tests/python/relax/test_vm_tuple_get_item.py new file mode 100644 index 000000000000..16f99e35fafb --- /dev/null +++ b/tests/python/relax/test_vm_tuple_get_item.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import relax as R, ir as I + +exec_mode = tvm.testing.parameter("bytecode", "compiled") + +tuple_type_annotation = tvm.testing.parameter( + by_dict={ + "tuple_of_obj": R.Tuple([R.Object, R.Object]), + "tuple_of_known_types": R.Tuple([R.Prim("int64"), R.Prim("float32")]), + } +) + +tuple_index_type = tvm.testing.parameter("static", "dynamic") + + +def test_vm_tuple_get_item(exec_mode, tuple_type_annotation, tuple_index_type): + index_var = tvm.tir.Var("index", "int64") + + def access_tuple(tuple_obj, dyn_index): + if tuple_index_type == "static": + return tuple_obj[0] + elif tuple_index_type == "dynamic": + return tuple_obj[dyn_index] + + @R.function(private=True) + def func(arg: tuple_type_annotation, index_param: R.Prim(value=index_var)): + # Trivial binding provides a usage of + # `tuple_type_annotation` within the body of the function, + # which is required to expose it as a meta-variable for + # TVMScript. + arg: tuple_type_annotation = arg + index_param: R.Prim(value=index_var) = index_param + return access_tuple(arg, index_param) + + mod = tvm.IRModule({"main": func}) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + res = vm["main"]((17, 42.5), 0) + assert res == 17 + + +if __name__ == "__main__": + tvm.testing.main() From 38c0cd1bb04a7a8af8e538a19f70086d461061e6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 26 Oct 2023 16:14:53 -0500 Subject: [PATCH 05/21] Passing unit tests --- tests/python/relax/test_vm_tuple_get_item.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/python/relax/test_vm_tuple_get_item.py b/tests/python/relax/test_vm_tuple_get_item.py index 16f99e35fafb..a441f0d85fb2 100644 --- a/tests/python/relax/test_vm_tuple_get_item.py +++ b/tests/python/relax/test_vm_tuple_get_item.py @@ -20,7 +20,7 @@ import tvm.script import tvm.testing from tvm import relax -from tvm.script import relax as R, ir as I +from tvm.script import relax as R, tir as T exec_mode = tvm.testing.parameter("bytecode", "compiled") @@ -35,8 +35,6 @@ def test_vm_tuple_get_item(exec_mode, tuple_type_annotation, tuple_index_type): - index_var = tvm.tir.Var("index", "int64") - def access_tuple(tuple_obj, dyn_index): if tuple_index_type == "static": return tuple_obj[0] @@ -44,13 +42,13 @@ def access_tuple(tuple_obj, dyn_index): return tuple_obj[dyn_index] @R.function(private=True) - def func(arg: tuple_type_annotation, index_param: R.Prim(value=index_var)): + def func(arg: tuple_type_annotation, index_param: R.Prim(value="index_var")): + index_var = T.int64() # Trivial binding provides a usage of # `tuple_type_annotation` within the body of the function, # which is required to expose it as a meta-variable for # TVMScript. arg: tuple_type_annotation = arg - index_param: R.Prim(value=index_var) = index_param return access_tuple(arg, index_param) mod = tvm.IRModule({"main": func}) From 1f569eb9fbeba72981a4c9c202965f4855077dcf Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 30 Oct 2023 08:47:58 -0500 Subject: [PATCH 06/21] Resolve failing unit tests --- .../tvm/relax/transform/lazy_transform_params.py | 3 +-- src/relax/ir/block_builder.cc | 16 +++++++++++++--- src/script/printer/relax/expr.cc | 11 ++++++++++- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py index 7f734f8a3c47..ee680a9995a4 100644 --- a/python/tvm/relax/transform/lazy_transform_params.py +++ b/python/tvm/relax/transform/lazy_transform_params.py @@ -198,8 +198,7 @@ def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr: get_item_result = self.builder_.emit( relax.Call( relax.ExternFunc(self.func_creator.fget_item), - self.func_creator.extra_get_item_params - + [relax.PrimValue(tuple_get_item.index)], + self.func_creator.extra_get_item_params + [tuple_get_item.index], None, [relax.ObjectStructInfo()], ) diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 76abefbad7b5..dabe0070bfb8 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -685,13 +685,23 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorNormalizeArgument(op->tuple); + Expr new_index = this->NormalizeArgument(op->index); - TupleGetItem node = new_tuple.same_as(op->tuple) ? GetRef(op) - : TupleGetItem(new_tuple, op->index); + TupleGetItem node = [&]() { + if (new_tuple.same_as(op->tuple) && new_index.same_as(op->index) && + op->struct_info_.defined()) { + return GetRef(op); + } else { + return TupleGetItem(new_tuple, new_index); + } + }(); ICHECK(node->struct_info_.defined()) << "InternalError: " - << "TupleGetItem expected to define its struct info on construction"; + << "TupleGetItem expected to define its struct info on construction, " + << "but access of " << node->tuple << " (struct info = " << node->tuple->struct_info_ + << ") at index " << node->index << " (struct info = " << node->index->struct_info_ + << ") produced empty struct info"; return node; } diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index 595615e9b01a..c80a6d370509 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -30,8 +30,17 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::PrimValue n, ObjectPath n_p, IRDocsifier d) -> Doc { + auto path = n_p->Attr("value"); + + // Special case to print `R.prim_value(0)` as `0`, since it + // would be converted back to `R.prim_value` on parsing. + if (d->cfg->syntax_sugar && n->value->dtype == DataType::Int(64)) { + if (auto as_int = n->value.as()) { + return LiteralDoc::Int(as_int->value, path); + } + } // TODO(@junrushao): float numbers - return Relax(d, "prim_value")->Call({d->AsDoc(n->value, n_p->Attr("value"))}); + return Relax(d, "prim_value")->Call({d->AsDoc(n->value, path)}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) From 0d9a78a294e19d3c1e49031d173ff72ad1235a75 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 30 Oct 2023 12:59:35 -0500 Subject: [PATCH 07/21] Fix printing of non-int64 relax integers --- src/script/printer/relax/tir.cc | 7 +++++-- .../python/relax/test_tvmscript_printer_relax.py | 16 +++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 7c7752cfe65d..f0fd7ad4d12b 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -72,8 +72,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", P TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // - // TODO(@junrushao): support non-int64 cases - return LiteralDoc::Int(n->value, n_p); + ExprDoc doc = LiteralDoc::Int(n->value, n_p); + if (n->dtype != DataType::Int(64)) { + doc = TIR(d, DType2Str(n->dtype))->Call({doc}); + } + return doc; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index dc3334f216c0..a951ca648a7f 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -266,9 +266,19 @@ def test_func_type(): ) -def test_prim_value(): - obj = relax.PrimValue(1) - _assert_print(obj, "R.prim_value(1)") +def test_prim_value_int64(): + obj = relax.PrimValue(T.int64(1)) + _assert_print(obj, "1") + + +def test_prim_value_int32(): + obj = relax.PrimValue(T.int32(1)) + _assert_print(obj, "R.prim_value(T.int32(1))") + + +def test_prim_value_int16(): + obj = relax.PrimValue(T.int16(1)) + _assert_print(obj, "R.prim_value(T.int16(1))") def test_string_imm(): From 3e19873a2279a57e698ba4b3db87cc77762ff4b5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 30 Oct 2023 13:48:13 -0500 Subject: [PATCH 08/21] Correct conversion of python bool to PrimValue Because `isinstance(bool_value, int)` returns True, boolean values were being converted to `T.int64`, instead of to `T.bool`. --- python/tvm/relax/expr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index f719641fa733..43c44d5ad291 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -563,7 +563,9 @@ class PrimValue(Expr, Scriptable): value: PrimExpr def __init__(self, value: Union[PrimExpr, int], span: Optional[Span] = None) -> None: - if isinstance(value, int): + if isinstance(value, bool): + value = tvm.tir.IntImm("bool", value) + elif isinstance(value, int): value = tvm.tir.IntImm("int64", value) self.__init_handle_by_constructor__(_ffi_api.PrimValue, value, span) # type: ignore From 7be26d5d0aa542ee99fa8b701bea80f525150d0b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 31 Oct 2023 10:58:50 -0500 Subject: [PATCH 09/21] Update to fix failing unit tests --- src/script/printer/ir/distributed.cc | 2 +- src/script/printer/relax/distributed.cc | 8 ++++++++ src/script/printer/relax/tir.cc | 9 ++++----- .../distributed/test_distributed_tvmscript_printer.py | 11 ++++++----- tests/python/relax/test_ast_printer.py | 4 ++-- tests/python/relax/test_tvmscript_printer_relax.py | 11 +++++++---- 6 files changed, 28 insertions(+), 17 deletions(-) diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/ir/distributed.cc index 29e45bc5c598..5b2c42e99cd7 100644 --- a/src/script/printer/ir/distributed.cc +++ b/src/script/printer/ir/distributed.cc @@ -32,7 +32,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array results; results.reserve(s); for (int i = 0; i < s; ++i) { - results.push_back(d->AsDoc(Integer(n[i]), n_p->ArrayIndex(i))); + results.push_back(LiteralDoc::Int(n[i], n_p->ArrayIndex(i))); } return TupleDoc(results); }); diff --git a/src/script/printer/relax/distributed.cc b/src/script/printer/relax/distributed.cc index 9bf49a2830db..b36d7f4480f0 100644 --- a/src/script/printer/relax/distributed.cc +++ b/src/script/printer/relax/distributed.cc @@ -20,6 +20,7 @@ #include #include "../ir/utils.h" +#include "../tir/utils.h" #include "./utils.h" namespace tvm { @@ -101,8 +102,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) f = ir_frame; } } + if (!has_relax_frame || !f) { Array args; + + // Device mesh uses the TIR integer conversion rules, so + // we print the arguments using the TIR printer. + With frame(d, n); + (*frame)->AddDispatchToken(d, "tir"); + args.push_back(d->AsDoc(n->shape, n_p->Attr("shape"))); if (n->device_range.defined()) { args.push_back(d->AsDoc(n->device_range, n_p->Attr("device_range"))); diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index f0fd7ad4d12b..a513a36df2f0 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -18,6 +18,7 @@ */ #include +#include "../tir/utils.h" #include "./utils.h" namespace tvm { @@ -114,11 +115,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("relax", [](Range range, ObjectPath p, IRDocsifier d) -> Doc { - return Relax(d, "Range") - ->Call({ - d->AsDoc(range->min, p->Attr("min")), - d->AsDoc(range->extent + range->min, p->Attr("extent")), - }); + With frame(d, range); + (*frame)->AddDispatchToken(d, "tir"); + return d->AsDoc(range, p); }); } // namespace printer diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py index f1709c449d16..31e4818e950c 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py @@ -40,7 +40,7 @@ def test_constant(): ) assert ( constant.__str__() - == """R.dist.const(1, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" + == """R.dist.const(1, R.DTensor((), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "R, R"))""" ) @@ -52,7 +52,7 @@ def test_dtensor_struct_info(): ) assert ( obj0.__str__() - == """R.DTensor((32, 32), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[1], R")""" + == """R.DTensor((32, 32), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "S[1], R")""" ) obj1 = DTensorStructInfo( @@ -60,7 +60,7 @@ def test_dtensor_struct_info(): ) assert ( obj1.__str__() - == """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), R.Range(0, 4)), placement="S[1], R")""" + == """R.DTensor((32, 32), device_mesh=R.device_mesh((2, 2), T.Range(0, 4)), placement="S[1], R")""" ) obj2 = DTensorStructInfo( @@ -113,11 +113,12 @@ def test_func(): _assert_print( TestModule["foo"], """ +# from tvm.script import tir as T # from tvm.script import relax as R @R.function -def foo(x: R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R")) -> R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R"): - gv0 = R.dist.call_tir(tir_func, (x,), out_sinfo=R.DTensor((128, 128), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "S[0], R")) +def foo(x: R.DTensor((128, 128), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "S[0], R")) -> R.DTensor((128, 128), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "S[0], R"): + gv0 = R.dist.call_tir(tir_func, (x,), out_sinfo=R.DTensor((128, 128), "float32", R.device_mesh((2, 2), T.Range(0, 4)), "S[0], R")) return gv0 """, ) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 2a554f16e23f..089c41ce76c1 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -559,8 +559,8 @@ def foo(x: R.Tensor): ) ) assert 'Op(name="relax.unique")' in foo_str - # the sorted argument is true, so it will be a PrimValue of 1 - assert "PrimExpr(value=`T.int64(1)`)" in foo_str + # the sorted argument is true, so it will be a PrimValue of True + assert "PrimExpr(value=`T.bool(True)`)" in foo_str # axis is -1 assert "PrimExpr(value=`T.int64(-1)`)" in foo_str diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index a951ca648a7f..3fc9a26c9fc9 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -41,11 +41,12 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore _assert_print( func, """ +# from tvm.script import tir as T # from tvm.script import relax as R @R.function def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): - R.func_attr({"some_attr": 1}) + R.func_attr({"some_attr": T.int32(1)}) return a""", ) @@ -60,11 +61,12 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore _assert_print( func, """ +# from tvm.script import tir as T # from tvm.script import relax as R @R.function(private=True) def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): - R.func_attr({"some_attr": 1}) + R.func_attr({"some_attr": T.int32(1)}) return a""", ) @@ -731,6 +733,7 @@ def quux(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): obj, """ # from tvm.script import ir as I +# from tvm.script import tir as T # from tvm.script import relax as R @I.ir_module @@ -742,7 +745,7 @@ def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": T.bool(1)}) y: R.Tuple = R.print(format=R.str("Hi there!")) z: R.Tensor((), dtype="int32") = R.add(x, x) return z @@ -754,7 +757,7 @@ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function(private=True) def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": T.bool(1)}) y: R.Tuple = R.print(format=R.str("Lol")) z: R.Tensor((), dtype="int32") = R.multiply(x, x) return z From c98041a4e38ed3fe74464d7ce051dedd4f13f69f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Nov 2023 09:53:12 -0600 Subject: [PATCH 10/21] Revert majority of implementation, in preparation for re-implementation --- include/tvm/relax/expr.h | 24 +---- include/tvm/relax/nested_msg.h | 44 ++++----- python/tvm/relax/expr.py | 7 +- .../relax/transform/lazy_transform_params.py | 3 +- src/contrib/msc/core/ir/graph_builder.cc | 10 +- .../msc/core/transform/set_expr_layout.cc | 12 +-- .../msc/core/transform/set_expr_name.cc | 23 ++--- .../contrib/codegen_json/codegen_json.h | 6 +- src/relax/backend/vm/codegen_vm.cc | 7 +- src/relax/backend/vm/codegen_vm_tir.cc | 7 +- src/relax/ir/block_builder.cc | 9 +- src/relax/ir/dataflow_matcher.cc | 18 +--- src/relax/ir/expr.cc | 94 ++++--------------- src/relax/transform/canonicalize_bindings.cc | 9 +- src/relax/transform/convert_layout.cc | 33 +------ src/relax/transform/fuse_ops.cc | 23 ++--- src/relax/transform/fuse_tir.cc | 34 ++----- src/relax/transform/gradient.cc | 12 +-- .../transform/static_plan_block_memory.cc | 19 +--- src/relax/transform/to_mixed_precision.cc | 19 +--- src/script/printer/relax/expr.cc | 16 +--- 21 files changed, 101 insertions(+), 328 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index d077fd896102..bb1b2c8dd74a 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -338,7 +338,7 @@ class TupleGetItemNode : public ExprNode { /*! \brief The tuple Expression */ Expr tuple; /*! \brief which value to get */ - Expr index; + int index; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple_value", &tuple); @@ -358,29 +358,12 @@ class TupleGetItemNode : public ExprNode { hash_reduce(index); } - /* \brief Utility to get the exact index, if known - * - * In the most common case where index is known to be an integer, - * this utility allows it to be extracted. - * - * \return The known integer index, or NullOpt if unknown. - */ - Optional GetKnownIndex() const; - static constexpr const char* _type_key = "relax.expr.TupleGetItem"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); }; class TupleGetItem : public Expr { public: - /*! - * \brief The constructor - * \param tuple The tuple to get an element from. - * \param index The index for extracting a value in the tuple. - * \param span The source span of the expression. - */ - TVM_DLL TupleGetItem(Expr tuple, Expr index, Span span = Span()); - /*! * \brief The constructor * \param tuple The tuple to get an element from. @@ -398,8 +381,9 @@ class TupleGetItem : public Expr { * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new * fields. */ -TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = NullOpt, - Optional opt_index = NullOpt, Optional opt_span = NullOpt); +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), + Optional opt_index = Optional(), + Optional opt_span = Optional()); /*! * \brief Base type of all (non-function) leaf Exprs. diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index e879eccb2500..1ad5f02e0763 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -388,36 +388,26 @@ TargetType NestedMsgTo(NestedMsg msg, FMapLeaf fmapleaf, FCombine fcombine) { */ template Expr NestedMsgToExpr(NestedMsg msg, FType fmapleaf) { - return NestedMsgTo(msg, fmapleaf, [](Array arr) -> Expr { - if (arr.empty()) { - return Tuple(arr); - } - + return NestedMsgTo(msg, fmapleaf, [](Array arr) { Optional simplified_tuple; - for (size_t i = 0; i < arr.size(); ++i) { - auto* node = arr[i].as(); - if (node == nullptr) { - return Tuple(arr); - } - - auto index_sinfo = node->index->struct_info_.as(); - CHECK(index_sinfo && index_sinfo->dtype == DataType::Int(64)) - << "The index of TupleGetItem must be R.Prim('int64'), " - << "but expression " << GetRef(node) << " has index " << node->index - << " with struct info " << node->index->struct_info_; - - auto known_index = index_sinfo->value.as(); - if (!known_index || known_index->value != static_cast(i)) { - return Tuple(arr); - } - - if (simplified_tuple && !simplified_tuple.same_as(node->tuple)) { - return Tuple(arr); - } else if (!simplified_tuple) { - simplified_tuple = node->tuple; + bool simplified_flag = false; + if (arr.size() >= 1) { + simplified_flag = true; + for (size_t i = 0; i < arr.size() && simplified_flag; ++i) { + auto* node = arr[i].as(); + if (node == nullptr || node->index != static_cast(i)) { + simplified_flag = false; + } else { + if (simplified_tuple.defined()) { + simplified_flag &= (simplified_tuple == node->tuple); + } else { + simplified_tuple = node->tuple; + ICHECK(simplified_tuple.defined()); + } + } } } - return simplified_tuple.value(); + return simplified_flag ? simplified_tuple.value() : Tuple(arr); }); } diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 43c44d5ad291..d52e933d0277 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -370,7 +370,7 @@ class TupleGetItem(ExprWithOp): tuple_value: Expr The input tuple expression. - index: Union[int, Expr] + index: int The index. span: Optional[Span] @@ -381,10 +381,7 @@ class TupleGetItem(ExprWithOp): index: int span: Optional[Span] - def __init__(self, tuple_value: Expr, index: Union[int, Expr],span: Optional[Span] = None):): - if isinstance(index, int): - index = PrimValue(index) - + def __init__(self, tuple_value: Expr, index: int, span: Optional[Span] = None): self.__init_handle_by_constructor__( _ffi_api.TupleGetItem, tuple_value, index, span # type: ignore ) diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py index ee680a9995a4..7f734f8a3c47 100644 --- a/python/tvm/relax/transform/lazy_transform_params.py +++ b/python/tvm/relax/transform/lazy_transform_params.py @@ -198,7 +198,8 @@ def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr: get_item_result = self.builder_.emit( relax.Call( relax.ExternFunc(self.func_creator.fget_item), - self.func_creator.extra_get_item_params + [tuple_get_item.index], + self.func_creator.extra_get_item_params + + [relax.PrimValue(tuple_get_item.index)], None, [relax.ObjectStructInfo()], ) diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index dd790b7f93ac..dab4ae813ea6 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -64,11 +64,7 @@ void RelaxFuncAttrGetter::VisitExpr_(const relax::CallNode* op) { } void RelaxFuncAttrGetter::VisitExpr_(const relax::TupleGetItemNode* op) { - if (auto known_index = op->GetKnownIndex()) { - attrs_.Set("index", std::to_string(known_index.value()->value)); - } else { - LOG(FATAL) << "MSC does not support TupleGetItem with dynamic index"; - } + attrs_.Set("index", std::to_string(op->index)); } void RelaxFuncValueGetter::VisitExpr_(const relax::CallNode* op) { @@ -284,9 +280,7 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } else if (const auto* shape_node = expr.as()) { attrs.Set("shape", StringUtils::ToString(shape_node->values)); } else if (const auto* get_node = expr.as()) { - auto known_value = get_node->GetKnownIndex(); - ICHECK(known_value) << "MSC does not support TupleGetItem with dynamic index"; - attrs.Set("index", std::to_string(known_value.value()->value)); + attrs.Set("index", std::to_string(get_node->index)); } // Get scope diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 3f33590fb82e..d2023b886a0f 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -48,14 +48,10 @@ LayoutDecision InferLayoutDecision(const Expr& expr, const VarLayoutMap& var_lay } LayoutDecision InferLayoutDecisionAt(const Expr& expr, const VarLayoutMap& var_layout_map, - Expr index = PrimValue::Int64(0)) { + size_t index = 0) { const auto& nlayouts = InferNLayout(expr, var_layout_map); if (nlayouts.IsLeaf()) { - auto int_index = Downcast(index->struct_info_) - ->value.as() - .value_or(Integer(0)) - ->value; - return int_index == 0 ? nlayouts.LeafValue() : LayoutDecision(""); + return index == 0 ? nlayouts.LeafValue() : LayoutDecision(""); } const auto& nlayout = nlayouts.NestedArray()[0]; ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr; @@ -719,7 +715,7 @@ InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map); + LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 0); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -753,7 +749,7 @@ InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, InferLayoutOutput BackwardInferLayoutNormalize(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map); + LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 0); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index 3897e38ff93c..97850c70e8e8 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -145,23 +145,12 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { ExprVisitor::VisitBinding_(binding, val); - - String unique_name = [&]() { - std::stringstream ss; - if (expr_names_.count(val->tuple)) { - ss << expr_names_[val->tuple]; - } else if (const auto* v_node = val->tuple.as()) { - ss << v_node->name_hint(); - } - ss << "."; - - if (auto known_index = val->GetKnownIndex()) { - ss << known_index.value()->value; - } - - return ss.str(); - }(); - + String unique_name; + if (expr_names_.count(val->tuple)) { + unique_name = expr_names_[val->tuple] + "." + std::to_string(val->index); + } else if (const auto* v_node = val->tuple.as()) { + unique_name = v_node->name_hint() + "." + std::to_string(val->index); + } if (unique_name != SpanUtils::GetAttr(val->span, "name")) { val->span = SpanUtils::SetAttr(val->span, "name", unique_name); } diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index cff040391301..7dc6ddc16227 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -363,11 +363,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { NodeEntries VisitExpr_(const TupleGetItemNode* gtn) { auto vtuple = VisitExpr(gtn->tuple); - if (auto known_index = gtn->GetKnownIndex()) { - return {vtuple[known_index.value()->value]}; - } else { - return vtuple; - } + return {vtuple[gtn->index]}; } NodeEntries VisitExpr_(const FunctionNode* fn) { diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 3aa7f06ba706..218fe6b1202c 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -282,10 +282,9 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const TupleGetItemNode* op) final { TupleGetItem expr = GetRef(op); - std::vector args = { - VisitExpr(expr->tuple), - VisitExpr(expr->index), - }; + std::vector args = {this->VisitExpr(expr->tuple)}; + + args.push_back(builder_->ConvertConstant(expr->index)); size_t dst_register = NewRegister(); builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register); diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index 804baf8f66f3..ec1678e9e0f3 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -323,10 +323,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Optional VisitExpr_(const TupleGetItemNode* op) final { TupleGetItem expr = GetRef(op); - Array args = { - VisitExpr(expr->tuple).value(), - VisitExpr(expr->index).value(), - }; + Array args = {this->VisitExpr(expr->tuple).value()}; + + args.push_back(ConstInt64(expr->index)); int64_t dst_register = NewRegister(); this->EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register); diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index dabe0070bfb8..f82612ce1662 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -685,14 +685,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorNormalizeArgument(op->tuple); - Expr new_index = this->NormalizeArgument(op->index); TupleGetItem node = [&]() { - if (new_tuple.same_as(op->tuple) && new_index.same_as(op->index) && - op->struct_info_.defined()) { + if (new_tuple.same_as(op->tuple) && op->struct_info_.defined()) { return GetRef(op); } else { - return TupleGetItem(new_tuple, new_index); + return TupleGetItem(new_tuple, op->index); } }(); @@ -700,8 +698,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctortuple << " (struct info = " << node->tuple->struct_info_ - << ") at index " << node->index << " (struct info = " << node->index->struct_info_ - << ") produced empty struct info"; + << ") at index " << node->index << " produced empty struct info"; return node; } diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index c2833680f1e4..9524c90b577c 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -358,19 +358,11 @@ bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr0) { auto expr = TryGetValOfVar(expr0, var2val_); - const auto* tuple_get_item_node = expr.as(); - if (!tuple_get_item_node) return false; - - bool is_correct_index = [&]() -> bool { - if (op->index == -1) return true; - - auto known_index = tuple_get_item_node->GetKnownIndex(); - if (!known_index) return false; - - return known_index.value()->value == op->index; - }(); - - return is_correct_index && VisitDFPattern(op->tuple, tuple_get_item_node->tuple); + if (const auto* tuple_get_item_node = expr.as()) { + return (op->index == -1 || op->index == tuple_get_item_node->index) && + VisitDFPattern(op->tuple, tuple_get_item_node->tuple); + } + return false; } bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr0) { diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 2791c7f169d1..00ad252ec4a4 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -165,91 +165,29 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o return tuple; } -TupleGetItem::TupleGetItem(Expr tuple, int arg_index, Span span) - : TupleGetItem(tuple, PrimValue::Int64(arg_index), span) {} - -TupleGetItem::TupleGetItem(Expr tuple, Expr arg_index, Span span) { - auto index_sinfo = arg_index->struct_info_.as(); - CHECK(index_sinfo && index_sinfo->dtype == DataType::Int(64)) - << "TupleGetItem requires the index to be a R.Prim('int64'), " - << "but received " << arg_index << " with struct info " << arg_index->struct_info_; - - auto known_index = index_sinfo->value.as(); - - if (known_index) { - // If we know the index, we can check against the lower bound of - // zero. Checking the upper bound will require also knowing the - // tuple's size. - CHECK_GE(known_index->value, 0) - << "IndexError: " - << "Tuple " << tuple << " cannot be accessed with negative index " << arg_index; - } +TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { + CHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple + << " cannot be accessed with negative index " << index; + ObjectPtr n = make_object(); - auto* tuple_info = tuple->struct_info_.as(); - - Optional item_sinfo = NullOpt; - - if (known_index && tuple_info) { - // The exact index used to access the tuple is known. We can - // apply bounds-checking, and can provide the exact StructInfo of - // the accessed element. - int int_index = known_index->value; - - CHECK_LT(int_index, tuple_info->fields.size()) - << "IndexError: " - << "Tuple " << tuple << " is of size " << tuple_info->fields.size() - << ", and cannot be accessed with index " << int_index; - item_sinfo = tuple_info->fields[int_index]; - - } else if (tuple_info) { - // The exact index used to access the tuple is unknown. We can't - // apply bounds checking, but we can check that an index might - // exist. We can't provide an exact StructInfo for the accessed - // type, but we can provide the common base type of all items in - // the tuple. - CHECK_GT(tuple_info->fields.size(), 0) - << "IndexError: " - << "The exact value of index " << arg_index << " is unknown, " - << "but expression " << tuple << " has struct info " << tuple->struct_info_ << ". " - << "This is a tuple of length zero, and there is no index such that 0 <= index < 0."; - - StructInfo reduce_lca = tuple_info->fields[0]; - for (size_t i = 1; i < tuple_info->fields.size(); i++) { - reduce_lca = StructInfoLCA(reduce_lca, tuple_info->fields[1]); - } - item_sinfo = reduce_lca; + if (auto* tuple_info = tuple->struct_info_.as()) { + CHECK_LT(index, tuple_info->fields.size()) + << "Index out of bounds: Tuple " << tuple << " is of size " << tuple_info->fields.size() + << ", and cannot be accessed with index " << index; + auto sinfo = tuple_info->fields[index]; + n->struct_info_ = sinfo; + n->checked_type_ = GetStaticType(sinfo); } - - ObjectPtr n = make_object(); n->tuple = std::move(tuple); - n->index = arg_index; + n->index = index; n->span = std::move(span); - if (item_sinfo) { - n->struct_info_ = item_sinfo; - n->checked_type_ = GetStaticType(item_sinfo.value()); - } - data_ = std::move(n); } -Optional TupleGetItemNode::GetKnownIndex() const { - auto prim_sinfo = index->struct_info_.as(); - CHECK(prim_sinfo->dtype == DataType::Int(64)) - << "The index of TupleGetItem must be R.Prim('int64'), " - << "but expression " << GetRef(this) << " has index " << index << " with struct info " - << index->struct_info_; - - if (auto int_index = prim_sinfo->value.as()) { - return Integer(int_index.value()); - } else { - return NullOpt; - } -} - TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, - Optional opt_index, Optional opt_span) { + Optional opt_index, Optional opt_span) { Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); - Expr index = opt_index.value_or(tuple_get_item->index); + Integer index = opt_index.value_or(tuple_get_item->index); Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && @@ -257,7 +195,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, if (!unchanged) { TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); cow_tuple_get_item_node->tuple = tuple; - cow_tuple_get_item_node->index = index; + cow_tuple_get_item_node->index = index.IntValue(); cow_tuple_get_item_node->span = span; } return tuple_get_item; @@ -265,7 +203,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, Expr index, Span span) { +TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { return TupleGetItem(tuple, index, span); }); diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 6ef07dd11c03..246b38f6f83b 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -138,8 +138,7 @@ class CanonicalizePlanner : public ExprVisitor { LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey(); } - // Unwrap TupleGetItem, if were know which Tuple is being - // accessed, and the index at which is is being accessed. + // Unwrap TupleGetItem, if the Tuple being accessed is known. if (auto tuple_get_item = value.as()) { Expr tuple = tuple_get_item->tuple; while (auto tuple_var = tuple.as()) { @@ -150,10 +149,8 @@ class CanonicalizePlanner : public ExprVisitor { } } - auto known_tuple = tuple.as(); - auto known_index = tuple_get_item->GetKnownIndex(); - if (known_tuple && known_index) { - value = known_tuple->fields[known_index.value()->value]; + if (auto ptr = tuple.as()) { + value = ptr->fields[tuple_get_item->index]; } } diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index b2a742ee3ae2..6530d0d2cf0c 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -234,37 +234,10 @@ class LayoutConvertMutator : public ExprMutator { NLayout input_layout = binding->var->IsInstance() ? GetNLayout(var_layout_map_, val->tuple) : InitialNLayout(val->tuple); - - Expr new_tuple = RewriteExpr(val->tuple, input_layout); - Expr new_index = RewriteExpr(val->index, input_layout); - - ReEmitBinding(binding, builder_->Normalize(TupleGetItem(new_tuple, new_index))); - - NLayout item_layout = [&]() { - if (auto known_index = val->GetKnownIndex()) { - // Most common case, we know the index at which the tuple is - // being accessed. - return input_layout.NestedArray()[known_index.value()->value]; - } - - std::unordered_set unique_layouts; - for (const auto& layout : input_layout.NestedArray()) { - unique_layouts.insert(layout); - } - if (unique_layouts.size() == 1) { - // Fallback case. We don't know where we are accessing the - // tuple, but it doesn't matter because all elements in the - // tuple are being transformed. - return *unique_layouts.begin(); - } - - LOG(FATAL) << "Cannot determine the layout of " << GetRef(val) - << ". The index is unknown, and the tuple contains more than multiple layouts: " - << Array(unique_layouts.begin(), unique_layouts.end()); - }(); - + ReEmitBinding(binding, builder_->Normalize( + TupleGetItem(RewriteExpr(val->tuple, input_layout), val->index))); // update the layout map - var_layout_map_[binding->var] = item_layout; + var_layout_map_[binding->var] = input_layout.NestedArray()[val->index]; } void VisitBinding_(const MatchCastNode* binding) final { diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index f196a84a527c..b0eeba399e90 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -424,15 +424,7 @@ class FunctionCreator : public ExprMutator { if (partially_used_tuple_params_.find(tuple_item->tuple.get()) != partially_used_tuple_params_.end()) { // Appending get-item index to the mapping. - auto& used_indices = partially_used_tuple_params_[tuple_item->tuple.get()]; - if (auto known_index = tuple_item->GetKnownIndex()) { - used_indices.push_back(known_index.value()->value); - } else { - auto num_fields = Downcast(tuple_item->struct_info_)->fields.size(); - for (size_t i = 0; i < num_fields; i++) { - used_indices.push_back(i); - } - } + partially_used_tuple_params_[tuple_item->tuple.get()].push_back(tuple_item->index); } } @@ -502,15 +494,12 @@ class FunctionCreator : public ExprMutator { // Special handing for TupleGetItem. if (const auto* var_binding = binding.as()) { if (const auto* tuple_get_item = var_binding->value.as()) { - if (auto it = tuple_get_item_remap.find(tuple_get_item->tuple.get()); - it != tuple_get_item_remap.end()) { - auto opt_known_index = tuple_get_item->GetKnownIndex(); - ICHECK(opt_known_index) << "FuseOps requires static indices into tuples"; - int known_index = opt_known_index.value()->value; - ICHECK(it->second.find(known_index) != it->second.end()); - var_remap_[var_binding->var->vid] = it->second[known_index]; + auto it = tuple_get_item_remap.find(tuple_get_item->tuple.get()); + if (it != tuple_get_item_remap.end()) { + ICHECK(it->second.find(tuple_get_item->index) != it->second.end()); + var_remap_[var_binding->var->vid] = it->second[tuple_get_item->index]; if (auto output_idx = GetOutputIndex(binding->var)) { - outputs.Set(*output_idx, it->second[known_index]); + outputs.Set(*output_idx, it->second[tuple_get_item->index]); } continue; } diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index e2513de19f2b..df3c85c05ce1 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -27,7 +27,6 @@ #include "../../relay/analysis/graph_partitioner.h" #include "../../support/arena.h" -#include "../../support/ordered_set.h" #include "../../tir/ir/functor_common.h" namespace tvm { @@ -380,15 +379,7 @@ class FusedTIRConstructor : public ExprVisitor { PostOrderVisit(func->body, [=, &tuple_param](Expr e) { if (auto tup_get = e.as(); tup_get && tuple_param.count(tup_get->tuple.get())) { - auto& used_indices = func_info_.used_tuple_field_indices[tup_get->tuple.get()]; - if (auto known_index = tup_get->GetKnownIndex()) { - used_indices.insert(known_index.value()->value); - } else { - auto num_fields = Downcast(tup_get->struct_info_)->fields.size(); - for (size_t i = 0; i < num_fields; i++) { - used_indices.insert(i); - } - } + func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index); } }); @@ -532,18 +523,12 @@ class FusedTIRConstructor : public ExprVisitor { void VisitExpr_(const TupleGetItemNode* tuple_get_item) final { ExprVisitor::VisitExpr_(tuple_get_item); - auto it = func_info_.expr2buffers.find(tuple_get_item->tuple); if (it != func_info_.expr2buffers.end()) { - auto opt_known_index = tuple_get_item->GetKnownIndex(); - ICHECK(opt_known_index) << "FuseTIR requires all tuple indices to be known, " - << "but " << GetRef(tuple_get_item) << " has a dynamic index"; - auto tuple_index = opt_known_index.value()->value; - int begin_buf_idx = 0; int end_buf_idx = 0; const TupleType& tuple_type = Downcast(tuple_get_item->tuple->checked_type()); - for (int i = 0; i < tuple_index; ++i) { + for (int i = 0; i < tuple_get_item->index; ++i) { auto it = func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get()); // If this tuple is not passed as a parameter, or if the field at the index i is actually // used, the corresponding buffer needs to be taken into account by this function. @@ -551,7 +536,7 @@ class FusedTIRConstructor : public ExprVisitor { begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]); } } - end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_index]); + end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]); func_info_.expr2buffers.Set( GetRef(tuple_get_item), {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx}); @@ -866,20 +851,15 @@ class FusedTIRConstructor : public ExprVisitor { std::vector GetTupleAccessedIndices(const FunctionNode* func, const Var& tuple_var) { // Need to be ordered - support::OrderedSet indices; + std::vector indices; PostOrderVisit(func->body, [&indices, tuple_var](Expr e) { if (auto tup_get = e.as(); tup_get && tup_get->tuple.same_as(tuple_var)) { - if (auto known_index = tup_get->GetKnownIndex()) { - indices.insert(known_index.value()->value); - } else { - auto num_fields = Downcast(tup_get->struct_info_)->fields.size(); - for (size_t i = 0; i < num_fields; i++) { - indices.insert(i); - } + if (std::find(indices.begin(), indices.end(), tup_get->index) == indices.end()) { + indices.push_back(tup_get->index); } } }); - return std::vector(indices.begin(), indices.end()); + return indices; } /*! diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index e48c659ca6ed..70e3e37876fd 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -418,21 +418,15 @@ class BackwardBindingGenerator : private ExprVisitor { auto* tuple_sinfo = GetStructInfoAs(tuple_get_item->tuple); ICHECK(tuple_sinfo) << "The tuple field of a TupleGetItem must has a TupleStructInfo"; - auto opt_index = tuple_get_item->GetKnownIndex(); - ICHECK(opt_index) << "relax.transform.Gradient requires all tuple indices to be known, " - << "but expression " << GetRef(tuple_get_item) << " has index " - << tuple_get_item->index << ", whose value isn't known."; - int index = opt_index.value()->value; - const Var& tuple_var = Downcast(tuple_get_item->tuple); if (adjoint_var_map_.count(tuple_var) == 0) { auto nested_zeros = Downcast(NestedZeros(GetRef(tuple_sinfo))); auto tuple_fields = nested_zeros->fields; - tuple_fields.Set(index, adjoint_var_map_[binding->var]); + tuple_fields.Set(tuple_get_item->index, adjoint_var_map_[binding->var]); EmitAdjoint(tuple_var, Tuple(tuple_fields), false); } else { - Expr updated_adjoint = - AddInTuple(adjoint_var_map_[tuple_var], index, adjoint_var_map_[binding->var]); + Expr updated_adjoint = AddInTuple(adjoint_var_map_[tuple_var], tuple_get_item->index, + adjoint_var_map_[binding->var]); EmitAdjoint(tuple_var, updated_adjoint, false); } } diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index b6c3d5396a1f..4a2a1555ff46 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -270,22 +270,9 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { } ICHECK(tokens.IsNested()); Array field_tokens = tokens.NestedArray(); - - auto item_tokens = [&]() -> Tokens { - if (auto known_index = tuple_item->GetKnownIndex()) { - // If the tuple access is at a specific index, the field uses - // the token of that index. - int index = known_index.value()->value; - ICHECK_GT(static_cast(field_tokens.size()), index); - ICHECK_GE(index, 0); - return field_tokens[index]; - } else { - // If the tuple access is at an unknown index, the field may - // require any token from the tuple. - return tokens; - } - }(); - SetTokens(tuple_item, item_tokens); + ICHECK_GT(static_cast(field_tokens.size()), tuple_item->index); + ICHECK_GE(tuple_item->index, 0); + SetTokens(tuple_item, field_tokens[tuple_item->index]); } /******************** Utilities ********************/ diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index a999d8e6c2a1..c844d5935623 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -218,24 +218,15 @@ class DTypeDecisionCollector : public ExprVisitor { // require the i-th field rhs tuple to be the type of the lhs NType lhs_type = GetDType(binding->var); std::vector require_rhs; - const TupleStructInfoNode* sinfo = tuple_get_item_node->tuple->struct_info_.as(); ICHECK(sinfo != nullptr) << "TupleGetItemNode must have TupleStructInfo"; - - auto opt_known_index = tuple_get_item_node->GetKnownIndex(); - ICHECK(opt_known_index) << "ToMixedPrecision pass does not support dynamic tuple indices"; - size_t known_index = opt_known_index.value()->value; - for (size_t i = 0; i < sinfo->fields.size(); ++i) { - NType field_type = [&]() { - if (i == known_index) { - return lhs_type; - } else { - return NTypeFrom(sinfo->fields[i], unknown_); - } - }(); - require_rhs.push_back(field_type); + if (i == static_cast(tuple_get_item_node->index)) { + require_rhs.push_back(lhs_type); + } else { + require_rhs.push_back(NTypeFrom(sinfo->fields[i], unknown_)); + } } RequireArgsToType({tuple_get_item_node->tuple}, {NType(require_rhs)}); } diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index c80a6d370509..f6cbde0b4b23 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -30,17 +30,8 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::PrimValue n, ObjectPath n_p, IRDocsifier d) -> Doc { - auto path = n_p->Attr("value"); - - // Special case to print `R.prim_value(0)` as `0`, since it - // would be converted back to `R.prim_value` on parsing. - if (d->cfg->syntax_sugar && n->value->dtype == DataType::Int(64)) { - if (auto as_int = n->value.as()) { - return LiteralDoc::Int(as_int->value, path); - } - } // TODO(@junrushao): float numbers - return Relax(d, "prim_value")->Call({d->AsDoc(n->value, path)}); + return Relax(d, "prim_value")->Call({d->AsDoc(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -73,9 +64,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::TupleGetItem n, ObjectPath n_p, IRDocsifier d) -> Doc { - auto tuple_doc = d->AsDoc(n->tuple, n_p->Attr("tuple")); - auto index_doc = d->AsDoc(n->index, n_p->Attr("index")); - return tuple_doc[{index_doc}]; + ExprDoc idx = LiteralDoc::Int(n->index, n_p->Attr("index")); + return d->AsDoc(n->tuple, n_p->Attr("tuple"))[{idx}]; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) From b4245fa2310ac86f524c92c898c10d3e1ace0e4d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Nov 2023 12:37:37 -0600 Subject: [PATCH 11/21] Re-implement dynamic tuple access in terms of intrinsic --- include/tvm/relax/expr.h | 33 +++++ python/tvm/relax/expr.py | 17 +-- src/relax/backend/vm/codegen_vm.cc | 9 ++ src/relax/backend/vm/codegen_vm_tir.cc | 9 ++ src/relax/ir/expr.cc | 7 +- src/relax/op/tuple.cc | 163 +++++++++++++++++++++++++ 6 files changed, 223 insertions(+), 15 deletions(-) create mode 100644 src/relax/op/tuple.cc diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index bb1b2c8dd74a..66fa2e8d2353 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -287,6 +287,25 @@ If WithFields(If if_expr, Optional opt_cond = Optional(), Optional opt_false_branch = Optional(), Optional opt_span = Optional()); +/*! \brief Perform tuple access + * + * Use of this method is recommended, rather than constructing a + * `TupleGetItem` directly. + * + * 1. May resolve to the tuple's contents, avoiding the intermediate + * `TupleGetItem`. + * + * 2. Handles access of a tuple at a dynamic index, where + * `TupleGetItem` requires a statically-known index. + * + * \param tuple The tuple to be accessed + * + * \param index The index at which the access occurs + * + * \return An expression for the access of the tuple + */ +Expr tuple_get_item(Expr tuple, Expr index); + /*! \brief Tuple container */ class TupleNode : public ExprNode { public: @@ -320,6 +339,20 @@ class Tuple : public Expr { */ TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); + /*! \brief Helper to delegate access to the tuple + * + * The `tuple_get_item` can be applied to any `relax::Expr`. + * However, this helper function is only provided for + * `relax::Tuple`, because `relax::Expr` is a typedef for + * `RelayExpr`, and we should avoid updating relay classes to + * provide relax-specific functionality.. + * + * \param index The index at which the tuple is accessed + * + * \return The contents of the tuple at the specified index + */ + inline Expr operator[](Expr index) { return tuple_get_item(*this, index); } + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode); }; diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index d52e933d0277..3ccdddc28287 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -215,7 +215,7 @@ def __call__(self, *args: List[Expr], attrs: Optional[Dict[str, Any]] = None) -> """ return Call(self, args, attrs=attrs) - def __getitem__(self, index: int) -> "ExprWithOp": + def __getitem__(self, index: Union[Expr, PrimExpr, int]) -> "ExprWithOp": """Get the i-th element of the tuple or Expr with TupleType. Parameters @@ -232,17 +232,10 @@ def __getitem__(self, index: int) -> "ExprWithOp": result: ExprWithOp The result expression. """ - try: - return TupleGetItem(self, index) - except tvm.TVMError as err: - # For Python objects with __getitem__, but without - # __len__, tuple unpacking is done by iterating over - # sequential indices until IndexError is raised. - # Therefore, convert from TVMError to IndexError for - # compatibility. - if "Index out of bounds" in err.args[0]: - raise IndexError from err - raise + if not isinstance(index, Expr): + index = PrimValue(index) + + return _ffi_api.tuple_get_item(self, index) @tvm._ffi.register_object("relax.expr.Call") diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 218fe6b1202c..a9402149b9a2 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -171,6 +171,8 @@ class CodeGenVM : public ExprFunctor { EmitAllocTensor(call, dst_reg); } else if (call_node->op == kill_object_op_) { dst_reg = EmitKillObject(call); + } else if (call_node->op == tuple_getitem_op_) { + EmitTupleAccess(call, dst_reg); } else { // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those // ops are handled in a pass when lowering them to TIR. @@ -373,6 +375,12 @@ class CodeGenVM : public ExprFunctor { return dst_reg; } + void EmitTupleAccess(const Call& call_node, RegName dst_register) { + ICHECK_EQ(call_node->args.size(), 2); + std::vector args = VisitArray(call_node->args); + builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register); + } + void EmitCallBuiltinWithCtx(const Call& call_node, RegName dst_reg) { std::vector args; args.push_back(Instruction::Arg::Register(Instruction::kVMRegister)); @@ -425,6 +433,7 @@ class CodeGenVM : public ExprFunctor { const Op& kill_object_op_ = Op::Get("relax.vm.kill_object"); const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); const Op& null_value_op_ = Op::Get("relax.null_value"); + const Op& tuple_getitem_op_ = Op::Get("relax.tuple_get_item_dyn"); }; /*! diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index ec1678e9e0f3..baee0062e99d 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -240,6 +240,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { EmitAllocTensor(call, dst_reg); } else if (call_node->op == kill_object_op_) { dst_reg = EmitKillObject(call); + } else if (call_node->op == tuple_getitem_op_) { + EmitTupleAccess(call, dst_reg); } else { // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those // ops are handled in a pass when lowering them to TIR. @@ -433,6 +435,12 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return dst_reg; } + void EmitTupleAccess(const Call& call_node, int64_t dst_register) { + ICHECK_EQ(call_node->args.size(), 2); + auto args = call_node->args.Map([this](Expr expr) { return VisitExpr(expr).value(); }); + EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register); + } + void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) { Array args; // if context is required, pass as first argument. @@ -522,6 +530,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { const Op& kill_object_op_ = Op::Get("relax.vm.kill_object"); const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); const Op& null_value_op_ = Op::Get("relax.null_value"); + const Op& tuple_getitem_op_ = Op::Get("relax.tuple_get_item_dyn"); }; /*! diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 00ad252ec4a4..1a1d41e6ced3 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -166,13 +166,14 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o } TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { - CHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple - << " cannot be accessed with negative index " << index; + CHECK_GE(index, 0) << "IndexError: " + << "Tuple " << tuple << " cannot be accessed with negative index " << index; ObjectPtr n = make_object(); if (auto* tuple_info = tuple->struct_info_.as()) { CHECK_LT(index, tuple_info->fields.size()) - << "Index out of bounds: Tuple " << tuple << " is of size " << tuple_info->fields.size() + << "IndexError: " + << "Tuple " << tuple << " is of size " << tuple_info->fields.size() << ", and cannot be accessed with index " << index; auto sinfo = tuple_info->fields[index]; n->struct_info_ = sinfo; diff --git a/src/relax/op/tuple.cc b/src/relax/op/tuple.cc new file mode 100644 index 000000000000..2412a588173b --- /dev/null +++ b/src/relax/op/tuple.cc @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/op/tuple.cc + * + * builtin intrinsic operators for manipulating tuples + */ +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +namespace { +/*! \brief Utility function for NormalizeTupleGetItem and tuple_get_item + * + * \param index The index at which the tuple is accessed + * + * \return The known index, if static, otherwise std::nullopt. + */ +std::optional FindStaticIndex(const Expr& index) { + if (auto index_sinfo = index->struct_info_.as()) { + if (auto known_index = index_sinfo->value.as()) { + return known_index->value; + } + } + return std::nullopt; +} +} // namespace + +StructInfo InferStructInfoTupleGetItem(const Call& call, const BlockBuilder&) { + CHECK_EQ(call->args.size(), 2) << "Operator " << call->op + << " expects exactly two arguments [tuple, index], " + << "but received " << call->args.size() + << " arguments in expression " << call; + auto tuple = call->args[0]; + auto index = call->args[1]; + + auto tuple_sinfo = tuple->struct_info_.as(); + CHECK(tuple_sinfo) << "Operator " << call->op + << " expects its first argument to specify a tuple, " + << "but expression " << call << " has tuple argument " << tuple + << ", which has struct info " << tuple->struct_info_; + + auto index_sinfo = index->struct_info_.as(); + CHECK(index_sinfo && index_sinfo->dtype == DataType::Int(64)) + << "TupleGetItem requires the index to be a R.Prim('int64'), " + << "but expression " << call << " has index argument " << index << ", which has struct info " + << index->struct_info_; + + auto known_index = index_sinfo->value.as(); + + if (known_index) { + // The exact index used to access the tuple is known. We can + // apply bounds-checking, and can provide the exact StructInfo of + // the accessed element. + int int_index = known_index->value; + + CHECK_GE(int_index, 0) << "IndexError: " + << "Operator " << call->op << " attempted to access tuple " << tuple + << " at index " << index << ". " + << "However, the index " << index << " is known to be " << int_index + << ", and negative indices are not allowed."; + + CHECK_LT(int_index, tuple_sinfo->fields.size()) + << "IndexError: " + << "Operator " << call->op << " attempted to access tuple " << tuple << " at index " + << index << ". " + << "However, tuple " << tuple << " is of size " << tuple_sinfo->fields.size() + << ", the index expression has a known value of " << int_index + << ", outside the bounds of the tuple"; + return tuple_sinfo->fields[int_index]; + + } else { + // The exact index used to access the tuple is unknown. We can't + // apply bounds checking, but we can check that an index might + // exist. We can't provide an exact StructInfo for the accessed + // type, but we can provide the common base type of all items in + // the tuple. + CHECK_GT(tuple_sinfo->fields.size(), 0) + << "IndexError: " + << "The exact value of index " << index << " is unknown, " + << "but expression " << tuple << " has struct info " << tuple->struct_info_ << ". " + << "This is a tuple of length zero, and there is no index such that 0 <= index < 0."; + + StructInfo reduce_lca = tuple_sinfo->fields[0]; + for (size_t i = 1; i < tuple_sinfo->fields.size(); i++) { + reduce_lca = StructInfoLCA(reduce_lca, tuple_sinfo->fields[1]); + } + return reduce_lca; + } +} + +Expr NormalizeTupleGetItem(const BlockBuilder&, const Call& call) { + ICHECK_EQ(call->args.size(), 2); + auto tuple = call->args[0]; + auto index = call->args[1]; + + if (auto index_sinfo = index->struct_info_.as()) { + if (auto known_index = index_sinfo->value.as()) { + return TupleGetItem(tuple, known_index->value); + } + } + return std::move(call); +} + +RELAY_REGISTER_OP("relax.tuple_get_item_dyn") + .set_num_inputs(2) + .add_argument("tuple", "Expr (R.Tuple([...]))", "The tuple to access") + .add_argument("index", "Expr (R.Prim(dtype='int64'))", + "The index at which to access the tuple.") + .set_attr("FInferStructInfo", InferStructInfoTupleGetItem) + .set_attr("FNormalize", NormalizeTupleGetItem) + .set_attr("FPurity", Bool(true)); + +Expr tuple_get_item(Expr tuple, Expr index) { + auto opt_static_index = FindStaticIndex(index); + auto known_tuple = tuple.as(); + + if (opt_static_index && known_tuple) { + // Both the tuple and index are known. We can return the accessed + // expression directly. + return known_tuple->fields[opt_static_index.value()]; + } else if (opt_static_index) { + // The index is known, but the tuple is bound to a variable. We + // can return a static TupleGetItem, which is useful in many + // passes. + return TupleGetItem(tuple, opt_static_index.value()); + } else { + // The index isn't known, so fall back to the most general case. + // If a later pass (e.g. BindParams) provides a statically-known + // index, then this will be normalized back to a TupleGetItem at + // that point. + static const auto op = Op::Get("relax.tuple_get_item_dyn"); + return Call(op, {tuple, index}); + } +} + +TVM_REGISTER_GLOBAL("relax.tuple_get_item").set_body_typed(tuple_get_item); + +} // namespace relax +} // namespace tvm From 00c4affc77ec2ccfd07f98eb638eda1ba23f27e8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Nov 2023 12:38:30 -0600 Subject: [PATCH 12/21] Rename unit test file --- .../relax/{test_vm_tuple_get_item.py => test_tuple_get_item.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/python/relax/{test_vm_tuple_get_item.py => test_tuple_get_item.py} (100%) diff --git a/tests/python/relax/test_vm_tuple_get_item.py b/tests/python/relax/test_tuple_get_item.py similarity index 100% rename from tests/python/relax/test_vm_tuple_get_item.py rename to tests/python/relax/test_tuple_get_item.py From b9bf03b9a3899ecd6a369b04af340d839793b6ef Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Nov 2023 13:25:12 -0600 Subject: [PATCH 13/21] Implement printing/parsing of dynamic tuple indices --- python/tvm/relax/expr.py | 5 +- python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/tuple.py | 93 +++++++++++++++++++++++ python/tvm/script/ir_builder/relax/ir.py | 4 + src/relax/op/tuple.cc | 2 +- src/script/printer/relax/call.cc | 22 ++++++ tests/python/relax/test_tuple_get_item.py | 34 +++++++++ 7 files changed, 156 insertions(+), 5 deletions(-) create mode 100644 python/tvm/relax/op/tuple.py diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 3ccdddc28287..0be9d1a6acea 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -232,10 +232,7 @@ def __getitem__(self, index: Union[Expr, PrimExpr, int]) -> "ExprWithOp": result: ExprWithOp The result expression. """ - if not isinstance(index, Expr): - index = PrimValue(index) - - return _ffi_api.tuple_get_item(self, index) + return tvm.relax.op.tuple_get_item(self, index) @tvm._ffi.register_object("relax.expr.Call") diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 60a4332d838c..61535df8ec37 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -101,6 +101,7 @@ from .set import unique from .statistical import cumsum, max, mean, min, prod, std, sum, variance from .ternary import ewise_fma +from .tuple import tuple_get_item, tuple_get_item_dyn from .unary import ( abs, acos, diff --git a/python/tvm/relax/op/tuple.py b/python/tvm/relax/op/tuple.py new file mode 100644 index 000000000000..39661a6b93ca --- /dev/null +++ b/python/tvm/relax/op/tuple.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tuple operators.""" +from typing import Union + +import tvm +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ..expr import Expr, PrimValue + + +def tuple_get_item(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: + """Perform tuple access + + Use of this method is recommended, rather than constructing a + `relax.TupleGetItem` directly. + + 1. May resolve to the tuple's contents, avoiding the intermediate + `TupleGetItem`. + + 2. Handles access of a tuple at a dynamic index, where + `TupleGetItem` requires a statically-known index. + + Parameters + ---------- + tuple: Expr + + The tuple to be accessed. The tuple is not required to be an + in-line `relax.Tuple`, but must have `TupleStructInfo` + + index: Union[int, PrimExpr, Expr] + + The index at which the tuple is accessed. The index may be + static or dynamic. + + Returns + ------- + Expr + + An expression representing the item in the tuple. + """ + + if not isinstance(index, Expr): + index = PrimValue(index) + + return _ffi_api.tuple_get_item(tuple, index) # type: ignore + + +def tuple_get_item_dyn(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: + """Explicitly generate a call to tuple_get_item_dyn + + This method is not recommended for general use, and is provided to + ensure round-trip consistency in TVMScript. In most cases, the + `tuple_get_item` method should be used, which will delegate to the + dynamic builtin for cases where the index is dynamic. + + Parameters + ---------- + tuple: Expr + + The tuple to be accessed. The tuple is not required to be an + in-line `relax.Tuple`, but must have `TupleStructInfo` + + index: Union[int, PrimExpr, Expr] + + The index at which the tuple is accessed. The index may be + static or dynamic. + + Returns + ------- + Expr + + An expression representing the item in the tuple. + + """ + if not isinstance(index, Expr): + index = PrimValue(index) + return tvm.relax.Call(tvm.ir.Op.get("relax.tuple_get_item_dyn"), [tuple, index]) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 142d0e6d96aa..6068880d3a96 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -146,6 +146,8 @@ tile, tril, triu, + tuple_get_item, + tuple_get_item_dyn, unique, vm, where, @@ -775,6 +777,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "tril", "triu", "tuple", + "tuple_get_item", + "tuple_get_item_dyn", "unique", "variance", "vm", diff --git a/src/relax/op/tuple.cc b/src/relax/op/tuple.cc index 2412a588173b..6defce67cd71 100644 --- a/src/relax/op/tuple.cc +++ b/src/relax/op/tuple.cc @@ -157,7 +157,7 @@ Expr tuple_get_item(Expr tuple, Expr index) { } } -TVM_REGISTER_GLOBAL("relax.tuple_get_item").set_body_typed(tuple_get_item); +TVM_REGISTER_GLOBAL("relax.op.tuple_get_item").set_body_typed(tuple_get_item); } // namespace relax } // namespace tvm diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 785dc6d96320..3d85477d1ab1 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -249,6 +249,24 @@ Optional PrintRelaxPrint(const relax::Call& n, const ObjectPath& n_p, return Relax(d, "print")->Call(args, {"format"}, {first_arg}); } +Optional PrintTupleGetItem(const relax::Call& call, const ObjectPath& path, + const IRDocsifier& doc) { + static const Op& print_op = Op::Get("relax.tuple_get_item_dyn"); + if (!call->op.same_as(print_op)) { + return NullOpt; + } + + if (!doc->cfg->syntax_sugar) { + // Fall back to the default printing for builtins as `R.tuple_get_item_dyn` + return NullOpt; + } + + ICHECK_EQ(call->args.size(), 2); + ExprDoc tuple = doc->AsDoc(call->args[0], path->Attr("args")->ArrayIndex(0)); + ExprDoc index = doc->AsDoc(call->args[1], path->Attr("args")->ArrayIndex(1)); + return tuple[{index}]; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc { @@ -272,6 +290,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (Optional doc = PrintRelaxPrint(n, n_p, d)) { return doc.value(); } + // Special case: tuple_get_item_dyn + if (Optional doc = PrintTupleGetItem(n, n_p, d)) { + return doc.value(); + } ExprDoc prefix{nullptr}; Array args; Array kwargs_keys; diff --git a/tests/python/relax/test_tuple_get_item.py b/tests/python/relax/test_tuple_get_item.py index a441f0d85fb2..ba8a9a6eb836 100644 --- a/tests/python/relax/test_tuple_get_item.py +++ b/tests/python/relax/test_tuple_get_item.py @@ -22,6 +22,8 @@ from tvm import relax from tvm.script import relax as R, tir as T +import pytest + exec_mode = tvm.testing.parameter("bytecode", "compiled") tuple_type_annotation = tvm.testing.parameter( @@ -33,6 +35,8 @@ tuple_index_type = tvm.testing.parameter("static", "dynamic") +syntax_sugar = tvm.testing.parameter(by_dict={"sugared": True, "unsugared": False}) + def test_vm_tuple_get_item(exec_mode, tuple_type_annotation, tuple_index_type): def access_tuple(tuple_obj, dyn_index): @@ -61,5 +65,35 @@ def func(arg: tuple_type_annotation, index_param: R.Prim(value="index_var")): assert res == 17 +def test_dynamic_index_printing(syntax_sugar: bool): + """Check syntax-sugar for dynamic tuple indices + + The "relax.tuple_get_item_dyn" operator should be printed as + `my_tuple[my_index]` by default, which will regenerate the + original operator when parsed. If syntax sugar is disabled, it + should display the `R.tuple_get_item_dyn` directly. + """ + + @R.function(private=True) + def func( + arg_tuple: R.Tuple([R.Prim("int64"), R.Prim("float32")]), + arg_index: R.Prim(value="index_var"), + ): + return arg_tuple[arg_index] + + script = func.script(syntax_sugar=syntax_sugar) + + if syntax_sugar: + assert "arg_tuple[arg_index]" in script + assert "tuple_get_item_dyn" not in script + else: + assert "arg_tuple[arg_index]" not in script + assert "tuple_get_item_dyn" in script + + roundtrip = tvm.script.from_source(script) + + tvm.ir.assert_structural_equal(func, roundtrip) + + if __name__ == "__main__": tvm.testing.main() From 1dba4398518f14df76767c8587a61479785e8446 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 7 Nov 2023 08:19:00 -0600 Subject: [PATCH 14/21] fix lint errors --- python/tvm/relax/op/tuple.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/op/tuple.py b/python/tvm/relax/op/tuple.py index 39661a6b93ca..e04cbf63be2d 100644 --- a/python/tvm/relax/op/tuple.py +++ b/python/tvm/relax/op/tuple.py @@ -24,7 +24,7 @@ from ..expr import Expr, PrimValue -def tuple_get_item(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: +def tuple_get_item(tuple_expr: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: """Perform tuple access Use of this method is recommended, rather than constructing a @@ -38,7 +38,7 @@ def tuple_get_item(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: Parameters ---------- - tuple: Expr + tuple_expr: Expr The tuple to be accessed. The tuple is not required to be an in-line `relax.Tuple`, but must have `TupleStructInfo` @@ -58,10 +58,10 @@ def tuple_get_item(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: if not isinstance(index, Expr): index = PrimValue(index) - return _ffi_api.tuple_get_item(tuple, index) # type: ignore + return _ffi_api.tuple_get_item(tuple_expr, index) # type: ignore -def tuple_get_item_dyn(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: +def tuple_get_item_dyn(tuple_expr: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: """Explicitly generate a call to tuple_get_item_dyn This method is not recommended for general use, and is provided to @@ -71,7 +71,7 @@ def tuple_get_item_dyn(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: Parameters ---------- - tuple: Expr + tuple_expr: Expr The tuple to be accessed. The tuple is not required to be an in-line `relax.Tuple`, but must have `TupleStructInfo` @@ -90,4 +90,4 @@ def tuple_get_item_dyn(tuple: Expr, index: Union[int, PrimExpr, Expr]) -> Expr: """ if not isinstance(index, Expr): index = PrimValue(index) - return tvm.relax.Call(tvm.ir.Op.get("relax.tuple_get_item_dyn"), [tuple, index]) + return tvm.relax.Call(tvm.ir.Op.get("relax.tuple_get_item_dyn"), [tuple_expr, index]) From 8f43de7b8bd248243a349b72a59259838ee91097 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 7 Nov 2023 16:01:06 -0600 Subject: [PATCH 15/21] Revert "[PR-15983][FFI] Allow IntImm arguments to PackedFunc with int parameter" This reverts commit 61f6322bf3887e9d4497f1e1ae97b39617d03993. --- include/tvm/ir/expr.h | 36 --------------- include/tvm/runtime/packed_func.h | 74 ------------------------------- tests/cpp/packed_func_test.cc | 60 ------------------------- 3 files changed, 170 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index eec2811a4764..594e2b86e9f9 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -540,24 +540,6 @@ class IntImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode); }; -/* \brief FFI extention, ObjectRef to integer conversion - * - * If a PackedFunc expects an integer type, and the user passes an - * IntImm as the argument, this specialization allows it to be - * converted by the FFI. - */ -template -struct runtime::PackedFuncObjectRefConverter>> { - static std::optional TryFrom(const ObjectRef& obj) { - if (auto ptr = obj.as()) { - return ptr->value; - } else { - return std::nullopt; - } - } -}; - /*! * \brief Constant floating point literals in the program. * \sa FloatImm @@ -605,24 +587,6 @@ class FloatImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); }; -/* \brief FFI extention, ObjectRef to integer conversion - * - * If a PackedFunc expects an integer type, and the user passes an - * IntImm as the argument, this specialization allows it to be - * converted by the FFI. - */ -template -struct runtime::PackedFuncObjectRefConverter< - FloatType, std::enable_if_t>> { - static std::optional TryFrom(const ObjectRef& obj) { - if (auto ptr = obj.as()) { - return ptr->value; - } else { - return std::nullopt; - } - } -}; - /*! * \brief Boolean constant. * diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index d6e3bcd44d1a..7266f8c4a50a 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -37,7 +37,6 @@ #include #include #include -#include #include #include #include @@ -546,42 +545,6 @@ struct ObjectTypeChecker> { } }; -class TVMPODValue_; - -/*! - * \brief Type trait to specify special value conversion rules from - * ObjectRef to primitive types. - * - * TVM containers, such as tvm::runtime::Array, require the contained - * objects to inherit from ObjectRef. As a result, the wrapper types - * IntImm, FloatImm, and StringImm are often used to hold primitive - * types inside a TVM container. Conversions into this type may be - * required when using a container, and may be performed - * automatically when passing an object across the FFI. By also - * handling conversions from wrapped to unwrapped types, these - * conversions can be transparent to users. - * - * The trait can be specialized to add type specific conversion logic - * from the TVMArgvalue and TVMRetValue. - * - * \tparam T The type (e.g. int64_t) which may be contained within the - * ObjectRef. - * - * \tparam (anonymous) An anonymous and unused type parameter, which - * may be used for SFINAE. - */ -template -struct PackedFuncObjectRefConverter { - /*! - * \brief Attempt to convert an ObjectRef from an argument value. - * - * \param obj The ObjectRef which may be convertible to T - * - * \return The converted result, or std::nullopt if not convertible. - */ - static std::optional TryFrom(const ObjectRef& obj) { return std::nullopt; } -}; - /*! * \brief Internal base class to * handle conversion to POD values. @@ -594,41 +557,25 @@ class TVMPODValue_ { // the frontend while the API expects a float. if (type_code_ == kDLInt) { return static_cast(value_.v_int64); - } else if (auto opt = ThroughObjectRef()) { - return opt.value(); - } else if (auto opt = ThroughObjectRef()) { - return opt.value(); } TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); return value_.v_float64; } operator int64_t() const { - if (auto opt = ThroughObjectRef()) { - return opt.value(); - } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator uint64_t() const { - if (auto opt = ThroughObjectRef()) { - return opt.value(); - } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator int() const { - if (auto opt = ThroughObjectRef()) { - return opt.value(); - } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); ICHECK_LE(value_.v_int64, std::numeric_limits::max()); ICHECK_GE(value_.v_int64, std::numeric_limits::min()); return static_cast(value_.v_int64); } operator bool() const { - if (auto opt = ThroughObjectRef()) { - return opt.value(); - } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64 != 0; } @@ -699,27 +646,6 @@ class TVMPODValue_ { TVMValue value_; /*! \brief the type code */ int type_code_; - - private: - /* \brief A utility function to check for conversions through - * PackedFuncObjectRefConverter - * - * \tparam T The type to attempt to convert into - * - * \return The converted type, or std::nullopt if the value cannot - * be converted into T. - */ - template - std::optional ThroughObjectRef() const { - if (IsObjectRef()) { - if (std::optional from_obj = - PackedFuncObjectRefConverter::TryFrom(AsObjectRef())) { - return from_obj.value(); - } - } - - return std::nullopt; - } }; /*! diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 17f194520aa5..183aca1385a7 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -319,63 +319,3 @@ TEST(TypedPackedFunc, RValue) { tf(1, true); } } - -TEST(TypedPackedFunc, IntImmWrapper) { - using namespace tvm::runtime; - - TypedPackedFunc typed_func = [](int x) {}; - PackedFunc func = typed_func; - - // Integer argument may be provided - func(5); - - // IntImm argument may be provided, automatically unwrapped. - tvm::IntImm lvalue_intimm(DataType::Int(32), 10); - func(lvalue_intimm); - - // Unwrapping of IntImm argument works for rvalues as well - func(tvm::IntImm(DataType::Int(32), 10)); -} - -TEST(TypedPackedFunc, FloatImmWrapper) { - using namespace tvm::runtime; - - TypedPackedFunc typed_func = [](double x) {}; - PackedFunc func = typed_func; - - // Argument may be provided as a floating point. If provided as an - // integer, it will be converted to a float. - func(static_cast(5.0)); - func(static_cast(5)); - - // IntImm and FloatImm arguments may be provided, and are - // automatically unwrapped. These arguments work correctly for - // either lvalue or rvalue arguments. - - tvm::IntImm lvalue_intimm(DataType::Int(32), 10); - tvm::FloatImm lvalue_floatimm(DataType::Float(32), 10.5); - - func(lvalue_intimm); - func(lvalue_floatimm); - func(tvm::IntImm(DataType::Int(32), 10)); - func(tvm::FloatImm(DataType::Float(32), 10.5)); -} - -TEST(TypedPackedFunc, BoolWrapper) { - using namespace tvm::runtime; - - TypedPackedFunc typed_func = [](bool x) {}; - PackedFunc func = typed_func; - - // Argument may be provided as a floating point. If provided as an - // integer, it will be converted to a float. - func(true); - - tvm::IntImm lvalue_intimm(DataType::Int(32), 10); - func(lvalue_intimm); - func(tvm::IntImm(DataType::Int(32), 10)); - - tvm::Bool lvalue_bool(false); - func(lvalue_bool); - func(tvm::Bool(true)); -} From 8d357fb7919cebf1df286e7ccd5f2d55248a2305 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 28 Nov 2023 09:37:49 -0600 Subject: [PATCH 16/21] [Container] Support non-nullable types in Array::Map Prior to this commit, the `Array::Map` member function could only be applied to nullable object types. This was due to the internal use of `U()` as the default value for initializing the output `ArrayNode`, where `U` is the return type of the mapping function. This default constructor is only available for nullable types, and would result in a compile-time failure for non-nullable types. This commit replaces `U()` with `ObjectRef()` in `Array::Map`, removing this limitation. Since all items in the output array are overwritten before returning to the calling scope, initializing the output array with `ObjectRef()` does not violate type safety. --- include/tvm/runtime/container/array.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index ff0bd03ab9cb..ba8fdfac5565 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -827,8 +827,13 @@ class Array : public ObjectRef { // consisting of any previous elements that had mapped to // themselves (if any), and the element that didn't map to // itself. + // + // We cannot use `U()` as the default object, as `U` may be + // a non-nullable type. Since the default `ObjectRef()` + // will be overwritten before returning, all objects will be + // of type `U` for the calling scope. all_identical = false; - output = ArrayNode::CreateRepeated(arr->size(), U()); + output = ArrayNode::CreateRepeated(arr->size(), ObjectRef()); output->InitRange(0, arr->begin(), it); output->SetItem(it - arr->begin(), std::move(mapped)); it++; @@ -843,7 +848,12 @@ class Array : public ObjectRef { // compatible types isn't strictly necessary, as the first // mapped.same_as(*it) would return false, but we might as well // avoid it altogether. - output = ArrayNode::CreateRepeated(arr->size(), U()); + // + // We cannot use `U()` as the default object, as `U` may be a + // non-nullable type. Since the default `ObjectRef()` will be + // overwritten before returning, all objects will be of type `U` + // for the calling scope. + output = ArrayNode::CreateRepeated(arr->size(), ObjectRef()); } // Normal path for incompatible types, or post-copy path for From e739ba9d8ea66bf3921721413c73566c82bcdec0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 13 Nov 2023 09:04:19 -0600 Subject: [PATCH 17/21] [FFI] Separate runtime types from IR types for int/float/bool Prior to this commit, `int`, `float`, and `bool` arguments from Python were converted to `IntImm`, `FloatImm`, and `Bool`. These are subtypes of `PrimExpr`, and should only be used at compile-time. By automatically applying this conversion as part of the FFI, these types are required to be present whenever a primitive is converted to a `tvm::ObjectRef`. This can become especially fragile for an end-user when storing objects into a TVM container. Because TVM containers require all contents to be `ObjectRef` subclasses, an automatic conversion may be applied on storing into a container, resulting in an unexpected type being retrieved from the container. For example, this currently occurs in Relax when extracting a `R.Prim` from a `R.Tuple`. This commit introduces a `Box` type for storage of boxed primitives at runtime, distinct from the IR types. * Primitive arguments provided to a PackedFunc that requires an `ObjectRef` will be converted to the corresponding boxed type. (e.g. Passing a Python `int` to a C++ function accepting `ObjectRef` produces a `Box`. * Boxed primitives provided to a PackedFunc that requires an unboxed primitive will be converted to the corresponding primitive. * PackedFunc return values of `ObjectRef` are converted to the corresponding primitive, if present. (e.g. If a `tuple_getitem` with static return type `ObjectRef` returns a `Box`, it will be unwrapped to a python `int`.) Together, these three rules provide backwards compatibility for existing PackedFunc definitions, while avoiding exposing the user to any container-induced type conversions betweeen primitive types and `ObjectRef`. --- include/tvm/ir/expr.h | 54 ++- .../tvm/runtime/container/boxed_primitive.h | 121 ++++++ include/tvm/runtime/packed_func.h | 368 ++++++++++++++---- python/tvm/_ffi/_ctypes/object.py | 13 + python/tvm/_ffi/_ctypes/packed_func.py | 9 +- python/tvm/_ffi/_cython/object.pxi | 9 + python/tvm/runtime/__init__.py | 4 +- python/tvm/runtime/container.py | 40 ++ python/tvm/runtime/object_generic.py | 74 ++-- src/node/boxed_primitive.cc | 125 ++++++ src/runtime/boxed_primitive.cc | 65 ++++ src/support/ffi_testing.cc | 19 + .../python/runtime/test_runtime_container.py | 73 +++- 13 files changed, 830 insertions(+), 144 deletions(-) create mode 100644 include/tvm/runtime/container/boxed_primitive.h create mode 100644 src/node/boxed_primitive.cc create mode 100644 src/runtime/boxed_primitive.cc diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 594e2b86e9f9..44f234eb286f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -773,49 +773,47 @@ namespace runtime { template <> struct PackedFuncValueConverter { static PrimExpr From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return PrimExpr(ObjectPtr(nullptr)); + if (auto opt = val.TryAsBool()) { + return Bool(opt.value()); + } else if (auto opt = val.TryAsInt()) { + int64_t value = opt.value(); + auto dtype = + (value > std::numeric_limits::max() || value < std::numeric_limits::min()) + ? DataType::Int(64) + : DataType::Int(32); + return IntImm(dtype, value); + } else if (auto opt = val.TryAsFloat()) { + return FloatImm(runtime::DataType::Float(32), opt.value()); + } else { + return val.AsObjectRef(); } - if (val.type_code() == kDLInt) { - int64_t value = val.operator int64_t(); - if (value > std::numeric_limits::max() || value < std::numeric_limits::min()) { - return IntImm(runtime::DataType::Int(64), value); - } - return IntImm(runtime::DataType::Int(32), val.operator int()); - } - if (val.type_code() == kDLFloat) { - return FloatImm(runtime::DataType::Float(32), val.operator double()); - } - - return PrimExpr::FromObject_(val.AsObjectRef()); } }; template <> struct PackedFuncValueConverter { static tvm::Integer From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Integer(ObjectPtr(nullptr)); + if (auto opt = val.TryAsInt()) { + return Integer(opt.value()); + } else { + return val.AsObjectRef(); } - if (val.type_code() == kTVMArgInt) { - return Integer(val.operator int()); - } - return val.AsObjectRef(); } }; template <> struct PackedFuncValueConverter { static tvm::Bool From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Bool(ObjectPtr(nullptr)); - } - if (val.type_code() == kTVMArgInt) { - int v = val.operator int(); - ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v; - return Bool(static_cast(v)); + if (auto opt = val.TryAsBool()) { + return Bool(opt.value()); + } else if (auto opt = val.TryAsInt()) { + int value = opt.value(); + ICHECK(value == 0 || value == 1) + << "ValueError: boolean value can only be 0 or 1, but get " << value; + return Bool(static_cast(value)); + } else { + return val.AsObjectRef(); } - return val.AsObjectRef(); } }; diff --git a/include/tvm/runtime/container/boxed_primitive.h b/include/tvm/runtime/container/boxed_primitive.h new file mode 100644 index 000000000000..bf47e354fb33 --- /dev/null +++ b/include/tvm/runtime/container/boxed_primitive.h @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/boxed_primitive.h + * \brief Runtime container types for primitives stored as ObjectRef. + */ +#ifndef TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ +#define TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +namespace detail { +/* \brief Provide the BoxNode type key in templated contexts + * + * The Box class is used in many templated contexts, and is easier + * to have templated over the primitive type. However, much of the + * TVM type system depends on classes having a unique name. For + * example, the use of `Object::IsInstance` depends on + * `Object::GetOrAllocRuntimeTypeIndex`. Any duplicate names will + * result in duplicate indices, and invalid downcasting. + * + * Furthermore, the name must be specified in the Python FFI using + * `tvm._ffi.register_object`. This prevents use of + * `typeid(T)::name()` to build a unique name, as the name is not + * required to be human-readable or consistent across compilers. + * + * This utility struct exists to bridge that gap, providing a unique + * name where required. + */ +template +struct BoxNodeTypeKey; + +template <> +struct BoxNodeTypeKey { + static constexpr const char* _type_key = "runtime.BoxInt"; +}; + +template <> +struct BoxNodeTypeKey { + static constexpr const char* _type_key = "runtime.BoxFloat"; +}; + +template <> +struct BoxNodeTypeKey { + static constexpr const char* _type_key = "runtime.BoxBool"; +}; +} // namespace detail + +template +class BoxNode : public Object { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + BoxNode(Prim value) : value(value) {} + + /*! \brief The boxed value */ + Prim value; + + static constexpr const char* _type_key = detail::BoxNodeTypeKey::_type_key; + static constexpr bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(BoxNode, Object); +}; + +template +class Box : public ObjectRef { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + Box(Prim value) : ObjectRef(make_object>(value)) {} + + operator Prim() const { return (*this)->value; } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Box, ObjectRef, BoxNode); +}; + +/*! \brief Runtime equivalent of IntImm */ +using BoxInt = Box; + +/*! \brief Runtime equivalent of FloatImm */ +using BoxFloat = Box; + +/*! \brief Runtime equivalent of IntImm with DataType::Bool() + * + * When passing from Python to C++, TVM PackedFunc conversion follow + * C++ conversion rules, and allow bool->int and int->bool + * conversions. When passing from C++ to Python, the types are + * returned as bool or int. If the C++ function uses ObjectRef to + * hold the object, a Python to C++ to Python round trip will preserve + * the distinction between bool and int. + */ +using BoxBool = Box; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..4a8f2a681f20 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -37,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -429,9 +431,11 @@ inline const char* ArgTypeCode2Str(int type_code); inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*) +#define TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) \ + "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) + // macro to check type code. -#define TVM_CHECK_TYPE_CODE(CODE, T) \ - ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) +#define TVM_CHECK_TYPE_CODE(CODE, T) ICHECK_EQ(CODE, T) << TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) /*! * \brief Type traits for runtime type check during FFI conversion. @@ -555,29 +559,43 @@ class TVMPODValue_ { // Allow automatic conversion from int to float // This avoids errors when user pass in int from // the frontend while the API expects a float. - if (type_code_ == kDLInt) { - return static_cast(value_.v_int64); + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsFloat()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat); } - TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); - return value_.v_float64; } operator int64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; - } - operator uint64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else if (IsObjectRef()) { + auto obj = AsObjectRef(); + LOG(FATAL) << "Expected integer, but found object with type key " << obj->GetTypeKey(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } } + operator uint64_t() const { return operator int64_t(); } operator int() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - ICHECK_LE(value_.v_int64, std::numeric_limits::max()); - ICHECK_GE(value_.v_int64, std::numeric_limits::min()); - return static_cast(value_.v_int64); + int64_t value = operator int64_t(); + ICHECK_LE(value, std::numeric_limits::max()); + ICHECK_GE(value, std::numeric_limits::min()); + return value; } operator bool() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64 != 0; + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } } operator void*() const { if (type_code_ == kTVMNullptr) return nullptr; @@ -635,6 +653,38 @@ class TVMPODValue_ { template inline TObjectRef AsObjectRef() const; + std::optional TryAsInt() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (auto opt = FromBoxed()) { + return opt.value(); + } else if (type_code_ == kDLInt) { + return value_.v_int64; + } else { + return std::nullopt; + } + } + + std::optional TryAsFloat() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (auto opt = FromBoxed()) { + return opt.value(); + } else if (type_code_ == kDLFloat) { + return value_.v_float64; + } else { + return std::nullopt; + } + } + + std::optional TryAsBool() const { + // Booleans may be kept distinct from Int by using Box and + // Box. + return FromBoxed(); + } + protected: friend class TVMArgsSetter; friend class TVMRetValue; @@ -642,6 +692,15 @@ class TVMPODValue_ { TVMPODValue_() : type_code_(kTVMNullptr) {} TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} + template + std::optional FromBoxed() const { + if (IsObjectRef>()) { + return AsObjectRef>()->value; + } else { + return std::nullopt; + } + } + /*! \brief The value */ TVMValue value_; /*! \brief the type code */ @@ -901,9 +960,12 @@ class TVMRetValue : public TVMPODValue_ { } TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { - this->SwitchToPOD(kDLInt); - value_.v_int64 = value; - return *this; + // While a boolean could be stored using the primitive kDLInt + // type, this causes round-trip inconsistencies for languages that + // distinguish between integer and boolean types (i.e. Anything + // after C89). Rather than adding another type for booleans, this + // is stored in the Box container. + return operator=(Box(value)); } TVMRetValue& operator=(std::string value) { this->SwitchToClass(kTVMStr, value); @@ -989,9 +1051,9 @@ class TVMRetValue : public TVMPODValue_ { } // ObjectRef handling template ::value>::type> + typename = typename std::enable_if_t>> inline TVMRetValue& operator=(TObjectRef other); - template ::value>::type> + template >> inline operator T() const; private: @@ -1019,9 +1081,11 @@ class TVMRetValue : public TVMPODValue_ { break; } case kTVMObjectHandle: { - // Avoid operator ObjectRef as we already know it is not NDArray/Module - SwitchToObject(kTVMObjectHandle, - GetObjectPtr(static_cast(other.value_.v_handle))); + // We already known it is not NDArray/Module, but + // operator=(ObjectRef) also handles conversions from wrappers + // around primitive types. For NDArray/Module, the duplicate + // checks are removed with if constexpr. + operator=(other.operator ObjectRef()); break; } case kTVMObjectRValueRefArg: { @@ -1951,33 +2015,96 @@ inline T TVMArgs::At(int i) const { template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { using ContainerType = typename std::remove_reference::type::ContainerType; - if (value.defined()) { - Object* ptr = value.data_.data_; - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + if (!value.defined()) { + type_codes_[i] = kTVMNullptr; + values_[i].v_handle = nullptr; + return; + } + + Object* ptr = value.data_.data_; + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = NDArray::FFIGetHandle(value); type_codes_[i] = kTVMNDArrayHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMPackedFuncHandle; - } else if (std::is_rvalue_reference::value) { - values_[i].v_handle = const_cast(&(value.data_.data_)); - type_codes_[i] = kTVMObjectRValueRefArg; - } else { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kTVMObjectHandle; + return; + } + } + + // If a boxed integer is being returned, always unbox it to the + // primitive type. This must be checked at the PackedFunc level to + // ensure that a boxed primitive argument is round-tripped correctly + // when the boxing is no longer required. + // + // For example, consider a PackedFunc with signature `ObjectRef + // func(Array)`, and returns the first element of that + // array. When passing a Python array `[5, 17.5, "hello"]`, the + // items are converted to `[Box(5), Box(17.5), + // String("hello")]` in order to provide an `Array`. + // + // If we had no additional conversions, the caller would receive the + // return value as a `Box(5)`, which would be unexpected and + // require additional unwrapping. We could perform this check + // inside the PackedFunc, but that would require a large amount of + // duplicated checked, and would require explicit handling of + // `TVMRetValue`. Instead, this conversion is checked in the FFI + // return value, to ensure that boxing/unboxing is applied + // consistently. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_int64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgInt; + return; + } + } + + // Like with BoxInt, unwrap any BoxFloat instances. See the BoxInt + // explanation for more detail. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_float64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgFloat; + return; } + } + + // Deliberately do *not* unwrap BoxBool instances. If BoxBool were + // unwrapped to kTVMArgInt, it would be ambiguous whether the + // user-defined object was a bool or an int. + + // Final fallback, if the ObjectRef has no special cases that must + // be expressed within the TVMRetValue. + if constexpr (std::is_rvalue_reference_v) { + values_[i].v_handle = const_cast(&(value.data_.data_)); + type_codes_[i] = kTVMObjectRValueRefArg; } else { - type_codes_[i] = kTVMNullptr; - values_[i].v_handle = nullptr; + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kTVMObjectHandle; } } @@ -2023,8 +2150,10 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } - // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + + // NOTE: The following code uses "if constexpr" wherever possible to + // minimize the number of runtime checks. + if constexpr (std::is_base_of_v) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); ObjectPtr data = @@ -2033,7 +2162,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2041,7 +2171,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of PackedFunc TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2049,6 +2180,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } + if (type_code_ == kTVMObjectHandle) { // normal object type check. Object* ptr = static_cast(value_.v_handle); @@ -2062,46 +2194,100 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker::TypeName() << ", but got " << checked_type.value(); return TObjectRef(GetObjectPtr(ptr)); - } else if (std::is_base_of::value && - type_code_ == kTVMNDArrayHandle) { - // Casting to a base class that NDArray can sub-class - ObjectPtr data = - NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); - return TObjectRef(data); - } else if (std::is_base_of::value && - type_code_ == kTVMModuleHandle) { - // Casting to a base class that Module can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else if (std::is_base_of::value && - type_code_ == kTVMPackedFuncHandle) { - // Casting to a base class that PackedFunc can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else { - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - return TObjectRef(ObjectPtr(nullptr)); } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMNDArrayHandle) { + // Casting to a base class that NDArray can sub-class + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); + return TObjectRef(data); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMModuleHandle) { + // Casting to a base class that Module can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMPackedFuncHandle) { + // Casting to a base class that PackedFunc can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgInt) { + return BoxInt(value_.v_int64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgFloat) { + return BoxFloat(value_.v_float64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMStr) { + return String(value_.v_str); + } + } + + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); } template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); - if (ptr != nullptr) { - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (ptr && (std::is_base_of_v || + ptr->IsInstance())) { return operator=(NDArray(std::move(other.data_))); } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (ptr && (std::is_base_of_v || + ptr->IsInstance())) { return operator=(Module(std::move(other.data_))); } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (ptr && (std::is_base_of_v || + ptr->IsInstance())) { return operator=(PackedFunc(std::move(other.data_))); } + } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (ptr && + (std::is_base_of_v || ptr->IsInstance())) { + int64_t value = static_cast(ptr)->value; + return operator=(value); + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (ptr && + (std::is_base_of_v || ptr->IsInstance())) { + double value = static_cast(ptr)->value; + return operator=(value); + } + } + + if (ptr) { SwitchToObject(kTVMObjectHandle, std::move(other.data_)); } else { SwitchToPOD(kTVMNullptr); @@ -2156,6 +2342,34 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { } }; +template +struct PackedFuncValueConverter> { + static Array From(const TVMArgValue& val) { + auto untyped_array = val.AsObjectRef>(); + + // Attempt to convert each item of the array into the desired + // type. If the items do not require a conversion, no copies are + // made. + return untyped_array.Map([](ObjectRef item) { + // The TVMArgValue is intentionally defined through + // `TVMArgsSetter`, rather than defining it with + // `value.data_ = item.get();` and type code + // `kTVMObjectHandle`. `TVMArgsSetter::operator()` includes + // special handling for unwrapping boxed primitives, + // PackedFunc, runtime::Module, etc, which should be checked + // before delegating to the array element's + // PackedFuncValueConverter implementation. + TVMValue value; + int type_code; + TVMArgsSetter setter(&value, &type_code); + setter(0, item); + TVMArgValue arg(value, type_code); + return PackedFuncValueConverter::From(arg); + }); + } + static Array From(const TVMRetValue& val) { return val.AsObjectRef>(); } +}; + template struct PackedFuncValueConverter> { static Optional From(const TVMArgValue& val) { @@ -2207,7 +2421,7 @@ struct PackedFuncValueConverter> { static Optional TryValueConverter(const PODSubclass& val) { try { return VType(PackedFuncValueConverter::From(val)); - } catch (const InternalError&) { + } catch (const Error&) { } if constexpr (sizeof...(VarRest)) { diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 520e0e42ebbe..2c739b7cfbab 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -60,14 +60,27 @@ def _return_object(x): tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) + + # Handle return values that subclass from both TVM objects and + # python native objects (e.g. runtime.String, a subclass of str). if issubclass(cls, PyNativeObject): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) obj.handle = handle return cls.__from_tvm_object__(cls, obj) + # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) obj.handle = handle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. + if hasattr(obj, '__into_pynative_object__'): + return obj.__into_pynative_object__() + return obj diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 6465e0335db0..dee5e0cc2f1a 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -134,6 +134,13 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode + elif isinstance(arg, bool): + # A python `bool` is a subclass of `int`, so this check + # must occur before `Integral`. + arg = _FUNC_CONVERT_TO_OBJECT(arg) + values[i].v_handle = arg.handle + type_codes[i] = ArgTypeCode.OBJECT_HANDLE + temp_args.append(arg) elif isinstance(arg, Integral): values[i].v_int64 = arg type_codes[i] = ArgTypeCode.INT @@ -147,7 +154,7 @@ def _make_tvm_args(args, temp_args): values[i].v_int64 = _device_to_int64(arg) type_codes[i] = ArgTypeCode.DLDEVICE elif isinstance(arg, (bytearray, bytes)): - # from_buffer only taeks in bytearray. + # from_buffer only takes in bytearray. if isinstance(arg, bytes): byte_arr = bytearray(arg) temp_args.append(byte_arr) diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 94a9310d7815..50ec3ce0fb00 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -60,6 +60,15 @@ cdef inline object make_ret_object(void* chandle): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = chandle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. + if hasattr(obj, '__into_pynative_object__'): + return obj.__into_pynative_object__() + return obj diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index eccdcbad9520..1f11cba76249 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -27,11 +27,11 @@ from .profiling import Report # function exposures -from .object_generic import convert_to_object, convert, const from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library -from .container import String, ShapeTuple +from .container import String, ShapeTuple, BoxBool +from .object_generic import convert_to_object, convert, const from .params import ( save_param_dict, load_param_dict, diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 686b4a26c80c..7c754ba1e622 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Runtime container structures.""" +from typing import Union + import tvm._ffi from .object import Object, PyNativeObject from .object_generic import ObjectTypes @@ -172,3 +174,41 @@ def __eq__(self, other): return False return True + + +@tvm._ffi.register_object("runtime.BoxBool") +class BoxBool(Object): + """A boolean wrapped as a tvm Object + + Parameters + ---------- + value: bool + + The value to hold + """ + + def __init__(self, value: bool): + # Convert to int to avoid an infinite recursion, because + # BoxBool may be constructed in _make_tvm_args, and calling + # the packed func `_ffi_api.BoxBool` internally calls + # `_make_tvm_args`. + self.__init_handle_by_constructor__(_ffi_api.BoxBool, int(value)) + + def __into_pynative_object__(self) -> bool: + return self.value + + @property + def value(self) -> bool: + """Unwrap the boxed value. + + This is implemented explicitly rather than using the usual + PackedFunc handling or AttrVisitor mechanics for two reasons. + First, because the PackedFunc handling would require ambiguous + representations between `True`/`1` and `False`/`0`. Second, + because the boxing/unboxing must be available in + `libtvm_runtime.so`, and AttrVisitor is only available in + `libtvm.so`. + """ + unboxed_bool = _ffi_api.UnBoxBool(self) + assert unboxed_bool is not None + return bool(unboxed_bool) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 887c2faaeb2b..b0b29426dbee 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -41,6 +41,14 @@ def asobject(self): def convert_to_object(value, span=None): """Convert a Python value to corresponding object type. + Type conversions performed by this function must *only* produce + types that are supported by `libtvm_runtime.so`. This function + must be usable in environments where only TVM runtime support is + present. Automatic conversions to compile-time representations + (e.g. `tir.IntImm` or `relax.PrimValue`) should not be done as + part of this conversion, as these types are not available in + `libtvm_runtime.so`. + Parameters ---------- value : str @@ -53,38 +61,46 @@ def convert_to_object(value, span=None): ------- obj : Object The corresponding object value. + """ + + # Import inside function call to avoid circular import from + # uninitialized tvm.runtime module. + from .container import BoxBool # pylint: disable=import-outside-toplevel + if isinstance(value, ObjectTypes): return value - if isinstance(value, bool): - return const(value, "uint1x1", span=span) - if isinstance(value, Number): + + elif isinstance(value, bool): + # Python types int and float will be converted to C++ types + # Box and Box using kDLInt. Boolean types need + # to be explicitly converted to Box to avoid ambiguous + # representation. This allows `bool(True)` and `int(1)` to be + # unambiguously passed to the C++ implementations. + return BoxBool(value) + + elif isinstance(value, Number): return const(value, span=span) - if isinstance(value, string_types): + elif isinstance(value, string_types): return _ffi_api.String(value) - if isinstance(value, (list, tuple)): - value = [convert_to_object(x) for x in value] + elif isinstance(value, (list, tuple)): + # The call to _ffi_api.Array will convert its own arguments, + # so we don't need to apply any explicit conversions here. return _ffi_api.Array(*value) - if isinstance(value, dict): - vlist = [] - for item in value.items(): - if ( - not isinstance(item[0], ObjectTypes) - and not isinstance(item[0], string_types) - and not isinstance(item[0], Number) - ): - raise ValueError("key of map must already been a container type") - vlist.append(convert_to_object(item[0])) - vlist.append(convert_to_object(item[1])) + elif isinstance(value, dict): + if any(not isinstance(key, (ObjectTypes, string_types, Number)) for key in value): + raise ValueError("key of map must already been a container type") + + vlist = [kv for item in value.items() for kv in item] return _ffi_api.Map(*vlist) - if isinstance(value, ObjectGeneric): + elif isinstance(value, ObjectGeneric): return value.asobject() - if callable(value): + elif callable(value): return convert_to_tvm_func(value) - if value is None: + elif value is None: return None - - raise ValueError(f"don't know how to convert type {type(value)} to object") + else: + raise TypeError(f"don't know how to convert type {type(value)} to object") def convert(value, span=None): @@ -107,29 +123,29 @@ def convert(value, span=None): This function is redirected to `convert_to_object` as it is widely used in the codebase. We can choose one to keep and discard the other one later. """ + return convert_to_object(value, span=span) def _scalar_type_inference(value): if hasattr(value, "dtype"): - dtype = str(value.dtype) + return str(value.dtype) elif isinstance(value, bool): - dtype = "bool" + return "bool" elif isinstance(value, float): # We intentionally prefer convert the float to float32 since it's more common in DL. if -3.40282347e38 <= value <= 3.40282347e38: - dtype = "float32" + return "float32" else: - dtype = "float64" + return "float64" elif isinstance(value, int): # We intentionally prefer convert the python int to int32 since it's more common in DL. if -2147483648 <= value <= 2147483647: - dtype = "int32" + return "int32" else: - dtype = "int64" + return "int64" else: raise NotImplementedError(f"Cannot automatically inference the type. value={value}") - return dtype def const(value, dtype=None, span=None): diff --git a/src/node/boxed_primitive.cc b/src/node/boxed_primitive.cc new file mode 100644 index 000000000000..10b308fcf4db --- /dev/null +++ b/src/node/boxed_primitive.cc @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file node/boxed_primitive.cc + * + * \brief Reflection utilities for runtime-supported classes + * + * The fundamental support for boxing and unboxing of primitives + * during FFI calls is implemented in runtime/boxed_primitive.cc. In + * addition, boxed primitives may be registered with compile-time + * utilities (e.g. reflection, JSON import/export) that can provide + * additional functionality and improved debugging ability. However, + * neither these compile-time utilities nor any registration of + * `Box` into the compile-time utilities should be included as + * part of `libtvm_runtime.so`. + * + * This file contains the registration of the `libtvm_runtime.so` + * class `Box` for utilities that are contained in `libtvm.so`. + */ +#include +#include +#include +#include + +namespace tvm { +namespace runtime_ext { + +using runtime::Box; +using runtime::BoxNode; + +template +struct BoxNodeTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const BoxNode* node, SHashReducer hash_reduce) { + hash_reduce(node->value); + } + + static bool SEqualReduce(const BoxNode* lhs, const BoxNode* rhs, + SEqualReducer equal) { + return equal(lhs->value, rhs->value); + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeTrait) + .set_creator([](const std::string& blob) -> ObjectPtr { + int64_t value = std::atoll(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + int64_t value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeTrait) + .set_creator([](const std::string& blob) -> ObjectPtr { + if (blob == "true") { + return make_object>(true); + } else if (blob == "false") { + return make_object>(false); + } else { + LOG(FATAL) << "Invalid string '" << blob << "' for boolean"; + } + }) + .set_repr_bytes([](const Object* n) -> std::string { + bool value = GetRef(n).as>().value()->value; + if (value) { + return "true"; + } else { + return "false"; + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << (box->value ? "true" : "false") << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeTrait) + .set_creator([](const std::string& blob) -> ObjectPtr { + double value = std::atof(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + double value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +} // namespace runtime_ext + +} // namespace tvm diff --git a/src/runtime/boxed_primitive.cc b/src/runtime/boxed_primitive.cc new file mode 100644 index 000000000000..9ab83a7b471c --- /dev/null +++ b/src/runtime/boxed_primitive.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/boxed_primitive.cc + * \brief Implementations of ObjectRef wrapper. + */ + +#include +#include + +namespace tvm { +namespace runtime { + +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); + +/* \brief Allow explicit construction of Box + * + * Convert a `bool` to `Box`. For use in FFI handling, to + * provide an umambiguous representation between `bool(true)` and + * `int(1)`. Will be automatically unboxed in the case where a + * `Box` is provided to a PackedFunc that requires `int` input, + * mimicking C++'s default conversions. + * + * This is only needed for Box, as Box and Box + * can be converted in C++ as part of `TVMArgValue::operator + * ObjectRef()` without ambiguity, postponing conversions until + * required. + */ +TVM_REGISTER_GLOBAL("runtime.BoxBool").set_body_typed([](bool value) { return Box(value); }); + +/* \brief Return the underlying boolean object. + * + * Used while unboxing a boolean return value during FFI handling. + * The return type is intentionally `int` and not `bool`, to avoid + * recursive unwrapping of boolean values. + * + * This is only needed for Box, as Box and Box + * can be unambiguously unboxed as part of + * `TVMRetValue::operator=(ObjectRef)`. + */ +TVM_REGISTER_GLOBAL("runtime.UnBoxBool").set_body_typed([](Box obj) -> int { + return obj->value; +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 75b5a2527f76..f7994d940bc0 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -189,4 +189,23 @@ TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Varian TVM_REGISTER_GLOBAL("testing.AcceptsVariant") .set_body_typed([](Variant arg) -> String { return arg->GetTypeKey(); }); +TVM_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRef").set_body_typed([](ObjectRef arg) -> ObjectRef { + return arg; +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray") + .set_body_typed([](Array arg) -> ObjectRef { return arg[0]; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") + .set_body_typed([](Map map, ObjectRef key) -> ObjectRef { + return map[key]; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") + .set_body_typed([](Map map) -> ObjectRef { return map; }); + } // namespace tvm diff --git a/tests/python/runtime/test_runtime_container.py b/tests/python/runtime/test_runtime_container.py index 7538075ae7f8..ec9da414cc36 100644 --- a/tests/python/runtime/test_runtime_container.py +++ b/tests/python/runtime/test_runtime_container.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. -import numpy as np +import pickle import random + +import numpy as np + import tvm import tvm.testing -import pickle -from tvm import te from tvm import nd, relay from tvm.runtime import container as _container @@ -96,8 +97,66 @@ def test_shape_tuple(): assert stuple == z +def test_bool_argument(): + """Boolean objects are currently stored as int""" + func = tvm.get_global_func("testing.AcceptsBool") + + assert isinstance(func(True), bool) + assert isinstance(func(1), bool) + assert isinstance(func(0), bool) + + +def test_int_argument(): + func = tvm.get_global_func("testing.AcceptsInt") + + assert isinstance(func(True), int) + assert isinstance(func(1), int) + assert isinstance(func(0), int) + + +def test_object_ref_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRef") + + assert isinstance(func(True), bool) + assert isinstance(func(1), int) + assert isinstance(func(3.5), float) + assert func(3.5) == 3.5 + + +def test_object_ref_array_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRefArray") + + assert isinstance(func([True, 17, "hello"]), bool) + assert isinstance(func([True]), bool) + assert isinstance(func([17]), int) + assert isinstance(func(["hello"]), str) + + +def test_map_argument_returns_value(): + func = tvm.get_global_func("testing.AcceptsMapReturnsValue") + + res = func({"a": 1, "b": 2}, "a") + assert isinstance(res, int) + assert res == 1 + + res = func({"a": True, "b": False}, "a") + assert isinstance(res, bool) + assert res == True + + +def test_map_argument_returns_map(): + func = tvm.get_global_func("testing.AcceptsMapReturnsMap") + + res = func({"a": 1, "b": 2}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, int) + + res = func({"a": False, "b": True}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, bool) + + if __name__ == "__main__": - test_string() - test_adt_constructor() - test_tuple_object() - test_shape_tuple() + tvm.testing.main() From d376d765eae31723ff0898ea9db1607d07ff0260 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 27 Nov 2023 15:04:26 -0600 Subject: [PATCH 18/21] [UnitTest] Update apache/main unit tests for Box Mostly, this requires removing `.value` unwrapping that is now applied automatically. --- include/tvm/runtime/packed_func.h | 10 +++++++++- src/node/script_printer.cc | 14 +++++++------- tests/python/ir/test_ir_container.py | 15 +++++++++------ .../python/tvmscript/test_tvmscript_roundtrip.py | 4 ++-- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 4a8f2a681f20..327ff78deefa 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -2367,7 +2367,15 @@ struct PackedFuncValueConverter> { return PackedFuncValueConverter::From(arg); }); } - static Array From(const TVMRetValue& val) { return val.AsObjectRef>(); } + static Array From(const TVMRetValue& val) { + auto untyped_array = val.AsObjectRef>(); + + return untyped_array.Map([](ObjectRef item) { + TVMRetValue item_val; + item_val = std::move(item); + return PackedFuncValueConverter::From(item_val); + }); + } }; template diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index f2d985279f12..301216d04cad 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -48,7 +48,7 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->binding_names.push_back(Downcast(v)); } if (auto v = config_dict.Get("show_meta")) { - n->show_meta = Downcast(v)->value; + n->show_meta = Downcast(v)->value; } if (auto v = config_dict.Get("ir_prefix")) { n->ir_prefix = Downcast(v); @@ -72,16 +72,16 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->float_dtype = DataType(runtime::String2DLDataType(Downcast(v))); } if (auto v = config_dict.Get("verbose_expr")) { - n->verbose_expr = Downcast(v)->value; + n->verbose_expr = Downcast(v)->value; } if (auto v = config_dict.Get("indent_spaces")) { - n->indent_spaces = Downcast(v)->value; + n->indent_spaces = Downcast(v)->value; } if (auto v = config_dict.Get("print_line_numbers")) { - n->print_line_numbers = Downcast(v)->value; + n->print_line_numbers = Downcast(v)->value; } if (auto v = config_dict.Get("num_context_lines")) { - n->num_context_lines = Downcast(v)->value; + n->num_context_lines = Downcast(v)->value; } if (auto v = config_dict.Get("path_to_underline")) { n->path_to_underline = Downcast>>(v).value_or(Array()); @@ -98,10 +98,10 @@ PrinterConfig::PrinterConfig(Map config_dict) { Downcast>>(v).value_or(Map()); } if (auto v = config_dict.Get("syntax_sugar")) { - n->syntax_sugar = Downcast(v)->value; + n->syntax_sugar = Downcast(v)->value; } if (auto v = config_dict.Get("show_object_address")) { - n->show_object_address = Downcast(v)->value; + n->show_object_address = Downcast(v)->value; } // Checking prefixes if they are valid Python identifiers. diff --git a/tests/python/ir/test_ir_container.py b/tests/python/ir/test_ir_container.py index aa482dd65cd7..1e3249197851 100644 --- a/tests/python/ir/test_ir_container.py +++ b/tests/python/ir/test_ir_container.py @@ -23,16 +23,19 @@ def test_array(): a = tvm.runtime.convert([1, 2, 3]) assert len(a) == 3 - assert a[-1].value == 3 + assert a[-1] == 3 a_slice = a[-3:-1] - assert (a_slice[0].value, a_slice[1].value) == (1, 2) + assert (a_slice[0], a_slice[1]) == (1, 2) def test_array_save_load_json(): - a = tvm.runtime.convert([1, 2, 3]) + a = tvm.runtime.convert([1, 2, 3.5, True]) json_str = tvm.ir.save_json(a) a_loaded = tvm.ir.load_json(json_str) - assert a_loaded[1].value == 2 + assert a_loaded[1] == 2 + assert a_loaded[2] == 3.5 + assert a_loaded[3] == True + assert isinstance(a_loaded[3], bool) def test_dir_array(): @@ -66,7 +69,7 @@ def test_str_map(): assert "a" in amap assert len(amap) == 2 dd = dict(amap.items()) - assert amap["a"].value == 2 + assert amap["a"] == 2 assert "a" in dd assert "b" in dd @@ -78,7 +81,7 @@ def test_map_save_load_json(): json_str = tvm.ir.save_json(amap) amap = tvm.ir.load_json(json_str) assert len(amap) == 2 - dd = {kv[0].name: kv[1].value for kv in amap.items()} + dd = {kv[0].name: kv[1] for kv in amap.items()} assert dd == {"a": 2, "b": 3} diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 5b3e68e22fa9..b2f9b7d51235 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2689,14 +2689,14 @@ def test_match_buffer_region(): outer_block = root.body.body.body.block assert len(outer_block.match_buffers) == 1 buffer_C = outer_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_C.shape, [16, 1, 4]) + tvm.ir.assert_structural_equal(buffer_C.shape, [T.int32(16), T.int32(1), T.int32(4)]) assert isinstance(outer_block.body, tir.stmt.For) assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) inner_block = outer_block.body.body.block assert len(inner_block.match_buffers) == 1 buffer_D = inner_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) + tvm.ir.assert_structural_equal(buffer_D.shape, [T.int32(4), T.int32(1), T.int32(4)]) def block_elements(): From 45c1133b1dc86fed0bc80c4ba3e56b296a5871b9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 28 Nov 2023 10:55:46 -0600 Subject: [PATCH 19/21] [TIR] Update FFI conversion registration * Change tir.Call signature accept `Array>` instead of `Array`. This allows the FFI to apply registered conversions. * Update target parsing to expect the default object types. * Extend conversion into PrimExpr. Several APIs that check for a PrimExpr may now receive a `runtime.String`, `runtime.Box` or `runtime.Box`. These must be converted to `StringImm`, `Bool`, or `IntImm` for use in functions that accept `PrimExpr`. --- include/tvm/ir/expr.h | 64 ++++++++++++--- include/tvm/target/target.h | 10 ++- include/tvm/target/target_kind.h | 4 +- include/tvm/tir/expr.h | 60 ++++++++++++++ src/relay/op/tensor/transform.cc | 6 +- src/target/tag.cc | 50 +++++------ src/target/target.cc | 63 ++++++-------- src/target/target_kind.cc | 137 ++++++++++++++++--------------- src/tir/ir/expr.cc | 4 +- 9 files changed, 250 insertions(+), 148 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 44f234eb286f..53c0e6be0fc4 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -769,23 +769,38 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -// common rule for RetValue and ArgValue + +// Automatic conversion into IntImm, Integer, and Bool, when called +// through the FFI. Automatic conversions into PrimExpr are +// registered in "tvm/tir/expr.h", as it includes conversions to the +// TIR-only StringImm. +// +// While the FFI only requires the From() method, these +// implementations also define a TryFrom() method to avoid duplicate +// logic in the PrimExpr conversion. + template <> -struct PackedFuncValueConverter { - static PrimExpr From(const TVMPODValue_& val) { - if (auto opt = val.TryAsBool()) { - return Bool(opt.value()); - } else if (auto opt = val.TryAsInt()) { +struct PackedFuncValueConverter { + static Optional TryFrom(const TVMPODValue_& val) { + if (auto opt = val.TryAsInt()) { int64_t value = opt.value(); auto dtype = (value > std::numeric_limits::max() || value < std::numeric_limits::min()) ? DataType::Int(64) : DataType::Int(32); return IntImm(dtype, value); - } else if (auto opt = val.TryAsFloat()) { - return FloatImm(runtime::DataType::Float(32), opt.value()); + } else if (auto opt = val.TryAsBool()) { + return IntImm(DataType::Int(32), opt.value()); } else { - return val.AsObjectRef(); + return NullOpt; + } + } + + static tvm::IntImm From(const TVMPODValue_& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.AsObjectRef(); } } }; @@ -795,6 +810,8 @@ struct PackedFuncValueConverter { static tvm::Integer From(const TVMPODValue_& val) { if (auto opt = val.TryAsInt()) { return Integer(opt.value()); + } else if (auto opt = val.TryAsBool()) { + return Integer(opt.value()); } else { return val.AsObjectRef(); } @@ -803,7 +820,7 @@ struct PackedFuncValueConverter { template <> struct PackedFuncValueConverter { - static tvm::Bool From(const TVMPODValue_& val) { + static Optional TryFrom(const TVMPODValue_& val) { if (auto opt = val.TryAsBool()) { return Bool(opt.value()); } else if (auto opt = val.TryAsInt()) { @@ -811,12 +828,39 @@ struct PackedFuncValueConverter { ICHECK(value == 0 || value == 1) << "ValueError: boolean value can only be 0 or 1, but get " << value; return Bool(static_cast(value)); + } else { + return NullOpt; + } + } + + static tvm::Bool From(const TVMPODValue_& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); } else { return val.AsObjectRef(); } } }; +template <> +struct PackedFuncValueConverter { + static Optional TryFrom(const TVMPODValue_& val) { + if (auto opt = val.TryAsFloat()) { + return FloatImm(runtime::DataType::Float(32), opt.value()); + } else { + return NullOpt; + } + } + + static tvm::FloatImm From(const TVMPODValue_& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.AsObjectRef(); + } + } +}; + } // namespace runtime } // namespace tvm #endif // TVM_IR_EXPR_H_ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index d47ac94e067e..4c1d1fc1f3d2 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -113,7 +113,15 @@ class TargetNode : public Object { "Can only call GetAttr with ObjectRef types."); auto it = attrs.find(attr_key); if (it != attrs.end()) { - return Downcast>((*it).second); + // For backwards compatibility, return through TVMRetValue. + // This triggers any automatic conversions registered with + // PackedFuncValueConverter. Importantly, this allows use of + // `GetAttr` and `GetAttr` for properties that + // are stored internally as `runtime::Box` and + // `runtime::Box`. + TVMRetValue ret; + ret = (*it).second; + return ret; } else { return default_value; } diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 130aea32f844..fb9c0f17b011 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -445,8 +445,8 @@ constexpr const char* kRelayToTIR = "RelayToTIR"; .add_attr_option("model") \ .add_attr_option>("libs") \ .add_attr_option("host") \ - .add_attr_option("from_device") \ - .add_attr_option("target_device_type") + .add_attr_option("from_device") \ + .add_attr_option("target_device_type") } // namespace tvm diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 4e29eddadd8c..7d1e5e3768de 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1150,6 +1150,66 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } // namespace tir } // namespace tvm +namespace tvm { +namespace runtime { + +// Automatic conversion into PrimExpr, when called through the FFI. +// Automatic conversions into IntImm, Integer, and Bool are registered +// in "tvm/ir/expr.h", as they are currently in use outside of TIR. + +template <> +struct PackedFuncValueConverter { + template + static Optional TryFrom(const PODSubclass& val) { + auto type_code = val.type_code(); + bool can_convert = type_code == kTVMDataType || type_code == kTVMBytes || + type_code == kTVMStr || val.template IsObjectRef(); + if (can_convert) { + return tvm::tir::StringImm(PackedFuncValueConverter::From(val)); + } else { + return NullOpt; + } + } + + template + static tvm::tir::StringImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } + } +}; + +template <> +struct PackedFuncValueConverter { + // Common rule for RetValue and ArgValue. Templated to ensure + // correct delegation to `operator std::string()` for either + // TVMArgValue or TVMRetValue. + template + static PrimExpr From(const PODSubclass& val) { + if (auto opt = val.TryAsBool()) { + // Check against val.TryAsBool directly, to avoid the + // bounds-checking in PackedFuncValueConverter::TryFrom. + return Bool(opt.value()); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (val.template IsObjectRef()) { + // Delegate to the implicit conversion from IterVar to PrimExpr + return val.template AsObjectRef(); + } else { + return val.template AsObjectRef(); + } + } +}; + +} // namespace runtime +} // namespace tvm + namespace std { template <> struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {}; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index fde6daa4d851..d50aaeb8de12 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -4157,11 +4157,13 @@ bool ScanopRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) { +Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Optional exclusive) { auto attrs = make_object(); attrs->dtype = dtype; attrs->axis = axis; - attrs->exclusive = exclusive; + if (exclusive.defined()) { + attrs->exclusive = exclusive.value(); + } static const Op& op = Op::Get("cumsum"); return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/target/tag.cc b/src/target/tag.cc index e6521d384397..f64ddf2d37c7 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -75,46 +75,46 @@ TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}, + {"num-cores", runtime::BoxInt(4)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}}}}); + {"num-cores", runtime::BoxInt(4)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") .set_config({{"kind", String("cuda")}, {"arch", String("sm_72")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::BoxInt(49152)}, + {"max_threads_per_block", runtime::BoxInt(1024)}, + {"thread_warp_size", runtime::BoxInt(32)}, + {"registers_per_block", runtime::BoxInt(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(8)}}}}); + {"num-cores", runtime::BoxInt(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::BoxInt(49152)}, + {"max_threads_per_block", runtime::BoxInt(1024)}, + {"thread_warp_size", runtime::BoxInt(32)}, + {"registers_per_block", runtime::BoxInt(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(6)}}}}); + {"num-cores", runtime::BoxInt(6)}}}}); -#define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ - TVM_REGISTER_TARGET_TAG(Name).set_config({ \ - {"kind", String("cuda")}, \ - {"keys", Array{"cuda", "gpu"}}, \ - {"arch", String(Arch)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"max_threads_per_block", Integer(1024)}, \ - {"thread_warp_size", Integer(32)}, \ - {"registers_per_block", Integer(RegPerBlock)}, \ +#define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ + TVM_REGISTER_TARGET_TAG(Name).set_config({ \ + {"kind", String("cuda")}, \ + {"keys", Array{"cuda", "gpu"}}, \ + {"arch", String(Arch)}, \ + {"max_shared_memory_per_block", runtime::BoxInt(SharedMem)}, \ + {"max_threads_per_block", runtime::BoxInt(1024)}, \ + {"thread_warp_size", runtime::BoxInt(32)}, \ + {"registers_per_block", runtime::BoxInt(RegPerBlock)}, \ }) // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus @@ -386,7 +386,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768); TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", String("llvm")}, \ {"keys", Array{"x86", "cpu"}}, \ {"mcpu", String(Arch)}, \ - {"num-cores", Integer(Cores)}}); + {"num-cores", runtime::BoxInt(Cores)}}); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512"); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.xlarge", 2, "skylake-avx512"); @@ -402,9 +402,9 @@ TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake"); #define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ TVM_REGISTER_TARGET_TAG(Name).set_config( \ {{"kind", String("metal")}, \ - {"max_threads_per_block", Integer(ThreadsPerBlock)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"thread_warp_size", Integer(WarpSize)}, \ + {"max_threads_per_block", runtime::BoxInt(ThreadsPerBlock)}, \ + {"max_shared_memory_per_block", runtime::BoxInt(SharedMem)}, \ + {"thread_warp_size", runtime::BoxInt(WarpSize)}, \ {"host", Map{{"kind", String("llvm")}, \ {"mtriple", String("arm64-apple-macos")}, \ {"mcpu", String("apple-latest")}}}}); diff --git a/src/target/target.cc b/src/target/target.cc index cd2e3714e422..10223b239385 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -359,24 +359,31 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { std::string interp_str = Interpret(str); - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing integer + if (info.type_index == runtime::BoxInt::ContainerType::_GetOrAllocRuntimeTypeIndex() || + info.type_index == runtime::BoxBool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + // Parsing integer or boolean std::istringstream is(interp_str); int v; if (!(is >> v)) { std::string lower(interp_str.size(), '\x0'); std::transform(interp_str.begin(), interp_str.end(), lower.begin(), [](unsigned char c) { return std::tolower(c); }); - // Bool is a subclass of IntImm, so allow textual boolean values. + // Mimic C++ automatic conversions, allowing bool to be used for + // integer parameters. if (lower == "true") { v = 1; } else if (lower == "false") { v = 0; } else { - throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); + throw Error(": Cannot parse integer from string: " + interp_str); } } - return Integer(v); + + if (info.type_index == runtime::BoxInt::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return runtime::BoxInt(v); + } else { + return runtime::BoxBool(v); + } } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing string, strip leading/trailing spaces, and enclosing quotes if any auto start = interp_str.find_first_not_of(' '); @@ -410,13 +417,14 @@ ObjectRef TargetInternal::ParseType(const std::string& str, ObjectRef TargetInternal::ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info) { - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + if (info.type_index == runtime::BoxInt::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer - return GetRef(ObjTypeCheck(obj, "Integer")); - } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return GetRef( + ObjTypeCheck(obj, "runtime.BoxInt")); + } else if (info.type_index == String::ContainerType::RuntimeTypeIndex()) { // Parsing string return GetRef(ObjTypeCheck(obj, "String")); - } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target if (auto opt = obj.as()) { return opt.value(); @@ -483,7 +491,9 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, /********** Stringifying **********/ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { - if (const auto* p = obj.as()) { + if (const auto* p = obj.as>()) { + return std::to_string(p->value); + } else if (const auto* p = obj.as>()) { return std::to_string(p->value); } if (auto tvm_str = obj.as()) { @@ -953,7 +963,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { // If requested, query attributes from the device. User-specified // parameters take precedence over queried parameters. if (attrs.count("from_device")) { - int device_id = Downcast(attrs.at("from_device")).IntValue(); + int device_id = Downcast(attrs.at("from_device"))->value; attrs.erase("from_device"); auto device_params = QueryDevice(device_id, target.get()); @@ -1006,38 +1016,13 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, for (const auto& kv : target->kind->key2vtype_) { const String& key = kv.first; - const TargetKindNode::ValueTypeInfo& type_info = kv.second; TVMRetValue ret; api->GetTargetProperty(device, key, &ret); - switch (ret.type_code()) { - case kTVMNullptr: - // Nothing returned for this parameter, move on to the next one. - continue; - - case kTVMArgInt: - if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Integer(static_cast(ret)); - } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Bool(static_cast(ret)); - } else { - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received integer from device api"; - } - break; - - case kTVMStr: - ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) - << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received string from device api"; - output[key] = String(ret.operator std::string()); - break; - - default: - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api"; - break; + // Delegate conversion from TVMRetValue to the FFI's default conversions. + if (Optional opt = ret) { + output[key] = opt.value(); } } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index aa4499ec9667..c659859c6b2b 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -266,7 +266,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { * \return The updated attributes */ TargetJSON TestTargetParser(TargetJSON target) { - Map features = {{"is_test", Bool(true)}}; + Map features = {{"is_test", runtime::BoxBool(true)}}; target.Set("features", features); return target; } @@ -279,22 +279,22 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mtriple") .add_attr_option("mfloat-abi") .add_attr_option("mabi") - .add_attr_option("num-cores") + .add_attr_option("num-cores") // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags - .add_attr_option("fast-math") // implies all the below - .add_attr_option("fast-math-nnan") - .add_attr_option("fast-math-ninf") - .add_attr_option("fast-math-nsz") - .add_attr_option("fast-math-arcp") - .add_attr_option("fast-math-contract") - .add_attr_option("fast-math-reassoc") - .add_attr_option("opt-level") + .add_attr_option("fast-math") // implies all the below + .add_attr_option("fast-math-nnan") + .add_attr_option("fast-math-ninf") + .add_attr_option("fast-math-nsz") + .add_attr_option("fast-math-arcp") + .add_attr_option("fast-math-contract") + .add_attr_option("fast-math-reassoc") + .add_attr_option("opt-level") // LLVM command line flags, see below .add_attr_option>("cl-opt") .set_default_keys({"cpu"}) // Force the external codegen kind attribute to be registered, even if no external // codegen targets are enabled by the TVM build. - .set_attr(tvm::attr::kIsExternalCodegen, Bool(false)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::BoxBool(false)) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); // Note regarding the "cl-opt" attribute: @@ -322,28 +322,29 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("mcpu") .add_attr_option("march") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constants-byte-alignment") + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constants-byte-alignment") .set_default_keys({"cpu"}) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("arch") - .add_attr_option("max_shared_memory_per_block") - .add_attr_option("max_threads_per_block") - .add_attr_option("thread_warp_size", Integer(32)) - .add_attr_option("registers_per_block") - .add_attr_option("l2_cache_size_bytes") - .add_attr_option("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it + .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_threads_per_block") + .add_attr_option("thread_warp_size", runtime::BoxInt(32)) + .add_attr_option("registers_per_block") + .add_attr_option("l2_cache_size_bytes") + .add_attr_option("max_num_threads", + runtime::BoxInt(1024)) // TODO(@zxybazh): deprecate it .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateCUDAAttrs); TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("mtriple") - .add_attr_option("max_num_threads", Integer(1024)) - .add_attr_option("thread_warp_size", Integer(32)) + .add_attr_option("max_num_threads", runtime::BoxInt(1024)) + .add_attr_option("thread_warp_size", runtime::BoxInt(32)) .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateNVPTXAttrs); @@ -353,24 +354,24 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option>("mattr") // TODO(masahi): Support querying from a target device // On RDNA cards, thread_warp_size should be 32 - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(65536)) - .add_attr_option("thread_warp_size", Integer(64)) + .add_attr_option("max_num_threads", runtime::BoxInt(256)) + .add_attr_option("max_threads_per_block", runtime::BoxInt(256)) + .add_attr_option("max_shared_memory_per_block", runtime::BoxInt(65536)) + .add_attr_option("thread_warp_size", runtime::BoxInt(64)) .set_default_keys({"rocm", "gpu"}) .set_target_parser(UpdateROCmAttrs); TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(16384)) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("texture_spatial_limit", Integer(16384)) + .add_attr_option("max_threads_per_block", runtime::BoxInt(256)) + .add_attr_option("max_shared_memory_per_block", runtime::BoxInt(16384)) + .add_attr_option("max_num_threads", runtime::BoxInt(256)) + .add_attr_option("thread_warp_size", runtime::BoxInt(1)) + .add_attr_option("texture_spatial_limit", runtime::BoxInt(16384)) // Faced that Qualcomm OpenCL runtime crashed without any error message in // the case when the number of kernel arguments was pretty big. OpenCL doesn't // specify any limitations on the number of kernel arguments. max_function_args // equals to 128 looks like a reasonable number of kernel arguments. - .add_attr_option("max_function_args", Integer(128)) + .add_attr_option("max_function_args", runtime::BoxInt(128)) .set_default_keys({"opencl", "gpu"}); // The metal has some limitations on the number of input parameters. This is why attribute @@ -379,55 +380,55 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) // https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc // See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf TVM_REGISTER_TARGET_KIND("metal", kDLMetal) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(32768)) - .add_attr_option("thread_warp_size", Integer(16)) - .add_attr_option("max_function_args", Integer(31)) + .add_attr_option("max_num_threads", runtime::BoxInt(256)) + .add_attr_option("max_threads_per_block", runtime::BoxInt(256)) + .add_attr_option("max_shared_memory_per_block", runtime::BoxInt(32768)) + .add_attr_option("thread_warp_size", runtime::BoxInt(16)) + .add_attr_option("max_function_args", runtime::BoxInt(31)) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option>("mattr") // Feature support - .add_attr_option("supports_float16") - .add_attr_option("supports_float32", Bool(true)) - .add_attr_option("supports_float64") - .add_attr_option("supports_int8") - .add_attr_option("supports_int16") - .add_attr_option("supports_int32", Bool(true)) - .add_attr_option("supports_int64") - .add_attr_option("supports_8bit_buffer") - .add_attr_option("supports_16bit_buffer") - .add_attr_option("supports_storage_buffer_storage_class") - .add_attr_option("supports_push_descriptor") - .add_attr_option("supports_dedicated_allocation") - .add_attr_option("supports_integer_dot_product") - .add_attr_option("supports_cooperative_matrix") - .add_attr_option("supported_subgroup_operations") + .add_attr_option("supports_float16") + .add_attr_option("supports_float32", runtime::BoxBool(true)) + .add_attr_option("supports_float64") + .add_attr_option("supports_int8") + .add_attr_option("supports_int16") + .add_attr_option("supports_int32", runtime::BoxBool(true)) + .add_attr_option("supports_int64") + .add_attr_option("supports_8bit_buffer") + .add_attr_option("supports_16bit_buffer") + .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_push_descriptor") + .add_attr_option("supports_dedicated_allocation") + .add_attr_option("supports_integer_dot_product") + .add_attr_option("supports_cooperative_matrix") + .add_attr_option("supported_subgroup_operations") // Physical device limits - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("max_block_size_x") - .add_attr_option("max_block_size_y") - .add_attr_option("max_block_size_z") - .add_attr_option("max_push_constants_size") - .add_attr_option("max_uniform_buffer_range") - .add_attr_option("max_storage_buffer_range") - .add_attr_option("max_per_stage_descriptor_storage_buffer") - .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_num_threads", runtime::BoxInt(256)) + .add_attr_option("max_threads_per_block", runtime::BoxInt(256)) + .add_attr_option("thread_warp_size", runtime::BoxInt(1)) + .add_attr_option("max_block_size_x") + .add_attr_option("max_block_size_y") + .add_attr_option("max_block_size_z") + .add_attr_option("max_push_constants_size") + .add_attr_option("max_uniform_buffer_range") + .add_attr_option("max_storage_buffer_range") + .add_attr_option("max_per_stage_descriptor_storage_buffer") + .add_attr_option("max_shared_memory_per_block") // Other device properties .add_attr_option("device_type") .add_attr_option("device_name") .add_attr_option("driver_name") - .add_attr_option("driver_version") - .add_attr_option("vulkan_api_version") - .add_attr_option("max_spirv_version") + .add_attr_option("driver_version") + .add_attr_option("vulkan_api_version") + .add_attr_option("max_spirv_version") // Tags .set_default_keys({"vulkan", "gpu"}); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) - .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_num_threads", runtime::BoxInt(256)) .set_default_keys({"webgpu", "gpu"}); TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL) // line break @@ -444,8 +445,8 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) .add_attr_option("mcpu") .add_attr_option("mtriple") .add_attr_option>("llvm-options") - .add_attr_option("num-cores") - .add_attr_option("vtcm-capacity") + .add_attr_option("num-cores") + .add_attr_option("vtcm-capacity") .set_default_keys({"hexagon", "cpu"}); TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 41500051fa89..7064433a2874 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -510,7 +510,9 @@ Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, Array args, Span span) { + .set_body_typed([](DataType type, RelayExpr op, + Array> args, + Span span) { Array prim_expr_args; for (const auto& it : args) { ICHECK(it->IsInstance() || it->IsInstance() || From 9fc7ebf4b0f86fe0a634a1f8a8cc5bb4ea06952c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Nov 2023 11:53:09 -0600 Subject: [PATCH 20/21] relax tuple get item unit test --- tests/python/relax/test_tuple_get_item.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/relax/test_tuple_get_item.py b/tests/python/relax/test_tuple_get_item.py index ba8a9a6eb836..d9a4a4ebacde 100644 --- a/tests/python/relax/test_tuple_get_item.py +++ b/tests/python/relax/test_tuple_get_item.py @@ -95,5 +95,22 @@ def func( tvm.ir.assert_structural_equal(func, roundtrip) +def test_tuple_get_item_simple(): + exec_mode = "bytecode" + + @R.function(private=True) + def func(arg: R.Tuple([R.Prim("int64"), R.Prim("float32")])): + return arg[0] + + mod = tvm.IRModule({"main": func}) + + target = tvm.target.Target("llvm", host="llvm") + ex = tvm.relax.build(mod, target, exec_mode=exec_mode) + vm = tvm.relax.VirtualMachine(ex, tvm.cpu()) + + res = vm["main"]((17, 42.5)) + assert res == 17 + + if __name__ == "__main__": tvm.testing.main() From 924c8ffee3d67a71d495fa194902814bc38eaed4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Nov 2023 12:26:11 -0600 Subject: [PATCH 21/21] Fixed type used in target tags --- src/target/tag.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/target/tag.cc b/src/target/tag.cc index f64ddf2d37c7..4360cf389bcf 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -130,7 +130,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2075", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(41943040)); + .with_config("l2_cache_size_bytes", runtime::BoxInt(41943040)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); @@ -233,7 +233,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(75497472)); + .with_config("l2_cache_size_bytes", runtime::BoxInt(75497472)); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536);