diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 5b44f79ad70a..380c2fcce25d 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -411,8 +411,11 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); * \param buffer The buffer. * \param value The value to be stored. * \param indices The indices location to be stored. + * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be + * stored. The number lanes of the mask must be equal to the number of lanes in value. */ -void BufferStore(Buffer buffer, PrimExpr value, Array indices); +void BufferStore(Buffer buffer, PrimExpr value, Array indices, + Optional predicate); /*! * \brief The prefetch hint for a buffer diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index b2736a30e4bb..276198abb89c 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -209,14 +209,20 @@ class Buffer : public ObjectRef { * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index * \param dtype The data type to be loaded. + * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be + * loaded. The number lanes of the mask must be equal to the number of lanes in being loaded. */ - TVM_DLL PrimExpr vload(Array begin, DataType dtype) const; + TVM_DLL PrimExpr vload(Array begin, DataType dtype, + Optional predicate = NullOpt) const; /*! * \brief Create a Stmt that does a vector store at begin index. * \param begin The beginning index * \param value The value to be stored. + * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be + * stored. The number lanes of the mask must be equal to the number of lanes in value. 
*/ - TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; + TVM_DLL Stmt vstore(Array begin, PrimExpr value, + Optional predicate = NullOpt) const; /*! * \brief Get a flattened version of the buffer diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 39b32f563350..d9b65dc8745c 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -630,11 +630,14 @@ class BufferLoadNode : public PrimExprNode { Buffer buffer; /*! \brief The indices location to be loaded. */ Array indices; + /*! \brief The predicate mask for loading values. */ + Optional predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); v->Visit("buffer", &buffer); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } @@ -647,6 +650,7 @@ class BufferLoadNode : public PrimExprNode { hash_reduce(dtype); hash_reduce(buffer); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferLoad"; @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode { */ class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, Span span = Span()); + TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, + Optional predicate = NullOpt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 07cc9b5ad0d5..c77254ed34cb 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -231,11 +231,14 @@ class BufferStoreNode : public StmtNode { PrimExpr value; /*! \brief The indices location to be stored. */ Array indices; + /*! \brief The predicate mask for storing values. 
*/ + Optional predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer", &buffer); v->Visit("value", &value); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } @@ -248,6 +251,7 @@ class BufferStoreNode : public StmtNode { hash_reduce(buffer); hash_reduce(value); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferStore"; @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode { class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, - Span span = Span()); + Optional predicate = NullOpt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index cb6e031667c5..756dbc4992f4 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -57,6 +57,31 @@ def _updater(data): return _updater +def create_updater_16_to_17(): + """ + Create an update to upgrade json from v0.16 to v0.17 + + Returns + ------- + fupdater : function + The updater function + """ + + def _update_predicate_argument(item, nodes): + null_value_idx = 0 + null_value = nodes[null_value_idx] + assert str(null_value) == "{'type_key': ''}", f"Expected a null value but got {null_value}" + item["attrs"]["predicate"] = str(null_value_idx) + return item + + node_map = { + "tir.BufferLoad": _update_predicate_argument, + "tir.BufferStore": _update_predicate_argument, + } + + return create_updater(node_map, "0.16", "0.17") + + def create_updater_15_to_16(): """ Create an update to upgrade json from v0.15 to v0.16 @@ -316,5 +341,7 @@ def _from_version(data): data = create_updater({}, "0.14", "0.15")(data) if _from_version(data).startswith("0.15"): data = create_updater_15_to_16()(data) + if _from_version(data).startswith("0.16"): + data = create_updater_16_to_17()(data) return 
json.dumps(data, indent=2) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 5a0a564a2ab5..8289ea96ae25 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1265,6 +1265,7 @@ def buffer_store( buffer: Buffer, # pylint: disable=redefined-outer-name value: PrimExpr, indices: List[Union[PrimExpr, slice]], + predicate: Optional[PrimExpr] = None, ) -> None: """Buffer store node. @@ -1278,6 +1279,11 @@ def buffer_store( indices : List[Union[PrimExpr, slice]] The indices location to be stored. + + predicate : Optional[PrimExpr] + A vector mask of boolean values indicating which lanes of a vector are to be + stored. The number lanes of the mask must be equal to the number of lanes in + value. """ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel @@ -1298,7 +1304,7 @@ def buffer_store( if isinstance(value, bool) and buffer.dtype == "bool": value = IntImm("bool", value) return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member - buffer, value, expr_indices + buffer, value, expr_indices, predicate ) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 679ae4e8adc0..600099bb0afb 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -462,6 +462,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: elif isinstance(res, str): # Ignore docstrings pass + elif isinstance(res, tvm.tir.stmt.BufferStore): + T.buffer_store(res.buffer, res.value, res.indices, res.predicate) else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index ec57ad7801ca..501d13b17e3d 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -101,7 +101,7 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0, self, access_mask, 
ptr_type, content_lanes, offset, extent # type: ignore ) - def vload(self, begin, dtype=None): + def vload(self, begin, dtype=None, predicate=None): """Generate an Expr that loads dtype from begin index. Parameters @@ -113,6 +113,10 @@ def vload(self, begin, dtype=None): The data type to be loaded, can be vector type which have lanes that is multiple of Buffer.dtype + predicate : Optional[PrimExpr] + A vector mask of boolean values indicating which lanes of a vector are to be + loaded. The number lanes of the mask must be equal to the number of lanes being loaded. + Returns ------- load : Expr @@ -120,9 +124,9 @@ def vload(self, begin, dtype=None): """ begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin dtype = dtype if dtype else self.dtype - return _ffi_api.BufferVLoad(self, begin, dtype) # type: ignore + return _ffi_api.BufferVLoad(self, begin, dtype, predicate) # type: ignore - def vstore(self, begin, value): + def vstore(self, begin, value, predicate=None): """Generate a Stmt that store value into begin index. Parameters @@ -133,13 +137,18 @@ def vstore(self, begin, value): value : Expr The value to be stored. + predicate : Optional[PrimExpr] + A vector mask of boolean values indicating which lanes of a vector are to be + stored. The number lanes of the mask must be equal to the number of lanes in + value. + Returns ------- store : Stmt The corresponding store stmt. """ begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin - return _ffi_api.BufferVStore(self, begin, value) # type: ignore + return _ffi_api.BufferVStore(self, begin, value, predicate) # type: ignore def scope(self): """Return the storage scope associated with this buffer. diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index fca501874d94..c78bb9e7ecd0 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1093,20 +1093,28 @@ class BufferLoad(PrimExprWithOp): The buffer to be loaded. indices : List[PrimExpr] - The buffer indices. 
+ The buffer indices to load values from. span : Optional[Span] The location of this expression in the source code. + + predicate : Optional[PrimExpr] + A vector mask of boolean values indicating which lanes of a vector are to be + loaded. The number lanes of the mask must be equal to the number of lanes being loaded. """ buffer: Buffer indices: List[PrimExpr] def __init__( - self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None + self, + buffer: Buffer, + indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, + span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferLoad, buffer, indices, span # type: ignore + _ffi_api.BufferLoad, buffer, indices, predicate, span # type: ignore ) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 992c388e27bb..aa3b17a7a12f 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -224,6 +224,11 @@ class BufferStore(Stmt): indices : List[PrimExpr] The indices location to be stored. + predicate : Optional[PrimExpr] + A vector mask of boolean values indicating which lanes of a vector are to be + stored. The number lanes of the mask must be equal to the number of lanes in + value. + span : Optional[Span] The location of the stmt in the source code. 
""" @@ -231,6 +236,7 @@ class BufferStore(Stmt): buffer: Buffer value: PrimExpr indices: List[PrimExpr] + predicate: Optional[PrimExpr] span: Optional[Span] def __init__( @@ -238,10 +244,11 @@ def __init__( buffer: Buffer, value: PrimExpr, indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferStore, buffer, value, indices, span # type: ignore + _ffi_api.BufferStore, buffer, value, indices, predicate, span # type: ignore ) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 0c4248bd3f26..08d5e9379dc6 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -233,15 +233,16 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // "T.vscale" and the compile target uses a scalable architecture extension like // SVE, we can make some assumptions about the value of vscale and iterate over a // space of pre-defined values to attempt to prove the expression. + Target curr_target = Target::Current(); if (ContainsVscaleCall(simplified)) { - if (TargetHasSVE()) { + if (TargetHasSVE(curr_target)) { return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues); } LOG(WARNING) << "The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. 
This proof currently only supports " "AArch64 SVE targets, but the target was " - << Target::Current(); + << curr_target; } return false; } diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 2f9d640ee712..ecd3b25bfc67 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -370,7 +370,7 @@ class ConstIntBoundAnalyzer::Impl return VisitLeftShift(op); } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); - } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE()) { + } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) { unsigned int max_val = *std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end()); return MakeBound(1, max_val); diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index 2df035d6151a..e5f3bc28ba52 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -93,8 +93,7 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr return can_prove_expr; } -bool TargetHasSVE() { - Target current_target = Target::Current(); +bool TargetHasSVE(Target current_target) { bool has_sve{false}; if (current_target.defined()) { has_sve = current_target->GetFeature("has_sve").value_or(Bool(false)); diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h index 8e807eb3b839..06ff8104e928 100644 --- a/src/arith/scalable_expression.h +++ b/src/arith/scalable_expression.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -79,9 +80,10 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr /*! * \brief Check whether the compilation target supports SVE + * \param target The target to check. 
* \return Whether SVE is supported */ -bool TargetHasSVE(); +bool TargetHasSVE(Target target); } // namespace arith } // namespace tvm diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7ea5032fa0cc..3026f6e58f18 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 3ce5c15e6cd0..17353561ee54 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -524,7 +524,8 @@ Var EnvThread(String thread_tag, DataType dtype) { return var; } -void BufferStore(Buffer buffer, PrimExpr value, Array indices) { +void BufferStore(Buffer buffer, PrimExpr value, Array indices, + Optional predicate = NullOpt) { runtime::DataType buffer_dtype = buffer->dtype; bool is_index_scalable = indices.empty() ? 
false : indices.back().dtype().is_scalable_vector(); bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector(); @@ -586,7 +587,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices) { } value = tvm::cast(lhs_dtype, value); } - AddToParent(tvm::tir::BufferStore(buffer, value, indices)); + AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate)); } void Prefetch(Buffer buffer, Array bounds) { diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 45a0dfd2aea4..87db53061ceb 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -273,14 +273,33 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); - return AssignDoc(/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], - /*rhs=*/d->AsDoc(store->value, p->Attr("value")), NullOpt); + ExprDoc value = d->AsDoc(store->value, p->Attr("value")); + + // Use .vstore(...) syntax when there is a predicate + if (store->predicate.defined()) { + ExprDoc indices = d->AsDoc(store->indices, p->Attr("indices")); + ExprDoc predicate = d->AsDoc(store->predicate, p->Attr("predicate")); + return ExprStmtDoc( + buffer->Attr("vstore")->Call({indices, value}, {"predicate"}, {predicate})); + } + + return AssignDoc( + /*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], + /*rhs=*/value, NullOpt); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); + + // Use .vload(...) 
syntax when there is a predicate + if (load->predicate.defined()) { + ExprDoc indices = d->AsDoc(load->indices, p->Attr("indices")); + ExprDoc predicate = d->AsDoc(load->predicate, p->Attr("predicate")); + return buffer->Attr("vload")->Call({indices}, {"predicate"}, {predicate}); + } + return buffer[BufferIndices(load->indices, p->Attr("indices"), d)]; }); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6fc083d17ccf..6098a3f32f0d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1668,9 +1668,9 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { } void CodeGenLLVM::BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, Optional predicate, DataType value_dtype, + std::function make_instruction) { DataType buffer_element_dtype = buffer->dtype; @@ -1750,6 +1750,11 @@ void CodeGenLLVM::BufferAccessHelper( std::vector all_index_values = earlier_index_values; all_index_values.push_back(last_index_value); + llvm::Value* predicate_value = nullptr; + if (predicate.defined()) { + predicate_value = MakeValue(predicate.value()); + } + TypedPointer buffer_ptr = value_dtype.is_scalable_vector() ? 
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values, @@ -1758,7 +1763,8 @@ void CodeGenLLVM::BufferAccessHelper( : CreateBufferPtr( MakeValue(buffer->data), buffer_element_dtype, all_index_values, value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); - auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile); + auto instruction = + make_instruction(buffer_ptr, subelement_i, predicate_value, alignment, is_volatile); AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin); } } @@ -1768,17 +1774,30 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { std::vector loads; - auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment, - bool is_volatile) { + auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, + llvm::Value* predicate, int alignment, bool is_volatile) { + llvm::Instruction* load = nullptr; + if (predicate != NULL) { + ICHECK(!is_volatile) + << "The masked load intrinsic does not support declaring load as volatile."; +#if TVM_LLVM_VERSION >= 130 + load = builder_->CreateMaskedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + predicate); +#elif TVM_LLVM_VERSION >= 110 + load = builder_->CreateMaskedLoad(buffer_ptr.addr, llvm::Align(alignment), predicate); +#else + load = builder_->CreateMaskedLoad(buffer_ptr.addr, alignment, predicate); +#endif + } else { #if TVM_LLVM_VERSION >= 110 - auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); + load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); #elif TVM_LLVM_VERSION >= 80 - auto load = - builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); + load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - auto load = 
builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); + load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif + } loads.push_back(load); return load; @@ -1787,7 +1806,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. - BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_load); if (loads.size() == 1) { return loads[0]; @@ -1902,24 +1921,39 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { llvm::Value* value = MakeValue(op->value); - auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment, - bool is_volatile) { + auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, llvm::Value* predicate, + int alignment, bool is_volatile) { llvm::Value* to_store = value; + llvm::Instruction* store; + if (subelement_i != -1) { to_store = builder_->CreateExtractElement(value, subelement_i); } + + if (predicate != NULL) { + ICHECK(!is_volatile) + << "The masked store intrinsic does not support declaring store as volatile."; #if TVM_LLVM_VERSION >= 110 - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), - is_volatile); + store = + builder_->CreateMaskedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), predicate); #else - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); + store = builder_->CreateMaskedStore(to_store, buffer_ptr.addr, alignment, predicate); #endif + } else { +#if TVM_LLVM_VERSION >= 110 + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); +#else + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); +#endif + 
} + return store; }; // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. - BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_store); } void CodeGenLLVM::VisitStmt_(const ForNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 06b36cb183d3..302a0d97b3f4 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -330,6 +330,10 @@ class CodeGenLLVM : public ExprFunctor, * * \param indices The indices at which the buffer is being accessed. * + * \param predicate A vector mask of boolean values indicating which lanes of a + * vector are to be accessed. The number lanes of the mask must be equal to the + * number of lanes being accessed. + * * \param value_dtype The datatype to be read from (BufferLoad) or * written to (BufferStore) the buffer. * @@ -342,6 +346,8 @@ class CodeGenLLVM : public ExprFunctor, * stored/loaded. If -1, indicates that the entire type, * vector or scalar, should be written. * + * - predicate: The predicate mask of the buffer. + * * - alignment: The alignment to be used for the read/write. * * - is_volatile: Whether the read/write should be volatile. @@ -349,9 +355,9 @@ class CodeGenLLVM : public ExprFunctor, * - Should return the generated expression. 
*/ void BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, Optional predicate, DataType value_dtype, + std::function make_instruction); // Initialize target virtual void InitTarget(); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 009fc1672ace..5f6f493e08a3 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -764,6 +764,7 @@ void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; DataType value_dtype = op->dtype; PrimExpr index = op->indices[0]; @@ -823,6 +824,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI void CodeGenC::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; DataType value_dtype = op->value.dtype(); DataType element_dtype = op->buffer->dtype; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index ba925056a379..f62e0db7ffdf 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -459,6 +459,7 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // // to ensure correctness in the case of nested-expression // do not try to lift common printings from each case ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; DataType value_dtype = op->dtype; PrimExpr index = op->indices[0]; @@ -531,6 +532,8 @@ void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { void 
CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { CHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; + DataType value_dtype = op->value.dtype(); DataType element_dtype = op->buffer->dtype; PrimExpr index = op->indices[0]; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 03de68e32624..c7dbf3f5e042 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -79,7 +79,7 @@ class BufferSubstituter : public StmtExprMutator { auto load = Downcast(StmtExprMutator::VisitExpr_(op)); auto it = buffer_map_.find(load->buffer.get()); if (it != buffer_map_.end()) { - return BufferLoad(it->second, load->indices, load->span); + return BufferLoad(it->second, load->indices, load->predicate, load->span); } return load; } @@ -88,7 +88,7 @@ class BufferSubstituter : public StmtExprMutator { auto store = Downcast(StmtExprMutator::VisitStmt_(op)); auto it = buffer_map_.find(store->buffer.get()); if (it != buffer_map_.end()) { - return BufferStore(it->second, store->value, store->indices, store->span); + return BufferStore(it->second, store->value, store->indices, store->predicate, store->span); } return store; } diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 4554038bc770..40df8b65c295 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -254,7 +254,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Downcast(StmtExprMutator::VisitExpr_(buffer_load_node)); Buffer new_buffer = Subst(new_buffer_load->buffer.get()); if (!new_buffer.same_as(new_buffer_load->buffer)) { - return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span); + return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->predicate, + new_buffer_load->span); } return 
std::move(new_buffer_load); } @@ -293,7 +294,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Buffer new_buffer = Subst(new_buffer_store->buffer.get()); if (!new_buffer.same_as(new_buffer_store->buffer)) { return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices, - new_buffer_store->span); + new_buffer_store->predicate, new_buffer_store->span); } return std::move(new_buffer_store); } diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 0c0d47571c4a..ac1cf0ef11bb 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -718,7 +718,8 @@ class MergeConstantsMutator : public StmtExprMutator { buffer->axis_separators, buffer->span}; old_to_new_read_buffers[buffer.as()] = new_buffer; - new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span)); + new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->predicate, + buffer_load->span)); break; } case 2: /* length */ { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index d71187922874..025605333138 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -399,37 +399,44 @@ Buffer Buffer::GetFlattenedBuffer() const { } } -PrimExpr Buffer::vload(Array begin, DataType value_dtype) const { +PrimExpr Buffer::vload(Array begin, DataType value_dtype, + Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); ICHECK(value_dtype.element_of() == n->dtype.element_of() && - value_dtype.lanes() % n->dtype.lanes() == 0) + value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot load " << value_dtype << " from buffer of " << n->dtype; Array indices = begin; - int factor = value_dtype.lanes() / n->dtype.lanes(); - if (factor > 1) { - indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); + PrimExpr base = indices[indices.size() - 1]; + if 
(value_dtype.is_fixed_length_vector()) { + int factor = value_dtype.lanes() / n->dtype.lanes(); + if (factor > 1 && base.dtype().is_scalar()) { + indices.Set(indices.size() - 1, Ramp(base, 1, factor)); + } } - return BufferLoad(*this, indices); + return BufferLoad(*this, indices, predicate); } -Stmt Buffer::vstore(Array begin, PrimExpr value) const { +Stmt Buffer::vstore(Array begin, PrimExpr value, Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); ICHECK(n != nullptr); DataType value_dtype = value.dtype(); ICHECK(value_dtype.element_of() == n->dtype.element_of() && - value_dtype.lanes() % n->dtype.lanes() == 0) + value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot store " << value_dtype << " to buffer of " << n->dtype; Array indices = begin; - int factor = value_dtype.lanes() / n->dtype.lanes(); - if (factor > 1) { - indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); + PrimExpr base = indices[indices.size() - 1]; + if (value_dtype.is_fixed_length_vector()) { + int factor = value_dtype.lanes() / n->dtype.lanes(); + if (factor > 1 && base.dtype().is_scalar()) { + indices.Set(indices.size() - 1, Ramp(base, 1, factor)); + } } - return BufferStore(*this, value, indices); + return BufferStore(*this, value, indices, predicate); } String Buffer::scope() const { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 2cd2a698debe..1506082003fd 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -772,24 +772,47 @@ void BufferLoadNode::LegalizeDType() { } } -BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { +BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional predicate, + Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() << "-dimensional indices provided."; + if (predicate.defined()) 
{ + DataType predicate_dtype = predicate.value().dtype(); + + bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector(); + bool is_predicate_scalable = predicate_dtype.is_scalable_vector(); + ICHECK_EQ(is_index_scalable, is_predicate_scalable) + << "Predicate mask dtype and load indices must both be scalable."; + + int buffer_lanes = buffer->dtype.get_lanes_or_vscale_factor(); + int index_lanes = indices.empty() ? 1 : indices.back().dtype().get_lanes_or_vscale_factor(); + int predicate_lanes = predicate_dtype.get_lanes_or_vscale_factor(); + ICHECK_EQ(index_lanes * buffer_lanes, predicate_lanes) + << "Got a predicate mask with " << predicate_lanes + << " lanes, but trying to load a vector with " << index_lanes + << " lanes. The number of lanes must match."; + + DataType predicate_element_dtype = predicate_dtype.element_of(); + ICHECK(predicate_element_dtype.is_bool()) + << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype + << "."; + } + ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); node->LegalizeDType(); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferLoad") - .set_body_typed([](Buffer buffer, Array indices, Span span) { - return BufferLoad(buffer, indices, span); - }); + .set_body_typed([](Buffer buffer, Array indices, Optional predicate, + Span span) { return BufferLoad(buffer, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 089a1d31e7d0..34b46583d5ad 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -127,7 +127,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { if (indices.same_as(op->indices)) { return GetRef(op); } else { - return BufferLoad(op->buffer, indices); + return BufferLoad(op->buffer, indices, 
op->predicate); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 4774471afcc0..5df76450ff1e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -458,7 +458,8 @@ TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) TVM_REGISTER_NODE_TYPE(EvaluateNode); // BufferStore -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, Span span) { +BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, + Optional predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -476,29 +477,39 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) << "Index dtype and buffer dtype can't both be scalable."; - if (is_index_scalable || is_buffer_dtype_scalable) { - ICHECK(is_value_dtype_scalable) << "Can't store non-scalable data into scalable buffer"; + if (predicate.defined()) { + bool is_predicate_dtype_scalable = predicate.value().dtype().is_scalable_vector(); + ICHECK_EQ(is_value_dtype_scalable, is_predicate_dtype_scalable) + << "Predicate mask dtype and value dtype must both be scalable."; } - int index_lanes; - if (indices.empty()) { - index_lanes = 1; - } else if (is_index_scalable) { - index_lanes = indices.back().dtype().vscale_factor(); - } else { - index_lanes = indices.back().dtype().lanes(); + if (is_index_scalable || is_buffer_dtype_scalable) { + ICHECK(is_value_dtype_scalable) << "Can't store non-scalable data into scalable buffer"; } - int buffer_lanes = - is_buffer_dtype_scalable ? buffer->dtype.vscale_factor() : buffer->dtype.lanes(); - int value_dtype_lanes = - is_value_dtype_scalable ? value.dtype().vscale_factor() : value.dtype().lanes(); + int index_lanes = indices.empty() ? 
1 : indices.back().dtype().get_lanes_or_vscale_factor(); + int buffer_lanes = buffer->dtype.get_lanes_or_vscale_factor(); + int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor(); ICHECK_EQ(index_lanes * buffer_lanes, value_dtype_lanes) << "Cannot store value with " << value_dtype_lanes << ", expected value with " << index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes << " buffer element lanes)"; + if (predicate.defined()) { + DataType predicate_dtype = predicate.value().dtype(); + int predicate_dtype_lanes = predicate_dtype.get_lanes_or_vscale_factor(); + ICHECK_EQ(value_dtype_lanes, predicate_dtype_lanes) + << "Got a predicate mask with " << predicate_dtype_lanes + << " lanes, but trying to store a value with " << value_dtype_lanes + << " lanes. The number of lanes must match."; + + DataType predicate_element_dtype = predicate_dtype.element_of(); + ICHECK(predicate_element_dtype.is_bool()) + << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype + << "."; + } + runtime::DataType buffer_dtype; if (is_index_scalable || is_buffer_dtype_scalable) { buffer_dtype = buffer->dtype.with_scalable_vscale_factor(buffer_lanes * index_lanes); @@ -517,14 +528,15 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, node->buffer = std::move(buffer); node->value = std::move(value); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferStore") - .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, Span span) { - return BufferStore(buffer, value, indices, span); - }); + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, + Optional predicate, + Span span) { return BufferStore(buffer, value, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); diff --git a/src/tir/transforms/inject_rolling_buffer.cc 
b/src/tir/transforms/inject_rolling_buffer.cc index 5f7b9b4156c3..03f94e3e9139 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -257,7 +257,9 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->span); + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "the inject rolling buffer pass."; + Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->predicate, op->span); // Then wrap the BufferStores in some Ifs to avoid recomputing elements for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) { auto iter_var{rolling_buffer_info.axis_iter_vars[i]}; @@ -293,7 +295,9 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - return BufferLoad(op->buffer, indices, op->span); + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not currently supported in inject rolling buffer pass."; + return BufferLoad(op->buffer, indices, op->predicate, op->span); } else { return expr; } diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 700587fe0e21..3c2c6b67e653 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -97,6 +97,8 @@ class MatchBufferLower : public StmtExprMutator { auto n = CopyOnWrite(op); n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); n->buffer = source->buffer; + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not currently supported in lower match buffer pass."; return Stmt(n); } } @@ -113,6 +115,8 @@ class MatchBufferLower : public StmtExprMutator { const Buffer& buffer = (*it).first; const BufferRegion& source = (*it).second; Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + ICHECK(!op->predicate.defined()) + 
<< "Predicated buffer load is not currently supported in lower match buffer pass."; return BufferLoad(source->buffer, indices); } } diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 619a9f0a9e8f..885d5917136d 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -67,6 +67,8 @@ class IntermediateStageRewriter { Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); // Step 3: Create BufferLoad from the intermediate buffer + ICHECK(!store->predicate.defined()) << "Predicated buffer store is not currently supported in " + "manifest shared memory local stage pass."; BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); BufferStore new_buffer_store = Downcast(block->body); new_buffer_store.CopyOnWrite()->value = new_buffer_load; diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index bc606aa0b7ff..3b418aac0cf5 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -213,7 +213,8 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { // A write whose destination is known to already contain the // values to be written is a no-op. // PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices); - PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices) == 0; + PrimExpr stores_existing_value = + store->value - BufferLoad(store->buffer, store->indices, store->predicate) == 0; if (touch_pattern_.has_value()) { Stmt context_arg = context_ ? 
GetRef(context_) : Stmt(store); stores_existing_value = diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 05b636f11403..e8d89bfb5700 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -196,7 +196,7 @@ class AllocateConstRewrite : public StmtExprMutator { op->buffer->elem_offset, it->second->name_hint, op->buffer->data_alignment, op->buffer->offset_factor, op->buffer->buffer_type); new_load_buf_[op->buffer->data.get()] = new_buffer; - return BufferLoad(new_buffer, op->indices); + return BufferLoad(new_buffer, op->indices, op->predicate); } return ExprMutator::VisitExpr_(op); } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index c51dfd7913e4..06554f5f1dd1 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -730,7 +730,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferLoad(it->second, op->indices, op->span); + return BufferLoad(it->second, op->indices, op->predicate, op->span); } else { return expr; } @@ -743,7 +743,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferStore(it->second, op->value, op->indices, op->span); + return BufferStore(it->second, op->value, op->indices, op->predicate, op->span); } else { return stmt; } @@ -938,8 +938,11 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "storage flatten pass."; return BufferLoad(e.remap->target, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, 
e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return expr; } @@ -952,8 +955,11 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "storage flatten pass."; return BufferStore(e.remap->target, op->value, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return stmt; } @@ -1418,7 +1424,9 @@ class StorageFlattener : public StmtExprMutator { auto flattened_indices = e.buffer->ElemOffset(op->indices); - Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->span); + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "storage flatten pass."; + Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->predicate, op->span); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } @@ -1573,8 +1581,10 @@ class StorageFlattener : public StmtExprMutator { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "storage flatten pass."; auto flattened_indices = e.buffer->ElemOffset(op->indices); - PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->span); + PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->predicate, op->span); if (op->dtype == DataType::Bool()) { ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 5a14beb6dc4c..c75ecf77e708 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ 
b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -330,6 +330,8 @@ class ComputeLegalizer : public StmtExprMutator { ICHECK(MatchDType(value->dtype)); value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value); } + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "data type legalizer pass."; return BufferStore(new_buf, value, indices); } } @@ -401,6 +403,8 @@ class ComputeLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "data type legalizer pass."; return BufferLoad(new_buf, op->indices); } } @@ -562,6 +566,8 @@ class StorageLegalizer : public StmtExprMutator { if (MatchDType(op->value.dtype())) { ICHECK(new_buf->dtype.is_uint()); } + ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " + "data type legalizer pass."; return BufferStore(new_buf, value, indices); } } @@ -595,6 +601,8 @@ class StorageLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " + "data type legalizer pass."; return BufferLoad(new_buf, op->indices); } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index c4dde01b8f81..aa62d5850513 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -72,6 +72,126 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { return Broadcast(e, CreateNewLanes(is_scalable, lanes)); } +bool EnableBufferLevelPredication(Target target) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + Optional enable_buffer_predication = + pass_ctx->GetConfig("tir.enable_buffer_level_predication"); + if (enable_buffer_predication.defined()) { + return enable_buffer_predication.value(); + } + + 
// Use buffer-level predication by default for AArch64 SVE targets +  return arith::TargetHasSVE(target); +} + +/*! + * \brief A pass that tries to rewrite buffer accesses (loads and stores) with a + * predicate expression where possible. + * + * \note For now we start with a minimal case targeting block-level predicates + * produced by the split schedule primitive, with the potential for predicating + * more complex terms in the future if needed. + * + * \example + * Before: + * for i_0 in T.serial(4): + * for i_1 in T.vectorized(4): + * if i_0 * 4 + i_1 < 14: + * B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + * + * After: + * for i_0 in T.serial(4): + * predicate = T.get_active_lane_mask("uint1x4", i_0 * 4, 14) + * A_load = T.meta_var(A.vload([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate)) + * B.vstore([T.Ramp(i_0 * 4, 1, 4)], A_load, predicate=predicate) + */ +class TryPredicateBufferAccesses : public StmtExprMutator { + public: + TryPredicateBufferAccesses() {} + + /*! + * \brief Run the pass to try to extract predicates. + * \param stmt - The statement containing buffer accesses (loads and stores) + * we want to attempt to predicate. + * \param condition - The conditional expression (block-level predicate) + * that we will try to remove. + * \return pair - Boolean value for success/failure, the rewritten + * stmt if successful. + */ + std::pair Run(Stmt stmt, PrimExpr condition) { + // Check that the condition provided is of the form a < b, for now. + if (!condition->IsInstance()) { + return {false, stmt}; + } + + LT lt = Downcast(condition); + + // Check the form of the vectorized condition, we're expecting + // Ramp(...) < Broadcast(...) 
+ if (!lt->a->IsInstance() || !lt->b->IsInstance()) { + return {false, stmt}; + } + + base_ = Downcast(lt->a)->base; + limit_ = Downcast(lt->b)->value; + + // Now we can try to predicate + Stmt predicated_stmt = StmtExprMutator::operator()(std::move(stmt)); + if (num_accesses_analyzed_ > 0 && num_accesses_analyzed_ == num_accesses_rewritten_) { + return {true, predicated_stmt}; + } + return {false, stmt}; + } + + private: + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + return TryPredicateBufferAccess(load); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + return TryPredicateBufferAccess(store); + } + + template + AccessNode TryPredicateBufferAccess(AccessNode node) { + num_accesses_analyzed_ += 1; + + // Do not try to predicate non-vectorized accesses + Array indices = node->indices; + if (!indices.size() || !indices[0]->IsInstance()) { + return node; + } + Ramp ramp = Downcast(node->indices[0]); + + // The vectorized access pattern must match the base of the predicate + if (!tvm::StructuralEqual()(ramp->base, base_)) { + return node; + } + + DataType buf_predicate_dtype = + DataType(DataType::kUInt, 1, ramp->dtype.get_lanes_or_vscale_factor(), + ramp->dtype.is_scalable_vector()); + Call lane_mask = Call(buf_predicate_dtype, builtin::get_active_lane_mask(), {base_, limit_}); + + num_accesses_rewritten_ += 1; + auto writer = node.CopyOnWrite(); + writer->predicate = lane_mask; + return node; + } + + /*! \brief The variable base expr of the predicate. */ + PrimExpr base_; + /*! \brief The limit of the predicate. The expr specifies the upper bound of the base's + * evaluated value. */ + PrimExpr limit_; + /*! \brief The number of buffer accesses in the stmt we will analyze. */ + size_t num_accesses_analyzed_ = 0; + /*! \brief The number of buffer accesses rewritten with predicates. 
*/ + size_t num_accesses_rewritten_ = 0; +}; + // Rewrite vectorized allocation access // This is necessary for making each vector component containing its own workspace. // Originates from Halide's loop vectorizer @@ -171,7 +291,8 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype, 0), IntImm(var->dtype, 1), var_lanes); } @@ -555,14 +676,26 @@ class Vectorizer : public StmtMutator, public ExprFunctorcondition.dtype().is_scalable_or_fixed_length_vector()); PrimExpr condition = this->VisitExpr(op->condition); - if (condition.dtype().is_scalable_or_fixed_length_vector()) { - return Scalarize(GetRef(op)); - } Stmt then_case = this->VisitStmt(op->then_case); Optional else_case = NullOpt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } + + // Check if we can rewrite the condition with predicated buffers + if (EnableBufferLevelPredication(target_) && + condition.dtype().is_scalable_or_fixed_length_vector() && !else_case.defined()) { + std::pair success_stmt_pair = + TryPredicateBufferAccesses().Run(then_case, condition); + bool can_remove_if_then_else = success_stmt_pair.first; + if (can_remove_if_then_else) { + return success_stmt_pair.second; + } + } + + if (condition.dtype().is_scalable_or_fixed_length_vector()) { + return Scalarize(GetRef(op)); + } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); @@ -659,6 +792,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor let_binding_; // vectorizable property OpAttrMap op_vectorizable_ = Op::GetAttrMap("TVectorizable"); + /*! \brief The current target context. */ + Target target_; // mutate array, with given lane requirement // when finished, p_lane updates the lane requirement. 
@@ -728,22 +863,41 @@ class Vectorizer : public StmtMutator, public ExprFunctor(tvm::attr::kTarget)) { + target_ = opt_target.value(); + } + } + Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { auto* extent_as_int = op->extent.as(); if (!extent_as_int || extent_as_int->value < 1) { bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); - ICHECK(is_scalable_expr && arith::TargetHasSVE()) - << "Failed to vectorize loop with extent " << op->extent << " for target " - << Target::Current(); + ICHECK(is_scalable_expr && arith::TargetHasSVE(target_)) + << "Failed to vectorize loop with extent " << op->extent << " for target " << target_; } ICHECK(is_zero(op->min)); - return Vectorizer(op->loop_var, op->extent)(op->body); + return Vectorizer(op->loop_var, op->extent, target_)(op->body); } else { return StmtMutator::VisitStmt_(op); } } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tvm::attr::kTarget) { + Target previous_target = target_; + target_ = op->node.as().value(); + Stmt new_op = StmtMutator::VisitStmt_(op); + target_ = previous_target; + return new_op; + } + return StmtMutator::VisitStmt_(op); + } + + private: + Target target_ = Target::Current(); }; class VectorizeSkipper : public StmtMutator { @@ -768,7 +922,7 @@ Pass VectorizeLoop(bool enable_vectorize) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); if (enable_vectorize) { - n->body = LoopVectorizer()(std::move(n->body)); + n->body = LoopVectorizer(n->attrs)(std::move(n->body)); } else { n->body = VectorizeSkipper()(std::move(n->body)); } diff --git a/tests/python/codegen/test_target_codegen.py b/tests/python/codegen/test_target_codegen.py new file mode 100644 index 000000000000..bae15b5377e3 --- /dev/null +++ b/tests/python/codegen/test_target_codegen.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +from tvm.script import tir as T + + +@tvm.testing.parametrize_targets("c") +def test_buffer_store_predicate_not_supported(target): + @T.prim_func + def func(b: T.handle): + B = T.match_buffer(b, (8,), "float32") + B.vstore([T.Ramp(0, 2, 4)], T.Broadcast(1.0, 4), predicate=T.Broadcast(T.bool(True), 4)) + + err_msg = "Predicated buffer store is not supported." + with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target(target): + tvm.build(func) + + +@tvm.testing.parametrize_targets("cuda", "opencl", "metal", "rocm", "vulkan -from_device=0") +def test_buffer_store_predicate_not_supported_gpu(target): + @T.prim_func + def func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (2, 3), "float32") + B = T.match_buffer(b, (6,), "float32") + T.func_attr({"global_symbol": "main"}) + for i_0 in T.thread_binding(3, thread="threadIdx.x"): + B.vstore( + [T.Ramp(i_0, 1, 4)], T.Broadcast(1.0, 4), predicate=T.Broadcast(T.bool(True), 4) + ) + + err_msg = "Predicated buffer store is not supported." 
+ with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target(target): + tvm.build(func) + + +@tvm.testing.parametrize_targets("c") +def test_buffer_load_predicate_not_supported(target): + @T.prim_func + def func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (8,), "float32") + B = T.match_buffer(b, (8,), "float32") + for i_0 in range(4): + B.vstore( + [T.Ramp(0, 2, 4)], + A.vload([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)), + ) + + err_msg = "Predicated buffer load is not supported." + with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target(target): + tvm.build(func) + + +@tvm.testing.parametrize_targets("cuda", "opencl", "metal", "rocm", "vulkan -from_device=0") +def test_buffer_load_predicate_not_supported_gpu(target): + @T.prim_func + def func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (8,), "float32") + B = T.match_buffer(b, (8,), "float32") + for i_0 in T.thread_binding(3, thread="threadIdx.x"): + B.vstore( + [T.Ramp(0, 2, 4)], + A.vload([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)), + ) + + err_msg = "Predicated buffer load is not supported." 
+ with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target(target): + tvm.build(func) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index f73d96e7c916..251e625b8173 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -771,7 +771,7 @@ def test_get_active_lane_mask(): def before(a: T.handle): A = T.match_buffer(a, (30,), "int1") for i in range(T.ceildiv(30, T.vscale() * 4)): - A[i : i + T.vscale() * 4] = T.get_active_lane_mask("int1xvscalex4", i, 30) + A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30) with tvm.target.Target(target): out = tvm.build(before) @@ -780,5 +780,31 @@ def before(a: T.handle): assert "get.active.lane.mask" in ll +@pytest.mark.skipif( + llvm_version_major() < 11, + reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", +) +def test_predicated_scalable_buffer(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): + for i_1 in T.vectorized(4 * T.vscale()): + if i_0 * 4 * T.vscale() + i_1 < 14: + B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 + + with tvm.target.Target(target): + out = tvm.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + assert "llvm.masked.load" in ll + assert "llvm.masked.store" in ll + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f1316ae3cee0..f50d63878e4f 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ 
b/tests/python/codegen/test_target_codegen_llvm.py @@ -1109,5 +1109,34 @@ def func(): built = tvm.build(func, target="llvm") +def test_invalid_volatile_masked_buffer_load(): + @T.prim_func + def func(b: T.handle): + B = T.match_buffer(b, [4]) + a = T.allocate([4], "float32", scope="global") + T.attr(a, "volatile_scope", 1) + A = T.Buffer([4], data=a) + B[0:4] = A.vload([T.Ramp(0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)) + + err_msg = "The masked load intrinsic does not support declaring load as volatile." + with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target("llvm"): + tvm.build(func) + + +def test_invalid_volatile_masked_buffer_store(): + @T.prim_func + def func(): + a = T.allocate([4], "float32", scope="global") + T.attr(a, "volatile_scope", 1) + A = T.Buffer([4], data=a) + A.vstore([T.Ramp(0, 1, 4)], T.Broadcast(0.0, 4), predicate=T.Broadcast(T.bool(True), 4)) + + err_msg = "The masked store intrinsic does not support declaring store as volatile." + with pytest.raises(tvm.TVMError, match=err_msg): + with tvm.target.Target("llvm"): + tvm.build(func) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index d4fa17bf8fa4..65381a0eb9ee 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -348,5 +348,99 @@ def test_v0_16_ramp_broadcast_lanes(): assert graph.value.lanes == 12 +def test_v0_17_load_store_predicate(): + json_graph_v0_16 = { + "root": 1, + "nodes": [ + {"type_key": ""}, + { + "type_key": "tir.BufferStore", + "attrs": { + "buffer": "2", + "indices": "19", + "predicate": "0", + "span": "0", + "value": "13", + }, + }, + { + "type_key": "tir.Buffer", + "attrs": { + "axis_separators": "11", + "buffer_type": "1", + "data": "3", + "data_alignment": "64", + "dtype": "float32", + "elem_offset": "12", + "name": "4", + "offset_factor": "1", + "shape": "8", + "span": "0", + "strides": "10", + }, + 
}, + { + "type_key": "tir.Var", + "attrs": {"dtype": "handle", "name": "4", "span": "0", "type_annotation": "5"}, + }, + {"type_key": "runtime.String"}, + {"type_key": "PointerType", "attrs": {"element_type": "6", "storage_scope": "7"}}, + {"type_key": "PrimType", "attrs": {"dtype": "float32"}}, + {"type_key": "runtime.String", "repr_str": "global"}, + {"type_key": "Array", "data": [9]}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "8"}}, + {"type_key": "Array"}, + {"type_key": "Array"}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "0"}}, + { + "type_key": "tir.BufferLoad", + "attrs": { + "buffer": "2", + "dtype": "float32x4", + "indices": "14", + "predicate": "0", + "span": "0", + }, + }, + {"type_key": "Array", "data": [15]}, + { + "type_key": "tir.Ramp", + "attrs": { + "base": "16", + "dtype": "int32x4", + "lanes": "18", + "span": "0", + "stride": "17", + }, + }, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "0"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "1"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "4"}}, + {"type_key": "Array", "data": [20]}, + { + "type_key": "tir.Ramp", + "attrs": { + "base": "21", + "dtype": "int32x4", + "lanes": "23", + "span": "0", + "stride": "22", + }, + }, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "4"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "1"}}, + {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0", "value": "4"}}, + ], + "b64ndarrays": [], + "attrs": {"tvm_version": "0.16.0"}, + } + + expr = tvm.ir.load_json(json.dumps(json_graph_v0_16)) + buffer_store = expr + buffer_load = buffer_store.value + assert not buffer_store.predicate + assert not buffer_load.predicate + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_nodes.py 
b/tests/python/tir-base/test_tir_nodes.py index 31a1317e6817..eeedae1f127c 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -468,6 +468,75 @@ def test_buffer_store_scalable_vec(): assert store.value.dtype == "int32xvscalex4" +def test_buffer_store_predicate_invalid_scalability(): + b = tvm.tir.decl_buffer((24,), "int32") + value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 4) + + err_msg = "Predicate mask dtype and value dtype must both be scalable." + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferStore(b, value, [index], predicate) + + +def test_buffer_store_predicate_invalid_lanes(): + b = tvm.tir.decl_buffer((24,), "int32") + value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 8 * tvm.tir.vscale()) + + err_msg = ( + "Got a predicate mask with 8 lanes, but trying to store a " + "value with 4 lanes. The number of lanes must match." + ) + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferStore(b, value, [index], predicate) + + +def test_buffer_store_predicate_elements_invalid_type(): + b = tvm.tir.decl_buffer((24,), "int32") + value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + + err_msg = "Predicate mask elements must be boolean values, but got int32." 
+ with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferStore(b, value, [index], predicate) + + +def test_buffer_load_predicate_elements_invalid_type(): + b = tvm.tir.decl_buffer((24,), "int32") + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + + err_msg = "Predicate mask elements must be boolean values, but got int32." + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferLoad(b, [index], predicate) + + +def test_buffer_store_predicate_invalid_scalability(): + b = tvm.tir.decl_buffer((24,), "int32") + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 4) + + err_msg = "Predicate mask dtype and load indices must both be scalable." + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferLoad(b, [index], predicate) + + +def test_buffer_store_predicate_invalid_lanes(): + b = tvm.tir.decl_buffer((24,), "int32") + index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 8 * tvm.tir.vscale()) + + err_msg = ( + "Got a predicate mask with 8 lanes, but trying to load a " + "vector with 4 lanes. The number of lanes must match." 
+ ) + with pytest.raises(tvm.TVMError, match=err_msg): + tvm.tir.BufferLoad(b, [index], predicate) + + def test_scalable_vec_cast(): b = tvm.tir.decl_buffer((24,), "float32") value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32xvscalex12") diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index de5453eb5c44..e02c227b05b7 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -125,12 +125,15 @@ def main(A: T.Buffer((25,), "float32")): tvm.tir.transform.VectorizeLoop()(Module) -@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) -def test_vectorize_with_if(extent, target): +def test_vectorize_with_if(): + extent = 4 + target = simple_target + @I.ir_module class Before: @T.prim_func - def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") for i in T.vectorized(extent): if x < n: A[i] = A[i] + T.float32(1) @@ -141,7 +144,8 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): @I.ir_module class After: @T.prim_func - def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") if x < n: A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( T.float32(1), extent @@ -156,6 +160,43 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): tvm.ir.assert_structural_equal(mod, After) +def test_vectorize_if_scalable_extent(): + extent = T.vscale() * 4 + target = sve_target + + @I.ir_module + class Before: + @T.prim_func + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") + for i in T.vectorized(extent): + if x < n: + A[i] = A[i] + T.float32(1) + else: + if i < n: + A[i] = T.float32(2) + 
+ @I.ir_module + class After: + @T.prim_func + def main(a: T.handle, n: T.int32, x: T.int32): + A = T.match_buffer(a, (25,), "float32") + if x < n: + A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( + T.float32(1), extent + ) + else: + A.vstore( + [T.Ramp(0, 1, T.vscale() * 4)], + T.Broadcast(T.float32(2), T.vscale() * 4), + predicate=T.get_active_lane_mask("uint1xvscalex4", 0, n), + ) + + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + def test_vectorize_with_if_cond_int64(): m = te.size_var("m", dtype="int64") A = te.placeholder((m,), name="A", dtype="float32") @@ -488,5 +529,243 @@ def main(A: T.Buffer((16,), "float32")): tvm.tir.transform.VectorizeLoop()(Mod) +def test_vectorize_and_predicate_all_buffer_loads_stores(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + load_a = T.meta_var( + A.vload( + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + ) + add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4)) + B.vstore( + [T.Ramp(i_0 * 4, 1, 4)], + add_1, + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def 
test_vectorize_and_predicate_some_buffer_loads_stores(): + # Currently revert to scalarizing the block if not all accesses + # have been predicated, otherwise incorrect code is generated. + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0, i_1_s in T.grid(4, 4): + if i_0 * 4 + i_1_s < 14: + B[i_0 * 4 + i_1_s] = A[i_0] + T.float32(1) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_multiple_access_statements(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + A[i_0 * 4 + i_1] = 2.0 + B[i_0 * 4 + i_1] = 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + A.vstore( + [T.Ramp(i_0 * 4, 1, 4)], + T.Broadcast(T.float32(2), 4), + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + B.vstore( + [T.Ramp(i_0 * 4, 1, 4)], + T.Broadcast(T.float32(1), 4), + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + + 
before_mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_invalid_conditions(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 > 14: + A[i_0 * 4 + i_1] = 2.0 + if 14 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + if i_0 * 4 + i_1 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + for i_1_s in range(4): + if i_0 * 4 + i_1_s > 14: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if 14 < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if i_0 * 4 + i_1_s < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + + before_mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_with_explicitly_disabled_buffer_level_predication(): + # Since the target has the SVE feature, buffer level predication is enabled + # by default. However, it has been explicitly disabled by the pass context + # option, so no buffer-level predicates should be added. 
+ @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0, i_1_s in T.grid(4, 4): + if i_0 * 4 + i_1_s < 14: + B[i_0 * 4 + i_1_s] = A[i_0 * 4 + i_1_s] + T.float32(1) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": False}): + with tvm.target.Target(sve_target): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_buffer_load_stores_with_sve_func_attr_target(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": sve_target}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "target": sve_target}) + for i_0 in range(4): + load_a = T.meta_var( + A.vload( + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + ) + add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4)) + B.vstore( + [T.Ramp(i_0 * 4, 1, 4)], + add_1, + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + + mod = tvm.IRModule.from_expr(before) + after = 
tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_buffer_load_stores_with_sve_attr_scope_target(): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.attr(sve_target, "target", 0): + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + with T.attr(sve_target, "target", 0): + for i_0 in range(4): + load_a = T.meta_var( + A.vload( + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + ) + add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4)) + B.vstore( + [T.Ramp(i_0 * 4, 1, 4)], + add_1, + predicate=T.get_active_lane_mask("uint1x4", i_0 * 4, 14), + ) + + mod = tvm.IRModule.from_expr(before) + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index c20784b4bf75..daad7f53140b 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -468,6 +468,20 @@ def test_ir_builder_tir_buffer_store_scalable_vec(): assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) +def test_ir_builder_tir_buffer_store_predicate(): + buffer_a = T.Buffer((30,), "float32") + value = T.broadcast(0.11, T.vscale() * 4) + index = T.ramp(0, 1, T.vscale() * 4) + predicate = T.broadcast(T.bool(True), T.vscale() * 
4) + + with IRBuilder() as ib: + T.buffer_store(buffer_a, value, [index], predicate) + + ir_actual = ib.get() + ir_expected = tir.BufferStore(buffer_a, value, [index], predicate) + assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) + + def test_ir_builder_tir_prefetch(): with IRBuilder() as ib: buffer_a = T.Buffer((128, 128), "float32") diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index edc6da31636b..9e77fa090021 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -948,5 +948,102 @@ def func(): _assert_print(func, expected_output) +def test_predicated_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (256, 256), "float32") + T.func_attr({"global_symbol": "func"}) + a_load = T.meta_var(A.vload([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(T.bool(False), 4))) + A.vstore([0, T.Ramp(0, 2, 4)], a_load, predicate=T.Broadcast(T.bool(False), 4)) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.vstore([0, T.Ramp(0, 2, 4)], A.vload([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(T.bool(False), 4)), predicate=T.Broadcast(T.bool(False), 4)) + """ + _assert_print(main, expected_output) + + +def test_predicated_buffer_load_store(): + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + buffer_map = { + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + } + buffer_load = tir.BufferLoad( + buffer=buffer_map[b], + indices=[0, tir.Ramp(0, 4, 4)], + predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4), + ) + body = tir.BufferStore( + buffer=buffer_map[a], + value=buffer_load, + indices=[0, tir.Ramp(0, 2, 
4)], + predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4), + ) + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map=buffer_map, + body=body, + ) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func(private=True) +def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.vstore([0, T.Ramp(0, 2, 4)], B.vload([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(T.bool(False), 4)), predicate=T.Broadcast(T.bool(False), 4)) + """ + _assert_print(func, expected_output) + + +def test_predicated_scalable_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (256, 256), "float32") + T.func_attr({"global_symbol": "func"}) + mask = T.meta_var(T.get_active_lane_mask("uint1xvscalex4", 0, 13)) + a_load = T.meta_var(A.vload([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=mask)) + A.vstore([0, T.Ramp(0, 2, T.vscale() * 4)], a_load, predicate=mask) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.vstore([0, T.Ramp(0, 2, T.vscale() * 4)], A.vload([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=T.get_active_lane_mask("uint1xvscalex4", 0, 13)), predicate=T.get_active_lane_mask("uint1xvscalex4", 0, 13)) + """ + _assert_print(main, expected_output) + + +def test_vload_with_explicit_scalable_data_type(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128,), "float32") + B = T.match_buffer(b, (128,), "float32") + B[0 : T.vscale() * 4] = A.vload([T.Ramp(0, 1, T.vscale() * 4)], dtype="float32xvscalex4") + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + B[0:T.vscale() * 4] = A[0:T.vscale() * 4] + """ + _assert_print(main, expected_output) + 
+ if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 73bf200bb22a..ee404f08efb8 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3352,6 +3352,20 @@ def func(a: T.handle): return func +def predicated_buffer_load_store(): + @T.prim_func + def func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (4,), "float32") + B = T.match_buffer(b, (8,), "float32") + for i_0 in range(4): + load_a = T.meta_var( + A.vload([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)) + ) + B.vstore([T.Ramp(0, 2, 4)], load_a, predicate=T.Broadcast(T.bool(True), 4)) + + return func + + def let_expression(): @T.prim_func def func(): @@ -4116,6 +4130,8 @@ def func(A: R.Object): buffer_axis_separator, buffer_ramp_access_as_slice_index, ramp_int64, + scalable_vectors, + predicated_buffer_load_store, let_expression, void_ptr, decl_buffer,