diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index ac7e879a644d..6d083e2631f7 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -71,11 +72,16 @@ class DataType { * \param code The type code. * \param bits The number of bits in the type. * \param lanes The number of lanes. + * \param scalable Whether or not the data type is scalable. */ - DataType(int code, int bits, int lanes) { + DataType(int code, int bits, int lanes, bool scalable = false) { data_.code = static_cast(code); data_.bits = static_cast(bits); - data_.lanes = static_cast(lanes); + if (scalable) { + data_.lanes = static_cast(-lanes); + } else { + data_.lanes = static_cast(lanes); + } if (code == kBFloat) { ICHECK_EQ(bits, 16); } @@ -90,7 +96,14 @@ class DataType { /*! \return number of bytes to store each scalar. */ int bytes() const { return (bits() + 7) / 8; } /*! \return number of lanes in the data. */ - int lanes() const { return static_cast(data_.lanes); } + int lanes() const { + int encoded_lanes = static_cast(data_.lanes); + if (is_scalable()) { + return -encoded_lanes; + } else { + return encoded_lanes; + } + } /*! \return whether type is a scalar type. */ bool is_scalar() const { return lanes() == 1; } /*! \return whether type is a scalar type. */ @@ -114,17 +127,28 @@ class DataType { /*! \return whether type is a handle type. */ bool is_handle() const { return code() == DataType::kHandle && !is_void(); } /*! \return whether type is a vector type. */ - bool is_vector() const { return lanes() > 1; } + bool is_vector() const { + int encoded_lanes = static_cast(data_.lanes); + return encoded_lanes != 0 && encoded_lanes != 1; + } /*! \return whether type is a bool vector type. */ bool is_vector_bool() const { return is_vector() && bits() == 1; } /*! \return whether type is a Void type. */ bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; } + /*! \return Whether the type is scalable. */ + bool is_scalable() const { return static_cast(data_.lanes) < 0; } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. * \return the result type. */ DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); } + /*! + * \brief Create a new scalable data type by changing the lanes to a specified value. + * \param lanes The target number of lanes. + * \return A copy of the old DataType with the number of scalable lanes. + */ + DataType with_scalable_lanes(int lanes) const { return DataType(data_.code, data_.bits, -lanes); } /*! * \brief Create a new data type by change bits to a specified value. * \param bits The target number of bits. @@ -247,6 +271,9 @@ class DataType { * \return Number of bytes needed. 
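The scalable encoding above reuses the existing 16-bit `lanes` field by storing the negated lane count. A minimal Python model of that sign trick (illustrative only; `encode_lanes`/`decode_lanes` are hypothetical helpers, not TVM API):

```python
import ctypes

def encode_lanes(lanes: int, scalable: bool) -> int:
    # Negative counts wrap via two's complement in the uint16 field,
    # mirroring the negated-lanes store in the DataType constructor above.
    return ctypes.c_uint16(-lanes if scalable else lanes).value

def decode_lanes(stored: int):
    # Mirrors DataType::lanes()/is_scalable(): a negative signed view of
    # the field marks a scalable vector of -value lanes.
    signed = ctypes.c_int16(stored).value
    return (-signed, True) if signed < 0 else (signed, False)

assert decode_lanes(encode_lanes(4, scalable=True)) == (4, True)
assert decode_lanes(encode_lanes(8, scalable=False)) == (8, False)
```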
*/ inline int GetVectorBytes(DataType dtype) { + if (dtype.is_scalable()) { + LOG(FATAL) << "Cannot get vector bytes of scalable vector"; + } int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || @@ -357,8 +384,12 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) } if (t.code == kTVMOpaqueHandle) return os; os << static_cast(t.bits); - if (t.lanes != 1) { - os << 'x' << static_cast(t.lanes); + + int16_t lanes = static_cast(t.lanes); + if (lanes > 1) { + os << 'x' << lanes; + } else if (lanes < 0) { + os << 'x' << -lanes << "xvscale"; } return os; } @@ -424,6 +455,10 @@ inline DLDataType String2DLDataType(std::string s) { if (*xdelim == 'x') { t.lanes = static_cast(strtoul(xdelim + 1, &endpt, 10)); } + if (strncmp(endpt, "xvscale", 7) == 0) { + t.lanes = -t.lanes; + endpt = endpt + 7; + } ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s; return t; } diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 735d5ba6c0a1..d83303daac9a 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -410,8 +410,9 @@ Var EnvThread(String thread_tag); * \param buffer The buffer. * \param value The value to be stored. * \param indices The indices location to be stored. + * \param predicate A vector mask of int1 values that prevents storing values on masked-off lanes. */ -void BufferStore(Buffer buffer, PrimExpr value, Array indices); +void BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate); /*! * \brief The prefetch hint for a buffer diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index 9ffcb105a7ba..c2f206f5c8c0 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -131,10 +131,11 @@ class Stage : public ObjectRef { * \param factor The split factor of the loop. * \param p_outer The result outer domain * \param p_inner The result inner domain. + * \param disable_predication Whether to not predicate the loop * \return reference to self. */ - TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, - IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner, + bool disable_predication = false); // NOLINT(*) /*! * \brief Split the iteration with given number of parts. * @@ -142,10 +143,11 @@ class Stage : public ObjectRef { * \param nparts The number of parts in the outer domain. * \param p_outer The result outer domain. * \param p_inner The result inner domain. + * \param disable_predication Whether to not predicate the loop * \return reference to self. */ TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, - IterVar* p_inner); // NOLINT(*) + IterVar* p_inner, bool disable_predication = false); // NOLINT(*) /*! * \brief Fuse the inner outer domain to the target * \param outer The outer domain to be fused. @@ -761,6 +763,8 @@ class SplitNode : public IterVarRelationNode { PrimExpr factor; /*! \brief Number of parts, only factor or nparts can be given */ PrimExpr nparts; + /*! 
\brief Whether to disable the predication */ + bool disable_predication; void VisitAttrs(AttrVisitor* v) { v->Visit("parent", &parent); @@ -768,6 +772,7 @@ class SplitNode : public IterVarRelationNode { v->Visit("inner", &inner); v->Visit("factor", &factor); v->Visit("nparts", &nparts); + v->Visit("disable_predication", &disable_predication); } static constexpr const char* _type_key = "Split"; @@ -780,7 +785,8 @@ class SplitNode : public IterVarRelationNode { */ class Split : public IterVarRelation { public: - TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts); + TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts, + bool disable_predication); TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode); }; diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e6116605f8a2..76a29997fdd1 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -909,6 +909,18 @@ TVM_DLL const Op& anylist_setitem_call_packed(); */ TVM_DLL const Op& anylist_setitem_call_cpacked(); +/*! + * \brief Return the value of vscale + */ +TVM_DLL const Op& vscale(); + +/*! + * \brief Provide the predicate constructed of the currently active lanes + * + * Calculate the active lane masks given a bound and a current value + */ +TVM_DLL const Op& get_active_lane_mask(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 4e29eddadd8c..1babb4a92a97 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -630,23 +630,27 @@ class BufferLoadNode : public PrimExprNode { Buffer buffer; /*! \brief The indices location to be loaded. */ Array indices; + /*! \brief The buffer predicate */ + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); v->Visit("buffer", &buffer); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { return equal(dtype, other->dtype) && equal(buffer, other->buffer) && - equal(indices, other->indices); + equal(indices, other->indices) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); hash_reduce(buffer); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferLoad"; @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode { */ class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, Span span = Span()); + TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, + PrimExpr predicate = PrimExpr(nullptr), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; @@ -746,7 +751,7 @@ class RampNode : public PrimExprNode { /*! \brief The stride of each step. */ PrimExpr stride; /*! \brief Total number of lanes. 
*/ - int lanes; + PrimExpr lanes; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -778,7 +783,7 @@ class RampNode : public PrimExprNode { */ class Ramp : public PrimExpr { public: - TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span()); + TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode); }; @@ -789,7 +794,7 @@ class BroadcastNode : public PrimExprNode { /*! \brief The base value. */ PrimExpr value; /*! \brief The number of lanes. */ - int lanes; + PrimExpr lanes; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -818,7 +823,7 @@ class BroadcastNode : public PrimExprNode { */ class Broadcast : public PrimExpr { public: - TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span()); + TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode); }; diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 273912ed1f8f..7f9c9f0a8801 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -350,10 +350,15 @@ class ScheduleNode : public runtime::Object { * \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means * that factor is inferred. * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings + * \param disable_predication If enabled, don't create a predicate for guarding the + * loop. This can be useful when splitting with scalable factors that the schedule writer + * knows are divisible. Warning: enabling this feature may result in incorrect code generation + * if not used carefully. * \return The new loops after split */ virtual Array Split(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters = true) = 0; + bool preserve_unit_iters = true, + bool disable_predication = false) = 0; /*! * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. * It requires: diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 07cc9b5ad0d5..3255b1ac90ac 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -231,23 +231,27 @@ class BufferStoreNode : public StmtNode { PrimExpr value; /*! \brief The indices location to be stored. */ Array indices; + /*! \brief The predicate for this store. 
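With `lanes` widened from `int` to `PrimExpr`, ramps and broadcasts can span `vscale`-scaled lane counts while fixed-width usage is unchanged. A sketch via the Python constructors touched by this patch (assumes the `tir.vscale()` binding added here):

```python
import tvm
from tvm import tir

zero = tir.IntImm("int32", 0)
one = tir.IntImm("int32", 1)

fixed = tir.Ramp(zero, one, 8)                    # 8 fixed lanes, as before
scalable = tir.Ramp(zero, one, 4 * tir.vscale())  # 4 x vscale lanes
splat = tir.Broadcast(tir.FloatImm("float32", 1.0), 4 * tir.vscale())
```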
*/ + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer", &buffer); v->Visit("value", &value); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const { return equal(buffer, other->buffer) && equal(value, other->value) && - equal(indices, other->indices); + equal(indices, other->indices) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer); hash_reduce(value); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferStore"; @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode { class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, - Span span = Span()); + PrimExpr predicate = PrimExpr(nullptr), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 54e4d8f205a1..e8a61eb73736 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -135,7 +135,10 @@ def __init__(self, type_str): arr = type_str.split("x") head = arr[0] - self.lanes = int(arr[1]) if len(arr) > 1 else 1 + if len(arr) == 3 and arr[2] == "vscale": + self.lanes = ctypes.c_uint16(-int(arr[1])) + elif len(arr) > 1: + self.lanes = ctypes.c_uint16(int(arr[1])) bits = 32 if head.startswith("int"): @@ -188,8 +191,11 @@ def __repr__(self): type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code) x = "%s%d" % (type_name, self.bits) - if self.lanes != 1: + lanes_as_int = ctypes.c_int16(self.lanes).value + if lanes_as_int > 1: x += "x%d" % self.lanes + elif lanes_as_int < 0: + x += "x%dxvscale" % -lanes_as_int return x def __eq__(self, other): diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 74b0bd2ba4e1..104ce5764c20 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1262,6 +1262,7 @@ def buffer_store( buffer: Buffer, # pylint: disable=redefined-outer-name value: PrimExpr, indices: List[Union[PrimExpr, slice]], + predicate: Optional[PrimExpr] = None, ) -> None: """Buffer store node. @@ -1275,6 +1276,9 @@ def buffer_store( indices : List[Union[PrimExpr, slice]] The indices location to be stored. + + predicate : Optional[PrimExpr] + A vector mask of int1 values that prevents storing values on masked-off lanes. 
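The `runtime_ctypes.py` changes above apply the same sign trick on the Python side and round-trip it through the `<type>x<n>xvscale` string form. A standalone sketch of that convention (hypothetical helpers mirroring `__init__`/`__repr__`, not TVM API):

```python
def lanes_from_type_str(type_str: str) -> int:
    arr = type_str.split("x")
    if len(arr) == 3 and arr[2] == "vscale":
        return -int(arr[1])  # a negative count marks a scalable vector
    return int(arr[1]) if len(arr) > 1 else 1

def lanes_suffix(lanes: int) -> str:
    if lanes > 1:
        return "x%d" % lanes
    if lanes < 0:
        return "x%dxvscale" % -lanes
    return ""

assert lanes_from_type_str("int32x4xvscale") == -4
assert "int32" + lanes_suffix(-4) == "int32x4xvscale"
```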
""" from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel @@ -1289,13 +1293,13 @@ def buffer_store( if lanes == 1: expr_indices.append(index.start) else: - expr_indices.append(ramp(index.start, step, int(lanes))) + expr_indices.append(ramp(index.start, step, lanes)) else: expr_indices.append(index) if isinstance(value, bool) and buffer.dtype == "bool": value = IntImm("bool", value) return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member - buffer, value, expr_indices + buffer, value, expr_indices, predicate ) @@ -1854,6 +1858,7 @@ def wrapped(*args, **kwargs): create_barriers = _op_wrapper(_tir_op.create_barriers) assume = _op_wrapper(_tir_op.assume) undef = _op_wrapper(_tir_op.undef) +vscale = _op_wrapper(_tir_op.vscale) TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace) TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) @@ -1890,7 +1895,7 @@ def wrapped(*args, **kwargs): vectorlow = _dtype_forward(_tir_op.vectorlow) vectorhigh = _dtype_forward(_tir_op.vectorhigh) vectorcombine = _dtype_forward(_tir_op.vectorcombine) - +get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask) broadcast = Broadcast ramp = Ramp @@ -2199,4 +2204,6 @@ def wrapped(*args, **kwargs): "IterVar", "CommReducer", "Range", + "vscale", + "get_active_lane_mask", ] diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 89673d291b88..18e3528e693b 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -462,6 +462,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: elif isinstance(res, str): # Ignore docstrings pass + elif isinstance(res, tvm.tir.stmt.BufferStore): + T.buffer_store(res.buffer, res.value, res.indices, res.predicate) else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index 936ead654dc8..180f334d1486 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -201,7 +201,7 @@ def rfactor(self, tensor, axis, factor_axis=0): class Stage(Object): """A Stage represents schedule for one operation.""" - def split(self, parent, factor=None, nparts=None): + def split(self, parent, factor=None, nparts=None, disable_predication=False): """Split the stage either by factor providing outer scope, or both Parameters @@ -215,6 +215,13 @@ def split(self, parent, factor=None, nparts=None): nparts : Expr, optional The number of outer parts. + disable_predication : bool + If enabled, don't create a predicate for guarding the loop. This + can be useful when splitting with scalable factors that the + schedule writer knows are divisible. + Warning: enabling this feature may result in incorrect code + generation if not used carefully. 
+ Returns ------- outer : IterVar @@ -226,11 +233,11 @@ def split(self, parent, factor=None, nparts=None): if nparts is not None: if factor is not None: raise ValueError("Do not need to provide both outer and nparts") - outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts) + outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts, disable_predication) else: if factor is None: raise ValueError("Either nparts or factor need to be provided") - outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor) + outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor, disable_predication) return outer, inner def fuse(self, *args): diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index ccad989c33ef..0fbb2084e4de 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1023,7 +1023,7 @@ def _corstone300_compile_time_check(): # check cpu features -def _has_cpu_feat(features): +def has_cpu_feat(features): cpu = codegen.llvm_get_system_cpu() triple = codegen.llvm_get_system_triple() target = "llvm -mtriple=%s -mcpu=%s" % (triple, cpu) @@ -1035,21 +1035,28 @@ def _has_cpu_feat(features): requires_arm_dot = Feature( "arm_dot", "ARM dot product", - run_time_check=lambda: _has_cpu_feat("dotprod"), + run_time_check=lambda: has_cpu_feat("dotprod"), +) + + +requires_aarch64_sve = Feature( + "arm_sve", + "AArch64 SVE", + run_time_check=lambda: has_cpu_feat("sve"), ) requires_x86_vnni = Feature( "x86_vnni", "x86 VNNI Extensions", - run_time_check=lambda: (_has_cpu_feat("avx512vnni") or _has_cpu_feat("avxvnni")), + run_time_check=lambda: (has_cpu_feat("avx512vnni") or has_cpu_feat("avxvnni")), ) requires_x86_avx512 = Feature( "x86_avx512", "x86 AVX512 Extensions", - run_time_check=lambda: _has_cpu_feat( + run_time_check=lambda: has_cpu_feat( ["avx512bw", "avx512cd", "avx512dq", "avx512vl", "avx512f"] ), ) @@ -1058,7 +1065,7 @@ def _has_cpu_feat(features): requires_x86_amx = Feature( "x86_amx", "x86 AMX Extensions", - run_time_check=lambda: _has_cpu_feat("amx-int8"), + run_time_check=lambda: has_cpu_feat("amx-int8"), ) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index f0500290b888..24ba4ccd2e58 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -88,6 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic +from .op import vscale, get_active_lane_mask from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index ec57ad7801ca..17104ba3d0ec 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -141,6 +141,57 @@ def vstore(self, begin, value): begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin return _ffi_api.BufferVStore(self, begin, value) # type: ignore + def load(self, indices, predicate=None): + """ + Load values at specified indices from buffer. + + Longhand notation that can be used for complex buffer load + expressions. For example, when the load involves predication. + + Parameters + ---------- + indices : List[PrimExpr] + The buffer indices to load values from. + + predicate : Optional[PrimExpr] + A vector mask of int1 values that prevents loading values on masked-off lanes. + + Returns + ------- + BufferLoad + A buffer load Expr. 
+ """ + from .expr import BufferLoad # pylint: disable=import-outside-toplevel + + return BufferLoad(self, indices, predicate) + + def store(self, value, indices, predicate=None): + """ + Store given value at the specified indices in the buffer. + + Longhand notation that can be used for complex buffer store + statements. For example, when the store involves predication. + + Parameters + ---------- + value : PrimExpr + The value to be stored. + + indices : List[PrimExpr] + The buffer indices to store values to. + + predicate : Optional[PrimExpr] + A vector mask of int1 values that prevents storing values on masked-off lanes. + + Returns + ------- + BufferStore + A buffer store Stmt. + """ + from .stmt import BufferStore # pylint: disable=import-outside-toplevel + + return BufferStore(self, value, indices, predicate) + def scope(self): """Return the storage scope associated with this buffer. Returns diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index fad9fca083a1..c6c502a29632 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1093,7 +1093,10 @@ class BufferLoad(PrimExprWithOp): The buffer to be loaded. indices : List[PrimExpr] - The buffer indices. + The buffer indices to load values from. + + predicate : Optional[PrimExpr] + A vector mask of int1 values that prevents loading values on masked-off lanes. span : Optional[Span] The location of this expression in the source code. @@ -1103,10 +1106,14 @@ class BufferLoad(PrimExprWithOp): indices: List[PrimExpr] def __init__( - self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None + self, + buffer: Buffer, + indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, + span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferLoad, buffer, indices, span # type: ignore + _ffi_api.BufferLoad, buffer, indices, predicate, span # type: ignore ) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 4735a2644e83..ddfadf1b0d4e 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1586,6 +1586,30 @@ def vectorlow(dtype, vec): return call_intrin(dtype, "tir.vectorlow", vec) +def vscale(): + return call_intrin("int32", "tir.vscale") + + +def get_active_lane_mask(dtype, base, limit): + """ + Creates a mask corresponding to active and inactive vector lanes. + + Analogous to https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics. + + Parameters + ---------- + dtype : str + The data type of the result. + + base : PrimExpr + An expression reprsenting the base. + + limit : PrimExpr + An expression representing the limit. + """ + return call_intrin(dtype, "tir.get_active_lane_mask", base, limit) + + def vectorhigh(dtype, vec): """Get the high level half of the vector diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 23b000c09015..32563d504f61 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -736,6 +736,7 @@ def split( loop: LoopRV, factors: List[Union[int, ExprRV, None]], preserve_unit_iters: bool = True, + disable_predication: bool = False, ) -> List[LoopRV]: """Split a loop into a list of consecutive loops. It requires: 1) The loop can't have annotation or thread binding. @@ -759,6 +760,14 @@ def split( preserve_unit_iters : bool Whether or not to preserve unit iterators in block bindings + disable_predication : bool + If enabled, don't create a predicate for guarding the loop. 
This + can be useful when splitting with scalable factors that the + schedule writer knows are divisible. + + Warning: enabling this feature may result in incorrect code + generation if not used carefully. + Returns ------- split_loops : List[LoopRV] @@ -809,7 +818,11 @@ def after_split(a: T.handle, b: T.handle) -> None: # that there is at most one None in `factors` return list( _ffi_api.ScheduleSplit( # type: ignore # pylint: disable=no-member - self, loop, factors, preserve_unit_iters + self, + loop, + factors, + preserve_unit_iters, + disable_predication, ) ) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 992c388e27bb..0950547082a7 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -219,11 +219,14 @@ class BufferStore(Stmt): The buffer. value : PrimExpr - The value we to be stored. + The value to be stored. indices : List[PrimExpr] The indices location to be stored. + predicate : Optional[PrimExpr] + A vector mask of int1 values that prevents storing values on masked-off lanes. + span : Optional[Span] The location of the stmt in the source code. """ @@ -231,6 +234,7 @@ class BufferStore(Stmt): buffer: Buffer value: PrimExpr indices: List[PrimExpr] + predicate: Optional[PrimExpr] span: Optional[Span] def __init__( @@ -238,10 +242,11 @@ def __init__( buffer: Buffer, value: PrimExpr, indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferStore, buffer, value, indices, span # type: ignore + _ffi_api.BufferStore, buffer, value, indices, predicate, span # type: ignore ) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 3e5b8834ebca..17b0634227eb 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -22,9 +22,13 @@ */ #include #include +#include #include #include +#include "../target/parsers/aprofile.h" +#include "../tir/analysis/check_contains.h" +#include "./scalable_expression.h" #include "const_fold.h" #include "product_normal_form.h" @@ -227,6 +231,33 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { } } + // Current analysis may not be powerful enough to prove expressions with + // multiple symbolic values. When the expression is scalable and the compile + // target is AArch64, we can make some assumptions about the value of vscale + // and iterate over a space of pre-defined values to attempt to prove the + // expression. 
+ Target current_target = tvm::Target::Current(true); + if (!current_target.defined()) { + return false; + } + TargetJSON target_json = target::parsers::aprofile::ParseTarget(current_target->Export()); + TargetFeatures features = Downcast(target_json.at("features")); + bool is_llvm_aarch64 = Downcast(features.at("is_aarch64")); + if (is_llvm_aarch64 && tir::CheckContains::ExprContains(simplified, IsVScaleCall)) { + bool can_prove_expr = true; + for (const unsigned int vscale_value : kVScaleValues) { + PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value); + result = Simplify(result); + const int64_t* as_int = tir::as_const_int(result); + if (!as_int || *as_int == 0) { + can_prove_expr = false; + } + } + if (can_prove_expr) { + return true; + } + } + return false; } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 625488430bf8..88d6d2149b7a 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -466,14 +467,21 @@ class IntervalSetEvaluator : public ExprFunctor { if (stride.Match(op->stride)) { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; - if (vstride > 0) { - return Combine(analyzer_, base, - IntervalSet(make_zero(t), make_const(t, vstride * (op->lanes - 1))), - op->dtype); - } else { - return Combine(analyzer_, base, - IntervalSet(make_const(t, vstride * (op->lanes - 1)), make_zero(t)), - op->dtype); + if (op->lanes->IsInstance()) { + int lanes = static_cast(Downcast(op->lanes)->value); + if (vstride > 0) { + return Combine(analyzer_, base, + IntervalSet(make_zero(t), make_const(t, vstride * (lanes - 1))), + op->dtype); + } else { + return Combine(analyzer_, base, + IntervalSet(make_const(t, vstride * (lanes - 1)), make_zero(t)), + op->dtype); + } + } else { /* Scalable vector */ + if (vstride > 0) { + return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), op->dtype); + } } } DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index d057a840e8b7..98cf61990d90 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -628,10 +628,11 @@ inline PRampExpr ramp(const Pattern& base, } template -inline PRampExpr, PConst> ramp(const Pattern& base, - int stride, int lanes) { - return PRampExpr, PConst>( - base.derived(), PConstWithTypeLike(base.derived(), stride), PConst(lanes)); +inline PRampExpr, PConstWithTypeLike> ramp( + const Pattern& base, int stride, int lanes) { + return PRampExpr, PConstWithTypeLike>( + base.derived(), PConstWithTypeLike(base.derived(), stride), + PConstWithTypeLike(base.derived(), lanes)); } /*! 
@@ -835,6 +836,12 @@ inline PCallExpr if_then_else(const Pattern false_value.derived()); } +// vscale +struct PVscaleOp { + static PrimExpr Eval() { return tir::Call(DataType::Int(32), GetOp(), {}); } + static const Op& GetOp() { return tir::builtin::vscale(); } +}; + template class PMatchesOneOf { public: diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index d5f946fca02a..ddcc4eff96a1 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -247,7 +247,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // Pattern var match FloatImm PVar c4; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes)); @@ -396,7 +396,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); @@ -580,7 +580,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { // Pattern var match FloatImm PVar c3; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); @@ -617,7 +617,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // x / 2.0 = x * 0.5 if (const FloatImmNode* ptr = op->b.as()) { @@ -639,10 +639,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); } // If all possible indices in ramp are the same. - if (CanProveGreaterEqual(b1.Eval(), 0)) { + if (CanProveGreaterEqual(b1.Eval(), 0) && !IsWithScalableLanes(lanes.Eval())) { ModularSet bmod = analyzer_->modular_set(b1.Eval()); int64_t ramp_min = bmod->base / c2val; - int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; + auto lanes_int = lanes.Eval().as()->value; + int64_t ramp_max = (bmod->base + (lanes_int - 1) * c1val) / c2val; if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) { return broadcast(div(b1, c2), lanes).Eval(); } @@ -777,7 +778,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { // Pattern var match IntImm PVar c1, c2; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { @@ -793,14 +794,23 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { return broadcast(truncmod(b1, c2), lanes).Eval(); } // If all possible indices in ramp are the same. 
if (CanProveGreaterEqual(b1.Eval(), 0)) { ModularSet bmod = analyzer_->modular_set(b1.Eval()); - int64_t ramp_min = bmod->base / c2val; - int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; - if (bmod->coeff % c2val == 0) { - if (ramp_min == ramp_max) { - return ramp(truncmod(bmod->base, c2), c1, lanes).Eval(); - } else { + if (!IsWithScalableLanes(lanes.Eval())) { + auto lanes_int = lanes.Eval().as<IntImmNode>()->value; + int64_t ramp_min = bmod->base / c2val; + int64_t ramp_max = (bmod->base + (lanes_int - 1) * c1val) / c2val; + if (bmod->coeff % c2val == 0) { + if (ramp_min == ramp_max) { + return ramp(truncmod(bmod->base, c2), c1, lanes).Eval(); + } else { + return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)) + .Eval(); + } + } + } else { /* Special case for scalable vectors */ + ModularSet bmod = analyzer_->modular_set(b1.Eval()); + if (bmod->coeff % c2val == 0) { return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } @@ -857,7 +867,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // Pattern var match IntImm PVar<IntImm> c1, c2, c3; // Pattern var for lanes in broadcast and ramp - PVar<int> lanes; + PVar<PrimExpr> lanes; // Vector rules if (op->dtype.lanes() != 1) { @@ -872,17 +882,20 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval(); } // If all possible indices in ramp are the same. - ModularSet bmod = analyzer_->modular_set(b1.Eval()); - int64_t ramp_min = floordiv(bmod->base, c2val); - int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); - if (ramp_min == ramp_max) { - // If b1 can devide c2 - if (bmod->coeff % c2val == 0) { - return broadcast(floordiv(b1, c2), lanes).Eval(); - } - // If all indices can be guaranteed to settle inside a coeff range - if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { - return broadcast(floordiv(b1, c2), lanes).Eval(); + if (!IsWithScalableLanes(lanes.Eval())) { + ModularSet bmod = analyzer_->modular_set(b1.Eval()); + int64_t ramp_min = floordiv(bmod->base, c2val); + auto lanes_int = lanes.Eval().as<IntImmNode>()->value; + int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val, c2val); + if (ramp_min == ramp_max) { + // If b1 can divide c2 + if (bmod->coeff % c2val == 0) { + return broadcast(floordiv(b1, c2), lanes).Eval(); + } + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes_int - 1) * c1val < bmod->coeff) { + return broadcast(floordiv(b1, c2), lanes).Eval(); + } } } } @@ -993,7 +1006,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // Pattern var match IntImm PVar<IntImm> c1, c2; // Pattern var for lanes in broadcast and ramp - PVar<int> lanes; + PVar<PrimExpr> lanes; // Vector rules if (op->dtype.lanes() != 1) { @@ -1010,21 +1023,28 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { } // If all possible indices in ramp are the same.
ModularSet bmod = analyzer_->modular_set(b1.Eval()); - int64_t ramp_min = floordiv(bmod->base, c2val); - int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); - if (ramp_min == ramp_max) { - // If b1 can devide c2 + if (!IsWithScalableLanes(lanes.Eval())) { + int64_t ramp_min = floordiv(bmod->base, c2val); + auto lanes_int = lanes.Eval().as<IntImmNode>()->value; + int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val, c2val); + if (ramp_min == ramp_max) { + // If b1 can divide c2 + if (bmod->coeff % c2val == 0) { + return ramp(floormod(bmod->base, c2), c1, lanes).Eval(); + } + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes_int - 1) * c1val < bmod->coeff) { + return ramp(floormod(b1, c2), c1, lanes).Eval(); + } + } if (bmod->coeff % c2val == 0) { - return ramp(floormod(bmod->base, c2), c1, lanes).Eval(); + return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } - // If all indices can be guaranteed to settle inside a coeff range - if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { - return ramp(floormod(b1, c2), c1, lanes).Eval(); + } else { /* scalable vectors */ + if (bmod->coeff % c2val == 0) { + return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } - if (bmod->coeff % c2val == 0) { - return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); - } } } @@ -1093,7 +1113,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PVar<PrimExpr> x, y, z, s1, s2; // Pattern var match IntImm PVar<IntImm> c1, c2; - PVar<int> lanes; + PVar<PrimExpr> lanes; // vector rule if (op->dtype.lanes() != 1) { @@ -1267,7 +1287,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PVar<PrimExpr> x, y, z, s1, s2; // Pattern var match IntImm PVar<IntImm> c1, c2; - PVar<int> lanes; + PVar<PrimExpr> lanes; // vector rule if (op->dtype.lanes() != 1) { @@ -1475,7 +1495,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { PVar<PrimExpr> x, y; // Pattern var match IntImm PVar<IntImm> c1, c2; - PVar<int> lanes; + PVar<PrimExpr> lanes; // vector rule if (ret->dtype.lanes() != 1) { @@ -1603,7 +1623,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { PVar<PrimExpr> x, y, z, s1, s2; // Pattern var match IntImm PVar<IntImm> c1, c2; - PVar<int> lanes; + PVar<PrimExpr> lanes; // vector rule if (ret->dtype.lanes() != 1) { @@ -1761,7 +1781,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) { // Pattern var to match any expression PVar<PrimExpr> x, y; - PVar<int> lanes; + PVar<PrimExpr> lanes; if (ret->dtype.lanes() != 1) { TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); } @@ -1836,7 +1856,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PVar<PrimExpr> x, y, z; // Pattern var match IntImm PVar<IntImm> c1, c2, c3; - PVar<int> lanes; + PVar<PrimExpr> lanes; if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); @@ -1984,7 +2004,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PVar<PrimExpr> x, y, z; // Pattern var match IntImm PVar<IntImm> c1, c2; - PVar<int> lanes; + PVar<PrimExpr> lanes; if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 7c4b0eab2224..288ee948ff90 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -221,6 +221,8 @@ class RewriteSimplifier::Impl : public
IRMutatorWithAnalyzer { bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { return analyzer_->CanProveGreaterEqual(x, val); } + // Whether the lanes are scalable + bool IsWithScalableLanes(const PrimExpr& lanes) { return !lanes.as<IntImmNode>(); } // Whether x < val bool CanProveLess(const PrimExpr& x, int64_t val) { return analyzer_->CanProveLess(x, val); } // Whether x == val diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc new file mode 100644 index 000000000000..c3437f5f3807 --- /dev/null +++ b/src/arith/scalable_expression.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/scalable_expression.cc + * \brief Analyze scalable expressions. + */ + +#include "scalable_expression.h" + +#include <tvm/tir/builtin.h> +#include <tvm/tir/op.h> + +#include "../tir/transforms/replace_selected_expr.h" +#include "./pattern_match.h" + +namespace tvm { +namespace arith { + +bool IsVScaleCall(const PrimExpr& expr) { + if (auto call = expr.as<tir::CallNode>()) { + return call->op.same_as(tir::builtin::vscale()); + } + return false; +} + +PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value) { + std::function<bool(const PrimExpr&)> predicate_selector = [](const PrimExpr& current_expr) { + return IsVScaleCall(current_expr); + }; + + std::function<bool(const PrimExpr&)> predicate_can_replace_inside = + [](const PrimExpr& current_expr) { return true; }; + + return tir::ReplaceSelectedExpr::ReplaceSelectedExprInExpr( + expr, predicate_selector, tir::MakeConstScalar(DataType::Int(32), vscale_value), + predicate_can_replace_inside); +} + +PrimExpr CanonicalizeScalableLanes(const PrimExpr& lanes) { + PVar<IntImm> multiplier; + PCallExpr<PVscaleOp> vscale; + + PrimExpr new_lanes; + + if ((multiplier * vscale).Match(lanes)) { + new_lanes = lanes; + } else if ((vscale * multiplier).Match(lanes)) { + new_lanes = + tir::Mul(multiplier.Eval(), tir::Call(DataType::Int(32), tir::builtin::vscale(), {})); + } else { + LOG(FATAL) << "Illegal form for scalable vector lanes: " << lanes; + } + + return new_lanes; +} + +} // namespace arith +} // namespace tvm diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h new file mode 100644 index 000000000000..c55e930b29fe --- /dev/null +++ b/src/arith/scalable_expression.h @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/scalable_expression.h + * \brief Analyze scalable expressions. + */ + +#ifndef TVM_ARITH_SCALABLE_EXPRESSION_H_ +#define TVM_ARITH_SCALABLE_EXPRESSION_H_ + +#include + +#include + +namespace tvm { +namespace arith { + +/*! \brief A list of known vscale values to try. */ +static const std::vector kVScaleValues = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + +/*! + * \brief Substitute a vscale intrinsic call with a known value. + * \param expr The expr to apply substitutions to. + * \param vscale_value The scalar value to replace vscale with. + * \return A rewritten expression with vscale values replaced with a scalar value. + */ +PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value); + +/*! + * \brief Check if an expr is a call to the vscale intrinsic. + * \param expr The expr to check + * \return True if the expr is a call to the vscale intrinsic, false if not. + */ +bool IsVScaleCall(const PrimExpr& expr); + +/*! + * \brief Returns the scalable lanes in a form multiplier * vscale + * \param lanes The scalable lanes as a PrimExpr + * \return Scalable lanes in a form multiplier * vscale + */ +PrimExpr CanonicalizeScalableLanes(const PrimExpr& lanes); + +} // namespace arith +} // namespace tvm + +#endif // TVM_ARITH_SCALABLE_EXPRESSION_H_ diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 17cd5c49a1bf..42a4a0a2fe3f 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_predication", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index fdd8c2cd8bc5..596805f74b24 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -53,8 +53,8 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { for (const Range& r : buffer_region->region) { if (tvm::tir::is_one(r->extent)) { indices.push_back(r->min); - } else if (const auto* extent = r->extent.as()) { - indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), extent->value)); + } else if (r->extent.as()) { + indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), r->extent)); } else { LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << ref; } diff --git a/src/relay/printer/tir_text_printer.cc b/src/relay/printer/tir_text_printer.cc index e9a9ee231358..c34788be91b8 100644 --- a/src/relay/printer/tir_text_printer.cc +++ b/src/relay/printer/tir_text_printer.cc @@ -381,13 +381,13 @@ Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode* op) { Doc TIRTextPrinter::VisitExpr_(const RampNode* op) { Doc doc; - doc << "ramp(" << Print(op->base) << ", " << 
Print(op->stride) << ", " << op->lanes << ")"; + doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << Print(op->lanes) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) { Doc doc; - doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + doc << "broadcast(" << Print(op->value) << ", " << Print(op->lanes) << ")"; return doc; } diff --git a/src/relay/printer/tvmscript_printer.cc b/src/relay/printer/tvmscript_printer.cc index b0085b82426e..1126e633d633 100644 --- a/src/relay/printer/tvmscript_printer.cc +++ b/src/relay/printer/tvmscript_printer.cc @@ -912,14 +912,14 @@ Doc TVMScriptPrinter::VisitExpr_(const RampNode* op, ExprPrecedence* out_precede *out_precedence = ExprPrecedence::kIdentity; Doc doc; doc << tir_prefix_ << ".ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " - << op->lanes << ")"; + << Print(op->lanes) << ")"; return doc; } Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << Print(op->lanes) << ")"; return doc; } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index d6554fc37103..c8f725606580 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -516,10 +516,18 @@ Var EnvThread(String thread_tag) { return var; } -void BufferStore(Buffer buffer, PrimExpr value, Array indices) { +void BufferStore(Buffer buffer, PrimExpr value, Array indices, + PrimExpr predicate = PrimExpr()) { runtime::DataType buffer_dtype = buffer->dtype; int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; - runtime::DataType lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() * index_lanes); + bool index_scalable = indices.size() ? indices.back().dtype().is_scalable() : false; + bool buffer_scalable = buffer->dtype.is_scalable(); + runtime::DataType lhs_dtype; + if (index_scalable || buffer_scalable) { + lhs_dtype = buffer_dtype.with_scalable_lanes(buffer_dtype.lanes() * index_lanes); + } else { + lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() * index_lanes); + } runtime::DataType rhs_dtype = value->dtype; if (lhs_dtype != rhs_dtype) { if (lhs_dtype.lanes() != rhs_dtype.lanes()) { @@ -545,7 +553,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices) { } value = tvm::cast(lhs_dtype, value); } - AddToParent(tvm::tir::BufferStore(buffer, value, indices)); + AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate)); } void Prefetch(Buffer buffer, Array bounds) { diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 45a0dfd2aea4..e5daee7d7b75 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -17,6 +17,7 @@ * under the License. */ #include // For `kAllocAlignment` +#include #include "./utils.h" @@ -273,14 +274,33 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); - return AssignDoc(/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], - /*rhs=*/d->AsDoc(store->value, p->Attr("value")), NullOpt); + ExprDoc value = d->AsDoc(store->value, p->Attr("value")); + + // Use .store(...) 
syntax when there is a predicate + if (store->predicate.defined()) { + ExprDoc indices = d->AsDoc(store->indices, p->Attr("indices")); + ExprDoc predicate = d->AsDoc(store->predicate, p->Attr("predicate")); + return ExprStmtDoc( + buffer->Attr("store")->Call({value, indices}, {"predicate"}, {predicate})); + } + + return AssignDoc( + /*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], + /*rhs=*/value, NullOpt); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); + + // Use .load(...) syntax when there is a predicate + if (load->predicate.defined()) { + ExprDoc indices = d->AsDoc(load->indices, p->Attr("indices")); + ExprDoc predicate = d->AsDoc(load->predicate, p->Attr("predicate")); + return buffer->Attr("load")->Call({indices}, {"predicate"}, {predicate}); + } + return buffer[BufferIndices(load->indices, p->Attr("indices"), d)]; }); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index e25b074401d4..8268e6b35ecb 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -137,7 +137,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, "Ramp")->Call({ d->AsDoc(ramp->base, ramp_p->Attr("base")), d->AsDoc(ramp->stride, ramp_p->Attr("stride")), - LiteralDoc::Int(ramp->lanes, ramp_p->Attr("lanes")), + d->AsDoc(ramp->lanes, ramp_p->Attr("lanes")), }); }); @@ -146,7 +146,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, "Broadcast") ->Call({ d->AsDoc(bc->value, bc_p->Attr("value")), - LiteralDoc::Int(bc->lanes, bc_p->Attr("lanes")), + d->AsDoc(bc->lanes, bc_p->Attr("lanes")), }); }); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9701a299f1d1..745fe5c00ec9 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -589,9 +589,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { } if (dtype.lanes() != 1) { #if TVM_LLVM_VERSION >= 110 - return llvm::FixedVectorType::get(etype, dtype.lanes()); + if (dtype.is_scalable()) { + return llvm::VectorType::get(etype, dtype.lanes(), true); + } else { + return llvm::FixedVectorType::get(etype, dtype.lanes()); + } #else - return llvm::VectorType::get(etype, dtype.lanes()); + if (dtype.is_scalable()) { + return llvm::VectorType::get(etype, dtype.lanes(), true); + } else { + return llvm::VectorType::get(etype, dtype.lanes()); + } #endif } else { return etype; @@ -641,13 +649,13 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer_va int64_t base = 0, width = 0; arith::PVar pbase, pstride; - arith::PVar planes; + arith::PVar planes; // create meta-data for alias analysis // Use a group of binary tree ranges of memory banks. 
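For orientation, the scalable branch in `DTypeToLLVMType` above corresponds to LLVM's scalable vector types; a sketch of the intended correspondence (illustrative, not generated output):

```python
# "float32x4"        -> <4 x float>          (llvm::FixedVectorType)
# "float32x4xvscale" -> <vscale x 4 x float> (scalable llvm::VectorType)
# "int32x2xvscale"   -> <vscale x 2 x i32>
```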
int64_t xwith = 0; if (arith::ramp(pbase, pstride, planes).Match(index)) { base = pbase.Eval()->value; - xwith = planes.Eval() * pstride.Eval()->value; + xwith = planes.Eval()->value * pstride.Eval()->value; } else if (auto* ptr = index.as()) { base = ptr->value; xwith = 1; @@ -749,26 +757,6 @@ std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Modul return debug_info; } -llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { -#if TVM_LLVM_VERSION >= 110 - llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes); -#else - llvm::Type* type = llvm::VectorType::get(value->getType(), lanes); -#endif - llvm::Constant* undef = llvm::UndefValue::get(type); - llvm::Constant* zero = ConstInt32(0); - value = builder_->CreateInsertElement(undef, value, zero); -#if TVM_LLVM_VERSION >= 120 - llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero); -#elif TVM_LLVM_VERSION >= 110 - llvm::Constant* mask = - llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero); -#else - llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero); -#endif - return builder_->CreateShuffleVector(value, undef, mask); -} - llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { int num_elems = GetVectorNumElements(vec); if (extent == num_elems && begin == 0) return vec; @@ -1478,6 +1466,17 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return builder_->CreateAssumption(cond); } else if (op->op.same_as(builtin::tvm_thread_invariant())) { return MakeValue(op->args[0]); +#if TVM_LLVM_VERSION >= 110 + } else if (op->op.same_as(builtin::vscale())) { + llvm::Intrinsic::ID id = llvm::Intrinsic::vscale; + llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {}); + return builder_->CreateCall(f); + } else if (op->op.same_as(builtin::get_active_lane_mask())) { + llvm::Intrinsic::ID id = llvm::Intrinsic::get_active_lane_mask; + llvm::Function* f = GetIntrinsicDecl(id, DTypeToLLVMType(op->dtype), + {builder_->getInt32Ty(), builder_->getInt32Ty()}); + return builder_->CreateCall(f, {MakeValue(op->args[0]), MakeValue(op->args[1])}); +#endif } else { LOG(FATAL) << "unknown intrinsic " << op->op; } @@ -1658,9 +1657,9 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { } void CodeGenLLVM::BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + std::function make_instruction) { DataType buffer_element_dtype = buffer->dtype; @@ -1739,10 +1738,20 @@ void CodeGenLLVM::BufferAccessHelper( std::vector all_index_values = earlier_index_values; all_index_values.push_back(last_index_value); + llvm::Value* pred_val = nullptr; + if (predicate.defined()) { + pred_val = MakeValue(predicate); + } + TypedPointer buffer_ptr = - CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values, - value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); - auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile); + value_dtype.is_scalable() + ? 
CreateBufferPtr( + MakeValue(buffer->data), buffer_element_dtype, all_index_values, + value_dtype.with_scalable_lanes(value_dtype.lanes() / last_index.dtype().lanes())) + : CreateBufferPtr( + MakeValue(buffer->data), buffer_element_dtype, all_index_values, + value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); + auto instruction = make_instruction(buffer_ptr, subelement_i, pred_val, alignment, is_volatile); AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin); } } @@ -1752,11 +1761,17 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { std::vector loads; - auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment, - bool is_volatile) { + auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, + llvm::Value* predicate, int alignment, bool is_volatile) { #if TVM_LLVM_VERSION >= 110 - auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); + llvm::Instruction* load = nullptr; + if (predicate != NULL) { + load = builder_->CreateMaskedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + predicate); + } else { + load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); + } #elif TVM_LLVM_VERSION >= 80 auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); @@ -1771,7 +1786,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. - BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_load); if (loads.size() == 1) { return loads[0]; @@ -1821,7 +1836,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); - for (int i = 0; i < op->lanes; ++i) { + auto* lanes_as_int = op->lanes.as(); + ICHECK(lanes_as_int) << "vscale in codegen_llvm"; + auto lanes = static_cast(lanes_as_int->value); + for (int i = 0; i < lanes; ++i) { vec = builder_->CreateInsertElement( vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i)); } @@ -1853,7 +1871,22 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { - return CreateBroadcast(MakeValue(op->value), op->lanes); + DataType dtype = op->dtype; + llvm::Value* value = MakeValue(op->value); + llvm::Type* type = DTypeToLLVMType(dtype); + llvm::Constant* undef = llvm::UndefValue::get(type); + llvm::Constant* zero = ConstInt32(0); + value = builder_->CreateInsertElement(undef, value, zero); +#if TVM_LLVM_VERSION >= 110 + llvm::ElementCount ec = llvm::ElementCount::get(dtype.lanes(), dtype.is_scalable()); + llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero); +#else + if (dtype.is_scalable()) { + LOG(FATAL) << "Can't create scalable broadcast"; + } + llvm::Constant* mask = llvm::ConstantVector::getSplat(dtype.lanes(), zero); +#endif + return builder_->CreateShuffleVector(value, undef, mask); } void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { @@ -1863,24 +1896,31 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { llvm::Value* 
value = MakeValue(op->value); - auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment, - bool is_volatile) { + auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, llvm::Value* predicate, + int alignment, bool is_volatile) { llvm::Value* to_store = value; + llvm::Instruction* store; if (subelement_i != -1) { to_store = builder_->CreateExtractElement(value, subelement_i); } #if TVM_LLVM_VERSION >= 110 - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), - is_volatile); + if (predicate != NULL) { + store = + builder_->CreateMaskedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), predicate); + } else { + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); + } #else - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); #endif + return store; }; // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. - BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_store); } void CodeGenLLVM::VisitStmt_(const ForNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 8c8929c8f093..23b2d6edbeef 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -349,9 +349,9 @@ class CodeGenLLVM : public ExprFunctor, * - Should return the generated expression. */ void BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + std::function make_instruction); // Initialize target virtual void InitTarget(); @@ -468,7 +468,6 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, llvm::ArrayRef indices, DataType value_dtype); // Vector concatenation. 
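A minimal sketch of what the predicated path above is expected to emit on LLVM 11 and newer, mirroring the new make_load/make_store lambdas; `vec_ty`, `addr` and `mask` are hypothetical placeholders for the values produced by CreateBufferPtr and the lowered tir.get_active_lane_mask call:

    llvm::Instruction* access = nullptr;
    if (mask != nullptr) {
      // llvm.masked.load: masked-off lanes are never read from memory.
      access = builder_->CreateMaskedLoad(vec_ty, addr, llvm::Align(alignment), mask);
    } else {
      // Unpredicated accesses keep the plain aligned load.
      access = builder_->CreateAlignedLoad(vec_ty, addr, llvm::Align(alignment), is_volatile);
    }

The store side mirrors this with CreateMaskedStore / CreateAlignedStore.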
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index dd83fbdcbd62..90722557550b 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -888,11 +888,13 @@ void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) // NOTE: C have comma expression so cannot use (int2)(v0, v1) // instead should use int2(v0, v1) PrintType(op->dtype, os); + ICHECK(op->lanes->IsInstance()) << "Scalable vectors are not supported in codegen_c"; + int lanes = static_cast(Downcast(op->lanes)->value); os << "("; - for (int i = 0; i < op->lanes; i++) { + for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; - if (i != op->lanes - 1) os << ", "; + if (i != lanes - 1) os << ", "; } os << ")"; } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index caef43e8af28..de37a591ac5b 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -204,10 +204,13 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_c_host"; + int lanes = static_cast(Downcast(op->lanes)->value); os << "(("; PrintType(op->dtype, os); os << ")("; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index efed5c02f1c9..e694c7eeb4ed 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1162,19 +1162,25 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { } void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { - CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_cuda"; + int lanes = static_cast(Downcast(op->lanes)->value); + CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; PrintVecConstructor(op->dtype, os); os << "("; - for (int i = 0; i < op->lanes; i++) { + for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; - if (i != op->lanes - 1) os << ", "; + if (i != lanes - 1) os << ", "; } os << ")"; } void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) - if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) { + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_cuda"; + int lanes = static_cast(Downcast(op->lanes)->value); + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) { // make_int8x4 const int64_t* p = as_const_int(op->value); ICHECK(p); @@ -1192,7 +1198,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO std::string v = PrintExpr(op->value); PrintVecConstructor(op->dtype, os); os << '('; - for (int i = 0; i < op->lanes / 2; ++i) { + for (int i = 0; i < lanes / 2; ++i) { if (i != 0) os << ", "; os << "__pack_half2(" << v << ", " << v << ")"; } @@ -1204,7 +1210,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO std::string v = PrintExpr(op->value); PrintVecConstructor(op->dtype, os);
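// (As with the __pack_half2 case above, the bfloat16 path below emits lanes / 2 calls to __pack_nv_bfloat162, which is why a fixed lane count extracted from op->lanes is required here as well.)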
os << '('; - for (int i = 0; i < op->lanes / 2; ++i) { + for (int i = 0; i < lanes / 2; ++i) { if (i != 0) os << ", "; os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; } @@ -1218,7 +1224,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO ICHECK(p); int64_t v = *p & 0xF; - if (op->lanes == 4) { + if (lanes == 4) { v = (v << 12) | (v << 8) | (v << 4) | v; if (op->dtype.is_uint()) { os << "(uint16_t)" << v; @@ -1227,16 +1233,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO } } else { v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; - if (op->lanes == 8) { + if (lanes == 8) { if (op->dtype.is_uint()) { os << "(uint)" << v; } else { os << "(int)" << v; } - } else if (op->lanes == 16 || op->lanes == 32) { + } else if (lanes == 16 || lanes == 32) { PrintVecConstructor(op->dtype, os); os << '('; - for (int i = 0; i < op->lanes / 8; ++i) { + for (int i = 0; i < lanes / 8; ++i) { if (i != 0) os << ", "; if (op->dtype.is_uint()) { os << "(uint)" << v; @@ -1258,7 +1264,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO std::string v = PrintExpr(op->value); PrintVecConstructor(op->dtype, os); os << '('; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 86d5956dec19..5a1ef12b32a2 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -308,9 +308,12 @@ void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLI void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_metal"; + int lanes = static_cast(Downcast(op->lanes)->value); PrintType(op->dtype, os); os << "("; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index da6a4de6196a..b33191734fc7 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -472,10 +472,13 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_opencl"; + int lanes = static_cast(Downcast(op->lanes)->value); os << "(("; PrintType(op->dtype, os); os << ")("; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } @@ -486,10 +489,13 @@ void CodeGenOpenCL::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLIN os << "(("; PrintType(op->dtype, os); os << ")("; - for (int i = 0; i < op->lanes; i++) { + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_opencl"; + int lanes = static_cast(Downcast(op->lanes)->value); + for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; - if (i != op->lanes - 1) os << ", "; + if (i != lanes - 1) os << ", "; } os << "))"; } diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 1702699ac232..e52be9b24da6 100644 --- 
a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -344,9 +344,12 @@ void CodeGenWebGPU::PrintSSAAssign(const std::string& target, const std::string& void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_webgpu"; + int lanes = static_cast(Downcast(op->lanes)->value); PrintType(op->dtype, os); os << "("; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index aca504b94b98..c997a5df197b 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -519,7 +519,10 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { std::vector values; spirv::Value base = MakeValue(op->base); - for (int i = 0; i < op->lanes; ++i) { + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_spirv"; + int lanes = static_cast(Downcast(op->lanes)->value); + for (int i = 0; i < lanes; ++i) { spirv::Value v = base; if (i != 0) { spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) * op->stride); @@ -533,7 +536,10 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { std::vector values; spirv::Value v = MakeValue(op->value); - for (int i = 0; i < op->lanes; i++) { + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_spirv"; + int lanes = static_cast(Downcast(op->lanes)->value); + for (int i = 0; i < lanes; i++) { values.push_back(v); } return builder_->Concat(values); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 0dc8b3870104..5eee46e2160c 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -79,7 +79,7 @@ class BufferSubstituter : public StmtExprMutator { auto load = Downcast(StmtExprMutator::VisitExpr_(op)); auto it = buffer_map_.find(load->buffer.get()); if (it != buffer_map_.end()) { - return BufferLoad(it->second, load->indices, load->span); + return BufferLoad(it->second, load->indices, load->predicate, load->span); } return load; } @@ -88,7 +88,7 @@ class BufferSubstituter : public StmtExprMutator { auto store = Downcast(StmtExprMutator::VisitStmt_(op)); auto it = buffer_map_.find(store->buffer.get()); if (it != buffer_map_.end()) { - return BufferStore(it->second, store->value, store->indices, store->span); + return BufferStore(it->second, store->value, store->indices, store->predicate, store->span); } return store; } diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 233663feac6d..e8f0d9332a16 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -637,7 +637,8 @@ void PassUpBoundCheck(const Stage& s, const Map& dom_map, if (outer || inner) { state[s->parent] = true; } else { - if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) { + if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step) || + s->disable_predication) { state[s->parent] = false; } else { state[s->parent] = true; diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 44e742eee4cf..9e142b1bf76c 100644 --- a/src/te/schedule/schedule_lang.cc +++ 
b/src/te/schedule/schedule_lang.cc @@ -70,7 +70,7 @@ DataType MatchDataType(std::vector dtypes) { } void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, - IterVar* p_outer, IterVar* p_inner) { + IterVar* p_outer, IterVar* p_inner, bool disable_predication) { // Check if split is valid. ICHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || parent->iter_type == kOrdered) @@ -83,7 +83,7 @@ void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr npar Array& all_vars = self->all_iter_vars; Array& leaf_vars = self->leaf_iter_vars; size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent); - self->relations.push_back(Split(parent, outer, inner, factor, nparts)); + self->relations.push_back(Split(parent, outer, inner, factor, nparts, disable_predication)); // add vars to all vars all_vars.push_back(outer); all_vars.push_back(inner); @@ -226,17 +226,17 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) { return *this; } -Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, - IterVar* p_inner) { // NOLINT(*) +Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner, + bool disable_predication) { // NOLINT(*) With ctx(operator->()->attach_sch, __func__); - SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); + SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner, disable_predication); return *this; } -Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, - IterVar* p_inner) { // NOLINT(*) +Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner, + bool disable_predication) { // NOLINT(*) With ctx(operator->()->attach_sch, __func__); - SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); + SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner, disable_predication); return *this; } @@ -805,13 +805,15 @@ void ScheduleContext::ExitWithScope() { } } -Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { +Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts, + bool disable_predication) { auto n = make_object(); n->parent = parent; n->outer = outer; n->inner = inner; n->factor = factor; n->nparts = nparts; + n->disable_predication = disable_predication; data_ = std::move(n); } @@ -927,6 +929,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ", nparts="; p->Print(op->nparts); } + p->stream << ", disable_predication="; + p->stream << op->disable_predication; p->stream << ')'; }) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -973,16 +977,16 @@ TVM_REGISTER_GLOBAL("te.StageSetScope").set_body_method(&Stage::set_scope); TVM_REGISTER_GLOBAL("te.StageBind").set_body_method(&Stage::bind); TVM_REGISTER_GLOBAL("te.StageSplitByFactor") - .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { + .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor, bool disable_predication) { IterVar outer, inner; - stage.split(parent, factor, &outer, &inner); + stage.split(parent, factor, &outer, &inner, disable_predication); return Array({outer, inner}); }); TVM_REGISTER_GLOBAL("te.StageSplitByNParts") - .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { + .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts, bool disable_predication) { IterVar outer, inner; - stage.split_by_nparts(parent, 
nparts, &outer, &inner); + stage.split_by_nparts(parent, nparts, &outer, &inner, disable_predication); return Array({outer, inner}); }); diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 4554038bc770..40df8b65c295 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -254,7 +254,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Downcast(StmtExprMutator::VisitExpr_(buffer_load_node)); Buffer new_buffer = Subst(new_buffer_load->buffer.get()); if (!new_buffer.same_as(new_buffer_load->buffer)) { - return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span); + return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->predicate, + new_buffer_load->span); } return std::move(new_buffer_load); } @@ -293,7 +294,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Buffer new_buffer = Subst(new_buffer_store->buffer.get()); if (!new_buffer.same_as(new_buffer_store->buffer)) { return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices, - new_buffer_store->span); + new_buffer_store->predicate, new_buffer_store->span); } return std::move(new_buffer_store); } diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index fba506fba1c9..6dc72d27286a 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -718,7 +718,8 @@ class MergeConstantsMutator : public StmtExprMutator { buffer->axis_separators, buffer->span}; old_to_new_read_buffers[buffer.as()] = new_buffer; - new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span)); + new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->predicate, + buffer_load->span)); break; } case 2: /* length */ { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 41500051fa89..6d41ce53818a 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -21,10 +21,13 @@ * \file expr.cc */ #include +#include +#include #include #include #include +#include "../../arith/scalable_expression.h" #include "../../support/str_escape.h" #include "buffer_common.h" @@ -427,18 +430,26 @@ TVM_REGISTER_GLOBAL("tir.Select") TVM_REGISTER_NODE_TYPE(SelectNode); // Ramp -Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) { +Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { ICHECK(base.defined()); ICHECK(stride.defined()); ICHECK(base.dtype().is_scalar()); ICHECK(stride.dtype().is_scalar()); - ICHECK_GT(lanes, 1); if (stride.dtype() != base.dtype()) { stride = cast(base.dtype(), stride); } ObjectPtr node = make_object(); - node->dtype = base.dtype().with_lanes(lanes); + auto* lanes_int = lanes.as(); + if (lanes_int) { + int lanes = static_cast(lanes_int->value); + ICHECK_GT(lanes, 1); + node->dtype = base.dtype().with_lanes(lanes); + } else { /* scalable vector */ + lanes = arith::CanonicalizeScalableLanes(lanes); + int vscale_multiplier = Downcast(Downcast(lanes)->a)->value; + node->dtype = base.dtype().with_scalable_lanes(vscale_multiplier); + } node->base = base; node->stride = stride; node->lanes = lanes; @@ -447,27 +458,35 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) { } TVM_REGISTER_GLOBAL("tir.Ramp") - .set_body_typed([](PrimExpr base, PrimExpr stride, int lanes, Span span) { + .set_body_typed([](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { return Ramp(base, stride, lanes, span); }); 
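A minimal sketch of constructing a scalable ramp through the updated constructor, assuming the canonical `multiplier * vscale` lanes form that arith::CanonicalizeScalableLanes produces (all variable names are illustrative only):

    PrimExpr vscale = tir::Call(DataType::Int(32), tir::builtin::vscale(), {});
    PrimExpr lanes = tir::Mul(IntImm(DataType::Int(32), 4), vscale);
    PrimExpr ramp = tir::Ramp(IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 1), lanes);
    // The resulting dtype carries with_scalable_lanes(4) and prints as "int32x4xvscale".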
TVM_REGISTER_NODE_TYPE(RampNode); // Broadcast -Broadcast::Broadcast(PrimExpr value, int lanes, Span span) { +Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { ICHECK(value.defined()); ICHECK(value.dtype().is_scalar()); - ICHECK_GT(lanes, 1); ObjectPtr node = make_object(); - node->dtype = value.dtype().with_lanes(lanes); + auto* lanes_int = lanes.as(); + if (lanes_int) { + int lanes = static_cast(lanes_int->value); + ICHECK_GT(lanes, 1); + node->dtype = value.dtype().with_lanes(lanes); + } else { /* scalable vector */ + lanes = arith::CanonicalizeScalableLanes(lanes); + int vscale_multiplier = Downcast(Downcast(lanes)->a)->value; + node->dtype = value.dtype().with_scalable_lanes(vscale_multiplier); + } node->value = std::move(value); node->lanes = lanes; node->span = std::move(span); data_ = node; } -TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes, Span span) { +TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, PrimExpr lanes, Span span) { return Broadcast(value, lanes, span); }); @@ -525,8 +544,8 @@ TVM_REGISTER_GLOBAL("tir.Call") for (Range r : br->region) { if (is_one(r->extent)) { indices.push_back(r->min); - } else if (const auto* extent = r->extent.as()) { - indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), extent->value)); + } else if (r->extent.as()) { + indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); } else { LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << GetRef(br); @@ -717,10 +736,14 @@ void BufferLoadNode::LegalizeDType() { int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; int buffer_lanes = buffer->dtype.lanes(); - this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes); + if ((indices.size() && indices.back().dtype().is_scalable()) || buffer->dtype.is_scalable()) { + this->dtype = buffer->dtype.with_scalable_lanes(index_lanes * buffer_lanes); + } else { + this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes); + } } -BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { +BufferLoad::BufferLoad(Buffer buffer, Array indices, PrimExpr predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -729,14 +752,15 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); node->LegalizeDType(); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferLoad") - .set_body_typed([](Buffer buffer, Array indices, Span span) { - return BufferLoad(buffer, indices, span); + .set_body_typed([](Buffer buffer, Array indices, PrimExpr predicate, Span span) { + return BufferLoad(buffer, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 8a93d9dd8242..34b46583d5ad 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -127,7 +127,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { if (indices.same_as(op->indices)) { return GetRef(op); } else { - return BufferLoad(op->buffer, indices); + return BufferLoad(op->buffer, indices, op->predicate); } } @@ -258,19 +258,21 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr ExprMutator::VisitExpr_(const 
RampNode* op) { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); - if (base.same_as(op->base) && stride.same_as(op->stride)) { + PrimExpr lanes = this->VisitExpr(op->lanes); + if (base.same_as(op->base) && stride.same_as(op->stride) && lanes.same_as(op->lanes)) { return GetRef(op); } else { - return Ramp(base, stride, op->lanes); + return Ramp(base, stride, lanes); } } PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { PrimExpr value = this->VisitExpr(op->value); - if (value.same_as(op->value)) { + PrimExpr lanes = this->VisitExpr(op->lanes); + if (value.same_as(op->value) && lanes.same_as(op->lanes)) { return GetRef(op); } else { - return Broadcast(value, op->lanes); + return Broadcast(value, lanes); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1d1e674a9dd1..fd9ffe7dd70d 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -458,7 +458,8 @@ TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) TVM_REGISTER_NODE_TYPE(EvaluateNode); // BufferStore -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, Span span) { +BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, + Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -470,16 +471,26 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, } int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; + bool index_scalable = indices.size() ? indices.back().dtype().is_scalable() : false; int buffer_lanes = buffer->dtype.lanes(); + bool buffer_scalable = buffer->dtype.is_scalable(); ICHECK_EQ(index_lanes * buffer_lanes, value.dtype().lanes()) << "Cannot store value with " << value.dtype().lanes() << ", expected value with " << index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes << " buffer element lanes)"; - if (buffer->dtype.with_lanes(buffer_lanes * index_lanes) != value.dtype()) { + + runtime::DataType buffer_dtype; + if (index_scalable || buffer_scalable) { + buffer_dtype = buffer->dtype.with_scalable_lanes(buffer_lanes * index_lanes); + } else { + buffer_dtype = buffer->dtype.with_lanes(buffer_lanes * index_lanes); + } + if (buffer_dtype != value.dtype()) { LOG(FATAL) << "TypeError: dtype mismatch on BufferStore: " // << "buffer's dtype is `" << buffer->dtype // << "`, the lanes of indexing are: `" << index_lanes // + << "`, the scalability is: `" << buffer_dtype.is_scalable() << "`, but RHS's dtype is `" << value.dtype() << "`"; } @@ -487,14 +498,14 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, node->buffer = std::move(buffer); node->value = std::move(value); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferStore") - .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, Span span) { - return BufferStore(buffer, value, indices, span); - }); + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, + Span span) { return BufferStore(buffer, value, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index fb92463c3c32..c5f6d707b95e 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -368,6 +368,15 @@ 
TIR_DEFINE_BUILTIN_FUNC(dma_start_group) TIR_DEFINE_BUILTIN_FUNC(dma_end_group) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask) + .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); + TIR_DEFINE_BUILTIN_FUNC(assume) .set_attr("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo)) .set_num_inputs(1); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index fd14f4892154..86317b0bbfec 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -122,20 +122,33 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s {x, y, q, s}, span); } +void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) { // NOLINT(*) + DataType dtype_a = op_a.dtype(); + DataType dtype_b = op_b.dtype(); + + if (dtype_a.lanes() == 1 && dtype_b.lanes() != 1) { + if (dtype_b.is_scalable()) { + op_a = tir::Broadcast( + op_a, tir::Mul(dtype_b.lanes(), Call(DataType::Int(32), builtin::vscale(), {}))); + } else { + op_a = tir::Broadcast(op_a, dtype_b.lanes()); + } + } +} + // The public function with a quick checking path. void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator"; CHECK(rhs.defined()) << "ValueError: `rhs` is null in the binary operator"; if (lhs.dtype() == rhs.dtype()) return; + + BroadcastToMatchLanes(lhs, rhs); + BroadcastToMatchLanes(rhs, lhs); + DataType ltype = lhs.dtype(); DataType rtype = rhs.dtype(); - if (ltype.lanes() == 1 && rtype.lanes() != 1) { - lhs = tir::Broadcast(lhs, rtype.lanes()); - } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { - rhs = tir::Broadcast(rhs, ltype.lanes()); - } else { - ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; - } + + ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; if (lhs.dtype() == rhs.dtype()) return; ltype = lhs.dtype(); @@ -337,11 +350,17 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { value = tir::Cast(vtype, value, span); } } - return tir::Broadcast(value, t.lanes(), span); + if (t.is_scalable()) { + return tir::Broadcast( + value, tir::Mul(t.lanes(), Call(DataType::Int(32), builtin::vscale(), {})), span); + } else { + return tir::Broadcast(value, t.lanes(), span); + } } else { ICHECK(value.dtype().lanes() == t.lanes()); + ICHECK(value.dtype().is_scalable() == t.is_scalable()); if (const auto* broadcast = value.as()) { - return tir::Broadcast(cast(vtype, broadcast->value, span), t.lanes(), span); + return tir::Broadcast(cast(vtype, broadcast->value, span), broadcast->lanes, span); } else if (const auto* ramp = value.as()) { if (t.is_int() || t.is_uint()) { // only cast to index data type can be folded to ramp diff --git a/src/tir/schedule/analysis/reducer.cc b/src/tir/schedule/analysis/reducer.cc index d8d1e8fc2572..6754ee334938 100644 --- a/src/tir/schedule/analysis/reducer.cc +++ b/src/tir/schedule/analysis/reducer.cc @@ -178,16 +178,14 @@ class PatternMatcher : public ExprVisitor { if (ptr == nullptr) { match_success_ = false; } else { - if (op->lanes != ptr->lanes) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->base; - VisitExpr(op->base); - expr_to_match_ = ptr->stride; - VisitExpr(op->stride); - 
std::swap(expr_to_match_, tmp); - } + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->base; + VisitExpr(op->base); + expr_to_match_ = ptr->stride; + VisitExpr(op->stride); + expr_to_match_ = ptr->lanes; + VisitExpr(op->lanes); + std::swap(expr_to_match_, tmp); } } @@ -196,14 +194,12 @@ class PatternMatcher : public ExprVisitor { if (ptr == nullptr) { match_success_ = false; } else { - if (op->lanes != ptr->lanes) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->value; - VisitExpr(op->value); - std::swap(expr_to_match_, tmp); - } + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + expr_to_match_ = ptr->lanes; + VisitExpr(op->lanes); + std::swap(expr_to_match_, tmp); } } @@ -265,7 +261,6 @@ class PatternMatcher : public ExprVisitor { void Match(const Array& exprs_to_match) { this->match_success_ = true; this->filled_map_.clear(); - ICHECK_EQ(pattern_.size(), exprs_to_match.size()); int n_buffers = pattern_.size(); for (int i = 0; i < n_buffers; ++i) { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 343fb7617886..930d9eb222c1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -20,6 +20,8 @@ #include +#include "../analysis/check_contains.h" + namespace tvm { namespace tir { @@ -396,7 +398,7 @@ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs, bool preserve_u Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, const Array>& factor_rvs, - bool preserve_unit_iters) { + bool preserve_unit_iters, bool disable_predication) { class NotSingleInferFactorError : public ScheduleError { public: explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} @@ -488,13 +490,14 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, tot_length *= factor; } } + if (infer_index != -1) { factors.Set(infer_index, this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) { throw WrongFactorProductError(state_->mod, GetRef(loop)); } - results = tir::Split(state_, loop_sref, factors, preserve_unit_iters); + results = tir::Split(state_, loop_sref, factors, preserve_unit_iters, disable_predication); TVM_TIR_SCHEDULE_END("split", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(results); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index a6c47070c8df..e4526425511a 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -108,7 +108,7 @@ class ConcreteScheduleNode : public ScheduleNode { LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) override; LoopRV Merge(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters) override; + bool preserve_unit_iters, bool disable_predication) override; void Reorder(const Array& ordered_loop_rvs) override; void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) override; LoopRV AddUnitLoop(const BlockRV& block_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 02fb982f5ed9..e0bb4f39cdc2 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -204,10 +204,15 @@ Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope * \param loop_sref The sref to the loop being split * \param factors The splitting factors * \param 
preserve_unit_iters Whether or not to preserve unit iterators in block bindings + * \param disable_predication If enabled, don't create a predicate for guarding the + * loop. This can be useful when splitting with scalable factors that the schedule writer + * knows are divisible. Warning: enabling this feature may result in incorrect code generation + * if not used carefully. * \return An array of srefs to the loops after splitting */ TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, - const Array& factors, bool preserve_unit_iters); + const Array& factors, bool preserve_unit_iters, + bool disable_predication); /*! * \brief Merge a list of loops into one. The loops under their LCA requires: diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index a6b97bf17906..4881504d67ef 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -386,7 +386,7 @@ class DependentLoopError : public ScheduleError { }; Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array& factors, - bool preserve_unit_iters) { + bool preserve_unit_iters, bool disable_predication) { // Invariance // - The total repeat number has not changed for each direct child block with updating predicate. // - The execution order has not changed. (The block executes with the same args and the same @@ -433,7 +433,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array &opaque_block_reuse)(std::move(new_stmt)); // Step 3. Update predicate to guard the loop PrimExpr predicate = substitute_value < loop->extent; - if (!analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { + if (!disable_predication && !analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); } // Step 4. 
Generate nested loops to replace the original loop and simplify the binding @@ -920,7 +920,7 @@ struct SplitTraits : public UnpackedInstTraits { private: static constexpr size_t kNumInputs = 2; - static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; template @@ -936,16 +936,18 @@ struct SplitTraits : public UnpackedInstTraits { static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Array> factors, - Bool preserve_unit_iters) { - return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool()); + Bool preserve_unit_iters, Bool disable_predication) { + return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool(), + disable_predication.operator bool()); } static String UnpackedAsPython(Array outputs, String loop_rv, Array factors, - Bool preserve_unit_iters) { + Bool preserve_unit_iters, Bool disable_predication) { PythonAPICall py("split"); py.Input("loop", loop_rv); py.Input("factors", factors); py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); + py.Input("disable_predication", disable_predication.operator bool()); py.OutputList(outputs); return py.Str(); } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index e55a5cf8078c..e25a9eaf6ac3 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -226,8 +226,9 @@ LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs, bool preserve_uni Array TracedScheduleNode::Split(const LoopRV& loop_rv, const Array>& factor_rvs, - bool preserve_unit_iters) { - Array results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters); + bool preserve_unit_iters, bool disable_predication) { + Array results = + ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters, disable_predication); std::vector inputs; inputs.reserve(1 + factor_rvs.size()); @@ -237,10 +238,12 @@ Array TracedScheduleNode::Split(const LoopRV& loop_rv, } static const InstructionKind& kind = InstructionKind::Get("Split"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/inputs, - /*attrs=*/{Integer(preserve_unit_iters)}, - /*outputs=*/{results.begin(), results.end()})); + trace_->Append( + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/inputs, + /*attrs=*/{Integer(preserve_unit_iters), Integer(disable_predication)}, + /*outputs=*/{results.begin(), results.end()})); return results; } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index d7a42f63d4dc..1d801f017d85 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -67,7 +67,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) final; LoopRV Merge(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs, - bool preserve_unit_iters) final; + bool preserve_unit_iters, bool disable_predication) final; void Reorder(const Array& ordered_loop_rvs) final; void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) final; LoopRV AddUnitLoop(const BlockRV& block_rv) final; diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index f5aa6773e66d..358f864d3a24 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -34,6 +34,8 @@ #include #include +#include "../../arith/unwrap_vector_expr.h" + namespace tvm { namespace tir { @@ -156,7 +158,12 @@ class 
BoundChecker : public StmtExprMutator { if (!IsValidScalar(ramp_index->stride)) { return false; } - if (ramp_index->lanes <= 0) { + bool lanes_int = ramp_index->lanes->IsInstance(); + if (!lanes_int) { + return false; + } + int lanes = static_cast(Downcast(ramp_index->lanes)->value); + if (lanes <= 0) { return false; } } @@ -192,11 +199,7 @@ class BoundChecker : public StmtExprMutator { PrimExpr upper_bound = shape[i]; if (const RampNode* ramp_index = index.as()) { - // In case index is base + stride * i. - // Non inclusive range. - index = Add(ramp_index->base, - Mul(ramp_index->stride, - make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1))); + index = arith::UnwrapVectorExpr(GetRef(ramp_index), ramp_index->lanes); } // Try to simplify index and bound. diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index 5f7b9b4156c3..afa3e15f0260 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -257,7 +257,8 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->span); + ICHECK(!op->predicate.defined()) << "Indices change can affect the predicate"; + Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->predicate, op->span); // Then wrap the BufferStores in some Ifs to avoid recomputing elements for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) { auto iter_var{rolling_buffer_info.axis_iter_vars[i]}; @@ -293,7 +294,8 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - return BufferLoad(op->buffer, indices, op->span); + ICHECK(!op->predicate.defined()) << "Indices change can affect the predicate"; + return BufferLoad(op->buffer, indices, op->predicate, op->span); } else { return expr; } diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 700587fe0e21..6579aa03dc03 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -97,6 +97,7 @@ class MatchBufferLower : public StmtExprMutator { auto n = CopyOnWrite(op); n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); n->buffer = source->buffer; + ICHECK(!op->predicate.defined()) << "Indices change can affect the predicate"; return Stmt(n); } } @@ -113,6 +114,7 @@ class MatchBufferLower : public StmtExprMutator { const Buffer& buffer = (*it).first; const BufferRegion& source = (*it).second; Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + ICHECK(!op->predicate.defined()) << "Indices change can affect the predicate"; return BufferLoad(source->buffer, indices); } } diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 870235954689..392373d40b63 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -137,6 +137,8 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { PrimExpr index = op->indices[0]; if (op->value.dtype().lanes() != 1) { + ICHECK(!op->value.dtype().is_scalable()) + << "Scalable vectors are not supported in lower_warp_memory"; arith::PVar base; ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(index)) << "LowerWarpMemory failed due to store index=" << index @@ -343,6 +345,8 @@ class WarpAccessRewriter : protected StmtExprMutator { std::pair SplitIndexByGroup(const PrimExpr& index) { if 
(index.dtype().lanes() != 1) { arith::PVar base; + ICHECK(!index.dtype().is_scalable()) + << "Scalable vectors are not supported in lower_warp_memory"; ICHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index)); auto [local_index, group] = SplitIndexByGroup(base.Eval()); diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 619a9f0a9e8f..f5b905a3f502 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -67,6 +67,7 @@ class IntermediateStageRewriter { Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); // Step 3: Create BufferLoad from the intermediate buffer + ICHECK(!store->predicate.defined()); BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); BufferStore new_buffer_store = Downcast(block->body); new_buffer_store.CopyOnWrite()->value = new_buffer_load; diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index bc606aa0b7ff..3b418aac0cf5 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -213,7 +213,8 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { // A write whose destination is known to already contain the // values to be written is a no-op. // PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices); - PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices) == 0; + PrimExpr stores_existing_value = + store->value - BufferLoad(store->buffer, store->indices, store->predicate) == 0; if (touch_pattern_.has_value()) { Stmt context_arg = context_ ? GetRef(context_) : Stmt(store); stores_existing_value = diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 05b636f11403..e8d89bfb5700 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -196,7 +196,7 @@ class AllocateConstRewrite : public StmtExprMutator { op->buffer->elem_offset, it->second->name_hint, op->buffer->data_alignment, op->buffer->offset_factor, op->buffer->buffer_type); new_load_buf_[op->buffer->data.get()] = new_buffer; - return BufferLoad(new_buffer, op->indices); + return BufferLoad(new_buffer, op->indices, op->predicate); } return ExprMutator::VisitExpr_(op); } diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc index eb596beb181a..beb5997d4982 100644 --- a/src/tir/transforms/renormalize_split_pattern.cc +++ b/src/tir/transforms/renormalize_split_pattern.cc @@ -63,7 +63,7 @@ class SplitPatternReNormalizer : public IRMutatorWithAnalyzer { // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1) TRY_RECURSIVE_REWRITE_IF(floordiv(floormod(x, c3), c2), diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 9c1244838173..2742498e5d74 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -730,7 +730,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferLoad(it->second, op->indices, op->span); + return BufferLoad(it->second, 
op->indices, op->predicate, op->span); } else { return expr; } @@ -743,7 +743,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferStore(it->second, op->value, op->indices, op->span); + return BufferStore(it->second, op->value, op->indices, op->predicate, op->span); } else { return stmt; } @@ -938,8 +938,10 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Remapping predicate is not supported"; return BufferLoad(e.remap->target, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return expr; } @@ -952,8 +954,10 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Remapping predicate is not supported"; return BufferStore(e.remap->target, op->value, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return stmt; } @@ -1418,7 +1422,8 @@ class StorageFlattener : public StmtExprMutator { auto flattened_indices = e.buffer->ElemOffset(op->indices); - Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->span); + ICHECK(!op->predicate.defined()) << "Changing the index can affect the predicate"; + Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->predicate, op->span); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } @@ -1574,7 +1579,9 @@ class StorageFlattener : public StmtExprMutator { } auto flattened_indices = e.buffer->ElemOffset(op->indices); - PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->span); + + ICHECK(!op->predicate.defined()) << "Changing the index can affect the predicate"; + PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->predicate, op->span); if (op->dtype == DataType::Bool()) { ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 6875523a956d..5c1026258561 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1318,9 +1318,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor { if (indices.size()) { const RampNode* ramp_index = indices[indices.size() - 1].as(); if (ramp_index && is_one(ramp_index->stride)) { - arith::ModularSet me = analyzer_.modular_set(ramp_index->base); - if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { - lanes_used = ramp_index->lanes; + if (ramp_index->lanes->IsInstance()) { + int lanes = static_cast(Downcast(ramp_index->lanes)->value); + arith::ModularSet me = analyzer_.modular_set(ramp_index->base); + if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { + lanes_used = lanes; + } } } } @@ -1453,13 +1456,15 @@ class VectorTypeRewriter : public StmtExprMutator { Array indices = node->indices; const PrimExpr& last_dim_index = indices[indices.size() - 1]; - if (const RampNode* ramp_index = last_dim_index.as(); - ramp_index && is_one(ramp_index->stride)) { - PrimExpr new_index = - ramp_index->base / make_const(ramp_index->base.dtype(), 
ramp_index->lanes); - if (ramp_index->lanes != info.factor()) { - ICHECK(info.factor() && ramp_index->lanes % info.factor() == 0); - int new_lanes = ramp_index->lanes / info.factor(); + const RampNode* ramp_index = indices[indices.size() - 1].as(); + if (ramp_index && is_one(ramp_index->stride)) { + ICHECK(ramp_index->lanes->IsInstance()) + << "Rewriting pointer type into scalable type is currently not supported"; + auto lanes = static_cast(Downcast(ramp_index->lanes)->value); + PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), lanes); + if (lanes != info.factor()) { + ICHECK(info.factor() && lanes % info.factor() == 0); + int new_lanes = lanes / info.factor(); new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes, ramp_index->span); } indices.Set(indices.size() - 1, new_index); diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 030dbd01badf..0882e1c4c3af 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -333,6 +333,8 @@ class ComputeLegalizer : public StmtExprMutator { ICHECK(MatchDType(value->dtype)); value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value); } + ICHECK(!op->predicate.defined()) + << "Predicated buffers are not supported in data type legalizer "; return BufferStore(new_buf, value, indices); } } @@ -404,6 +406,8 @@ class ComputeLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) + << "Predicated buffers are not supported in data type legalizer "; return BufferLoad(new_buf, op->indices); } } @@ -565,6 +569,8 @@ class StorageLegalizer : public StmtExprMutator { if (MatchDType(op->value.dtype())) { ICHECK(new_buf->dtype.is_uint()); } + ICHECK(!op->predicate.defined()) + << "Predicated buffers are not supported in data type legalizer "; return BufferStore(new_buf, value, indices); } } @@ -598,6 +604,8 @@ class StorageLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) + << "Predicated buffers are not supported in data type legalizer "; return BufferLoad(new_buf, op->indices); } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index b80a71aa311c..ecd26e7ba2f6 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -35,14 +35,30 @@ #include #include +#include "../../target/parsers/aprofile.h" + namespace tvm { namespace tir { -inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { - if (e.dtype().lanes() == lanes) return e; - if (const BroadcastNode* op = e.as()) { - if (lanes % op->lanes == 0) { - return Broadcast(op->value, lanes); +inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool scalable) { + if (!scalable) { + // Check that we are not broadcasting a scalable vector into fixed length vector + ICHECK(!e.dtype().is_scalable()); + if (e.dtype().lanes() == lanes) return e; + if (const BroadcastNode* op = e.as()) { + int op_lanes = op->lanes.as()->value; + if (lanes % op_lanes == 0) { + return Broadcast(op->value, lanes); + } + } + } else { + if (e.dtype().is_scalable() && e.dtype().lanes() == lanes) { + // It's already a scalable vector in a correct form + return e; + } else { + ICHECK(lanes % e.dtype().lanes() == 0); + PrimExpr scalable_lanes = Mul(lanes, Call(DataType::Int(32), builtin::vscale(), {})); + return Broadcast(e, scalable_lanes); } } 
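// (Control reaches the check below only when the fixed-length branch above falls through with a scalar input; both scalable sub-branches return early.)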
ICHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to " @@ -50,6 +66,24 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { return Broadcast(e, lanes); } +bool EnableBufferPredication() { + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool enable_buffer_predication = + pass_ctx->GetConfig("tir.enable_buffer_predication", Bool(false)).value(); + if (enable_buffer_predication) { + return true; + } + + // When compiling for aarch64 devices with SVE, we should enable predication by default + Target current_target = Target::Current(); + if (!current_target.defined()) { + return false; + } + TargetJSON target_json = target::parsers::aprofile::ParseTarget(current_target->Export()); + TargetFeatures features = Downcast(target_json.at("features")); + return Downcast(features.at("has_sve")); +} + // Rewrite vectorized allocation access // This is necessary for making each vector component containing its own workspace. // Originates from Halide's loop vectorizer @@ -60,7 +94,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { // class VecAllocAccess : public StmtExprMutator { public: - VecAllocAccess(const VarNode* buf, Var var, int var_lanes) + VecAllocAccess(const VarNode* buf, Var var, PrimExpr var_lanes) : buf_(buf), var_(var), var_lanes_(var_lanes) {} PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -136,7 +170,7 @@ class VecAllocAccess : public StmtExprMutator { // variable to be replaced Var var_; // the lanes. - int var_lanes_; + PrimExpr var_lanes_; // Analyzer for simplifications arith::Analyzer analyzer_; }; @@ -149,7 +183,7 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype, 0), IntImm(var->dtype, 1), var_lanes); } @@ -191,7 +225,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorbase * a, b_ramp->stride * a, b_ramp->lanes); } } - return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + bool is_scalable = a.dtype().is_scalable() || b.dtype().is_scalable(); + return Mul(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable)); } return BinaryVec(op); } @@ -224,13 +259,17 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->stride); if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { const RampNode* base_ramp = base.as(); - if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) { + ICHECK(op->lanes->IsInstance()) << "Scalable ramp of ramps is not supported yet"; + int lanes = static_cast(Downcast(op->lanes)->value); + if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), lanes))) { return Ramp(base_ramp->base, stride, op->lanes * base_ramp->lanes); } } + ICHECK(!base.dtype().is_scalable()) << "Ramp base with scalable dtype is not supported"; + ICHECK(!stride.dtype().is_scalable()) << "Ramp stride with scalable dtype is not supported"; int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); - base = BroadcastTo(base, lanes); - stride = BroadcastTo(stride, lanes); + base = BroadcastTo(base, lanes, false); + stride = BroadcastTo(stride, lanes, false); Array elems; for (int i = 0; i < lanes; ++i) { elems.push_back( @@ -260,15 +299,21 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); } else { int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes()); - return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); + bool is_scalable = t.dtype().is_scalable() || f.dtype().is_scalable(); + return Select(cond,
                    BroadcastTo(t, lanes, is_scalable), BroadcastTo(f, lanes, is_scalable));
     }
   }
+
   PrimExpr VisitExpr_(const CastNode* op) final {
     PrimExpr value = this->VisitExpr(op->value);
     if (value.same_as(op->value)) {
       return GetRef<PrimExpr>(op);
     } else {
-      return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
+      if (value.dtype().is_scalable()) {
+        return Cast(op->dtype.with_scalable_lanes(value.dtype().lanes()), value);
+      } else {
+        return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
+      }
     }
   }
 
@@ -305,9 +350,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
       return GetRef<PrimExpr>(op);
     } else {
       int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
-      t = BroadcastTo(t, lanes);
-      f = BroadcastTo(f, lanes);
-      return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
+      bool is_scalable = t.dtype().is_scalable() || f.dtype().is_scalable();
+      t = BroadcastTo(t, lanes, is_scalable);
+      f = BroadcastTo(f, lanes, is_scalable);
+      if (op->dtype.is_scalable()) {
+        return Call(op->dtype.with_scalable_lanes(lanes), op->op, {cond, t, f});
+      } else {
+        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
+      }
     }
   }
   // Call
@@ -413,10 +463,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
     if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
       // How many lanes of indexing are present in the index and
-      // buffer element type, excluding the last index. T
+      // buffer element type, excluding the last index.
       int other_index_lanes = op->buffer->dtype.lanes();
+      ICHECK(!op->buffer->dtype.is_scalable())
+          << "Scalable buffer elements are not supported in vectorizer";
       for (size_t i = 0; i < indices.size() - 1; i++) {
         other_index_lanes *= indices[i].dtype().lanes();
+        // Only allow the last index to be scalable
+        ICHECK(!indices[i].dtype().is_scalable());
       }
 
       // The total number of lanes of indexing, including the last index.
@@ -434,11 +488,13 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
       writer->indices = indices;
-      writer->value = BroadcastTo(value, total_lanes);
+      writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable);
     }
 
     return std::move(store);
@@ -462,18 +518,137 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
                  op->annotations);
     }
   }
+
+  /*!
+   * \brief A pass that tries to rewrite buffer accesses (loads and stores) with a
+   * predicate expression where possible.
+   *
+   * \note For now we start with a minimal case targeting block-level predicates
+   * produced by the split schedule primitive, with the potential for predicating
+   * more complex terms in the future if needed.
+   *
+   * \example
+   * Before:
+   *   for i_0 in T.serial(4):
+   *     for i_1 in T.vectorized(4):
+   *       if i_0 * 4 + i_1 < 14:
+   *         B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0
+   *
+   * After:
+   *   for i_0 in T.serial(4):
+   *     pred = T.get_active_lane_mask(i_0 * 4, 14)
+   *     B(pred=pred)[i_0 * 4:i_0 * 4 + 4] = A(pred=pred)[i_0 * 4:i_0 * 4 + 4] + \
+   *         T.Broadcast(T.float32(1), 4)
+   */
+  class TryPredicateBufferAccesses : public StmtExprMutator {
+   public:
+    TryPredicateBufferAccesses() {}
+
+    /*!
+     * \brief Run the pass to try to extract predicates.
+     * \param stmt - The statement containing buffer accesses (loads and stores)
+     * we want to attempt to predicate.
+     * \param condition - The conditional expression (block-level predicate)
+     * that we will try to remove.
+     * \return pair - Boolean value for success/failure, the rewritten
+     * stmt if successful.
+     */
+    std::pair<bool, Stmt> Run(Stmt stmt, PrimExpr condition) {
+      // Check that the condition provided is of the form a < b, for now.
+      if (!condition->IsInstance<LTNode>()) {
+        return {false, stmt};
+      }
+
+      LT lt = Downcast<LT>(condition);
+
+      // Check the form of the vectorized condition, we're expecting
+      // Ramp(...) < Broadcast(...)
+      if (!lt->a->IsInstance<RampNode>() || !lt->b->IsInstance<BroadcastNode>()) {
+        return {false, stmt};
+      }
+
+      base_ = Downcast<Ramp>(lt->a)->base;
+      limit_ = Downcast<Broadcast>(lt->b)->value;
+
+      // Now we can try to predicate
+      Stmt predicated_stmt = StmtExprMutator::operator()(std::move(stmt));
+      if (num_accesses_analyzed_ > 0 && num_accesses_analyzed_ == num_accesses_rewritten_) {
+        return {true, predicated_stmt};
+      }
+      return {false, stmt};
+    }
+
+   private:
+    PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+      auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+      return TryPredicateBufferAccess(load);
+    }
+
+    Stmt VisitStmt_(const BufferStoreNode* op) final {
+      auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+      return TryPredicateBufferAccess(store);
+    }
+
+    template <typename AccessNode>
+    AccessNode TryPredicateBufferAccess(AccessNode node) {
+      num_accesses_analyzed_ += 1;
+
+      // Do not try to predicate non-vectorized accesses
+      Array<PrimExpr> indices = node->indices;
+      if (!indices.size() || !indices[0]->IsInstance<RampNode>()) {
+        return node;
+      }
+      Ramp ramp = Downcast<Ramp>(node->indices[0]);
+
+      // The vectorized access pattern must match the base of the predicate
+      if (!tvm::StructuralEqual()(ramp->base, base_)) {
+        return node;
+      }
+
+      DataType buf_predicate_dtype =
+          DataType(DataType::kInt, 1, ramp->dtype.lanes(), ramp->dtype.is_scalable());
+      Call lane_mask = Call(buf_predicate_dtype, builtin::get_active_lane_mask(), {base_, limit_});
+
+      num_accesses_rewritten_ += 1;
+      auto writer = node.CopyOnWrite();
+      writer->predicate = lane_mask;
+      return node;
+    }
+
+    /*! \brief The variable base expr of the predicate. */
+    PrimExpr base_;
+    /*! \brief The limit of the predicate. The expr specifies the upper bound of the base's
+     * evaluated value. */
+    PrimExpr limit_;
+    /*! \brief The number of buffer accesses in the stmt we will analyze. */
+    size_t num_accesses_analyzed_ = 0;
+    /*! \brief The number of buffer accesses rewritten with predicates.
+     */
+    size_t num_accesses_rewritten_ = 0;
+  };
+
   // IfThenElse
   Stmt VisitStmt_(const IfThenElseNode* op) final {
     ICHECK(!op->condition.dtype().is_vector());
     PrimExpr condition = this->VisitExpr(op->condition);
-    if (condition.dtype().is_vector()) {
-      return Scalarize(GetRef<Stmt>(op));
-    }
     Stmt then_case = this->VisitStmt(op->then_case);
     Optional<Stmt> else_case = NullOpt;
     if (op->else_case) {
       else_case = this->VisitStmt(op->else_case.value());
     }
+
+    // Check if we can rewrite the condition with predicated buffers
+    if (EnableBufferPredication() && condition.dtype().is_vector() && !else_case.defined()) {
+      std::pair<bool, Stmt> success_stmt_pair =
+          TryPredicateBufferAccesses().Run(then_case, condition);
+      bool can_remove_if_then_else = success_stmt_pair.first;
+      if (can_remove_if_then_else) {
+        return success_stmt_pair.second;
+      }
+    }
+
+    if (condition.dtype().is_vector()) {
+      return Scalarize(GetRef<Stmt>(op));
+    }
     if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
         else_case.same_as(op->else_case)) {
       return GetRef<Stmt>(op);
@@ -481,6 +656,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
     Var idx(var_->name_hint + ".s", var_->dtype);
     stmt = Substitute(stmt, {{var_, idx}});
-    return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), ForKind::kSerial,
-               stmt);
+    return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
   }
   // ProducerStore
   Stmt VisitStmt_(const ProducerStoreNode* op) final {
@@ -561,7 +736,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
   Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
     if (arr.size() == 0) return arr;
     int& lanes = *p_lanes;
@@ -588,7 +764,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
       return GetRef<PrimExpr>(op);
     } else {
       int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
-      return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
+      bool is_scalable = a.dtype().is_scalable() || b.dtype().is_scalable();
+      return TOp(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable));
     }
   }
   template <typename T, typename FCompute>
@@ -626,7 +803,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
         return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
       }
     }
-    return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
+    bool is_scalable = a.dtype().is_scalable() || b.dtype().is_scalable();
+    return fcompute(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable));
   }
 }
};
@@ -636,11 +814,7 @@ class LoopVectorizer : public StmtMutator {
   Stmt VisitStmt_(const ForNode* op) final {
     if (op->kind == ForKind::kVectorized) {
       ICHECK(is_zero(op->min));
-      auto* extent_as_int = op->extent.as<IntImmNode>();
-      if (!extent_as_int || extent_as_int->value < 1) {
-        LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
-      }
-      return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
+      return Vectorizer(op->loop_var, op->extent)(op->body);
     } else {
       return StmtMutator::VisitStmt_(op);
     }
diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc
index 2e386c48b75c..f5e1210b807f 100644
--- a/tests/cpp/pattern_match_test.cc
+++ b/tests/cpp/pattern_match_test.cc
@@ -27,9 +27,11 @@ TEST(Pattern, Basic) {
   using namespace tvm::tir;
   using namespace tvm::arith;
   tvm::tir::Var x("x"), y("y"), z("z");
+  PrimExpr scalable_lanes = Mul(4, Call(DataType::Int(32), builtin::vscale(), {}));
  arith::PVar<PrimExpr> px, py, pz;
  arith::PVar<DataType> pt;
-  arith::PVar<int> planes;
+  arith::PVar<PrimExpr> planes;
+  arith::PCallExpr vscale;
 
   // arithmetics
   auto r = 1 + (y + 1);
@@ -110,14 +112,18 @@ TEST(Pattern, Basic) {
   // ramp pattern
   {
     ICHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 1, 10)));
-    ICHECK(planes.Eval() == 10);
+    ICHECK(planes.Eval().as<IntImmNode>()->value == 10);
+    ICHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 1, scalable_lanes)));
+    ICHECK((PConst<PrimExpr>(4) * vscale).Match(planes.Eval()));
     ICHECK(!ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 2, 10)));
   }
   // broadcast pattern
   {
     ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10)));
-    ICHECK(planes.Eval() == 10);
+    ICHECK(planes.Eval().as<IntImmNode>()->value == 10);
     ICHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10)));
+    ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, scalable_lanes)));
+    ICHECK((PConst<PrimExpr>(4) * vscale).Match(planes.Eval()));
   }
 }
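(Illustrative note, not part of the diff: the scalable lane encoding exercised by the C++ test file below is also reachable through the Python DataType API added by this patch. A minimal sketch, mirroring assertions from the Python tests later in this series; all dtype strings follow the new "x<k>xvscale" form for k * vscale lanes.)

import tvm
from tvm import tir

# "int32x4xvscale" denotes a scalable vector of 4 * vscale int32 lanes,
# where vscale is only known at runtime on a concrete SVE device.
dtype = tvm.DataType("int32x4xvscale")
assert str(dtype) == "int32x4xvscale"

# In TIR, scalable lane counts are written as `k * tir.vscale()` expressions:
ramp = tir.Ramp(0, 1, 4 * tir.vscale())
assert ramp.dtype == "int32x4xvscale"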
diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc
new file mode 100644
index 000000000000..6db36b8db796
--- /dev/null
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include
+#include
+#include
+#include
+
+#include "../../src/script/printer/utils.h"
+
+using ::testing::HasSubstr;
+
+// ---------
+// Data Type
+// ---------
+TEST(TIR, TestCreateScalableType) {
+  tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+  ASSERT_EQ(scalable_type.code(), kDLInt);
+  ASSERT_EQ(scalable_type.bits(), 32);
+  ASSERT_EQ(scalable_type.lanes(), 4);
+  ASSERT_TRUE(scalable_type.is_scalable());
+  ASSERT_TRUE(scalable_type.is_vector());
+}
+
+TEST(TIR, TestScalableWithBits) {
+  tvm::DataType scalable_type = tvm::DataType(kDLInt, 1, 1, true);
+  scalable_type = scalable_type.with_bits(32);
+  ASSERT_EQ(scalable_type.bits(), 32);
+  ASSERT_TRUE(scalable_type.is_scalable());
+  ASSERT_TRUE(scalable_type.is_vector());
+}
+
+TEST(TIR, TestScalableWithLanes) {
+  tvm::DataType type = tvm::DataType(kDLInt, 32, 1);
+  tvm::DataType scalable_type = type.with_scalable_lanes(4);
+  ASSERT_EQ(scalable_type.lanes(), 4);
+  ASSERT_TRUE(scalable_type.is_scalable());
+  ASSERT_TRUE(scalable_type.is_vector());
+}
+
+TEST(TIR, TestAssignScalableDataType) {
+  tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 1, true);
+  tvm::DataType scalable_type_copy = scalable_type;
+  ASSERT_TRUE(scalable_type_copy.is_scalable());
+  ASSERT_TRUE(scalable_type_copy.is_vector());
+}
+
+TEST(TIR, TestScalableDataTypeEquality) {
+  ASSERT_TRUE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32, 4, true));
+}
+
+TEST(TIR, TestScalableDataTypeAndNonScalableDataTypeInequality) {
+  ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32, 4));
+}
+
+TEST(TIR, TestScalableDataTypeToString) {
+  tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+  EXPECT_EQ(tvm::runtime::DLDataType2String(scalable_type), "int32x4xvscale");
+}
+
+TEST(TIR, TestStringToScalableDataType) {
+  std::string
scalable_type_str = "int32x4xvscale"; + EXPECT_EQ(tvm::DataType(tvm::runtime::String2DLDataType(scalable_type_str)), + tvm::DataType(kDLInt, 32, -4)); +} + +TEST(TIR, TestInvalidStringToScalableDataType) { + std::string scalable_type_str = "int32xvscalex4"; + EXPECT_THROW( + { + try { + tvm::runtime::String2DLDataType(scalable_type_str); + } catch (const tvm::InternalError& e) { + EXPECT_THAT(e.what(), HasSubstr("unknown type int32xvscalex4")); + throw; + } + }, + tvm::InternalError); +} + +TEST(TIR, TestGetScalableVectorBytes) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + EXPECT_THROW( + { + try { + tvm::runtime::GetVectorBytes(scalable_type); + } catch (const tvm::InternalError& e) { + EXPECT_THAT(e.what(), HasSubstr("Cannot get vector bytes of scalable vector")); + throw; + } + }, + tvm::InternalError); +} + +// ----------- +// Integration +// ----------- +#if TVM_LLVM_VERSION >= 120 +TEST(TIR, TestScalableIntrinCall) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + tvm::tir::Call call = tvm::tir::Call( + scalable_type, tvm::tir::builtin::call_llvm_intrin(), + {tvm::IntImm(tvm::DataType::Int(32), ::llvm::Intrinsic::experimental_stepvector)}); + ASSERT_EQ(call->dtype, scalable_type); + ASSERT_EQ(call->Script(), + "T.call_llvm_intrin(\"int32x4xvscale\", \"llvm.experimental.stepvector\")"); +} +#endif + +TEST(TIR, TestTIRScriptScalableDtype2Str) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + ASSERT_EQ(tvm::script::printer::DType2Str(scalable_type), "int32x4xvscale"); +} diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 5b0627542204..21ab753330d7 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -80,8 +80,19 @@ class TestVector(BaseCompare): TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x + y, 3, 4)), TestCase(tvm.tir.Ramp(x, 1, 2) + y, tvm.tir.Ramp(x + y, 1, 2)), TestCase(y + tvm.tir.Ramp(x, 1, 2), tvm.tir.Ramp(y + x, 1, 2)), + TestCase( + tvm.tir.Ramp(x, 1, 4 * tir.vscale()) + tvm.tir.Ramp(y, 2, 4 * tir.vscale()), + tvm.tir.Ramp(x + y, 3, 4 * tir.vscale()), + ), TestCase(y.astype("int32x2") + x.astype("int32x2"), (y + x).astype("int32x2")), + TestCase( + y.astype("int32x8xvscale") + x.astype("int32x8xvscale"), + (y + x).astype("int32x8xvscale"), + ), TestCase(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4)), + TestCase( + tvm.tir.Broadcast(0, 8 * tir.vscale()) + y, tvm.tir.Broadcast(y, 8 * tir.vscale()) + ), TestCase( tvm.tir.Ramp(x, 1, 4).astype("float32x4") + tvm.tir.Broadcast(0.0, 4), tvm.tir.Ramp(x, 1, 4).astype("float32x4"), @@ -100,22 +111,43 @@ class TestVector(BaseCompare): ## DivMod rules # trunc div TestCase(tdiv(y.astype("int32x2"), x.astype("int32x2")), tdiv(y, x).astype("int32x2")), + TestCase( + tdiv(y.astype("int32x4xvscale"), x.astype("int32x4xvscale")), + tdiv(y, x).astype("int32x4xvscale"), + ), TestCase(tdiv(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(tdiv(x, 2), 2, 4)), + TestCase( + tdiv(tvm.tir.Ramp(x, 4, 5 * tir.vscale()), 2), + tvm.tir.Ramp(tdiv(x, 2), 2, 5 * tir.vscale()), + ), TestCase(tdiv(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), x.astype("int32x4"), x >= 0), TestCase(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)), # trunc mod TestCase(tmod(y.astype("int32x2"), x.astype("int32x2")), tmod(y, x).astype("int32x2")), TestCase(tmod(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(tmod(x, 2), 4)), 
TestCase(tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4), x >= 0), + TestCase( + tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4 * tir.vscale()), 8), + tmod(tvm.tir.Ramp(1, 1, 4 * tir.vscale()), 8), + x >= 0, + ), TestCase(tmod(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), tmod(tvm.tir.Ramp(1, 15, 4), 8), x >= 0), # floor div TestCase(fld(y.astype("int32x2"), x.astype("int32x2")), fld(y, x).astype("int32x2")), TestCase(fld(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(fld(x, 2), 2, 4)), + TestCase( + fld(tvm.tir.Ramp(x, 4, 4 * tir.vscale()), 2), + tvm.tir.Ramp(fld(x, 2), 2, 4 * tir.vscale()), + ), TestCase(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")), TestCase(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)), TestCase( fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5) ), + TestCase( + fld(tvm.tir.Ramp(x, 8, 4 * tir.vscale()), tvm.tir.Broadcast(4, 4 * tir.vscale())), + tvm.tir.Ramp(fld(x, 4), 2, 4 * tir.vscale()), + ), TestCase( fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)), tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4), @@ -127,6 +159,10 @@ class TestVector(BaseCompare): TestCase( fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(x * 2, 4) ), + TestCase( + fld(tvm.tir.Ramp(x * 8, 1, 4 * tir.vscale()), tvm.tir.Broadcast(4, 4 * tir.vscale())), + fld(tvm.tir.Ramp(x * 8, 1, 4 * tir.vscale()), tvm.tir.Broadcast(4, 4 * tir.vscale())), + ), TestCase( fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)), fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)), @@ -158,7 +194,15 @@ class TestVector(BaseCompare): # floor mod TestCase(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2")), TestCase(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4)), + TestCase( + flm(tvm.tir.Ramp(x, 4, 8 * tir.vscale()), 2), + tvm.tir.Broadcast(flm(x, 2), 8 * tir.vscale()), + ), TestCase(flm(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4)), + TestCase( + flm(tvm.tir.Ramp(x * 8 + 1, 1, 4 * tir.vscale()), 8), + flm(tvm.tir.Ramp(1, 1, 4 * tir.vscale()), 8), + ), TestCase(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), flm(tvm.tir.Ramp(1, 15, 4), 8)), TestCase( flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(flm(x, 4), 4) diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 754bf36d7ab2..1353e2d88b01 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -65,5 +65,13 @@ def test_regression_simplify_inf_recursion(): ana.rewrite_simplify(res) +def test_symbolic_vscale_expression(): + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"): + ana = tvm.arith.Analyzer() + assert ana.can_prove(128 // tir.vscale() * tir.vscale() <= 128) + assert ana.can_prove(128 // (tir.vscale() * 4) * (tir.vscale() * 4) <= 128) + assert ana.can_prove(tir.vscale() % 2 <= tir.vscale()) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index e873bce52bdf..014f479ed913 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -16,7 +16,7 @@ # under the License. 
import tvm from tvm import te -from tvm.script import tir as TIR +from tvm.script import tir as T import re import os import ctypes @@ -476,5 +476,25 @@ def check_correct_assembly(type): check_correct_assembly(type=dtype) +def test_predicated_scalable_buffer(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + @T.prim_func + def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): + for i_1 in T.vectorized(4 * T.vscale()): + if i_0 * 4 * T.vscale() + i_1 < 14: + B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 + + with tvm.target.Target(target): + out = tvm.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + assert "llvm.masked.load" in ll + assert "llvm.masked.store" in ll + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py index 6c069dc6bf0a..e0d6876d7626 100644 --- a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py @@ -343,14 +343,20 @@ def correct_trace(a, b, c, d): ' b2 = sch.get_block(name="C", func_name="main")', " sch.compute_inline(block=b1)", " l3, l4 = sch.get_loops(block=b2)", - " l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)", - " l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)", + " l5, l6 = sch.split(loop=l3, factors=" + + str(a) + + ", preserve_unit_iters=True, disable_predication=False)", + " l7, l8 = sch.split(loop=l4, factors=" + + str(b) + + ", preserve_unit_iters=True, disable_predication=False)", " sch.reorder(l5, l7, l6, l8)", " l9, l10 = sch.get_loops(block=b0)", - " l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)", + " l11, l12 = sch.split(loop=l9, factors=" + + str(c) + + ", preserve_unit_iters=True, disable_predication=False)", " l13, l14 = sch.split(loop=l10, factors=" + str(d) - + ", preserve_unit_iters=True)", + + ", preserve_unit_iters=True, disable_predication=False)", " sch.reorder(l11, l13, l12, l14)", ] ) diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py index dc8452710a8a..6a66731c3430 100644 --- a/tests/python/target/test_arm_target.py +++ b/tests/python/target/test_arm_target.py @@ -14,11 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import subprocess
+import tempfile
+
 import pytest
+import numpy as np
 
 import tvm
 from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support
 from tvm.target import codegen
+from tvm.script import tir as T
+
+import re
+
 llvm_version, arm_target, input_dtype, kernel_dtype, is_supported = tvm.testing.parameters(
     # Testing mcpu type
@@ -61,3 +69,411 @@ def test_arm_conv2d_int8_support(
     with tvm.target.Target(arm_target):
         monkeypatch.setattr(codegen, "llvm_version_major", lambda: llvm_version)
         assert is_int8_hw_support(input_dtype, kernel_dtype) == is_supported
+
+
+@pytest.fixture(scope="session")
+def sve_device_vector_length():
+    c_code = r"""
+    #include <arm_sve.h>
+    #include <stdio.h>
+
+    int main() {
+        printf("%ld\n", svcntb() * 8);
+    }
+    """
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        c_path = f"{tmp_dir}/vl.c"
+        o_path = f"{tmp_dir}/out.o"
+        with open(c_path, "w") as f:
+            f.write(c_code)
+        tvm.contrib.cc.create_executable(o_path, c_path, ["-march=native"])
+        out = subprocess.check_output(o_path, shell=True).strip().decode()
+
+    return int(out)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_vectorize(sve_device_vector_length):
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    num_elements = sve_device_vector_length // 32
+    vscale = tvm.tir.vscale()
+
+    @T.prim_func
+    def main(A: T.Buffer((num_elements,), "float32"), B: T.Buffer((num_elements,), "float32")):
+        for i_0 in range(T.ceildiv(num_elements, 4 * vscale)):
+            for i_1 in T.vectorized(4 * vscale):
+                A_1 = T.Buffer((num_elements,), data=A.data)
+                B_1 = T.Buffer((num_elements,), data=B.data)
+                B_1[i_0 * 4 * vscale + i_1] = A_1[i_0 * 4 * vscale + i_1]
+
+    build_mod = tvm.build(main, target=target)
+
+    llvm = build_mod.get_source()
+    sve_vec_instrs = re.findall(r"<vscale x \d+ x float>", llvm)
+    assert len(sve_vec_instrs) > 0, "No scalable vectors in assembly"
+
+    dev = tvm.cpu(0)
+    np_zeros = np.zeros((num_elements,)).astype("float32")
+    np_ones = np.ones((num_elements,)).astype("float32")
+
+    input_buf = tvm.nd.array(np_ones, device=dev)
+    output_buf = tvm.nd.array(np_zeros, device=dev)
+
+    build_mod(input_buf, output_buf)
+    tvm.testing.assert_allclose(output_buf.numpy(), np_ones)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_div(sve_device_vector_length):
+    np.random.seed(0)
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(A: T.Buffer((1,), "int32")):
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        A[0] = T.Div(10000, 4 * T.vscale())
+
+    mod = tvm.build(my_func, target=target)
+
+    A_nd = tvm.nd.array(np.empty((1,), dtype="int32"), device=dev)
+    mod(A_nd)
+
+    ref = 10000 // (sve_device_vector_length // 32)
+    tvm.testing.assert_allclose(A_nd.numpy()[0], ref)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_buffer_load_store(sve_device_vector_length):
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    num_elements = sve_device_vector_length // 32
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(A: T.Buffer((num_elements,), "float32"), B: T.Buffer((num_elements,), "float32")):
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
+
+    mod = tvm.build(my_func, target=target)
+
+    A_np = np.random.uniform(size=(num_elements,)).astype("float32")
+    B_np = np.zeros((num_elements,)).astype("float32")
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    mod(A_nd, B_nd)
+
+    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
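(Illustrative note, not part of the diff: `svcntb()` in the fixture above returns the SVE vector length in bytes, so the fixture reports VL in bits. Since vscale = VL / 128 by definition, a float32 vector with `4 * T.vscale()` lanes holds 4 * VL/128 = VL/32 elements, which is why `num_elements = sve_device_vector_length // 32` sizes each buffer to exactly one scalable vector. A quick self-contained check:)

def lanes_of_4x_vscale(vl_bits: int) -> int:
    # vscale = VL / 128, so a "float32x4xvscale" vector has 4 * vscale lanes
    return 4 * (vl_bits // 128)

for vl_bits in (128, 256, 512):  # legal SVE vector lengths are multiples of 128 bits
    assert lanes_of_4x_vscale(vl_bits) == vl_bits // 32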
+@tvm.testing.requires_aarch64_sve +def test_scalable_loop_bound(sve_device_vector_length): + np.random.seed(0) + + dtype = "float32" + num_elements = sve_device_vector_length // 32 + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "float32"), B: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(0, 4 * T.vscale()): + B[i] = A[i] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(num_elements,)).astype(dtype) + B_np = np.zeros((num_elements,)).astype(dtype) + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_broadcast(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + + A_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + mod(A_nd) + + ref = np.ones((num_elements,)) + tvm.testing.assert_allclose(A_nd.numpy(), ref) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_ptrue_predicate(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(T.IntImm("int1", 1), 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + + A_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + mod(A_nd) + + ref = np.ones((num_elements,)) + tvm.testing.assert_allclose(A_nd.numpy(), ref) + + +@pytest.mark.skip(reason="Currently don't support scalable gathers in codegen") +@tvm.testing.requires_aarch64_sve +def test_scalable_gather(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + stride = 2 + + @T.prim_func + def my_func( + A: T.Buffer((stride * num_elements,), "float32"), B: T.Buffer((num_elements,), "float32") + ): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, stride, 4 * T.vscale())] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(stride * num_elements,)).astype("float32") + B_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + ref = A_np[::stride] + tvm.testing.assert_allclose(B_nd.numpy(), ref) + + +@pytest.mark.skip(reason="Currently don't support scalable scatters in codegen") +@tvm.testing.requires_aarch64_sve +def test_scalable_scatter(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + stride = 2 + + @T.prim_func + def my_func( + A: T.Buffer((num_elements,), "float32"), B: T.Buffer((stride * num_elements,), "float32") + ): + T.func_attr({"global_symbol": 
"my_module", "tir.noalias": True}) + B[T.ramp(0, stride, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(num_elements,)).astype("float32") + B_np = np.zeros((stride * num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + ref = B_np + ref[::stride] = A_np + tvm.testing.assert_allclose(B_nd.numpy(), ref) + + +@pytest.mark.skip(reason="Currently don't support scalable gathers in codegen") +@tvm.testing.requires_aarch64_sve +def test_scalable_complex_gather(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "float32"), B: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[2 * T.ramp(0, 1, 4 * T.vscale()) % 4] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(num_elements,)).astype("float32") + B_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + +@pytest.mark.skip(reason="Currently don't support scalable ramps in codegen") +@tvm.testing.requires_aarch64_sve +def test_scalable_ramp(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "int32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.ramp(11, 1, 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + + A_np = np.zeros((num_elements,)).astype("int32") + A_nd = tvm.nd.array(A_np, device=dev) + mod(A_nd) + + ref = np.arange(11, 11 + num_elements) + tvm.testing.assert_allclose(A_nd.numpy(), ref) + + +@tvm.testing.requires_aarch64_sve +@pytest.mark.parametrize("disable_predication", [True, False]) +def test_schedule_split_vectorized(disable_predication): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(128): + with T.block("A"): + v_i = T.axis.remap("S", [i]) + B[v_i] = A[v_i] + 1.0 + + sch = tvm.tir.Schedule(my_func) + (a,) = sch.get_loops("A") + + with tvm.target.Target(target): + _, a1 = sch.split( + a, + factors=[T.ceildiv(128, 4 * T.vscale()), 4 * T.vscale()], + disable_predication=disable_predication, + ) + + sch.vectorize(a1) + mod = tvm.build(sch.mod["main"], target=target) + + A_np = np.arange(128).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_np = np.zeros(128).astype("float32") + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + ref = A_np + 1.0 + tvm.testing.assert_allclose(B_nd.numpy(), ref) + + +def _test_accuracy(input_values, output_values, build_mod): + dev = tvm.cpu(0) + + input_buf = tvm.nd.array(input_values, device=dev) + + np_zeros = np.zeros(output_values.shape).astype("float32") + output_buf = tvm.nd.array(np_zeros, device=dev) + + build_mod(input_buf, output_buf) + tvm.testing.assert_allclose(output_buf.numpy(), output_values) + + +@tvm.testing.skip_if_32bit(reason="Skipping test 
for i386 due to old version of LLVM")
+def test_vectorize_to_sve():
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    vscale = tvm.tir.vscale()
+    buffer_size = 128
+
+    @T.prim_func
+    def main(A: T.Buffer((buffer_size,), "float32"), B: T.Buffer((buffer_size,), "float32")):
+        for i_0 in range(tvm.tir.ceildiv(128, 4 * vscale)):
+            for i_1 in T.vectorized(4 * vscale):
+                A_1 = T.Buffer((128,), data=A.data)
+                B_1 = T.Buffer((128,), data=B.data)
+                B_1[i_0 * 4 * vscale + i_1] = A_1[i_0 * 4 * vscale + i_1]
+
+    build_mod = tvm.build(main, target=target)
+
+    llvm = build_mod.get_source()
+
+    assert re.findall(r"<vscale x \d+ x float>", llvm), "No scalable vectors in assembly"
+
+    if tvm.testing.has_cpu_feat("sve"):
+        print("running on an SVE enabled machine...")
+
+        np_ones = np.ones((buffer_size,)).astype("float32")
+        _test_accuracy(np_ones, np_ones, build_mod)
+
+
+@tvm.testing.skip_if_32bit(reason="Skipping test for i386 due to old version of LLVM")
+def test_vectorize_to_sve_with_broadcast():
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    vscale = tvm.tir.vscale()
+    buffer_size = 128
+
+    @T.prim_func
+    def main(A: T.Buffer((buffer_size,), "float32"), B: T.Buffer((buffer_size,), "float32")):
+        for i_0 in range(tvm.tir.ceildiv(128, 4 * vscale)):
+            for i_1 in T.vectorized(4 * vscale):
+                A_1 = T.Buffer((128,), data=A.data)
+                B_1 = T.Buffer((128,), data=B.data)
+                B_1[i_0 * 4 * vscale + i_1] = A_1[i_0 * 4 * vscale + i_1] * 5
+
+    build_mod = tvm.build(main, target=target)
+
+    llvm = build_mod.get_source()
+
+    assert re.findall(r"<vscale x \d+ x float>", llvm), "No scalable vectors in assembly"
+
+    if tvm.testing.has_cpu_feat("sve"):
+        print("running on an SVE enabled machine...")
+
+        np_ones = np.ones((buffer_size,)).astype("float32")
+        output_values = np_ones * 5
+
+        _test_accuracy(np_ones, output_values, build_mod)
+
+
+@tvm.testing.skip_if_32bit(reason="Skipping test for i386 due to old version of LLVM")
+def test_sve_full_stack():
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    vscale = tvm.tir.vscale()
+    buffer_size = 130
+
+    @T.prim_func
+    def main(A: T.Buffer((buffer_size,), "float32"), B: T.Buffer((buffer_size,), "float32")):
+        for i in range(buffer_size):
+            with T.block("A"):
+                B[i] = A[i]
+
+    # Schedule
+    with tvm.target.Target(target):
+        sch = tvm.tir.Schedule(main)
+        (l,) = sch.get_loops("A")
+
+        _, l1 = sch.split(l, factors=[T.ceildiv(buffer_size, 4 * vscale), 4 * vscale])
+
+        sch.vectorize(l1)
+
+    with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}):
+        build_mod = tvm.build(sch.mod["main"], target=target)
+
+    llvm = build_mod.get_source()
+
+    assert re.findall(r"<vscale x \d+ x float>", llvm), "No scalable vectors in llvm"
+    assert re.findall(r"llvm.masked", llvm), "No masked instructions in llvm"
+
+    if tvm.testing.has_cpu_feat("sve"):
+        print("running on an SVE enabled machine...")
+
+        np_ones = np.ones((buffer_size,)).astype("float32")
+        _test_accuracy(np_ones, np_ones, build_mod)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/te/test_te_schedule.py b/tests/python/te/test_te_schedule.py
index ed224883478e..fa04e729cf5b 100644
--- a/tests/python/te/test_te_schedule.py
+++ b/tests/python/te/test_te_schedule.py
@@ -19,6 +19,7 @@ import pytest
 
 import tvm
 from tvm import te
+from tvm.driver.build_module import schedule_to_module
 
 
 def test_schedule_create():
@@ -354,21 +355,28 @@ def invalid_compute_at_loop():
     invalid_compute_at_loop()
 
 
+@pytest.mark.parametrize("split_factor", [4, 4 * tvm.tir.vscale()])
+@pytest.mark.parametrize("disable_predication", [True, False])
+def test_split_disable_predicate(disable_predication, split_factor): + A = te.placeholder((43,), name="A") + B = te.compute(A.shape, lambda i: A[i] + 2, name="C") + + sch = te.create_schedule(B.op) + (i,) = sch[B].op.axis + _, _ = sch[B].split(i, factor=split_factor, disable_predication=disable_predication) + + mod = schedule_to_module(sch, [A, B], "main") + + predicates = [] + + def _find_predicates(stmt): + if isinstance(stmt, tvm.tir.stmt.IfThenElse): + predicates.append(stmt) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _find_predicates) + + assert bool(len(predicates)) != disable_predication + + if __name__ == "__main__": - test_singleton() - test_pragma() - test_tensor_intrin() - test_tensor_intrin_scalar_params() - test_rfactor() - test_schedule_create() - test_reorder() - test_tile() - test_split() - test_fuse() - test_fuse_with_split() - test_fuse_with_out_of_order_axis() - test_fuse_with_out_of_order_axis_with_reorder() - test_vectorize() - test_vectorize_commreduce() - test_legalize_invalid_attach() - test_compute_at() + tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index 49816778f11f..bb8f14191380 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -401,5 +401,32 @@ def test_intimm_cond(): assert x == 1 +def _create_ramp(lanes): + return tvm.tir.Ramp(0, 1, lanes) + + +def _create_broadcast(lanes): + return tvm.tir.Broadcast(0, lanes) + + +@pytest.mark.parametrize("lanes", [(11 * tvm.tir.vscale()), (tvm.tir.vscale() * 11)]) +@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast]) +def test_scalable_vec(lanes, node_func): + def _check_dtype(node): + assert node.dtype == "int32x11xvscale" + + _check_dtype(node_func(lanes)) + + +@pytest.mark.parametrize( + "lanes", [(tvm.tir.vscale()), (tvm.tir.vscale() + 3), (tvm.tir.vscale() * 2 + 5)] +) +@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast]) +def test_scalable_vec_error(lanes, node_func): + + with pytest.raises(tvm.error.TVMError): + node_func(lanes) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_scalable_datatype.py b/tests/python/tir-base/test_tir_scalable_datatype.py new file mode 100644 index 000000000000..2c83314470de --- /dev/null +++ b/tests/python/tir-base/test_tir_scalable_datatype.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import tir +from tvm.script import tir as T + +""" +Tests for scalable data types. 
+""" + + +def test_create_scalable_data_type_python_api(): + dtype = tvm.DataType("float32x4xvscale") + assert str(dtype) == "float32x4xvscale" + + +def test_create_scalable_tir_intrin(): + intrin = tir.call_llvm_intrin("int32x4xvscale", "llvm.experimental.stepvector") + assert intrin.dtype == "int32x4xvscale" + assert str(intrin) == 'T.call_llvm_intrin("int32x4xvscale", "llvm.experimental.stepvector")' + + +def test_tvm_script_create_scalable_tir_intrin(): + @T.prim_func + def my_func(): + T.call_llvm_intrin("int32x4xvscale", "llvm.experimental.stepvector") + + assert ( + 'T.call_llvm_intrin("int32x4xvscale", "llvm.experimental.stepvector")' in my_func.script() + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py index 679b147446ea..23bcdd1ec55b 100644 --- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py @@ -653,5 +653,73 @@ def test_split_int64_factors(): assert_structural_equal_ignore_global_symbol(elementwise_symbolic_split, sch.mod["main"]) +@pytest.mark.parametrize("num_elements", [128, 115]) +def test_split_predicated(num_elements): + # By default, splitting with vscale will result in predication being + # applied. This is because at compile-time we don't know if vscale is + # a multiple of the extent of the loop to be split. + @T.prim_func + def before(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(num_elements): + with T.block("A"): + v_i = T.axis.remap("S", [i]) + A[v_i] = 1.0 + + @T.prim_func + def after(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i_0, i_1 in T.grid( + (T.vscale() * 4 + (num_elements - 1)) // (T.vscale() * 4), T.vscale() * 4 + ): + with T.block("A"): + v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + i_1) + T.where(i_0 * (T.vscale() * 4) + i_1 < num_elements) + A[v_i] = 1.0 + + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"): + sch = tvm.tir.Schedule(before) + (a,) = sch.get_loops("A") + sch.split(a, factors=[T.ceildiv(num_elements, 4 * T.vscale()), 4 * T.vscale()]) + + tvm.ir.assert_structural_equal(sch.mod["main"], after) + + +def test_split_assume_exact_multiple(): + # If the schedule writer knows the extent of the loop to be split will always + # be a multiple, they may use `disable_predication=True` to ensure + # a predicate is not created. 
+    num_elements = 128
+
+    @T.prim_func
+    def before(A: T.Buffer((num_elements,), "float32")):
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i in T.serial(num_elements):
+            with T.block("A"):
+                v_i = T.axis.remap("S", [i])
+                A[v_i] = 1.0
+
+    @T.prim_func
+    def after(A: T.Buffer((num_elements,), "float32")):
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i_0, i_1 in T.grid(
+            (T.vscale() * 4 + (num_elements - 1)) // (T.vscale() * 4), T.vscale() * 4
+        ):
+            with T.block("A"):
+                v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + i_1)
+                A[v_i] = 1.0
+
+    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"):
+        sch = tvm.tir.Schedule(before)
+        (a,) = sch.get_loops("A")
+        sch.split(
+            a,
+            factors=[T.ceildiv(num_elements, 4 * T.vscale()), 4 * T.vscale()],
+            disable_predication=True,
+        )
+
+    tvm.ir.assert_structural_equal(sch.mod["main"], after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tir-schedule/test_tir_schedule_trace.py b/tests/python/tir-schedule/test_tir_schedule_trace.py
index a793699ca755..18f15d6a7af8 100644
--- a/tests/python/tir-schedule/test_tir_schedule_trace.py
+++ b/tests/python/tir-schedule/test_tir_schedule_trace.py
@@ -88,7 +88,7 @@ def _make_split(inputs, outputs):  # pylint: disable=redefined-builtin
     return Instruction(
         kind=InstructionKind.get("Split"),
         inputs=inputs,
-        attrs=[True],
+        attrs=[True, False],
         outputs=outputs,
     )
@@ -304,7 +304,7 @@ def test_trace_simplified_3():
             "def apply_trace(sch: tir.Schedule) -> None:",
             '    b0 = sch.get_block(name="B", func_name="main")',
             "    l1, = sch.get_loops(block=b0)",
-            "    l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)",
+            "    l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True, disable_predication=False)",
         )
     )
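(Illustrative note, not part of the diff: the vectorizer tests below rely on the semantics of `T.get_active_lane_mask(base, limit)`, which mirrors LLVM's `llvm.get.active.lane.mask` intrinsic: lane i of the resulting boolean vector is active iff base + i < limit, and predicated loads/stores then only touch active lanes. A minimal Python model of that behaviour:)

def get_active_lane_mask(base: int, limit: int, lanes: int):
    # lane i is active iff base + i < limit
    return [base + i < limit for i in range(lanes)]

# The tail iteration of the loops tested below: elements 12..15 against limit 14
assert get_active_lane_mask(12, 14, 4) == [True, True, False, False]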
diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py
index 2448fffe8929..8c54e3a8afbf 100644
--- a/tests/python/tir-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tir-transform/test_tir_transform_vectorize.py
@@ -16,6 +16,7 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm.script import tir as T
 
 
 def test_vectorize_loop():
@@ -226,13 +227,124 @@ def test_vectorize_dtype_mismatch():
         tvm.lower(s, [A], "llvm", simple_mode=True)
 
 
+def test_vectorize_and_predicate_all_buffer_loads_stores():
+    @T.prim_func
+    def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i_0 in T.serial(T.ceildiv(14, 4)):
+            for i_1 in T.vectorized(4):
+                if i_0 * 4 + i_1 < 14:
+                    B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        for i_0 in range(4):
+            load_a = T.meta_var(
+                A.load(
+                    [T.Ramp(i_0 * 4, 1, 4)], predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14)
+                )
+            )
+            add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4))
+            B.store(
+                add_1,
+                [T.Ramp(i_0 * 4, 1, 4)],
+                predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14),
+            )
+
+    mod = tvm.IRModule.from_expr(before)
+    with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}):
+        after = tvm.tir.transform.VectorizeLoop()(mod)["main"]
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_vectorize_and_predicate_some_buffer_loads_stores():
+    # For now we revert to scalarizing the block if not all accesses have been
+    # predicated; otherwise incorrect code would be generated.
+    @T.prim_func
+    def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i_0 in T.serial(T.ceildiv(14, 4)):
+            for i_1 in T.vectorized(4):
+                if i_0 * 4 + i_1 < 14:
+                    B[i_0 * 4 + i_1] = A[i_0] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        for i_0, i_1_s in T.grid(4, 4):
+            if i_0 * 4 + i_1_s < 14:
+                B[i_0 * 4 + i_1_s] = A[i_0] + T.float32(1)
+
+    mod = tvm.IRModule.from_expr(before)
+    with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}):
+        after = tvm.tir.transform.VectorizeLoop()(mod)["main"]
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_vectorize_and_predicate_multiple_access_statements():
+    @T.prim_func
+    def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i_0 in T.serial(T.ceildiv(14, 4)):
+            for i_1 in T.vectorized(4):
+                if i_0 * 4 + i_1 < 14:
+                    A[i_0 * 4 + i_1] = 2.0
+                    B[i_0 * 4 + i_1] = 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        for i_0 in range(4):
+            A.store(
+                T.Broadcast(T.float32(2), 4),
+                [T.Ramp(i_0 * 4, 1, 4)],
+                predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14),
+            )
+            B.store(
+                T.Broadcast(T.float32(1), 4),
+                [T.Ramp(i_0 * 4, 1, 4)],
+                predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14),
+            )
+
+    before_mod = tvm.IRModule.from_expr(before)
+    with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}):
+        after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"]
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_vectorize_and_predicate_invalid_conditions():
+    @T.prim_func
+    def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i_0 in T.serial(T.ceildiv(14,
4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 > 14: + A[i_0 * 4 + i_1] = 2.0 + if 14 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + if i_0 * 4 + i_1 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + + @T.prim_func + def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + for i_1_s in range(4): + if i_0 * 4 + i_1_s > 14: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if 14 < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if i_0 * 4 + i_1_s < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + + before_mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": - test_vectorize_vector() - test_vectorize_with_if() - test_vectorize_loop() - test_vectorize_if_then_else() - test_vectorize_with_le_cond() - test_vectorize_with_ge_cond() - test_vectorize_let() - test_vectorize_while_fail() - test_vectorize_dtype_mismatch() + tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 18f4e153bff9..e258d58de606 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -889,5 +889,80 @@ def func(a_name: T.handle): assert re.match(expected_regex, script) +def test_predicated_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + T.func_attr({"global_symbol": "func"}) + a_load = T.meta_var(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4))) + A.store(a_load, [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + """ + _assert_print(main, expected_output) + + +def test_predicated_buffer_load_store(): + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + buffer_map = { + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + } + buffer_load = tir.BufferLoad( + buffer=buffer_map[b], indices=[0, tir.Ramp(0, 4, 4)], predicate=tir.Broadcast(0, 4) + ) + body = tir.BufferStore( + buffer=buffer_map[a], + value=buffer_load, + indices=[0, tir.Ramp(0, 2, 4)], + predicate=tir.Broadcast(0, 4), + ) + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map=buffer_map, + body=body, + ) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func(private=True) +def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(B.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + """ + _assert_print(func, expected_output) + + +def test_predicated_scalable_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + T.func_attr({"global_symbol": "func"}) + mask = T.meta_var(T.get_active_lane_mask("int1x4xvscale", 0, 13)) + a_load = 
T.meta_var(A.load([0, T.Ramp(0, 4, 4 * T.vscale())], predicate=mask)) + A.store(a_load, [0, T.Ramp(0, 2, 4 * T.vscale())], predicate=mask) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(\ + A.load([0, T.Ramp(0, 4, 4 * T.vscale())], predicate=T.get_active_lane_mask("int1x4xvscale", 0, 13)), \ + [0, T.Ramp(0, 2, 4 * T.vscale())], predicate=T.get_active_lane_mask("int1x4xvscale", 0, 13)\ + ) + """ + _assert_print(main, expected_output) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 5b3e68e22fa9..fefacb5bd10b 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3339,6 +3339,25 @@ def func() -> None: return func +def scalable_vectors(): + @T.prim_func + def func(a: T.handle): + A = T.match_buffer(a, (200,), "float32") + A[T.Ramp(11, 2, 4 * tir.vscale())] = T.Broadcast(125, 4 * tir.vscale()) + + return func + + +def predicated_buffer_load_store(): + @T.prim_func + def func(A: T.Buffer((4,), "float32"), B: T.Buffer((8,), "float32")): + for i_0 in range(4): + load_a = T.meta_var(A.load([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(1.0, 4))) + B.store(load_a, [T.Ramp(0, 2, 4)], predicate=T.Broadcast(1.0, 4)) + + return func + + def let_expression(): @T.prim_func def func(): @@ -4038,6 +4057,8 @@ def func(): buffer_axis_separator, buffer_ramp_access_as_slice_index, ramp_int64, + scalable_vectors, + predicated_buffer_load_store, let_expression, void_ptr, decl_buffer,