From 6808c2a2a409186d019c59bb12296978ab516d0d Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Wed, 3 Jan 2024 11:21:11 +0000 Subject: [PATCH 1/6] [SVE] Implement scalable vectors in TVM This prototype is to accompany the open SVE RFC. It implements the design outlined in the RFC. The main changes to the stack include: 1. tir.split can accept an expression with vscale as a factor 2. LoopVectorizer can create Ramp and Broadcast nodes with scalable lanes 3. BufferLoad and BufferStore nodes can accept an optional predicate which is created in LoopVectorizer 4. LLVM codegen can lower the scalable predicated vectors into llvm.masked.* intrinsics The prototype is currently missing tir.tile and TVMScript parser support for predicates. Co-authored-by: Luke Hutton Co-authored-by: Neil Hickey --- include/tvm/runtime/data_type.h | 47 +- include/tvm/tir/builtin.h | 12 + include/tvm/tir/expr.h | 17 +- include/tvm/tir/schedule/schedule.h | 7 +- include/tvm/tir/stmt.h | 8 +- python/tvm/_ffi/runtime_ctypes.py | 10 +- python/tvm/script/ir_builder/tir/ir.py | 5 +- python/tvm/testing/utils.py | 17 +- python/tvm/tir/__init__.py | 1 + python/tvm/tir/expr.py | 11 +- python/tvm/tir/op.py | 4 + python/tvm/tir/schedule/schedule.py | 15 +- python/tvm/tir/stmt.py | 7 +- src/arith/analyzer.cc | 31 ++ src/arith/int_set.cc | 24 +- src/arith/pattern_match.h | 15 +- src/arith/rewrite_simplify.cc | 110 +++-- src/arith/rewrite_simplify.h | 2 + src/arith/scalable_expression.cc | 76 ++++ src/arith/scalable_expression.h | 64 +++ src/driver/driver_api.cc | 1 + src/ir/expr.cc | 4 +- src/relay/printer/tir_text_printer.cc | 4 +- src/relay/printer/tvmscript_printer.cc | 4 +- src/script/ir_builder/tir/ir.cc | 9 +- src/script/printer/tir/buffer.cc | 13 +- src/script/printer/tir/expr.cc | 4 +- src/target/llvm/codegen_llvm.cc | 131 ++++-- src/target/llvm/codegen_llvm.h | 7 +- src/target/source/codegen_c.cc | 6 +- src/target/source/codegen_c_host.cc | 5 +- src/target/source/codegen_cuda.cc | 28 +- src/target/source/codegen_metal.cc | 5 +- src/target/source/codegen_opencl.cc | 12 +- src/target/source/codegen_webgpu.cc | 5 +- src/te/operation/create_primfunc.cc | 4 +- src/tir/analysis/device_constraint_utils.cc | 5 +- src/tir/contrib/ethosu/passes.cc | 3 +- src/tir/ir/expr.cc | 52 ++- src/tir/ir/expr_functor.cc | 12 +- src/tir/ir/stmt.cc | 21 +- src/tir/op/builtin.cc | 7 + src/tir/op/op.cc | 37 +- src/tir/schedule/analysis/reducer.cc | 33 +- src/tir/schedule/concrete_schedule.cc | 7 +- src/tir/schedule/concrete_schedule.h | 2 +- src/tir/schedule/primitive.h | 7 +- .../schedule/primitive/loop_transformation.cc | 14 +- src/tir/schedule/traced_schedule.cc | 15 +- src/tir/schedule/traced_schedule.h | 2 +- src/tir/transforms/bound_checker.cc | 15 +- src/tir/transforms/inject_rolling_buffer.cc | 6 +- src/tir/transforms/lower_match_buffer.cc | 2 + src/tir/transforms/lower_warp_memory.cc | 4 + .../manifest_shared_memory_local_stage.cc | 1 + src/tir/transforms/remove_no_op.cc | 3 +- .../remove_weight_layout_rewrite_block.cc | 2 +- .../transforms/renormalize_split_pattern.cc | 2 +- src/tir/transforms/storage_flatten.cc | 19 +- src/tir/transforms/storage_rewrite.cc | 25 +- .../transforms/unsupported_dtype_legalize.cc | 8 + src/tir/transforms/vectorize_loop.cc | 242 ++++++++-- tests/cpp/pattern_match_test.cc | 12 +- tests/cpp/tir_scalable_datatype.cc | 129 ++++++ .../arith/test_arith_rewrite_simplify.py | 44 ++ tests/python/arith/test_arith_simplify.py | 8 + .../codegen/test_target_codegen_aarch64.py | 22 +- .../test_meta_schedule_post_order_apply.py | 14 +- tests/python/target/test_arm_target.py | 413 ++++++++++++++++++ tests/python/tir-base/test_tir_nodes.py | 27 ++ .../tir-base/test_tir_scalable_datatype.py | 49 +++ .../test_tir_schedule_split_fuse.py | 68 +++ .../tir-schedule/test_tir_schedule_trace.py | 4 +- .../test_tir_transform_vectorize.py | 119 ++++- .../tvmscript/test_tvmscript_roundtrip.py | 10 + 75 files changed, 1874 insertions(+), 316 deletions(-) create mode 100644 src/arith/scalable_expression.cc create mode 100644 src/arith/scalable_expression.h create mode 100644 tests/cpp/tir_scalable_datatype.cc create mode 100644 tests/python/tir-base/test_tir_scalable_datatype.py diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index ac7e879a644d..6d083e2631f7 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -71,11 +72,16 @@ class DataType { * \param code The type code. * \param bits The number of bits in the type. * \param lanes The number of lanes. + * \param scalable Whether or not the data type is scalable. */ - DataType(int code, int bits, int lanes) { + DataType(int code, int bits, int lanes, bool scalable = false) { data_.code = static_cast(code); data_.bits = static_cast(bits); - data_.lanes = static_cast(lanes); + if (scalable) { + data_.lanes = static_cast(-lanes); + } else { + data_.lanes = static_cast(lanes); + } if (code == kBFloat) { ICHECK_EQ(bits, 16); } @@ -90,7 +96,14 @@ class DataType { /*! \return number of bytes to store each scalar. */ int bytes() const { return (bits() + 7) / 8; } /*! \return number of lanes in the data. */ - int lanes() const { return static_cast(data_.lanes); } + int lanes() const { + int encoded_lanes = static_cast(data_.lanes); + if (is_scalable()) { + return -encoded_lanes; + } else { + return encoded_lanes; + } + } /*! \return whether type is a scalar type. */ bool is_scalar() const { return lanes() == 1; } /*! \return whether type is a scalar type. */ @@ -114,17 +127,28 @@ class DataType { /*! \return whether type is a handle type. */ bool is_handle() const { return code() == DataType::kHandle && !is_void(); } /*! \return whether type is a vector type. */ - bool is_vector() const { return lanes() > 1; } + bool is_vector() const { + int encoded_lanes = static_cast(data_.lanes); + return encoded_lanes != 0 && encoded_lanes != 1; + } /*! \return whether type is a bool vector type. */ bool is_vector_bool() const { return is_vector() && bits() == 1; } /*! \return whether type is a Void type. */ bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; } + /*! \return Whether the type is scalable. */ + bool is_scalable() const { return static_cast(data_.lanes) < 0; } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. * \return the result type. */ DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); } + /*! + * \brief Create a new scalable data type by changing the lanes to a specified value. + * \param lanes The target number of lanes. + * \return A copy of the old DataType with the number of scalable lanes. + */ + DataType with_scalable_lanes(int lanes) const { return DataType(data_.code, data_.bits, -lanes); } /*! * \brief Create a new data type by change bits to a specified value. * \param bits The target number of bits. @@ -247,6 +271,9 @@ class DataType { * \return Number of bytes needed. */ inline int GetVectorBytes(DataType dtype) { + if (dtype.is_scalable()) { + LOG(FATAL) << "Cannot get vector bytes of scalable vector"; + } int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || @@ -357,8 +384,12 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) } if (t.code == kTVMOpaqueHandle) return os; os << static_cast(t.bits); - if (t.lanes != 1) { - os << 'x' << static_cast(t.lanes); + + int16_t lanes = static_cast(t.lanes); + if (lanes > 1) { + os << 'x' << lanes; + } else if (lanes < 0) { + os << 'x' << -lanes << "xvscale"; } return os; } @@ -424,6 +455,10 @@ inline DLDataType String2DLDataType(std::string s) { if (*xdelim == 'x') { t.lanes = static_cast(strtoul(xdelim + 1, &endpt, 10)); } + if (strncmp(endpt, "xvscale", 7) == 0) { + t.lanes = -t.lanes; + endpt = endpt + 7; + } ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s; return t; } diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e6116605f8a2..76a29997fdd1 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -909,6 +909,18 @@ TVM_DLL const Op& anylist_setitem_call_packed(); */ TVM_DLL const Op& anylist_setitem_call_cpacked(); +/*! + * \brief Return the value of vscale + */ +TVM_DLL const Op& vscale(); + +/*! + * \brief Provide the predicate constructed of the currently active lanes + * + * Calculate the active lane masks given a bound and a current value + */ +TVM_DLL const Op& get_active_lane_mask(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 4e29eddadd8c..1babb4a92a97 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -630,23 +630,27 @@ class BufferLoadNode : public PrimExprNode { Buffer buffer; /*! \brief The indices location to be loaded. */ Array indices; + /*! \brief The buffer predicate */ + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->dtype)); v->Visit("buffer", &buffer); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { return equal(dtype, other->dtype) && equal(buffer, other->buffer) && - equal(indices, other->indices); + equal(indices, other->indices) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); hash_reduce(buffer); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferLoad"; @@ -675,7 +679,8 @@ class BufferLoadNode : public PrimExprNode { */ class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, Span span = Span()); + TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, + PrimExpr predicate = PrimExpr(nullptr), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; @@ -746,7 +751,7 @@ class RampNode : public PrimExprNode { /*! \brief The stride of each step. */ PrimExpr stride; /*! \brief Total number of lanes. */ - int lanes; + PrimExpr lanes; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -778,7 +783,7 @@ class RampNode : public PrimExprNode { */ class Ramp : public PrimExpr { public: - TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span()); + TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode); }; @@ -789,7 +794,7 @@ class BroadcastNode : public PrimExprNode { /*! \brief The base value. */ PrimExpr value; /*! \brief The number of lanes. */ - int lanes; + PrimExpr lanes; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -818,7 +823,7 @@ class BroadcastNode : public PrimExprNode { */ class Broadcast : public PrimExpr { public: - TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span()); + TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode); }; diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 273912ed1f8f..7f9c9f0a8801 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -350,10 +350,15 @@ class ScheduleNode : public runtime::Object { * \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means * that factor is inferred. * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings + * \param disable_predication If enabled, don't create a predicate for guarding the + * loop. This can be useful when splitting with scalable factors that the schedule writer + * knows are divisible. Warning: enabling this feature may result in incorrect code generation + * if not used carefully. * \return The new loops after split */ virtual Array Split(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters = true) = 0; + bool preserve_unit_iters = true, + bool disable_predication = false) = 0; /*! * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. * It requires: diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 07cc9b5ad0d5..3255b1ac90ac 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -231,23 +231,27 @@ class BufferStoreNode : public StmtNode { PrimExpr value; /*! \brief The indices location to be stored. */ Array indices; + /*! \brief The predicate for this store. */ + PrimExpr predicate; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer", &buffer); v->Visit("value", &value); v->Visit("indices", &indices); + v->Visit("predicate", &predicate); v->Visit("span", &span); } bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const { return equal(buffer, other->buffer) && equal(value, other->value) && - equal(indices, other->indices); + equal(indices, other->indices) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer); hash_reduce(value); hash_reduce(indices); + hash_reduce(predicate); } static constexpr const char* _type_key = "tir.BufferStore"; @@ -261,7 +265,7 @@ class BufferStoreNode : public StmtNode { class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, - Span span = Span()); + PrimExpr predicate = PrimExpr(nullptr), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 54e4d8f205a1..e8a61eb73736 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -135,7 +135,10 @@ def __init__(self, type_str): arr = type_str.split("x") head = arr[0] - self.lanes = int(arr[1]) if len(arr) > 1 else 1 + if len(arr) == 3 and arr[2] == "vscale": + self.lanes = ctypes.c_uint16(-int(arr[1])) + elif len(arr) > 1: + self.lanes = ctypes.c_uint16(int(arr[1])) bits = 32 if head.startswith("int"): @@ -188,8 +191,11 @@ def __repr__(self): type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code) x = "%s%d" % (type_name, self.bits) - if self.lanes != 1: + lanes_as_int = ctypes.c_int16(self.lanes).value + if lanes_as_int > 1: x += "x%d" % self.lanes + elif lanes_as_int < 0: + x += "x%dxvscale" % -lanes_as_int return x def __eq__(self, other): diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 74b0bd2ba4e1..7e42ea3c2204 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1289,7 +1289,7 @@ def buffer_store( if lanes == 1: expr_indices.append(index.start) else: - expr_indices.append(ramp(index.start, step, int(lanes))) + expr_indices.append(ramp(index.start, step, lanes)) else: expr_indices.append(index) if isinstance(value, bool) and buffer.dtype == "bool": @@ -1854,6 +1854,7 @@ def wrapped(*args, **kwargs): create_barriers = _op_wrapper(_tir_op.create_barriers) assume = _op_wrapper(_tir_op.assume) undef = _op_wrapper(_tir_op.undef) +vscale = _op_wrapper(_tir_op.vscale) TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace) TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) @@ -1891,7 +1892,6 @@ def wrapped(*args, **kwargs): vectorhigh = _dtype_forward(_tir_op.vectorhigh) vectorcombine = _dtype_forward(_tir_op.vectorcombine) - broadcast = Broadcast ramp = Ramp fabs = abs @@ -2199,4 +2199,5 @@ def wrapped(*args, **kwargs): "IterVar", "CommReducer", "Range", + "vscale", ] diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index ccad989c33ef..0fbb2084e4de 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1023,7 +1023,7 @@ def _corstone300_compile_time_check(): # check cpu features -def _has_cpu_feat(features): +def has_cpu_feat(features): cpu = codegen.llvm_get_system_cpu() triple = codegen.llvm_get_system_triple() target = "llvm -mtriple=%s -mcpu=%s" % (triple, cpu) @@ -1035,21 +1035,28 @@ def _has_cpu_feat(features): requires_arm_dot = Feature( "arm_dot", "ARM dot product", - run_time_check=lambda: _has_cpu_feat("dotprod"), + run_time_check=lambda: has_cpu_feat("dotprod"), +) + + +requires_aarch64_sve = Feature( + "arm_sve", + "AArch64 SVE", + run_time_check=lambda: has_cpu_feat("sve"), ) requires_x86_vnni = Feature( "x86_vnni", "x86 VNNI Extensions", - run_time_check=lambda: (_has_cpu_feat("avx512vnni") or _has_cpu_feat("avxvnni")), + run_time_check=lambda: (has_cpu_feat("avx512vnni") or has_cpu_feat("avxvnni")), ) requires_x86_avx512 = Feature( "x86_avx512", "x86 AVX512 Extensions", - run_time_check=lambda: _has_cpu_feat( + run_time_check=lambda: has_cpu_feat( ["avx512bw", "avx512cd", "avx512dq", "avx512vl", "avx512f"] ), ) @@ -1058,7 +1065,7 @@ def _has_cpu_feat(features): requires_x86_amx = Feature( "x86_amx", "x86 AMX Extensions", - run_time_check=lambda: _has_cpu_feat("amx-int8"), + run_time_check=lambda: has_cpu_feat("amx-int8"), ) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index f0500290b888..1723804388b9 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -88,6 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic +from .op import vscale from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index fad9fca083a1..7123584f6c8e 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1095,6 +1095,9 @@ class BufferLoad(PrimExprWithOp): indices : List[PrimExpr] The buffer indices. + predicate : Optional[PrimExpr] + The buffer predicate + span : Optional[Span] The location of this expression in the source code. """ @@ -1103,10 +1106,14 @@ class BufferLoad(PrimExprWithOp): indices: List[PrimExpr] def __init__( - self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None + self, + buffer: Buffer, + indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, + span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferLoad, buffer, indices, span # type: ignore + _ffi_api.BufferLoad, buffer, indices, predicate, span # type: ignore ) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 4735a2644e83..a6f9e3fd12d3 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1586,6 +1586,10 @@ def vectorlow(dtype, vec): return call_intrin(dtype, "tir.vectorlow", vec) +def vscale(): + return call_intrin("int32", "tir.vscale") + + def vectorhigh(dtype, vec): """Get the high level half of the vector diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 23b000c09015..32563d504f61 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -736,6 +736,7 @@ def split( loop: LoopRV, factors: List[Union[int, ExprRV, None]], preserve_unit_iters: bool = True, + disable_predication: bool = False, ) -> List[LoopRV]: """Split a loop into a list of consecutive loops. It requires: 1) The loop can't have annotation or thread binding. @@ -759,6 +760,14 @@ def split( preserve_unit_iters : bool Whether or not to preserve unit iterators in block bindings + disable_predication : bool + If enabled, don't create a predicate for guarding the loop. This + can be useful when splitting with scalable factors that the + schedule writer knows are divisible. + + Warning: enabling this feature may result in incorrect code + generation if not used carefully. + Returns ------- split_loops : List[LoopRV] @@ -809,7 +818,11 @@ def after_split(a: T.handle, b: T.handle) -> None: # that there is at most one None in `factors` return list( _ffi_api.ScheduleSplit( # type: ignore # pylint: disable=no-member - self, loop, factors, preserve_unit_iters + self, + loop, + factors, + preserve_unit_iters, + disable_predication, ) ) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 992c388e27bb..31298af67f3d 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -224,6 +224,9 @@ class BufferStore(Stmt): indices : List[PrimExpr] The indices location to be stored. + predicate : Optional[PrimExpr] + The buffer predicate + span : Optional[Span] The location of the stmt in the source code. """ @@ -231,6 +234,7 @@ class BufferStore(Stmt): buffer: Buffer value: PrimExpr indices: List[PrimExpr] + predicate: Optional[PrimExpr] span: Optional[Span] def __init__( @@ -238,10 +242,11 @@ def __init__( buffer: Buffer, value: PrimExpr, indices: List[PrimExpr], + predicate: Optional[PrimExpr] = None, span: Optional[Span] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.BufferStore, buffer, value, indices, span # type: ignore + _ffi_api.BufferStore, buffer, value, indices, predicate, span # type: ignore ) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 3e5b8834ebca..17b0634227eb 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -22,9 +22,13 @@ */ #include #include +#include #include #include +#include "../target/parsers/aprofile.h" +#include "../tir/analysis/check_contains.h" +#include "./scalable_expression.h" #include "const_fold.h" #include "product_normal_form.h" @@ -227,6 +231,33 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { } } + // Current analysis may not be powerful enough to prove expressions with + // multiple symbolic values. When the expression is scalable and the compile + // target is AArch64, we can make some assumptions about the value of vscale + // and iterate over a space of pre-defined values to attempt to prove the + // expression. + Target current_target = tvm::Target::Current(true); + if (!current_target.defined()) { + return false; + } + TargetJSON target_json = target::parsers::aprofile::ParseTarget(current_target->Export()); + TargetFeatures features = Downcast(target_json.at("features")); + bool is_llvm_aarch64 = Downcast(features.at("is_aarch64")); + if (is_llvm_aarch64 && tir::CheckContains::ExprContains(simplified, IsVScaleCall)) { + bool can_prove_expr = true; + for (const unsigned int vscale_value : kVScaleValues) { + PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value); + result = Simplify(result); + const int64_t* as_int = tir::as_const_int(result); + if (!as_int || *as_int == 0) { + can_prove_expr = false; + } + } + if (can_prove_expr) { + return true; + } + } + return false; } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 625488430bf8..88d6d2149b7a 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -466,14 +467,21 @@ class IntervalSetEvaluator : public ExprFunctor { if (stride.Match(op->stride)) { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; - if (vstride > 0) { - return Combine(analyzer_, base, - IntervalSet(make_zero(t), make_const(t, vstride * (op->lanes - 1))), - op->dtype); - } else { - return Combine(analyzer_, base, - IntervalSet(make_const(t, vstride * (op->lanes - 1)), make_zero(t)), - op->dtype); + if (op->lanes->IsInstance()) { + int lanes = static_cast(Downcast(op->lanes)->value); + if (vstride > 0) { + return Combine(analyzer_, base, + IntervalSet(make_zero(t), make_const(t, vstride * (lanes - 1))), + op->dtype); + } else { + return Combine(analyzer_, base, + IntervalSet(make_const(t, vstride * (lanes - 1)), make_zero(t)), + op->dtype); + } + } else { /* Scalable vector */ + if (vstride > 0) { + return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), op->dtype); + } } } DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index d057a840e8b7..98cf61990d90 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -628,10 +628,11 @@ inline PRampExpr ramp(const Pattern& base, } template -inline PRampExpr, PConst> ramp(const Pattern& base, - int stride, int lanes) { - return PRampExpr, PConst>( - base.derived(), PConstWithTypeLike(base.derived(), stride), PConst(lanes)); +inline PRampExpr, PConstWithTypeLike> ramp( + const Pattern& base, int stride, int lanes) { + return PRampExpr, PConstWithTypeLike>( + base.derived(), PConstWithTypeLike(base.derived(), stride), + PConstWithTypeLike(base.derived(), lanes)); } /*! @@ -835,6 +836,12 @@ inline PCallExpr if_then_else(const Pattern false_value.derived()); } +// vscale +struct PVscaleOp { + static PrimExpr Eval() { return tir::Call(DataType::Int(32), GetOp(), {}); } + static const Op& GetOp() { return tir::builtin::vscale(); } +}; + template class PMatchesOneOf { public: diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index d5f946fca02a..ddcc4eff96a1 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -247,7 +247,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // Pattern var match FloatImm PVar c4; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes)); @@ -396,7 +396,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); @@ -580,7 +580,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { // Pattern var match FloatImm PVar c3; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); @@ -617,7 +617,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // x / 2.0 = x * 0.5 if (const FloatImmNode* ptr = op->b.as()) { @@ -639,10 +639,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); } // If all possible indices in ramp are the same. - if (CanProveGreaterEqual(b1.Eval(), 0)) { + if (CanProveGreaterEqual(b1.Eval(), 0) && !IsWithScalableLanes(lanes.Eval())) { ModularSet bmod = analyzer_->modular_set(b1.Eval()); int64_t ramp_min = bmod->base / c2val; - int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; + auto lanes_int = lanes.Eval().as()->value; + int64_t ramp_max = (bmod->base + (lanes_int - 1) * c1val) / c2val; if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) { return broadcast(div(b1, c2), lanes).Eval(); } @@ -777,7 +778,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { // Pattern var match IntImm PVar c1, c2; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { @@ -793,14 +794,23 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { return broadcast(truncmod(b1, c2), lanes).Eval(); } // If all possible indices in ramp are the same. - if (CanProveGreaterEqual(b1.Eval(), 0)) { + if (CanProveGreaterEqual(b1.Eval(), 0) /*&& !IsWithScalableLanes(lanes.Eval())*/) { ModularSet bmod = analyzer_->modular_set(b1.Eval()); - int64_t ramp_min = bmod->base / c2val; - int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; - if (bmod->coeff % c2val == 0) { - if (ramp_min == ramp_max) { - return ramp(truncmod(bmod->base, c2), c1, lanes).Eval(); - } else { + if (!IsWithScalableLanes(lanes.Eval())) { + auto lanes_int = lanes.Eval().as()->value; + int64_t ramp_min = bmod->base / c2val; + int64_t ramp_max = (bmod->base + (lanes_int - 1) * c1val) / c2val; + if (bmod->coeff % c2val == 0) { + if (ramp_min == ramp_max) { + return ramp(truncmod(bmod->base, c2), c1, lanes).Eval(); + } else { + return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)) + .Eval(); + } + } + } else { /* Special case for scalable vectors */ + ModularSet bmod = analyzer_->modular_set(b1.Eval()); + if (bmod->coeff % c2val == 0) { return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } @@ -857,7 +867,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { @@ -872,17 +882,20 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval(); } // If all possible indices in ramp are the same. - ModularSet bmod = analyzer_->modular_set(b1.Eval()); - int64_t ramp_min = floordiv(bmod->base, c2val); - int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); - if (ramp_min == ramp_max) { - // If b1 can devide c2 - if (bmod->coeff % c2val == 0) { - return broadcast(floordiv(b1, c2), lanes).Eval(); - } - // If all indices can be guaranteed to settle inside a coeff range - if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { - return broadcast(floordiv(b1, c2), lanes).Eval(); + if (!IsWithScalableLanes(lanes.Eval())) { + ModularSet bmod = analyzer_->modular_set(b1.Eval()); + int64_t ramp_min = floordiv(bmod->base, c2val); + auto lanes_int = lanes.Eval().as()->value; + int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val, c2val); + if (ramp_min == ramp_max) { + // If b1 can devide c2 + if (bmod->coeff % c2val == 0) { + return broadcast(floordiv(b1, c2), lanes).Eval(); + } + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes_int - 1) * c1val < bmod->coeff) { + return broadcast(floordiv(b1, c2), lanes).Eval(); + } } } } @@ -993,7 +1006,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // Pattern var match IntImm PVar c1, c2; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { @@ -1010,21 +1023,28 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { } // If all possible indices in ramp are the same. ModularSet bmod = analyzer_->modular_set(b1.Eval()); - int64_t ramp_min = floordiv(bmod->base, c2val); - int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); - if (ramp_min == ramp_max) { - // If b1 can devide c2 + if (!IsWithScalableLanes(lanes.Eval())) { + int64_t ramp_min = floordiv(bmod->base, c2val); + auto lanes_int = lanes.Eval().as()->value; + int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val, c2val); + if (ramp_min == ramp_max) { + // If b1 can devide c2 + if (bmod->coeff % c2val == 0) { + return ramp(floormod(bmod->base, c2), c1, lanes).Eval(); + } + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes_int - 1) * c1val < bmod->coeff) { + return ramp(floormod(b1, c2), c1, lanes).Eval(); + } + } if (bmod->coeff % c2val == 0) { - return ramp(floormod(bmod->base, c2), c1, lanes).Eval(); + return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } - // If all indices can be guaranteed to settle inside a coeff range - if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { - return ramp(floormod(b1, c2), c1, lanes).Eval(); + } else { /* scalable vectors */ + if (bmod->coeff % c2val == 0) { + return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } - if (bmod->coeff % c2val == 0) { - return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); - } } } @@ -1093,7 +1113,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PVar x, y, z, s1, s2; // Pattern var match IntImm PVar c1, c2; - PVar lanes; + PVar lanes; // vector rule if (op->dtype.lanes() != 1) { @@ -1267,7 +1287,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PVar x, y, z, s1, s2; // Pattern var match IntImm PVar c1, c2; - PVar lanes; + PVar lanes; // vector rule if (op->dtype.lanes() != 1) { @@ -1475,7 +1495,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { PVar x, y; // Pattern var match IntImm PVar c1, c2; - PVar lanes; + PVar lanes; // vector rule if (ret->dtype.lanes() != 1) { @@ -1603,7 +1623,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { PVar x, y, z, s1, s2; // Pattern var match IntImm PVar c1, c2; - PVar lanes; + PVar lanes; // vector rule if (ret->dtype.lanes() != 1) { @@ -1761,7 +1781,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) { // Pattern var to match any expression PVar x, y; - PVar lanes; + PVar lanes; if (ret->dtype.lanes() != 1) { TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); } @@ -1836,7 +1856,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PVar x, y, z; // Pattern var match IntImm PVar c1, c2, c3; - PVar lanes; + PVar lanes; if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); @@ -1984,7 +2004,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PVar x, y, z; // Pattern var match IntImm PVar c1, c2; - PVar lanes; + PVar lanes; if (op->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 7c4b0eab2224..288ee948ff90 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -221,6 +221,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { return analyzer_->CanProveGreaterEqual(x, val); } + // Whether the lanes are scalable + bool IsWithScalableLanes(const PrimExpr& lanes) { return !lanes.as(); } // Whether x < val bool CanProveLess(const PrimExpr& x, int64_t val) { return analyzer_->CanProveLess(x, val); } // Whether x == val diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc new file mode 100644 index 000000000000..c3437f5f3807 --- /dev/null +++ b/src/arith/scalable_expression.cc @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/scalable_expression.cc + * \brief Analyze scalable expressions. + */ + +#include "scalable_expression.h" + +#include +#include + +// #include "../tir/analysis/vscale_expr.h" +#include "../tir/transforms/replace_selected_expr.h" +#include "./pattern_match.h" + +namespace tvm { +namespace arith { + +bool IsVScaleCall(const PrimExpr& expr) { + if (auto call = expr.as()) { + return call->op.same_as(tir::builtin::vscale()); + } + return false; +} + +PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value) { + std::function predicate_selector = [](const PrimExpr& current_expr) { + return IsVScaleCall(current_expr); + }; + + std::function predicate_can_replace_inside = + [](const PrimExpr& current_expr) { return true; }; + + return tir::ReplaceSelectedExpr::ReplaceSelectedExprInExpr( + expr, predicate_selector, tir::MakeConstScalar(DataType::Int(32), vscale_value), + predicate_can_replace_inside); +} + +PrimExpr CanonicalizeScalableLanes(const PrimExpr& lanes) { + PVar multiplier; + PCallExpr vscale; + + PrimExpr new_lanes; + + if ((multiplier * vscale).Match(lanes)) { + new_lanes = lanes; + } else if ((vscale * multiplier).Match(lanes)) { + new_lanes = + tir::Mul(multiplier.Eval(), tir::Call(DataType::Int(32), tir::builtin::vscale(), {})); + } else { + LOG(FATAL) << "Illegal form for scalable vector lanes: " << lanes; + } + + return new_lanes; +} + +} // namespace arith +} // namespace tvm diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h new file mode 100644 index 000000000000..c55e930b29fe --- /dev/null +++ b/src/arith/scalable_expression.h @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/scalable_expression.h + * \brief Analyze scalable expressions. + */ + +#ifndef TVM_ARITH_SCALABLE_EXPRESSION_H_ +#define TVM_ARITH_SCALABLE_EXPRESSION_H_ + +#include + +#include + +namespace tvm { +namespace arith { + +/*! \brief A list of known vscale values to try. */ +static const std::vector kVScaleValues = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + +/*! + * \brief Substitute a vscale intrinsic call with a known value. + * \param expr The expr to apply substitutions to. + * \param vscale_value The scalar value to replace vscale with. + * \return A rewritten expression with vscale values replaced with a scalar value. + */ +PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value); + +/*! + * \brief Check if an expr is a call to the vscale intrinsic. + * \param expr The expr to check + * \return True if the expr is a call to the vscale intrinsic, false if not. + */ +bool IsVScaleCall(const PrimExpr& expr); + +/*! + * \brief Returns the scalable lanes in a form multiplier * vscale + * \param lanes The scalable lanes as a PrimExpr + * \return Scalable lanes in a form multiplier * vscale + */ +PrimExpr CanonicalizeScalableLanes(const PrimExpr& lanes); + +} // namespace arith +} // namespace tvm + +#endif // TVM_ARITH_SCALABLE_EXPRESSION_H_ diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 17cd5c49a1bf..42a4a0a2fe3f 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_predication", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index fdd8c2cd8bc5..596805f74b24 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -53,8 +53,8 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { for (const Range& r : buffer_region->region) { if (tvm::tir::is_one(r->extent)) { indices.push_back(r->min); - } else if (const auto* extent = r->extent.as()) { - indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), extent->value)); + } else if (r->extent.as()) { + indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), r->extent)); } else { LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << ref; } diff --git a/src/relay/printer/tir_text_printer.cc b/src/relay/printer/tir_text_printer.cc index e9a9ee231358..c34788be91b8 100644 --- a/src/relay/printer/tir_text_printer.cc +++ b/src/relay/printer/tir_text_printer.cc @@ -381,13 +381,13 @@ Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode* op) { Doc TIRTextPrinter::VisitExpr_(const RampNode* op) { Doc doc; - doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")"; + doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << Print(op->lanes) << ")"; return doc; } Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) { Doc doc; - doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + doc << "broadcast(" << Print(op->value) << ", " << Print(op->lanes) << ")"; return doc; } diff --git a/src/relay/printer/tvmscript_printer.cc b/src/relay/printer/tvmscript_printer.cc index b0085b82426e..1126e633d633 100644 --- a/src/relay/printer/tvmscript_printer.cc +++ b/src/relay/printer/tvmscript_printer.cc @@ -912,14 +912,14 @@ Doc TVMScriptPrinter::VisitExpr_(const RampNode* op, ExprPrecedence* out_precede *out_precedence = ExprPrecedence::kIdentity; Doc doc; doc << tir_prefix_ << ".ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " - << op->lanes << ")"; + << Print(op->lanes) << ")"; return doc; } Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << Print(op->lanes) << ")"; return doc; } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index d6554fc37103..a6569952177e 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -519,7 +519,14 @@ Var EnvThread(String thread_tag) { void BufferStore(Buffer buffer, PrimExpr value, Array indices) { runtime::DataType buffer_dtype = buffer->dtype; int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; - runtime::DataType lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() * index_lanes); + bool index_scalable = indices.size() ? indices.back().dtype().is_scalable() : false; + bool buffer_scalable = buffer->dtype.is_scalable(); + runtime::DataType lhs_dtype; + if (index_scalable || buffer_scalable) { + lhs_dtype = buffer_dtype.with_scalable_lanes(buffer_dtype.lanes() * index_lanes); + } else { + lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() * index_lanes); + } runtime::DataType rhs_dtype = value->dtype; if (lhs_dtype != rhs_dtype) { if (lhs_dtype.lanes() != rhs_dtype.lanes()) { diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 45a0dfd2aea4..d23a3fe17aba 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -273,14 +273,23 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); - return AssignDoc(/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], - /*rhs=*/d->AsDoc(store->value, p->Attr("value")), NullOpt); + if (store->predicate.defined()) { + ExprDoc predicate = d->AsDoc(store->predicate, p->Attr("predicate")); + buffer = CallDoc(buffer, {}, {"pred"}, {predicate}); + } + return AssignDoc( + /*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], + /*rhs=*/d->AsDoc(store->value, p->Attr("value")), NullOpt); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); + if (load->predicate.defined()) { + ExprDoc predicate = d->AsDoc(load->predicate, p->Attr("predicate")); + buffer = CallDoc(buffer, {}, {"pred"}, {predicate}); + } return buffer[BufferIndices(load->indices, p->Attr("indices"), d)]; }); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index e25b074401d4..8268e6b35ecb 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -137,7 +137,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, "Ramp")->Call({ d->AsDoc(ramp->base, ramp_p->Attr("base")), d->AsDoc(ramp->stride, ramp_p->Attr("stride")), - LiteralDoc::Int(ramp->lanes, ramp_p->Attr("lanes")), + d->AsDoc(ramp->lanes, ramp_p->Attr("lanes")), }); }); @@ -146,7 +146,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, "Broadcast") ->Call({ d->AsDoc(bc->value, bc_p->Attr("value")), - LiteralDoc::Int(bc->lanes, bc_p->Attr("lanes")), + d->AsDoc(bc->lanes, bc_p->Attr("lanes")), }); }); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9701a299f1d1..f38e27e5b524 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -589,9 +589,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { } if (dtype.lanes() != 1) { #if TVM_LLVM_VERSION >= 110 - return llvm::FixedVectorType::get(etype, dtype.lanes()); + if (dtype.is_scalable()) { + return llvm::VectorType::get(etype, dtype.lanes(), true); + } else { + return llvm::FixedVectorType::get(etype, dtype.lanes()); + } #else - return llvm::VectorType::get(etype, dtype.lanes()); + if (dtype.is_scalable()) { + return llvm::VectorType::get(etype, dtype.lanes(), true); + } else { + return llvm::VectorType::get(etype, dtype.lanes()); + } #endif } else { return etype; @@ -641,13 +649,13 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer_va int64_t base = 0, width = 0; arith::PVar pbase, pstride; - arith::PVar planes; + arith::PVar planes; // create meta-data for alias analysis // Use a group of binary tree ranges of memory banks. int64_t xwith = 0; if (arith::ramp(pbase, pstride, planes).Match(index)) { base = pbase.Eval()->value; - xwith = planes.Eval() * pstride.Eval()->value; + xwith = planes.Eval()->value * pstride.Eval()->value; } else if (auto* ptr = index.as()) { base = ptr->value; xwith = 1; @@ -749,26 +757,6 @@ std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Modul return debug_info; } -llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { -#if TVM_LLVM_VERSION >= 110 - llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes); -#else - llvm::Type* type = llvm::VectorType::get(value->getType(), lanes); -#endif - llvm::Constant* undef = llvm::UndefValue::get(type); - llvm::Constant* zero = ConstInt32(0); - value = builder_->CreateInsertElement(undef, value, zero); -#if TVM_LLVM_VERSION >= 120 - llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero); -#elif TVM_LLVM_VERSION >= 110 - llvm::Constant* mask = - llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero); -#else - llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero); -#endif - return builder_->CreateShuffleVector(value, undef, mask); -} - llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { int num_elems = GetVectorNumElements(vec); if (extent == num_elems && begin == 0) return vec; @@ -1478,6 +1466,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return builder_->CreateAssumption(cond); } else if (op->op.same_as(builtin::tvm_thread_invariant())) { return MakeValue(op->args[0]); + } else if (op->op.same_as(builtin::vscale())) { + llvm::Intrinsic::ID id = llvm::Intrinsic::vscale; + llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {}); + return builder_->CreateCall(f); + } else if (op->op.same_as(builtin::get_active_lane_mask())) { + llvm::Intrinsic::ID id = llvm::Intrinsic::get_active_lane_mask; + llvm::Function* f = GetIntrinsicDecl(id, DTypeToLLVMType(op->dtype), + {builder_->getInt32Ty(), builder_->getInt32Ty()}); + return builder_->CreateCall(f, {MakeValue(op->args[0]), MakeValue(op->args[1])}); } else { LOG(FATAL) << "unknown intrinsic " << op->op; } @@ -1658,9 +1655,9 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) { } void CodeGenLLVM::BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + std::function make_instruction) { DataType buffer_element_dtype = buffer->dtype; @@ -1739,10 +1736,20 @@ void CodeGenLLVM::BufferAccessHelper( std::vector all_index_values = earlier_index_values; all_index_values.push_back(last_index_value); + llvm::Value* pred_val = nullptr; + if (predicate.defined()) { + pred_val = MakeValue(predicate); + } + TypedPointer buffer_ptr = - CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values, - value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); - auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile); + value_dtype.is_scalable() + ? CreateBufferPtr( + MakeValue(buffer->data), buffer_element_dtype, all_index_values, + value_dtype.with_scalable_lanes(value_dtype.lanes() / last_index.dtype().lanes())) + : CreateBufferPtr( + MakeValue(buffer->data), buffer_element_dtype, all_index_values, + value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); + auto instruction = make_instruction(buffer_ptr, subelement_i, pred_val, alignment, is_volatile); AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin); } } @@ -1752,11 +1759,17 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { std::vector loads; - auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, int alignment, - bool is_volatile) { + auto make_load = [this, &loads](TypedPointer buffer_ptr, int /* subelement_i */, + llvm::Value* predicate, int alignment, bool is_volatile) { #if TVM_LLVM_VERSION >= 110 - auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, - llvm::Align(alignment), is_volatile); + llvm::Instruction* load = nullptr; + if (predicate != NULL) { + load = builder_->CreateMaskedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + predicate); + } else { + load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); + } #elif TVM_LLVM_VERSION >= 80 auto load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); @@ -1771,7 +1784,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. - BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_load); if (loads.size() == 1) { return loads[0]; @@ -1821,7 +1834,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); - for (int i = 0; i < op->lanes; ++i) { + auto* lanes_as_int = op->lanes.as(); + ICHECK(lanes_as_int) << "vscale in codegen_llvm"; + auto lanes = static_cast(lanes_as_int->value); + for (int i = 0; i < lanes; ++i) { vec = builder_->CreateInsertElement( vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i)); } @@ -1853,7 +1869,29 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { - return CreateBroadcast(MakeValue(op->value), op->lanes); + DataType dtype = op->dtype; + llvm::Value* value = MakeValue(op->value); + llvm::Type* type = DTypeToLLVMType(dtype); + llvm::ElementCount ec; + if (dtype.is_scalable()) { + ec = llvm::ElementCount::get(dtype.lanes(), true); + } else { + ec = llvm::ElementCount::get(dtype.lanes(), false); + } + llvm::Constant* undef = llvm::UndefValue::get(type); + llvm::Constant* zero = ConstInt32(0); + value = builder_->CreateInsertElement(undef, value, zero); +#if TVM_LLVM_VERSION >= 120 + llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero); +#elif TVM_LLVM_VERSION >= 110 + llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero); +#else + if (dtype->is_scalable()) { + LOG(FATAL) << "Can't create scalable broadcast"; + } + llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero); +#endif + return builder_->CreateShuffleVector(value, undef, mask); } void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { @@ -1863,24 +1901,31 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { llvm::Value* value = MakeValue(op->value); - auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, int alignment, - bool is_volatile) { + auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, llvm::Value* predicate, + int alignment, bool is_volatile) { llvm::Value* to_store = value; + llvm::Instruction* store; if (subelement_i != -1) { to_store = builder_->CreateExtractElement(value, subelement_i); } #if TVM_LLVM_VERSION >= 110 - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), - is_volatile); + if (predicate != NULL) { + store = + builder_->CreateMaskedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), predicate); + } else { + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, llvm::Align(alignment), + is_volatile); + } #else - return builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); + store = builder_->CreateAlignedStore(to_store, buffer_ptr.addr, alignment, is_volatile); #endif + return store; }; // Pass all indices into BufferAccessHelper. In CodeGenLLVM, // non-flat indices will result in an error in CreateBufferPtr, but // a subclass may override CreateBufferPtr. - BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store); + BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, make_store); } void CodeGenLLVM::VisitStmt_(const ForNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 8c8929c8f093..23b2d6edbeef 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -349,9 +349,9 @@ class CodeGenLLVM : public ExprFunctor, * - Should return the generated expression. */ void BufferAccessHelper( - Buffer buffer, Array indices, DataType value_dtype, - std::function + Buffer buffer, Array indices, PrimExpr predicate, DataType value_dtype, + std::function make_instruction); // Initialize target virtual void InitTarget(); @@ -468,7 +468,6 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, llvm::ArrayRef indices, DataType value_dtype); // Vector concatenation. diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index dd83fbdcbd62..90722557550b 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -888,11 +888,13 @@ void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) // NOTE: C have comma expression so cannot use (int2)(v0, v1) // instead should use int2(v0, v1) PrintType(op->dtype, os); + ICHECK(op->lanes->IsInstance()) << "Scalable vectors are not supported in codegen_c"; + int lanes = static_cast(Downcast(op->lanes)->value); os << "("; - for (int i = 0; i < op->lanes; i++) { + for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; - if (i != op->lanes - 1) os << ", "; + if (i != lanes - 1) os << ", "; } os << ")"; } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index caef43e8af28..de37a591ac5b 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -204,10 +204,13 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_c_host"; + int lanes = static_cast(Downcast(op->lanes)->value); os << "(("; PrintType(op->dtype, os); os << ")("; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index efed5c02f1c9..e694c7eeb4ed 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1162,19 +1162,25 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { } void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { - CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_cuda"; + int lanes = static_cast(Downcast(op->lanes)->value); + CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; PrintVecConstructor(op->dtype, os); os << "("; - for (int i = 0; i < op->lanes; i++) { + for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; - if (i != op->lanes - 1) os << ", "; + if (i != lanes - 1) os << ", "; } os << ")"; } void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) - if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) { + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_c_host"; + int lanes = static_cast(Downcast(op->lanes)->value); + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) { // make_int8x4 const int64_t* p = as_const_int(op->value); ICHECK(p); @@ -1192,7 +1198,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO std::string v = PrintExpr(op->value); PrintVecConstructor(op->dtype, os); os << '('; - for (int i = 0; i < op->lanes / 2; ++i) { + for (int i = 0; i < lanes / 2; ++i) { if (i != 0) os << ", "; os << "__pack_half2(" << v << ", " << v << ")"; } @@ -1204,7 +1210,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO std::string v = PrintExpr(op->value); PrintVecConstructor(op->dtype, os); os << '('; - for (int i = 0; i < op->lanes / 2; ++i) { + for (int i = 0; i < lanes / 2; ++i) { if (i != 0) os << ", "; os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; } @@ -1218,7 +1224,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO ICHECK(p); int64_t v = *p & 0xF; - if (op->lanes == 4) { + if (lanes == 4) { v = (v << 12) | (v << 8) | (v << 4) | v; if (op->dtype.is_uint()) { os << "(uint16_t)" << v; @@ -1227,16 +1233,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO } } else { v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; - if (op->lanes == 8) { + if (lanes == 8) { if (op->dtype.is_uint()) { os << "(uint)" << v; } else { os << "(int)" << v; } - } else if (op->lanes == 16 || op->lanes == 32) { + } else if (lanes == 16 || lanes == 32) { PrintVecConstructor(op->dtype, os); os << '('; - for (int i = 0; i < op->lanes / 8; ++i) { + for (int i = 0; i < lanes / 8; ++i) { if (i != 0) os << ", "; if (op->dtype.is_uint()) { os << "(uint)" << v; @@ -1258,7 +1264,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO std::string v = PrintExpr(op->value); PrintVecConstructor(op->dtype, os); os << '('; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 86d5956dec19..5a1ef12b32a2 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -308,9 +308,12 @@ void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLI void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_metal"; + int lanes = static_cast(Downcast(op->lanes)->value); PrintType(op->dtype, os); os << "("; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index da6a4de6196a..b33191734fc7 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -472,10 +472,13 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_opencl"; + int lanes = static_cast(Downcast(op->lanes)->value); os << "(("; PrintType(op->dtype, os); os << ")("; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } @@ -486,10 +489,13 @@ void CodeGenOpenCL::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLIN os << "(("; PrintType(op->dtype, os); os << ")("; - for (int i = 0; i < op->lanes; i++) { + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_opencl"; + int lanes = static_cast(Downcast(op->lanes)->value); + for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; - if (i != op->lanes - 1) os << ", "; + if (i != lanes - 1) os << ", "; } os << "))"; } diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 1702699ac232..e52be9b24da6 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -344,9 +344,12 @@ void CodeGenWebGPU::PrintSSAAssign(const std::string& target, const std::string& void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_webgpu"; + int lanes = static_cast(Downcast(op->lanes)->value); PrintType(op->dtype, os); os << "("; - for (int i = 0; i < op->lanes; ++i) { + for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; os << v; } diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 0dc8b3870104..5eee46e2160c 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -79,7 +79,7 @@ class BufferSubstituter : public StmtExprMutator { auto load = Downcast(StmtExprMutator::VisitExpr_(op)); auto it = buffer_map_.find(load->buffer.get()); if (it != buffer_map_.end()) { - return BufferLoad(it->second, load->indices, load->span); + return BufferLoad(it->second, load->indices, load->predicate, load->span); } return load; } @@ -88,7 +88,7 @@ class BufferSubstituter : public StmtExprMutator { auto store = Downcast(StmtExprMutator::VisitStmt_(op)); auto it = buffer_map_.find(store->buffer.get()); if (it != buffer_map_.end()) { - return BufferStore(it->second, store->value, store->indices, store->span); + return BufferStore(it->second, store->value, store->indices, store->predicate, store->span); } return store; } diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 4554038bc770..40df8b65c295 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -254,7 +254,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Downcast(StmtExprMutator::VisitExpr_(buffer_load_node)); Buffer new_buffer = Subst(new_buffer_load->buffer.get()); if (!new_buffer.same_as(new_buffer_load->buffer)) { - return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span); + return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->predicate, + new_buffer_load->span); } return std::move(new_buffer_load); } @@ -293,7 +294,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { Buffer new_buffer = Subst(new_buffer_store->buffer.get()); if (!new_buffer.same_as(new_buffer_store->buffer)) { return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices, - new_buffer_store->span); + new_buffer_store->predicate, new_buffer_store->span); } return std::move(new_buffer_store); } diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index fba506fba1c9..6dc72d27286a 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -718,7 +718,8 @@ class MergeConstantsMutator : public StmtExprMutator { buffer->axis_separators, buffer->span}; old_to_new_read_buffers[buffer.as()] = new_buffer; - new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->span)); + new_args.push_back(BufferLoad(new_buffer, buffer_load->indices, buffer_load->predicate, + buffer_load->span)); break; } case 2: /* length */ { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 41500051fa89..6d41ce53818a 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -21,10 +21,13 @@ * \file expr.cc */ #include +#include +#include #include #include #include +#include "../../arith/scalable_expression.h" #include "../../support/str_escape.h" #include "buffer_common.h" @@ -427,18 +430,26 @@ TVM_REGISTER_GLOBAL("tir.Select") TVM_REGISTER_NODE_TYPE(SelectNode); // Ramp -Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) { +Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { ICHECK(base.defined()); ICHECK(stride.defined()); ICHECK(base.dtype().is_scalar()); ICHECK(stride.dtype().is_scalar()); - ICHECK_GT(lanes, 1); if (stride.dtype() != base.dtype()) { stride = cast(base.dtype(), stride); } ObjectPtr node = make_object(); - node->dtype = base.dtype().with_lanes(lanes); + auto* lanes_int = lanes.as(); + if (lanes_int) { + int lanes = static_cast(lanes_int->value); + ICHECK_GT(lanes, 1); + node->dtype = base.dtype().with_lanes(lanes); + } else { /* scalable vector */ + lanes = arith::CanonicalizeScalableLanes(lanes); + int vscale_multiplier = Downcast(Downcast(lanes)->a)->value; + node->dtype = base.dtype().with_scalable_lanes(vscale_multiplier); + } node->base = base; node->stride = stride; node->lanes = lanes; @@ -447,27 +458,35 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) { } TVM_REGISTER_GLOBAL("tir.Ramp") - .set_body_typed([](PrimExpr base, PrimExpr stride, int lanes, Span span) { + .set_body_typed([](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { return Ramp(base, stride, lanes, span); }); TVM_REGISTER_NODE_TYPE(RampNode); // Broadcast -Broadcast::Broadcast(PrimExpr value, int lanes, Span span) { +Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { ICHECK(value.defined()); ICHECK(value.dtype().is_scalar()); - ICHECK_GT(lanes, 1); ObjectPtr node = make_object(); - node->dtype = value.dtype().with_lanes(lanes); + auto* lanes_int = lanes.as(); + if (lanes_int) { + int lanes = static_cast(lanes_int->value); + ICHECK_GT(lanes, 1); + node->dtype = value.dtype().with_lanes(lanes); + } else { /* scalable vector */ + lanes = arith::CanonicalizeScalableLanes(lanes); + int vscale_multiplier = Downcast(Downcast(lanes)->a)->value; + node->dtype = value.dtype().with_scalable_lanes(vscale_multiplier); + } node->value = std::move(value); node->lanes = lanes; node->span = std::move(span); data_ = node; } -TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes, Span span) { +TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, PrimExpr lanes, Span span) { return Broadcast(value, lanes, span); }); @@ -525,8 +544,8 @@ TVM_REGISTER_GLOBAL("tir.Call") for (Range r : br->region) { if (is_one(r->extent)) { indices.push_back(r->min); - } else if (const auto* extent = r->extent.as()) { - indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), extent->value)); + } else if (r->extent.as()) { + indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); } else { LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << GetRef(br); @@ -717,10 +736,14 @@ void BufferLoadNode::LegalizeDType() { int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; int buffer_lanes = buffer->dtype.lanes(); - this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes); + if ((indices.size() && indices.back().dtype().is_scalable()) || buffer->dtype.is_scalable()) { + this->dtype = buffer->dtype.with_scalable_lanes(index_lanes * buffer_lanes); + } else { + this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes); + } } -BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { +BufferLoad::BufferLoad(Buffer buffer, Array indices, PrimExpr predicate, Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -729,14 +752,15 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); node->LegalizeDType(); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferLoad") - .set_body_typed([](Buffer buffer, Array indices, Span span) { - return BufferLoad(buffer, indices, span); + .set_body_typed([](Buffer buffer, Array indices, PrimExpr predicate, Span span) { + return BufferLoad(buffer, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 8a93d9dd8242..34b46583d5ad 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -127,7 +127,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { if (indices.same_as(op->indices)) { return GetRef(op); } else { - return BufferLoad(op->buffer, indices); + return BufferLoad(op->buffer, indices, op->predicate); } } @@ -258,19 +258,21 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); - if (base.same_as(op->base) && stride.same_as(op->stride)) { + PrimExpr lanes = this->VisitExpr(op->lanes); + if (base.same_as(op->base) && stride.same_as(op->stride) && lanes.same_as(op->lanes)) { return GetRef(op); } else { - return Ramp(base, stride, op->lanes); + return Ramp(base, stride, lanes); } } PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { PrimExpr value = this->VisitExpr(op->value); - if (value.same_as(op->value)) { + PrimExpr lanes = this->VisitExpr(op->lanes); + if (value.same_as(op->value) && lanes.same_as(op->lanes)) { return GetRef(op); } else { - return Broadcast(value, op->lanes); + return Broadcast(value, lanes); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1d1e674a9dd1..fd9ffe7dd70d 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -458,7 +458,8 @@ TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) TVM_REGISTER_NODE_TYPE(EvaluateNode); // BufferStore -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, Span span) { +BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, + Span span) { ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() @@ -470,16 +471,26 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, } int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; + bool index_scalable = indices.size() ? indices.back().dtype().is_scalable() : false; int buffer_lanes = buffer->dtype.lanes(); + bool buffer_scalable = buffer->dtype.is_scalable(); ICHECK_EQ(index_lanes * buffer_lanes, value.dtype().lanes()) << "Cannot store value with " << value.dtype().lanes() << ", expected value with " << index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes << " buffer element lanes)"; - if (buffer->dtype.with_lanes(buffer_lanes * index_lanes) != value.dtype()) { + + runtime::DataType buffer_dtype; + if (index_scalable || buffer_scalable) { + buffer_dtype = buffer->dtype.with_scalable_lanes(buffer_lanes * index_lanes); + } else { + buffer_dtype = buffer->dtype.with_lanes(buffer_lanes * index_lanes); + } + if (buffer_dtype != value.dtype()) { LOG(FATAL) << "TypeError: dtype mismatch on BufferStore: " // << "buffer's dtype is `" << buffer->dtype // << "`, the lanes of indexing are: `" << index_lanes // + << "`, the scalability is: `" << buffer_dtype.is_scalable() << "`, but RHS's dtype is `" << value.dtype() << "`"; } @@ -487,14 +498,14 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, node->buffer = std::move(buffer); node->value = std::move(value); node->indices = std::move(indices); + node->predicate = std::move(predicate); node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferStore") - .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, Span span) { - return BufferStore(buffer, value, indices, span); - }); + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate, + Span span) { return BufferStore(buffer, value, indices, predicate, span); }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index fb92463c3c32..d8b73a8d90da 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -368,6 +368,13 @@ TIR_DEFINE_BUILTIN_FUNC(dma_start_group) TIR_DEFINE_BUILTIN_FUNC(dma_end_group) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask) + .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(assume) .set_attr("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo)) .set_num_inputs(1); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index fd14f4892154..86317b0bbfec 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -122,20 +122,33 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s {x, y, q, s}, span); } +void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) { // NOLINT(*) + DataType dtype_a = op_a.dtype(); + DataType dtype_b = op_b.dtype(); + + if (dtype_a.lanes() == 1 && dtype_b.lanes() != 1) { + if (dtype_b.is_scalable()) { + op_a = tir::Broadcast( + op_a, tir::Mul(dtype_b.lanes(), Call(DataType::Int(32), builtin::vscale(), {}))); + } else { + op_a = tir::Broadcast(op_a, dtype_b.lanes()); + } + } +} + // The public function with a quick checking path. void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator"; CHECK(rhs.defined()) << "ValueError: `rhs` is null in the binary operator"; if (lhs.dtype() == rhs.dtype()) return; + + BroadcastToMatchLanes(lhs, rhs); + BroadcastToMatchLanes(rhs, lhs); + DataType ltype = lhs.dtype(); DataType rtype = rhs.dtype(); - if (ltype.lanes() == 1 && rtype.lanes() != 1) { - lhs = tir::Broadcast(lhs, rtype.lanes()); - } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { - rhs = tir::Broadcast(rhs, ltype.lanes()); - } else { - ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; - } + + ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; if (lhs.dtype() == rhs.dtype()) return; ltype = lhs.dtype(); @@ -337,11 +350,17 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { value = tir::Cast(vtype, value, span); } } - return tir::Broadcast(value, t.lanes(), span); + if (t.is_scalable()) { + return tir::Broadcast( + value, tir::Mul(t.lanes(), Call(DataType::Int(32), builtin::vscale(), {})), span); + } else { + return tir::Broadcast(value, t.lanes(), span); + } } else { ICHECK(value.dtype().lanes() == t.lanes()); + ICHECK(value.dtype().is_scalable() == t.is_scalable()); if (const auto* broadcast = value.as()) { - return tir::Broadcast(cast(vtype, broadcast->value, span), t.lanes(), span); + return tir::Broadcast(cast(vtype, broadcast->value, span), broadcast->lanes, span); } else if (const auto* ramp = value.as()) { if (t.is_int() || t.is_uint()) { // only cast to index data type can be folded to ramp diff --git a/src/tir/schedule/analysis/reducer.cc b/src/tir/schedule/analysis/reducer.cc index d8d1e8fc2572..6754ee334938 100644 --- a/src/tir/schedule/analysis/reducer.cc +++ b/src/tir/schedule/analysis/reducer.cc @@ -178,16 +178,14 @@ class PatternMatcher : public ExprVisitor { if (ptr == nullptr) { match_success_ = false; } else { - if (op->lanes != ptr->lanes) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->base; - VisitExpr(op->base); - expr_to_match_ = ptr->stride; - VisitExpr(op->stride); - std::swap(expr_to_match_, tmp); - } + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->base; + VisitExpr(op->base); + expr_to_match_ = ptr->stride; + VisitExpr(op->stride); + expr_to_match_ = ptr->lanes; + VisitExpr(op->lanes); + std::swap(expr_to_match_, tmp); } } @@ -196,14 +194,12 @@ class PatternMatcher : public ExprVisitor { if (ptr == nullptr) { match_success_ = false; } else { - if (op->lanes != ptr->lanes) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->value; - VisitExpr(op->value); - std::swap(expr_to_match_, tmp); - } + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + expr_to_match_ = ptr->lanes; + VisitExpr(op->lanes); + std::swap(expr_to_match_, tmp); } } @@ -265,7 +261,6 @@ class PatternMatcher : public ExprVisitor { void Match(const Array& exprs_to_match) { this->match_success_ = true; this->filled_map_.clear(); - ICHECK_EQ(pattern_.size(), exprs_to_match.size()); int n_buffers = pattern_.size(); for (int i = 0; i < n_buffers; ++i) { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 343fb7617886..930d9eb222c1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -20,6 +20,8 @@ #include +#include "../analysis/check_contains.h" + namespace tvm { namespace tir { @@ -396,7 +398,7 @@ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs, bool preserve_u Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, const Array>& factor_rvs, - bool preserve_unit_iters) { + bool preserve_unit_iters, bool disable_predication) { class NotSingleInferFactorError : public ScheduleError { public: explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} @@ -488,13 +490,14 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, tot_length *= factor; } } + if (infer_index != -1) { factors.Set(infer_index, this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) { throw WrongFactorProductError(state_->mod, GetRef(loop)); } - results = tir::Split(state_, loop_sref, factors, preserve_unit_iters); + results = tir::Split(state_, loop_sref, factors, preserve_unit_iters, disable_predication); TVM_TIR_SCHEDULE_END("split", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(results); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index a6c47070c8df..e4526425511a 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -108,7 +108,7 @@ class ConcreteScheduleNode : public ScheduleNode { LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) override; LoopRV Merge(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors, - bool preserve_unit_iters) override; + bool preserve_unit_iters, bool disable_predication) override; void Reorder(const Array& ordered_loop_rvs) override; void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) override; LoopRV AddUnitLoop(const BlockRV& block_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 02fb982f5ed9..e0bb4f39cdc2 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -204,10 +204,15 @@ Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope * \param loop_sref The sref to the loop being split * \param factors The splitting factors * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings + * \param disable_predication If enabled, don't create a predicate for guarding the + * loop. This can be useful when splitting with scalable factors that the schedule writer + * knows are divisible. Warning: enabling this feature may result in incorrect code generation + * if not used carefully. * \return An array of srefs to the loops after splitting */ TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, - const Array& factors, bool preserve_unit_iters); + const Array& factors, bool preserve_unit_iters, + bool disable_predication); /*! * \brief Merge a list of loops into one. The loops under their LCA requires: diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index a6b97bf17906..4881504d67ef 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -386,7 +386,7 @@ class DependentLoopError : public ScheduleError { }; Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array& factors, - bool preserve_unit_iters) { + bool preserve_unit_iters, bool disable_predication) { // Invariance // - The total repeat number has not changed for each direct child block with updating predicate. // - The execution order has not changed. (The block executes with the same args and the same @@ -433,7 +433,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array &opaque_block_reuse)(std::move(new_stmt)); // Step 3. Update predicate to guard the loop PrimExpr predicate = substitute_value < loop->extent; - if (!analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { + if (!disable_predication && !analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); } // Step 4. Generate nested loops to replace the original loop and simplify the binding @@ -920,7 +920,7 @@ struct SplitTraits : public UnpackedInstTraits { private: static constexpr size_t kNumInputs = 2; - static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; template @@ -936,16 +936,18 @@ struct SplitTraits : public UnpackedInstTraits { static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Array> factors, - Bool preserve_unit_iters) { - return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool()); + Bool preserve_unit_iters, Bool disable_predication) { + return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool(), + disable_predication.operator bool()); } static String UnpackedAsPython(Array outputs, String loop_rv, Array factors, - Bool preserve_unit_iters) { + Bool preserve_unit_iters, Bool disable_predication) { PythonAPICall py("split"); py.Input("loop", loop_rv); py.Input("factors", factors); py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); + py.Input("disable_predication", disable_predication.operator bool()); py.OutputList(outputs); return py.Str(); } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index e55a5cf8078c..e25a9eaf6ac3 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -226,8 +226,9 @@ LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs, bool preserve_uni Array TracedScheduleNode::Split(const LoopRV& loop_rv, const Array>& factor_rvs, - bool preserve_unit_iters) { - Array results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters); + bool preserve_unit_iters, bool disable_predication) { + Array results = + ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters, disable_predication); std::vector inputs; inputs.reserve(1 + factor_rvs.size()); @@ -237,10 +238,12 @@ Array TracedScheduleNode::Split(const LoopRV& loop_rv, } static const InstructionKind& kind = InstructionKind::Get("Split"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/inputs, - /*attrs=*/{Integer(preserve_unit_iters)}, - /*outputs=*/{results.begin(), results.end()})); + trace_->Append( + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/inputs, + /*attrs=*/{Integer(preserve_unit_iters), Integer(disable_predication)}, + /*outputs=*/{results.begin(), results.end()})); return results; } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index d7a42f63d4dc..1d801f017d85 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -67,7 +67,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) final; LoopRV Merge(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs, - bool preserve_unit_iters) final; + bool preserve_unit_iters, bool disable_predication) final; void Reorder(const Array& ordered_loop_rvs) final; void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) final; LoopRV AddUnitLoop(const BlockRV& block_rv) final; diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index f5aa6773e66d..358f864d3a24 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -34,6 +34,8 @@ #include #include +#include "../../arith/unwrap_vector_expr.h" + namespace tvm { namespace tir { @@ -156,7 +158,12 @@ class BoundChecker : public StmtExprMutator { if (!IsValidScalar(ramp_index->stride)) { return false; } - if (ramp_index->lanes <= 0) { + bool lanes_int = ramp_index->lanes->IsInstance(); + if (!lanes_int) { + return false; + } + int lanes = static_cast(Downcast(ramp_index->lanes)->value); + if (lanes <= 0) { return false; } } @@ -192,11 +199,7 @@ class BoundChecker : public StmtExprMutator { PrimExpr upper_bound = shape[i]; if (const RampNode* ramp_index = index.as()) { - // In case index is base + stride * i. - // Non inclusive range. - index = Add(ramp_index->base, - Mul(ramp_index->stride, - make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1))); + index = arith::UnwrapVectorExpr(GetRef(ramp_index), ramp_index->lanes); } // Try to simplify index and bound. diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index 5f7b9b4156c3..afa3e15f0260 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -257,7 +257,8 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->span); + ICHECK(!op->predicate.defined()) << "Indices change can affect the predicate"; + Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->predicate, op->span); // Then wrap the BufferStores in some Ifs to avoid recomputing elements for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) { auto iter_var{rolling_buffer_info.axis_iter_vars[i]}; @@ -293,7 +294,8 @@ class RollingBufferInjector : public StmtExprMutator { indices.push_back(index); } } - return BufferLoad(op->buffer, indices, op->span); + ICHECK(!op->predicate.defined()) << "Indices change can affect the predicate"; + return BufferLoad(op->buffer, indices, op->predicate, op->span); } else { return expr; } diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 700587fe0e21..6579aa03dc03 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -97,6 +97,7 @@ class MatchBufferLower : public StmtExprMutator { auto n = CopyOnWrite(op); n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); n->buffer = source->buffer; + ICHECK(!op->predicate.defined()) << "Indices change can affect the predicate"; return Stmt(n); } } @@ -113,6 +114,7 @@ class MatchBufferLower : public StmtExprMutator { const Buffer& buffer = (*it).first; const BufferRegion& source = (*it).second; Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + ICHECK(!op->predicate.defined()) << "Indices change can affect the predicate"; return BufferLoad(source->buffer, indices); } } diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 870235954689..392373d40b63 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -137,6 +137,8 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { PrimExpr index = op->indices[0]; if (op->value.dtype().lanes() != 1) { + ICHECK(!op->value.dtype().is_scalable()) + << "Scalable vectors are not supported in lower_warp_memory"; arith::PVar base; ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(index)) << "LowerWarpMemory failed due to store index=" << index @@ -343,6 +345,8 @@ class WarpAccessRewriter : protected StmtExprMutator { std::pair SplitIndexByGroup(const PrimExpr& index) { if (index.dtype().lanes() != 1) { arith::PVar base; + ICHECK(!index.dtype().is_scalable()) + << "Scalable vectors are not supported in lower_warp_memory"; ICHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index)); auto [local_index, group] = SplitIndexByGroup(base.Eval()); diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 619a9f0a9e8f..f5b905a3f502 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -67,6 +67,7 @@ class IntermediateStageRewriter { Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); // Step 3: Create BufferLoad from the intermediate buffer + ICHECK(!store->predicate.defined()); BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); BufferStore new_buffer_store = Downcast(block->body); new_buffer_store.CopyOnWrite()->value = new_buffer_load; diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index bc606aa0b7ff..3b418aac0cf5 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -213,7 +213,8 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { // A write whose destination is known to already contain the // values to be written is a no-op. // PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices); - PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices) == 0; + PrimExpr stores_existing_value = + store->value - BufferLoad(store->buffer, store->indices, store->predicate) == 0; if (touch_pattern_.has_value()) { Stmt context_arg = context_ ? GetRef(context_) : Stmt(store); stores_existing_value = diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index 05b636f11403..e8d89bfb5700 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -196,7 +196,7 @@ class AllocateConstRewrite : public StmtExprMutator { op->buffer->elem_offset, it->second->name_hint, op->buffer->data_alignment, op->buffer->offset_factor, op->buffer->buffer_type); new_load_buf_[op->buffer->data.get()] = new_buffer; - return BufferLoad(new_buffer, op->indices); + return BufferLoad(new_buffer, op->indices, op->predicate); } return ExprMutator::VisitExpr_(op); } diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc index eb596beb181a..beb5997d4982 100644 --- a/src/tir/transforms/renormalize_split_pattern.cc +++ b/src/tir/transforms/renormalize_split_pattern.cc @@ -63,7 +63,7 @@ class SplitPatternReNormalizer : public IRMutatorWithAnalyzer { // Pattern var match IntImm PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp - PVar lanes; + PVar lanes; // floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1) TRY_RECURSIVE_REWRITE_IF(floordiv(floormod(x, c3), c2), diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 9c1244838173..2742498e5d74 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -730,7 +730,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferLoad(it->second, op->indices, op->span); + return BufferLoad(it->second, op->indices, op->predicate, op->span); } else { return expr; } @@ -743,7 +743,7 @@ class ThreadScopePropagate : public StmtExprMutator { auto it = buf_remap_.find(op->buffer->data); if (it != buf_remap_.end()) { - return BufferStore(it->second, op->value, op->indices, op->span); + return BufferStore(it->second, op->value, op->indices, op->predicate, op->span); } else { return stmt; } @@ -938,8 +938,10 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Remapping predicate is not supported"; return BufferLoad(e.remap->target, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return expr; } @@ -952,8 +954,10 @@ class BufferBindUnwrapper : public StmtExprMutator { const BufferEntry& e = GetBufferEntry(op->buffer); if (e.remap) { + ICHECK(!op->predicate.defined()) << "Remapping predicate is not supported"; return BufferStore(e.remap->target, op->value, - remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + remap_indices(op->indices, e.remap->begins, e.remap->extents), + op->predicate, op->span); } else { return stmt; } @@ -1418,7 +1422,8 @@ class StorageFlattener : public StmtExprMutator { auto flattened_indices = e.buffer->ElemOffset(op->indices); - Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->span); + ICHECK(!op->predicate.defined()) << "Changing the index can affect the predicate"; + Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->predicate, op->span); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } @@ -1574,7 +1579,9 @@ class StorageFlattener : public StmtExprMutator { } auto flattened_indices = e.buffer->ElemOffset(op->indices); - PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->span); + + ICHECK(!op->predicate.defined()) << "Changing the index can affect the predicate"; + PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->predicate, op->span); if (op->dtype == DataType::Bool()) { ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 6875523a956d..5c1026258561 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1318,9 +1318,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor { if (indices.size()) { const RampNode* ramp_index = indices[indices.size() - 1].as(); if (ramp_index && is_one(ramp_index->stride)) { - arith::ModularSet me = analyzer_.modular_set(ramp_index->base); - if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { - lanes_used = ramp_index->lanes; + if (ramp_index->lanes->IsInstance()) { + int lanes = static_cast(Downcast(ramp_index->lanes)->value); + arith::ModularSet me = analyzer_.modular_set(ramp_index->base); + if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { + lanes_used = lanes; + } } } } @@ -1453,13 +1456,15 @@ class VectorTypeRewriter : public StmtExprMutator { Array indices = node->indices; const PrimExpr& last_dim_index = indices[indices.size() - 1]; - if (const RampNode* ramp_index = last_dim_index.as(); - ramp_index && is_one(ramp_index->stride)) { - PrimExpr new_index = - ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); - if (ramp_index->lanes != info.factor()) { - ICHECK(info.factor() && ramp_index->lanes % info.factor() == 0); - int new_lanes = ramp_index->lanes / info.factor(); + const RampNode* ramp_index = indices[indices.size() - 1].as(); + if (ramp_index && is_one(ramp_index->stride)) { + ICHECK(ramp_index->lanes->IsInstance()) + << "Rewriting pointer type into scalable type is currently not supported"; + auto lanes = static_cast(Downcast(ramp_index->lanes)->value); + PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), lanes); + if (lanes != info.factor()) { + ICHECK(info.factor() && lanes % info.factor() == 0); + int new_lanes = lanes / info.factor(); new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes, ramp_index->span); } indices.Set(indices.size() - 1, new_index); diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 030dbd01badf..0882e1c4c3af 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -333,6 +333,8 @@ class ComputeLegalizer : public StmtExprMutator { ICHECK(MatchDType(value->dtype)); value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value); } + ICHECK(!op->predicate.defined()) + << "Predicated buffers are not supported in data type legalizer "; return BufferStore(new_buf, value, indices); } } @@ -404,6 +406,8 @@ class ComputeLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) + << "Predicated buffers are not supported in data type legalizer "; return BufferLoad(new_buf, op->indices); } } @@ -565,6 +569,8 @@ class StorageLegalizer : public StmtExprMutator { if (MatchDType(op->value.dtype())) { ICHECK(new_buf->dtype.is_uint()); } + ICHECK(!op->predicate.defined()) + << "Predicated buffers are not supported in data type legalizer "; return BufferStore(new_buf, value, indices); } } @@ -598,6 +604,8 @@ class StorageLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { + ICHECK(!op->predicate.defined()) + << "Predicated buffers are not supported in data type legalizer "; return BufferLoad(new_buf, op->indices); } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index b80a71aa311c..ecd26e7ba2f6 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -35,14 +35,30 @@ #include #include +#include "../../target/parsers/aprofile.h" + namespace tvm { namespace tir { -inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { - if (e.dtype().lanes() == lanes) return e; - if (const BroadcastNode* op = e.as()) { - if (lanes % op->lanes == 0) { - return Broadcast(op->value, lanes); +inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool scalable) { + if (!scalable) { + // Check that we are not broadcasting a scalable vector into fixed length vector + ICHECK(!e.dtype().is_scalable()); + if (e.dtype().lanes() == lanes) return e; + if (const BroadcastNode* op = e.as()) { + int op_lanes = op->lanes.as()->value; + if (lanes % op_lanes == 0) { + return Broadcast(op->value, lanes); + } + } + } else { + if (e.dtype().is_scalable() && e.dtype().lanes() == lanes) { + // It's already a scalable vector in a correct form + return e; + } else { + ICHECK(lanes % e.dtype().lanes() == 0); + PrimExpr scalable_lanes = Mul(lanes, Call(DataType::Int(32), builtin::vscale(), {})); + return Broadcast(e, scalable_lanes); } } ICHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to " @@ -50,6 +66,24 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { return Broadcast(e, lanes); } +bool EnableBufferPredication() { + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool enable_buffer_predication = + pass_ctx->GetConfig("tir.enable_buffer_predication", Bool(false)).value(); + if (enable_buffer_predication) { + return true; + } + + // When compiling for aarch64 devices with SVE, we should enable predication by default + Target current_target = Target::Current(); + if (!current_target.defined()) { + return false; + } + TargetJSON target_json = target::parsers::aprofile::ParseTarget(current_target->Export()); + TargetFeatures features = Downcast(target_json.at("features")); + return Downcast(features.at("has_sve")); +} + // Rewrite vectorized allocation access // This is necessary for making each vector component containing its own workspace. // Originates from Halide's loop vectorizer @@ -60,7 +94,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { // class VecAllocAccess : public StmtExprMutator { public: - VecAllocAccess(const VarNode* buf, Var var, int var_lanes) + VecAllocAccess(const VarNode* buf, Var var, PrimExpr var_lanes) : buf_(buf), var_(var), var_lanes_(var_lanes) {} PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -136,7 +170,7 @@ class VecAllocAccess : public StmtExprMutator { // variable to be replaced Var var_; // the lanes. - int var_lanes_; + PrimExpr var_lanes_; // Analyzer for simplifications arith::Analyzer analyzer_; }; @@ -149,7 +183,7 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype, 0), IntImm(var->dtype, 1), var_lanes); } @@ -191,7 +225,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorbase * a, b_ramp->stride * a, b_ramp->lanes); } } - return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + bool is_scalable = a.dtype().is_scalable() || b.dtype().is_scalable(); + return Mul(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable)); } return BinaryVec(op); } @@ -224,13 +259,17 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->stride); if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { const RampNode* base_ramp = base.as(); - if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) { + ICHECK(op->lanes->IsInstance()) << "Scalable ramp of ramps is not supported yet"; + int lanes = static_cast(Downcast(op->lanes)->value); + if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), lanes))) { return Ramp(base_ramp->base, stride, op->lanes * base_ramp->lanes); } } + ICHECK(!base.dtype().is_scalable()) << "Ramp base with scalable dtype is not supported"; + ICHECK(!stride.dtype().is_scalable()) << "Ramp stride with scalable dtype is not supported"; int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); - base = BroadcastTo(base, lanes); - stride = BroadcastTo(stride, lanes); + base = BroadcastTo(base, lanes, false); + stride = BroadcastTo(stride, lanes, false); Array elems; for (int i = 0; i < lanes; ++i) { elems.push_back( @@ -260,15 +299,21 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); } else { int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes()); - return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); + bool is_scalable = t.dtype().is_scalable() || t.dtype().is_scalable(); + return Select(cond, BroadcastTo(t, lanes, is_scalable), BroadcastTo(f, lanes, is_scalable)); } } + PrimExpr VisitExpr_(const CastNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { - return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); + if (value.dtype().is_scalable()) { + return Cast(op->dtype.with_scalable_lanes(value.dtype().lanes()), value); + } else { + return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); + } } } @@ -305,9 +350,14 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); } else { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); - t = BroadcastTo(t, lanes); - f = BroadcastTo(f, lanes); - return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); + bool is_scalable = t.dtype().is_scalable() || f.dtype().is_scalable(); + t = BroadcastTo(t, lanes, is_scalable); + f = BroadcastTo(f, lanes, is_scalable); + if (op->dtype.is_scalable()) { + return Call(op->dtype.with_scalable_lanes(lanes), op->op, {cond, t, f}); + } else { + return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); + } } } // Call @@ -413,10 +463,14 @@ class Vectorizer : public StmtMutator, public ExprFunctorindices) || !value.same_as(op->value)) { // How many lanes of indexing are present in the index and - // buffer element type, excluding the last index. T + // buffer element type, excluding the last index. int other_index_lanes = op->buffer->dtype.lanes(); + ICHECK(!op->buffer->dtype.is_scalable()) + << "Scalable buffer elements are not supported in vectorizer"; for (size_t i = 0; i < indices.size() - 1; i++) { other_index_lanes *= indices[i].dtype().lanes(); + // Only allow the last index to be scalable + ICHECK(!indices[i].dtype().is_scalable()); } // The total number of lanes of indexing, including the last index. @@ -434,11 +488,13 @@ class Vectorizer : public StmtMutator, public ExprFunctorindices = indices; - writer->value = BroadcastTo(value, total_lanes); + writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable); } return std::move(store); @@ -462,18 +518,137 @@ class Vectorizer : public StmtMutator, public ExprFunctorannotations); } } + + /*! + * \brief A pass that tries to rewrite buffer accesses (loads and stores) with a + * predicate expression where possible. + * + * \note For now we start with a minimalized case targeting block-level predicates + * produced by the split schedule primitive, with the potential for predicating + * more complex terms in the future if needed. + * + * \example + * Before: + * for i_0 in T.serial(4): + * for i_1 in T.vectorized(4): + * if i_0 * 4 + i_1 < 14: + * B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + * + * After: + * for i_0 in T.serial(4): + * pred = T.get_active_lane_mask(i_0 * 4, 14) + * B(pred=pred)[i_0 * 4:i_0 * 4 + 4] = A(pred=pred)[i_0 * 4:i_0 * 4 + 4] + \ + * T.Broadcast(T.float32(1), 4) + */ + class TryPredicateBufferAccesses : public StmtExprMutator { + public: + TryPredicateBufferAccesses() {} + + /*! + * \brief Run the pass to try to exact predicates. + * \param stmt - The statement containing buffer accesses (loads and stores) + * we want to attempt to predicate. + * \param condition - The conditional expression (block-level predicate) + * that we will try to remove. + * \return pair - Boolean value for success/failure, the rewritten + * stmt if successful. + */ + std::pair Run(Stmt stmt, PrimExpr condition) { + // Check that the condition provided is of the form a < b, for now. + if (!condition->IsInstance()) { + return {false, stmt}; + } + + LT lt = Downcast(condition); + + // Check the form of the vectorized condition, we're expecting + // Ramp(...) < Broadcast(...) + if (!lt->a->IsInstance() || !lt->b->IsInstance()) { + return {false, stmt}; + } + + base_ = Downcast(lt->a)->base; + limit_ = Downcast(lt->b)->value; + + // Now we can try to predicate + Stmt predicated_stmt = StmtExprMutator::operator()(std::move(stmt)); + if (num_accesses_analyzed_ > 0 && num_accesses_analyzed_ == num_accesses_rewritten_) { + return {true, predicated_stmt}; + } + return {false, stmt}; + } + + private: + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + return TryPredicateBufferAccess(load); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + return TryPredicateBufferAccess(store); + } + + template + AccessNode TryPredicateBufferAccess(AccessNode node) { + num_accesses_analyzed_ += 1; + + // Do not try to predicate non-vectorized accesses + Array indices = node->indices; + if (!indices.size() || !indices[0]->IsInstance()) { + return node; + } + Ramp ramp = Downcast(node->indices[0]); + + // The vectorized access pattern must match the base of the predicate + if (!tvm::StructuralEqual()(ramp->base, base_)) { + return node; + } + + DataType buf_predicate_dtype = + DataType(DataType::kInt, 1, ramp->dtype.lanes(), ramp->dtype.is_scalable()); + Call lane_mask = Call(buf_predicate_dtype, builtin::get_active_lane_mask(), {base_, limit_}); + + num_accesses_rewritten_ += 1; + auto writer = node.CopyOnWrite(); + writer->predicate = lane_mask; + return node; + } + + /*! \brief The variable base expr of the predicate. */ + PrimExpr base_; + /*! \brief The limit of the predicate. The expr specifies the upper bound of the base's + * evaluated value. */ + PrimExpr limit_; + /*! \brief The number of buffer accesses in the stmt we will analyze. */ + size_t num_accesses_analyzed_ = 0; + /*! \brief The number of buffer accesses rewritten with predicates. */ + size_t num_accesses_rewritten_ = 0; + }; + // IfThenElse Stmt VisitStmt_(const IfThenElseNode* op) final { ICHECK(!op->condition.dtype().is_vector()); PrimExpr condition = this->VisitExpr(op->condition); - if (condition.dtype().is_vector()) { - return Scalarize(GetRef(op)); - } Stmt then_case = this->VisitStmt(op->then_case); Optional else_case = NullOpt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } + + // Check if we can rewrite the condition with predicated buffers + if (EnableBufferPredication() && condition.dtype().is_vector() && !else_case.defined()) { + std::pair success_stmt_pair = + TryPredicateBufferAccesses().Run(then_case, condition); + bool can_remove_if_then_else = success_stmt_pair.first; + if (can_remove_if_then_else) { + return success_stmt_pair.second; + } + } + + if (condition.dtype().is_vector()) { + return Scalarize(GetRef(op)); + } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); @@ -481,6 +656,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorname_hint + ".s", var_->dtype); stmt = Substitute(stmt, {{var_, idx}}); - return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), ForKind::kSerial, - stmt); + return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt); } // ProducerStore Stmt VisitStmt_(const ProducerStoreNode* op) final { @@ -561,7 +736,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor MutateArray(Array arr, int* p_lanes) { if (arr.size() == 0) return arr; int& lanes = *p_lanes; @@ -588,7 +764,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); - return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + bool is_scalable = a.dtype().is_scalable() || b.dtype().is_scalable(); + return TOp(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable)); } } template @@ -626,7 +803,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorbase, b), a_ramp->stride, a_ramp->lanes); } } - return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + bool is_scalable = a.dtype().is_scalable() || b.dtype().is_scalable(); + return fcompute(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable)); } } }; @@ -636,11 +814,7 @@ class LoopVectorizer : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { ICHECK(is_zero(op->min)); - auto* extent_as_int = op->extent.as(); - if (!extent_as_int || extent_as_int->value < 1) { - LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; - } - return Vectorizer(op->loop_var, static_cast(extent_as_int->value))(op->body); + return Vectorizer(op->loop_var, op->extent)(op->body); } else { return StmtMutator::VisitStmt_(op); } diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 2e386c48b75c..f5e1210b807f 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -27,9 +27,11 @@ TEST(Pattern, Basic) { using namespace tvm::tir; using namespace tvm::arith; tvm::tir::Var x("x"), y("y"), z("z"); + PrimExpr scalable_lanes = Mul(4, Call(DataType::Int(32), builtin::vscale(), {})); arith::PVar px, py, pz; arith::PVar pt; - arith::PVar planes; + arith::PVar planes; + arith::PCallExpr vscale; // arithmetics auto r = 1 + (y + 1); @@ -110,14 +112,18 @@ TEST(Pattern, Basic) { // ramp pattern { ICHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, 10))); - ICHECK(planes.Eval() == 10); + ICHECK(planes.Eval().as()->value == 10); + ICHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, scalable_lanes))); + ICHECK((PConst(4) * vscale).Match(planes.Eval())); ICHECK(!ramp(px, PConst(1), planes).Match(tir::Ramp(x, 2, 10))); } // broadcast pattern { ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10))); - ICHECK(planes.Eval() == 10); + ICHECK(planes.Eval().as()->value == 10); ICHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10))); + ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, scalable_lanes))); + ICHECK((PConst(4) * vscale).Match(planes.Eval())); } } diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc new file mode 100644 index 000000000000..59162a446b93 --- /dev/null +++ b/tests/cpp/tir_scalable_datatype.cc @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include "../../src/script/printer/utils.h" + +using ::testing::HasSubstr; + +// --------- +// Data Type +// --------- +TEST(TIR, TestCreateScalableType) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + ASSERT_EQ(scalable_type.code(), kDLInt); + ASSERT_EQ(scalable_type.bits(), 32); + ASSERT_EQ(scalable_type.lanes(), 4); + ASSERT_TRUE(scalable_type.is_scalable()); + ASSERT_TRUE(scalable_type.is_vector()); +} + +TEST(TIR, TestScalableWithBits) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 1, 1, true); + scalable_type = scalable_type.with_bits(32); + ASSERT_EQ(scalable_type.bits(), 32); + ASSERT_TRUE(scalable_type.is_scalable()); + ASSERT_TRUE(scalable_type.is_vector()); +} + +TEST(TIR, TestScalableWithLanes) { + tvm::DataType type = tvm::DataType(kDLInt, 32, 1); + tvm::DataType scalable_type = type.with_scalable_lanes(4); + ASSERT_EQ(scalable_type.lanes(), 4); + ASSERT_TRUE(scalable_type.is_scalable()); + ASSERT_TRUE(scalable_type.is_vector()); +} + +TEST(TIR, TestAssignScalableDataType) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 1, true); + tvm::DataType scalable_type_copy = scalable_type; + ASSERT_TRUE(scalable_type_copy.is_scalable()); + ASSERT_TRUE(scalable_type_copy.is_vector()); +} + +TEST(TIR, TestScalableDataTypeEquality) { + ASSERT_TRUE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32, 4, true)); +} + +TEST(TIR, TestScalableDataTypeAndNonScalableDataTypeInequality) { + ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32, 4)); +} + +TEST(TIR, TestScalableDataTypeToString) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + EXPECT_EQ(tvm::runtime::DLDataType2String(scalable_type), "int32x4xvscale"); +} + +TEST(TIR, TestStringToScalableDataType) { + std::string scalable_type_str = "int32x4xvscale"; + EXPECT_EQ(tvm::DataType(tvm::runtime::String2DLDataType(scalable_type_str)), + tvm::DataType(kDLInt, 32, -4)); +} + +TEST(TIR, TestInvalidStringToScalableDataType) { + std::string scalable_type_str = "int32xvscalex4"; + EXPECT_THROW( + { + try { + tvm::runtime::String2DLDataType(scalable_type_str); + } catch (const tvm::InternalError& e) { + EXPECT_THAT(e.what(), HasSubstr("unknown type int32xvscalex4")); + throw; + } + }, + tvm::InternalError); +} + +TEST(TIR, TestGetScalableVectorBytes) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + EXPECT_THROW( + { + try { + tvm::runtime::GetVectorBytes(scalable_type); + } catch (const tvm::InternalError& e) { + EXPECT_THAT(e.what(), HasSubstr("Cannot get vector bytes of scalable vector")); + throw; + } + }, + tvm::InternalError); +} + +// ----------- +// Integration +// ----------- +TEST(TIR, TestScalableIntrinCall) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + tvm::tir::Call call = tvm::tir::Call( + scalable_type, tvm::tir::builtin::call_llvm_intrin(), + {tvm::IntImm(tvm::DataType::Int(32), ::llvm::Intrinsic::experimental_stepvector)}); + ASSERT_EQ(call->dtype, scalable_type); + ASSERT_EQ(call->Script(), + "T.call_llvm_intrin(\"int32x4xvscale\", \"llvm.experimental.stepvector\")"); +} + +TEST(TIR, TestTIRScriptScalableDtype2Str) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + ASSERT_EQ(tvm::script::printer::DType2Str(scalable_type), "int32x4xvscale"); +} diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 5b0627542204..21ab753330d7 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -80,8 +80,19 @@ class TestVector(BaseCompare): TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x + y, 3, 4)), TestCase(tvm.tir.Ramp(x, 1, 2) + y, tvm.tir.Ramp(x + y, 1, 2)), TestCase(y + tvm.tir.Ramp(x, 1, 2), tvm.tir.Ramp(y + x, 1, 2)), + TestCase( + tvm.tir.Ramp(x, 1, 4 * tir.vscale()) + tvm.tir.Ramp(y, 2, 4 * tir.vscale()), + tvm.tir.Ramp(x + y, 3, 4 * tir.vscale()), + ), TestCase(y.astype("int32x2") + x.astype("int32x2"), (y + x).astype("int32x2")), + TestCase( + y.astype("int32x8xvscale") + x.astype("int32x8xvscale"), + (y + x).astype("int32x8xvscale"), + ), TestCase(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4)), + TestCase( + tvm.tir.Broadcast(0, 8 * tir.vscale()) + y, tvm.tir.Broadcast(y, 8 * tir.vscale()) + ), TestCase( tvm.tir.Ramp(x, 1, 4).astype("float32x4") + tvm.tir.Broadcast(0.0, 4), tvm.tir.Ramp(x, 1, 4).astype("float32x4"), @@ -100,22 +111,43 @@ class TestVector(BaseCompare): ## DivMod rules # trunc div TestCase(tdiv(y.astype("int32x2"), x.astype("int32x2")), tdiv(y, x).astype("int32x2")), + TestCase( + tdiv(y.astype("int32x4xvscale"), x.astype("int32x4xvscale")), + tdiv(y, x).astype("int32x4xvscale"), + ), TestCase(tdiv(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(tdiv(x, 2), 2, 4)), + TestCase( + tdiv(tvm.tir.Ramp(x, 4, 5 * tir.vscale()), 2), + tvm.tir.Ramp(tdiv(x, 2), 2, 5 * tir.vscale()), + ), TestCase(tdiv(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), x.astype("int32x4"), x >= 0), TestCase(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)), # trunc mod TestCase(tmod(y.astype("int32x2"), x.astype("int32x2")), tmod(y, x).astype("int32x2")), TestCase(tmod(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(tmod(x, 2), 4)), TestCase(tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4), x >= 0), + TestCase( + tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4 * tir.vscale()), 8), + tmod(tvm.tir.Ramp(1, 1, 4 * tir.vscale()), 8), + x >= 0, + ), TestCase(tmod(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), tmod(tvm.tir.Ramp(1, 15, 4), 8), x >= 0), # floor div TestCase(fld(y.astype("int32x2"), x.astype("int32x2")), fld(y, x).astype("int32x2")), TestCase(fld(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(fld(x, 2), 2, 4)), + TestCase( + fld(tvm.tir.Ramp(x, 4, 4 * tir.vscale()), 2), + tvm.tir.Ramp(fld(x, 2), 2, 4 * tir.vscale()), + ), TestCase(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")), TestCase(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)), TestCase( fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5) ), + TestCase( + fld(tvm.tir.Ramp(x, 8, 4 * tir.vscale()), tvm.tir.Broadcast(4, 4 * tir.vscale())), + tvm.tir.Ramp(fld(x, 4), 2, 4 * tir.vscale()), + ), TestCase( fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)), tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4), @@ -127,6 +159,10 @@ class TestVector(BaseCompare): TestCase( fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(x * 2, 4) ), + TestCase( + fld(tvm.tir.Ramp(x * 8, 1, 4 * tir.vscale()), tvm.tir.Broadcast(4, 4 * tir.vscale())), + fld(tvm.tir.Ramp(x * 8, 1, 4 * tir.vscale()), tvm.tir.Broadcast(4, 4 * tir.vscale())), + ), TestCase( fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)), fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)), @@ -158,7 +194,15 @@ class TestVector(BaseCompare): # floor mod TestCase(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2")), TestCase(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4)), + TestCase( + flm(tvm.tir.Ramp(x, 4, 8 * tir.vscale()), 2), + tvm.tir.Broadcast(flm(x, 2), 8 * tir.vscale()), + ), TestCase(flm(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4)), + TestCase( + flm(tvm.tir.Ramp(x * 8 + 1, 1, 4 * tir.vscale()), 8), + flm(tvm.tir.Ramp(1, 1, 4 * tir.vscale()), 8), + ), TestCase(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), flm(tvm.tir.Ramp(1, 15, 4), 8)), TestCase( flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(flm(x, 4), 4) diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 754bf36d7ab2..1353e2d88b01 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -65,5 +65,13 @@ def test_regression_simplify_inf_recursion(): ana.rewrite_simplify(res) +def test_symbolic_vscale_expression(): + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"): + ana = tvm.arith.Analyzer() + assert ana.can_prove(128 // tir.vscale() * tir.vscale() <= 128) + assert ana.can_prove(128 // (tir.vscale() * 4) * (tir.vscale() * 4) <= 128) + assert ana.can_prove(tir.vscale() % 2 <= tir.vscale()) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index e873bce52bdf..014f479ed913 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm import te -from tvm.script import tir as TIR +from tvm.script import tir as T import re import os import ctypes @@ -476,5 +476,25 @@ def check_correct_assembly(type): check_correct_assembly(type=dtype) +def test_predicated_scalable_buffer(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + @T.prim_func + def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): + for i_1 in T.vectorized(4 * T.vscale()): + if i_0 * 4 * T.vscale() + i_1 < 14: + B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 + + with tvm.target.Target(target): + out = tvm.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + assert "llvm.masked.load" in ll + assert "llvm.masked.store" in ll + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py index 6c069dc6bf0a..e0d6876d7626 100644 --- a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py @@ -343,14 +343,20 @@ def correct_trace(a, b, c, d): ' b2 = sch.get_block(name="C", func_name="main")', " sch.compute_inline(block=b1)", " l3, l4 = sch.get_loops(block=b2)", - " l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)", - " l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)", + " l5, l6 = sch.split(loop=l3, factors=" + + str(a) + + ", preserve_unit_iters=True, disable_predication=False)", + " l7, l8 = sch.split(loop=l4, factors=" + + str(b) + + ", preserve_unit_iters=True, disable_predication=False)", " sch.reorder(l5, l7, l6, l8)", " l9, l10 = sch.get_loops(block=b0)", - " l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)", + " l11, l12 = sch.split(loop=l9, factors=" + + str(c) + + ", preserve_unit_iters=True, disable_predication=False)", " l13, l14 = sch.split(loop=l10, factors=" + str(d) - + ", preserve_unit_iters=True)", + + ", preserve_unit_iters=True, disable_predication=False)", " sch.reorder(l11, l13, l12, l14)", ] ) diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py index dc8452710a8a..8cb559f81b18 100644 --- a/tests/python/target/test_arm_target.py +++ b/tests/python/target/test_arm_target.py @@ -14,11 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import subprocess +import tempfile + import pytest +import numpy as np import tvm from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support from tvm.target import codegen +from tvm.script import tir as T + +import re + llvm_version, arm_target, input_dtype, kernel_dtype, is_supported = tvm.testing.parameters( # Testing mcpu type @@ -61,3 +69,408 @@ def test_arm_conv2d_int8_support( with tvm.target.Target(arm_target): monkeypatch.setattr(codegen, "llvm_version_major", lambda: llvm_version) assert is_int8_hw_support(input_dtype, kernel_dtype) == is_supported + + +@pytest.fixture(scope="session") +def sve_device_vector_length(): + c_code = r""" + #include + #include + + int main() { + printf("%ld\n", svcntb() * 8); + } + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + c_path = f"{tmp_dir}/vl.c" + o_path = f"{tmp_dir}/out.o" + with open(c_path, "w") as f: + f.write(c_code) + tvm.contrib.cc.create_executable(o_path, c_path, ["-march=native"]) + out = subprocess.check_output(o_path, shell=True).strip().decode() + + return int(out) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_vectorize(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + vscale = tvm.tir.vscale() + + @T.prim_func + def main(A: T.Buffer((num_elements,), "float32"), B: T.Buffer((num_elements,), "float32")): + for i_0 in range(T.ceildiv(num_elements, 4 * vscale)): + for i_1 in T.vectorized(4 * vscale): + A_1 = T.Buffer((num_elements,), data=A.data) + B_1 = T.Buffer((num_elements,), data=B.data) + B_1[i_0 * 4 * vscale + i_1] = A_1[i_0 * 4 * vscale + i_1] + + build_mod = tvm.build(main, target=target) + + llvm = build_mod.get_source() + sve_vec_instrs = re.findall(r"\", llvm) + assert len(sve_vec_instrs) > 0, "No scalable vectors in assembly" + + dev = tvm.cpu(0) + np_zeros = np.zeros((num_elements,)).astype("float32") + np_ones = np.ones((num_elements,)).astype("float32") + + input_buf = tvm.nd.array(np_ones, device=dev) + output_buf = tvm.nd.array(np_zeros, device=dev) + + build_mod(input_buf, output_buf) + tvm.testing.assert_allclose(output_buf.numpy(), np_ones) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_div(sve_device_vector_length): + np.random.seed(0) + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((1,), "int32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[0] = T.Div(10000, 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + + A_nd = tvm.nd.array(np.empty((1,), dtype="int32"), device=dev) + mod(A_nd) + + ref = 10000 // (sve_device_vector_length // 32) + tvm.testing.assert_allclose(A_nd.numpy()[0], ref) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_buffer_load_store(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "float32"), B: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(num_elements,)).astype("float32") + B_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_loop_bound(sve_device_vector_length): + np.random.seed(0) + + dtype = "float32" + num_elements = sve_device_vector_length // 32 + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "float32"), B: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(0, 4 * T.vscale()): + B[i] = A[i] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(num_elements,)).astype(dtype) + B_np = np.zeros((num_elements,)).astype(dtype) + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_broadcast(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + + A_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + mod(A_nd) + + ref = np.ones((num_elements,)) + tvm.testing.assert_allclose(A_nd.numpy(), ref) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_ptrue_predicate(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(T.IntImm("int1", 1), 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + + A_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + mod(A_nd) + + ref = np.ones((num_elements,)) + tvm.testing.assert_allclose(A_nd.numpy(), ref) + + +@pytest.mark.skip(reason="Currently don't support scalable gathers in codegen") +@tvm.testing.requires_aarch64_sve +def test_scalable_gather(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + stride = 2 + + @T.prim_func + def my_func( + A: T.Buffer((stride * num_elements,), "float32"), B: T.Buffer((num_elements,), "float32") + ): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, stride, 4 * T.vscale())] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(stride * num_elements,)).astype("float32") + B_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + ref = A_np[::stride] + tvm.testing.assert_allclose(B_nd.numpy(), ref) + + +@pytest.mark.skip(reason="Currently don't support scalable scatters in codegen") +@tvm.testing.requires_aarch64_sve +def test_scalable_scatter(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + stride = 2 + + @T.prim_func + def my_func( + A: T.Buffer((num_elements,), "float32"), B: T.Buffer((stride * num_elements,), "float32") + ): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, stride, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(num_elements,)).astype("float32") + B_np = np.zeros((stride * num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + ref = B_np + ref[::stride] = A_np + tvm.testing.assert_allclose(B_nd.numpy(), ref) + + +@pytest.mark.skip(reason="Currently don't support scalable gathers in codegen") +@tvm.testing.requires_aarch64_sve +def test_scalable_complex_gather(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "float32"), B: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[2 * T.ramp(0, 1, 4 * T.vscale()) % 4] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(num_elements,)).astype("float32") + B_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + +@pytest.mark.skip(reason="Currently don't support scalable ramps in codegen") +@tvm.testing.requires_aarch64_sve +def test_scalable_ramp(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((num_elements,), "int32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.ramp(11, 1, 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + + A_np = np.zeros((num_elements,)).astype("int32") + A_nd = tvm.nd.array(A_np, device=dev) + mod(A_nd) + + ref = np.arange(11, 11 + num_elements) + tvm.testing.assert_allclose(A_nd.numpy(), ref) + + +@tvm.testing.requires_aarch64_sve +@pytest.mark.parametrize("disable_predication", [True, False]) +def test_schedule_split_vectorized(disable_predication): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + dev = tvm.cpu(0) + + @T.prim_func + def my_func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(128): + with T.block("A"): + v_i = T.axis.remap("S", [i]) + B[v_i] = A[v_i] + 1.0 + + sch = tvm.tir.Schedule(my_func) + (a,) = sch.get_loops("A") + + with tvm.target.Target(target): + _, a1 = sch.split( + a, + factors=[T.ceildiv(128, 4 * T.vscale()), 4 * T.vscale()], + disable_predication=disable_predication, + ) + + sch.vectorize(a1) + mod = tvm.build(sch.mod["main"], target=target) + + A_np = np.arange(128).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_np = np.zeros(128).astype("float32") + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + ref = A_np + 1.0 + tvm.testing.assert_allclose(B_nd.numpy(), ref) + + +def _test_accuracy(input_values, output_values, build_mod): + dev = tvm.cpu(0) + + input_buf = tvm.nd.array(input_values, device=dev) + + np_zeros = np.zeros(output_values.shape).astype("float32") + output_buf = tvm.nd.array(np_zeros, device=dev) + + build_mod(input_buf, output_buf) + tvm.testing.assert_allclose(output_buf.numpy(), output_values) + + +def test_vectorize_to_sve(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + vscale = tvm.tir.vscale() + buffer_size = 128 + + @T.prim_func + def main(A: T.Buffer((buffer_size,), "float32"), B: T.Buffer((buffer_size,), "float32")): + for i_0 in range(tvm.tir.ceildiv(128, 4 * vscale)): + for i_1 in T.vectorized(4 * vscale): + A_1 = T.Buffer((128,), data=A.data) + B_1 = T.Buffer((128,), data=B.data) + B_1[i_0 * 4 * vscale + i_1] = A_1[i_0 * 4 * vscale + i_1] + + build_mod = tvm.build(main, target=target) + + llvm = build_mod.get_source() + + assert re.findall(r"\", llvm), "No scalable vectors in assembly" + + if tvm.testing.has_cpu_feat("sve"): + print("running on an SVE enabled machine...") + + np_ones = np.ones((buffer_size,)).astype("float32") + _test_accuracy(np_ones, np_ones, build_mod) + + +def test_vectorize_to_sve_with_broadcast(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + vscale = tvm.tir.vscale() + buffer_size = 128 + + @T.prim_func + def main(A: T.Buffer((buffer_size,), "float32"), B: T.Buffer((buffer_size,), "float32")): + for i_0 in range(tvm.tir.ceildiv(128, 4 * vscale)): + for i_1 in T.vectorized(4 * vscale): + A_1 = T.Buffer((128,), data=A.data) + B_1 = T.Buffer((128,), data=B.data) + B_1[i_0 * 4 * vscale + i_1] = A_1[i_0 * 4 * vscale + i_1] * 5 + + build_mod = tvm.build(main, target=target) + + llvm = build_mod.get_source() + + assert re.findall(r"\", llvm), "No scalable vectors in assembly" + + if tvm.testing.has_cpu_feat("sve"): + print("running on an SVE enabled machine...") + + np_ones = np.ones((buffer_size,)).astype("float32") + output_values = np_ones * 5 + + _test_accuracy(np_ones, output_values, build_mod) + + +def test_sve_full_stack(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + vscale = tvm.tir.vscale() + buffer_size = 130 + + @T.prim_func + def main(A: T.Buffer((buffer_size,), "float32"), B: T.Buffer((buffer_size,), "float32")): + for i in range(buffer_size): + with T.block("A"): + B[i] = A[i] + + # Schedule + with tvm.target.Target(target): + sch = tvm.tir.Schedule(main) + (l,) = sch.get_loops("A") + + _, l1 = sch.split(l, factors=[T.ceildiv(buffer_size, 4 * vscale), 4 * vscale]) + + sch.vectorize(l1) + + with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): + build_mod = tvm.build(sch.mod["main"], target=target) + + llvm = build_mod.get_source() + + assert re.findall(r"\", llvm), "No scalable vectors in llvm" + assert re.findall(r"llvm.masked", llvm), "No masked instructions in llvm" + + if tvm.testing.has_cpu_feat("sve"): + print("running on an SVE enabled machine...") + + np_ones = np.ones((buffer_size,)).astype("float32") + _test_accuracy(np_ones, np_ones, build_mod) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index 49816778f11f..bb8f14191380 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -401,5 +401,32 @@ def test_intimm_cond(): assert x == 1 +def _create_ramp(lanes): + return tvm.tir.Ramp(0, 1, lanes) + + +def _create_broadcast(lanes): + return tvm.tir.Broadcast(0, lanes) + + +@pytest.mark.parametrize("lanes", [(11 * tvm.tir.vscale()), (tvm.tir.vscale() * 11)]) +@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast]) +def test_scalable_vec(lanes, node_func): + def _check_dtype(node): + assert node.dtype == "int32x11xvscale" + + _check_dtype(node_func(lanes)) + + +@pytest.mark.parametrize( + "lanes", [(tvm.tir.vscale()), (tvm.tir.vscale() + 3), (tvm.tir.vscale() * 2 + 5)] +) +@pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast]) +def test_scalable_vec_error(lanes, node_func): + + with pytest.raises(tvm.error.TVMError): + node_func(lanes) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_scalable_datatype.py b/tests/python/tir-base/test_tir_scalable_datatype.py new file mode 100644 index 000000000000..2c83314470de --- /dev/null +++ b/tests/python/tir-base/test_tir_scalable_datatype.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import tir +from tvm.script import tir as T + +""" +Tests for scalable data types. +""" + + +def test_create_scalable_data_type_python_api(): + dtype = tvm.DataType("float32x4xvscale") + assert str(dtype) == "float32x4xvscale" + + +def test_create_scalable_tir_intrin(): + intrin = tir.call_llvm_intrin("int32x4xvscale", "llvm.experimental.stepvector") + assert intrin.dtype == "int32x4xvscale" + assert str(intrin) == 'T.call_llvm_intrin("int32x4xvscale", "llvm.experimental.stepvector")' + + +def test_tvm_script_create_scalable_tir_intrin(): + @T.prim_func + def my_func(): + T.call_llvm_intrin("int32x4xvscale", "llvm.experimental.stepvector") + + assert ( + 'T.call_llvm_intrin("int32x4xvscale", "llvm.experimental.stepvector")' in my_func.script() + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py index 679b147446ea..23bcdd1ec55b 100644 --- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py @@ -653,5 +653,73 @@ def test_split_int64_factors(): assert_structural_equal_ignore_global_symbol(elementwise_symbolic_split, sch.mod["main"]) +@pytest.mark.parametrize("num_elements", [128, 115]) +def test_split_predicated(num_elements): + # By default, splitting with vscale will result in predication being + # applied. This is because at compile-time we don't know if vscale is + # a multiple of the extent of the loop to be split. + @T.prim_func + def before(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(num_elements): + with T.block("A"): + v_i = T.axis.remap("S", [i]) + A[v_i] = 1.0 + + @T.prim_func + def after(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i_0, i_1 in T.grid( + (T.vscale() * 4 + (num_elements - 1)) // (T.vscale() * 4), T.vscale() * 4 + ): + with T.block("A"): + v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + i_1) + T.where(i_0 * (T.vscale() * 4) + i_1 < num_elements) + A[v_i] = 1.0 + + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"): + sch = tvm.tir.Schedule(before) + (a,) = sch.get_loops("A") + sch.split(a, factors=[T.ceildiv(num_elements, 4 * T.vscale()), 4 * T.vscale()]) + + tvm.ir.assert_structural_equal(sch.mod["main"], after) + + +def test_split_assume_exact_multiple(): + # If the schedule writer knows the extent of the loop to be split will always + # be a multiple, they may use `disable_predication=True` to ensure + # a predicate is not created. + num_elements = 128 + + @T.prim_func + def before(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(num_elements): + with T.block("A"): + v_i = T.axis.remap("S", [i]) + A[v_i] = 1.0 + + @T.prim_func + def after(A: T.Buffer((num_elements,), "float32")): + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i_0, i_1 in T.grid( + (T.vscale() * 4 + (num_elements - 1)) // (T.vscale() * 4), T.vscale() * 4 + ): + with T.block("A"): + v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + i_1) + A[v_i] = 1.0 + + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"): + sch = tvm.tir.Schedule(before) + (a,) = sch.get_loops("A") + sch.split( + a, + factors=[T.ceildiv(num_elements, 4 * T.vscale()), 4 * T.vscale()], + disable_predication=True, + ) + + tvm.ir.assert_structural_equal(sch.mod["main"], after) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_trace.py b/tests/python/tir-schedule/test_tir_schedule_trace.py index a793699ca755..18f15d6a7af8 100644 --- a/tests/python/tir-schedule/test_tir_schedule_trace.py +++ b/tests/python/tir-schedule/test_tir_schedule_trace.py @@ -88,7 +88,7 @@ def _make_split(inputs, outputs): # pylint: disable=redefined-builtin return Instruction( kind=InstructionKind.get("Split"), inputs=inputs, - attrs=[True], + attrs=[True, False], outputs=outputs, ) @@ -304,7 +304,7 @@ def test_trace_simplified_3(): "def apply_trace(sch: tir.Schedule) -> None:", ' b0 = sch.get_block(name="B", func_name="main")', " l1, = sch.get_loops(block=b0)", - " l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)", + " l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True, disable_predication=False)", ) ) diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 2448fffe8929..ffbf862a36e6 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.script import tir as T def test_vectorize_loop(): @@ -226,13 +227,113 @@ def test_vectorize_dtype_mismatch(): tvm.lower(s, [A], "llvm", simple_mode=True) +def test_vectorize_and_predicate_all_buffer_loads_stores(): + @T.prim_func + def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 + + # TODO(lhutton1) Needs parser support for predicates + # @T.prim_func + # def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + # T.func_attr({"tir.noalias": T.bool(True)}) + # for i_0 in range(4): + # B(pred=T.get_active_lane_mask(i_0 * 4, 14))[i_0 * 4:i_0 * 4 + 4] = A(pred=T.get_active_lane_mask(i_0 * 4, 14))[i_0 * 4:i_0 * 4 + 4] + T.Broadcast(T.float32(1), 4) + # tvm.ir.assert_structural_equal(after, expected) + + # Instead, settle for a simple predicate check.. + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + assert after.body.body.predicate != None + + +def test_vectorize_and_predicate_some_buffer_loads_stores(): + # Currently revert to scalarizing the block if not all accesses + # have been predicated, otherwise incorrect code is generated. + @T.prim_func + def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + B[i_0 * 4 + i_1] = A[i_0] + 1.0 + + @T.prim_func + def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0, i_1_s in T.grid(4, 4): + if i_0 * 4 + i_1_s < 14: + B[i_0 * 4 + i_1_s] = A[i_0] + T.float32(1) + + mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + +def test_vectorize_and_predicate_multiple_access_statements(): + @T.prim_func + def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 < 14: + A[i_0 * 4 + i_1] = 2.0 + B[i_0 * 4 + i_1] = 1.0 + + # TODO(lhutton1) Needs parser support for predicates + # @T.prim_func + # def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + # T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # for i_0 in range(4): + # A(pred=T.get_active_lane_mask(i_0 * 4, 14))[i_0 * 4:i_0 * 4 + 4] = T.Broadcast(T.float32(2), 4) + # B(pred=T.get_active_lane_mask(i_0 * 4, 14))[i_0 * 4:i_0 * 4 + 4] = T.Broadcast(T.float32(1), 4) + # tvm.ir.assert_structural_equal(after, expected) + + # Instead, settle for a simple predicate check.. + before_mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + assert after.body.body[0].predicate != None + assert after.body.body[1].predicate != None + + +def test_vectorize_and_predicate_invalid_conditions(): + @T.prim_func + def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(14, 4)): + for i_1 in T.vectorized(4): + if i_0 * 4 + i_1 > 14: + A[i_0 * 4 + i_1] = 2.0 + if 14 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + if i_0 * 4 + i_1 < i_0 * 4 + i_1: + A[i_0 * 4 + i_1] = 2.0 + + @T.prim_func + def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + for i_1_s in range(4): + if i_0 * 4 + i_1_s > 14: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if 14 < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + for i_1_s in range(4): + if i_0 * 4 + i_1_s < i_0 * 4 + i_1_s: + A[i_0 * 4 + i_1_s] = T.float32(2) + + before_mod = tvm.IRModule.from_expr(before) + with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): + after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": - test_vectorize_vector() - test_vectorize_with_if() - test_vectorize_loop() - test_vectorize_if_then_else() - test_vectorize_with_le_cond() - test_vectorize_with_ge_cond() - test_vectorize_let() - test_vectorize_while_fail() - test_vectorize_dtype_mismatch() + tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 5b3e68e22fa9..cfe48d4d923b 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3339,6 +3339,15 @@ def func() -> None: return func +def scalable_vectors(): + @T.prim_func + def func(a: T.handle): + A = T.match_buffer(a, (200,), "float32") + A[T.Ramp(11, 2, 4 * tir.vscale())] = T.Broadcast(125, 4 * tir.vscale()) + + return func + + def let_expression(): @T.prim_func def func(): @@ -4038,6 +4047,7 @@ def func(): buffer_axis_separator, buffer_ramp_access_as_slice_index, ramp_int64, + scalable_vectors, let_expression, void_ptr, decl_buffer, From 1a3c9593736eee540ac7aa6f4ce8d5b72ac8d11c Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Tue, 9 Jan 2024 10:44:14 +0000 Subject: [PATCH 2/6] Fix ci_gpu and ci_i386 builds Change-Id: I7d90c8b8396ba7a2b609a91bea2fe5f599d5cb96 --- src/target/llvm/codegen_llvm.cc | 17 ++++++----------- src/target/spirv/codegen_spirv.cc | 10 ++++++++-- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f38e27e5b524..745fe5c00ec9 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1466,6 +1466,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return builder_->CreateAssumption(cond); } else if (op->op.same_as(builtin::tvm_thread_invariant())) { return MakeValue(op->args[0]); +#if TVM_LLVM_VERSION >= 110 } else if (op->op.same_as(builtin::vscale())) { llvm::Intrinsic::ID id = llvm::Intrinsic::vscale; llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {}); @@ -1475,6 +1476,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { llvm::Function* f = GetIntrinsicDecl(id, DTypeToLLVMType(op->dtype), {builder_->getInt32Ty(), builder_->getInt32Ty()}); return builder_->CreateCall(f, {MakeValue(op->args[0]), MakeValue(op->args[1])}); +#endif } else { LOG(FATAL) << "unknown intrinsic " << op->op; } @@ -1872,24 +1874,17 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { DataType dtype = op->dtype; llvm::Value* value = MakeValue(op->value); llvm::Type* type = DTypeToLLVMType(dtype); - llvm::ElementCount ec; - if (dtype.is_scalable()) { - ec = llvm::ElementCount::get(dtype.lanes(), true); - } else { - ec = llvm::ElementCount::get(dtype.lanes(), false); - } llvm::Constant* undef = llvm::UndefValue::get(type); llvm::Constant* zero = ConstInt32(0); value = builder_->CreateInsertElement(undef, value, zero); -#if TVM_LLVM_VERSION >= 120 - llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero); -#elif TVM_LLVM_VERSION >= 110 +#if TVM_LLVM_VERSION >= 110 + llvm::ElementCount ec = llvm::ElementCount::get(dtype.lanes(), dtype.is_scalable()); llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero); #else - if (dtype->is_scalable()) { + if (dtype.is_scalable()) { LOG(FATAL) << "Can't create scalable broadcast"; } - llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero); + llvm::Constant* mask = llvm::ConstantVector::getSplat(dtype.lanes(), zero); #endif return builder_->CreateShuffleVector(value, undef, mask); } diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index aca504b94b98..c997a5df197b 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -519,7 +519,10 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { std::vector values; spirv::Value base = MakeValue(op->base); - for (int i = 0; i < op->lanes; ++i) { + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_spirv"; + int lanes = static_cast(Downcast(op->lanes)->value); + for (int i = 0; i < lanes; ++i) { spirv::Value v = base; if (i != 0) { spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) * op->stride); @@ -533,7 +536,10 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { std::vector values; spirv::Value v = MakeValue(op->value); - for (int i = 0; i < op->lanes; i++) { + ICHECK(op->lanes->IsInstance()) + << "Scalable vectors are not supported in codegen_spirv"; + int lanes = static_cast(Downcast(op->lanes)->value); + for (int i = 0; i < lanes; i++) { values.push_back(v); } return builder_->Concat(values); From dd6ba3e4ce73fec367d8ebdb0fe581e42cb1151b Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Fri, 12 Jan 2024 09:53:35 +0000 Subject: [PATCH 3/6] Fix i386 cpptest build Change-Id: I2e994ffdacaf1dacdc875c2cfd62be433e6952c6 --- tests/cpp/tir_scalable_datatype.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index 59162a446b93..6db36b8db796 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -113,6 +113,7 @@ TEST(TIR, TestGetScalableVectorBytes) { // ----------- // Integration // ----------- +#if TVM_LLVM_VERSION >= 120 TEST(TIR, TestScalableIntrinCall) { tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); tvm::tir::Call call = tvm::tir::Call( @@ -122,6 +123,7 @@ TEST(TIR, TestScalableIntrinCall) { ASSERT_EQ(call->Script(), "T.call_llvm_intrin(\"int32x4xvscale\", \"llvm.experimental.stepvector\")"); } +#endif TEST(TIR, TestTIRScriptScalableDtype2Str) { tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); From 34de084580eb5d734484c559dca53bb443b377cd Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 5 Jan 2024 16:01:51 +0000 Subject: [PATCH 4/6] [SVE] Add parser support for predicates This commit adds support for expressing and printing buffer loads/stores in TVMScript. The buffer API has been extended with load and store methods which support passing a predicate parameter to BufferLoad/Store. When the printer encounters a predicated BufferLoad/Store, it will print with the .load/.store syntax as opposed to the shorthand [...] syntax as it is easier to represent predicates. Extending the functionality of vload and vstore was considered but they do not currently support expressing loading/storing of non-consecutive values and such a change will result in many changes across the codebase. An example of a predicated load and store in TVMScript: ``` A.load( [T.Ramp(i_0 * 4, 1, 4)], predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), ) B.store( T.Broadcast(T.float32(1), 4), [T.Ramp(i_0 * 4, 1, 4)], predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), ) ``` Change-Id: I1305c1b5d052ad109232604c6660e40d4a566dd6 --- include/tvm/script/ir_builder/tir/ir.h | 3 +- python/tvm/script/ir_builder/tir/ir.py | 8 +- python/tvm/script/parser/tir/parser.py | 2 + python/tvm/tir/__init__.py | 2 +- python/tvm/tir/buffer.py | 51 +++++++++++++ python/tvm/tir/expr.py | 4 +- python/tvm/tir/op.py | 20 +++++ python/tvm/tir/stmt.py | 4 +- src/script/ir_builder/tir/ir.cc | 5 +- src/script/printer/tir/buffer.cc | 17 ++++- src/tir/op/builtin.cc | 4 +- .../test_tir_transform_vectorize.py | 53 +++++++------ .../tvmscript/test_tvmscript_printer_tir.py | 75 +++++++++++++++++++ .../tvmscript/test_tvmscript_roundtrip.py | 11 +++ 14 files changed, 225 insertions(+), 34 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 735d5ba6c0a1..d83303daac9a 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -410,8 +410,9 @@ Var EnvThread(String thread_tag); * \param buffer The buffer. * \param value The value to be stored. * \param indices The indices location to be stored. + * \param predicate A vector mask of int1 values that prevents storing values on masked-off lanes. */ -void BufferStore(Buffer buffer, PrimExpr value, Array indices); +void BufferStore(Buffer buffer, PrimExpr value, Array indices, PrimExpr predicate); /*! * \brief The prefetch hint for a buffer diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 7e42ea3c2204..104ce5764c20 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1262,6 +1262,7 @@ def buffer_store( buffer: Buffer, # pylint: disable=redefined-outer-name value: PrimExpr, indices: List[Union[PrimExpr, slice]], + predicate: Optional[PrimExpr] = None, ) -> None: """Buffer store node. @@ -1275,6 +1276,9 @@ def buffer_store( indices : List[Union[PrimExpr, slice]] The indices location to be stored. + + predicate : Optional[PrimExpr] + A vector mask of int1 values that prevents storing values on masked-off lanes. """ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel @@ -1295,7 +1299,7 @@ def buffer_store( if isinstance(value, bool) and buffer.dtype == "bool": value = IntImm("bool", value) return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member - buffer, value, expr_indices + buffer, value, expr_indices, predicate ) @@ -1891,6 +1895,7 @@ def wrapped(*args, **kwargs): vectorlow = _dtype_forward(_tir_op.vectorlow) vectorhigh = _dtype_forward(_tir_op.vectorhigh) vectorcombine = _dtype_forward(_tir_op.vectorcombine) +get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask) broadcast = Broadcast ramp = Ramp @@ -2200,4 +2205,5 @@ def wrapped(*args, **kwargs): "CommReducer", "Range", "vscale", + "get_active_lane_mask", ] diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 89673d291b88..18e3528e693b 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -462,6 +462,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: elif isinstance(res, str): # Ignore docstrings pass + elif isinstance(res, tvm.tir.stmt.BufferStore): + T.buffer_store(res.buffer, res.value, res.indices, res.predicate) else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 1723804388b9..24ba4ccd2e58 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -88,7 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic -from .op import vscale +from .op import vscale, get_active_lane_mask from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index ec57ad7801ca..17104ba3d0ec 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -141,6 +141,57 @@ def vstore(self, begin, value): begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin return _ffi_api.BufferVStore(self, begin, value) # type: ignore + def load(self, indices, predicate=None): + """ + Load values at specified indices from buffer. + + Longhand notation that can be used for complex buffer load + expressions. For example, when the load involves predication. + + Parameters + ---------- + indices : List[PrimExpr] + The buffer indices to load values from. + + predicate : Optional[PrimExpr] + A vector mask of int1 values that prevents loading values on masked-off lanes. + + Returns + ------- + BufferLoad + A buffer load Expr. + """ + from .expr import BufferLoad # pylint: disable=import-outside-toplevel + + return BufferLoad(self, indices, predicate) + + def store(self, value, indices, predicate=None): + """ + Store given value at the specified indices in the buffer. + + Longhand notation that can be used for complex buffer store + statements. For example, when the store involves predication. + + Parameters + ---------- + value : PrimExpr + The value to be stored. + + indices : List[PrimExpr] + The buffer indices to store values to. + + predicate : Optional[PrimExpr] + A vector mask of int1 values that prevents storing values on masked-off lanes. + + Returns + ------- + BufferStore + A buffer store Stmt. + """ + from .stmt import BufferStore # pylint: disable=import-outside-toplevel + + return BufferStore(self, value, indices, predicate) + def scope(self): """Return the storage scope associated with this buffer. Returns diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 7123584f6c8e..c6c502a29632 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1093,10 +1093,10 @@ class BufferLoad(PrimExprWithOp): The buffer to be loaded. indices : List[PrimExpr] - The buffer indices. + The buffer indices to load values from. predicate : Optional[PrimExpr] - The buffer predicate + A vector mask of int1 values that prevents loading values on masked-off lanes. span : Optional[Span] The location of this expression in the source code. diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index a6f9e3fd12d3..ddfadf1b0d4e 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1590,6 +1590,26 @@ def vscale(): return call_intrin("int32", "tir.vscale") +def get_active_lane_mask(dtype, base, limit): + """ + Creates a mask corresponding to active and inactive vector lanes. + + Analogous to https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics. + + Parameters + ---------- + dtype : str + The data type of the result. + + base : PrimExpr + An expression reprsenting the base. + + limit : PrimExpr + An expression representing the limit. + """ + return call_intrin(dtype, "tir.get_active_lane_mask", base, limit) + + def vectorhigh(dtype, vec): """Get the high level half of the vector diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 31298af67f3d..0950547082a7 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -219,13 +219,13 @@ class BufferStore(Stmt): The buffer. value : PrimExpr - The value we to be stored. + The value to be stored. indices : List[PrimExpr] The indices location to be stored. predicate : Optional[PrimExpr] - The buffer predicate + A vector mask of int1 values that prevents storing values on masked-off lanes. span : Optional[Span] The location of the stmt in the source code. diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index a6569952177e..c8f725606580 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -516,7 +516,8 @@ Var EnvThread(String thread_tag) { return var; } -void BufferStore(Buffer buffer, PrimExpr value, Array indices) { +void BufferStore(Buffer buffer, PrimExpr value, Array indices, + PrimExpr predicate = PrimExpr()) { runtime::DataType buffer_dtype = buffer->dtype; int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; bool index_scalable = indices.size() ? indices.back().dtype().is_scalable() : false; @@ -552,7 +553,7 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices) { } value = tvm::cast(lhs_dtype, value); } - AddToParent(tvm::tir::BufferStore(buffer, value, indices)); + AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate)); } void Prefetch(Buffer buffer, Array bounds) { diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index d23a3fe17aba..e5daee7d7b75 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -17,6 +17,7 @@ * under the License. */ #include // For `kAllocAlignment` +#include #include "./utils.h" @@ -273,23 +274,33 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); + ExprDoc value = d->AsDoc(store->value, p->Attr("value")); + + // Use .store(...) syntax when there is a predicate if (store->predicate.defined()) { + ExprDoc indices = d->AsDoc(store->indices, p->Attr("indices")); ExprDoc predicate = d->AsDoc(store->predicate, p->Attr("predicate")); - buffer = CallDoc(buffer, {}, {"pred"}, {predicate}); + return ExprStmtDoc( + buffer->Attr("store")->Call({value, indices}, {"predicate"}, {predicate})); } + return AssignDoc( /*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], - /*rhs=*/d->AsDoc(store->value, p->Attr("value")), NullOpt); + /*rhs=*/value, NullOpt); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); + + // Use .load(...) syntax when there is a predicate if (load->predicate.defined()) { + ExprDoc indices = d->AsDoc(load->indices, p->Attr("indices")); ExprDoc predicate = d->AsDoc(load->predicate, p->Attr("predicate")); - buffer = CallDoc(buffer, {}, {"pred"}, {predicate}); + return buffer->Attr("load")->Call({indices}, {"predicate"}, {predicate}); } + return buffer[BufferIndices(load->indices, p->Attr("indices"), d)]; }); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index d8b73a8d90da..c5f6d707b95e 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -373,7 +373,9 @@ TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr("TCallEffectKind", TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask) .set_num_inputs(2) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(assume) .set_attr("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo)) diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index ffbf862a36e6..8c54e3a8afbf 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -236,19 +236,26 @@ def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): if i_0 * 4 + i_1 < 14: B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0 - # TODO(lhutton1) Needs parser support for predicates - # @T.prim_func - # def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): - # T.func_attr({"tir.noalias": T.bool(True)}) - # for i_0 in range(4): - # B(pred=T.get_active_lane_mask(i_0 * 4, 14))[i_0 * 4:i_0 * 4 + 4] = A(pred=T.get_active_lane_mask(i_0 * 4, 14))[i_0 * 4:i_0 * 4 + 4] + T.Broadcast(T.float32(1), 4) - # tvm.ir.assert_structural_equal(after, expected) - - # Instead, settle for a simple predicate check.. + @T.prim_func + def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + load_a = T.meta_var( + A.load( + [T.Ramp(i_0 * 4, 1, 4)], predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14) + ) + ) + add_1 = T.meta_var(load_a + T.Broadcast(T.float32(1), 4)) + B.store( + add_1, + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + ) + mod = tvm.IRModule.from_expr(before) with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): after = tvm.tir.transform.VectorizeLoop()(mod)["main"] - assert after.body.body.predicate != None + tvm.ir.assert_structural_equal(after, expected) def test_vectorize_and_predicate_some_buffer_loads_stores(): @@ -285,21 +292,25 @@ def before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): A[i_0 * 4 + i_1] = 2.0 B[i_0 * 4 + i_1] = 1.0 - # TODO(lhutton1) Needs parser support for predicates - # @T.prim_func - # def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): - # T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) - # for i_0 in range(4): - # A(pred=T.get_active_lane_mask(i_0 * 4, 14))[i_0 * 4:i_0 * 4 + 4] = T.Broadcast(T.float32(2), 4) - # B(pred=T.get_active_lane_mask(i_0 * 4, 14))[i_0 * 4:i_0 * 4 + 4] = T.Broadcast(T.float32(1), 4) - # tvm.ir.assert_structural_equal(after, expected) + @T.prim_func + def expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + for i_0 in range(4): + A.store( + T.Broadcast(T.float32(2), 4), + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + ) + B.store( + T.Broadcast(T.float32(1), 4), + [T.Ramp(i_0 * 4, 1, 4)], + predicate=T.get_active_lane_mask("int1x4", i_0 * 4, 14), + ) - # Instead, settle for a simple predicate check.. before_mod = tvm.IRModule.from_expr(before) with tvm.transform.PassContext(config={"tir.enable_buffer_predication": True}): after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] - assert after.body.body[0].predicate != None - assert after.body.body[1].predicate != None + tvm.ir.assert_structural_equal(after, expected) def test_vectorize_and_predicate_invalid_conditions(): diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 18f4e153bff9..e258d58de606 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -889,5 +889,80 @@ def func(a_name: T.handle): assert re.match(expected_regex, script) +def test_predicated_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + T.func_attr({"global_symbol": "func"}) + a_load = T.meta_var(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4))) + A.store(a_load, [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(A.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + """ + _assert_print(main, expected_output) + + +def test_predicated_buffer_load_store(): + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + buffer_map = { + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + } + buffer_load = tir.BufferLoad( + buffer=buffer_map[b], indices=[0, tir.Ramp(0, 4, 4)], predicate=tir.Broadcast(0, 4) + ) + body = tir.BufferStore( + buffer=buffer_map[a], + value=buffer_load, + indices=[0, tir.Ramp(0, 2, 4)], + predicate=tir.Broadcast(0, 4), + ) + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map=buffer_map, + body=body, + ) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func(private=True) +def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(B.load([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(0, 4)), [0, T.Ramp(0, 2, 4)], predicate=T.Broadcast(0, 4)) + """ + _assert_print(func, expected_output) + + +def test_predicated_scalable_load_store(): + from tvm.script import tir as T + + @T.prim_func + def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + T.func_attr({"global_symbol": "func"}) + mask = T.meta_var(T.get_active_lane_mask("int1x4xvscale", 0, 13)) + a_load = T.meta_var(A.load([0, T.Ramp(0, 4, 4 * T.vscale())], predicate=mask)) + A.store(a_load, [0, T.Ramp(0, 2, 4 * T.vscale())], predicate=mask) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + A.store(\ + A.load([0, T.Ramp(0, 4, 4 * T.vscale())], predicate=T.get_active_lane_mask("int1x4xvscale", 0, 13)), \ + [0, T.Ramp(0, 2, 4 * T.vscale())], predicate=T.get_active_lane_mask("int1x4xvscale", 0, 13)\ + ) + """ + _assert_print(main, expected_output) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index cfe48d4d923b..fefacb5bd10b 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3348,6 +3348,16 @@ def func(a: T.handle): return func +def predicated_buffer_load_store(): + @T.prim_func + def func(A: T.Buffer((4,), "float32"), B: T.Buffer((8,), "float32")): + for i_0 in range(4): + load_a = T.meta_var(A.load([T.Ramp(i_0, 1, 4)], predicate=T.Broadcast(1.0, 4))) + B.store(load_a, [T.Ramp(0, 2, 4)], predicate=T.Broadcast(1.0, 4)) + + return func + + def let_expression(): @T.prim_func def func(): @@ -4048,6 +4058,7 @@ def func(): buffer_ramp_access_as_slice_index, ramp_int64, scalable_vectors, + predicated_buffer_load_store, let_expression, void_ptr, decl_buffer, From 3a8c8e03834cf4e44855aac023b580f6d86d0c67 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 16 Jan 2024 11:03:35 +0000 Subject: [PATCH 5/6] Skip codegen tests on i386 target Change-Id: I8a893377d7361bed8840a645fbaff81ab401ccb8 --- tests/python/target/test_arm_target.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py index 8cb559f81b18..6a66731c3430 100644 --- a/tests/python/target/test_arm_target.py +++ b/tests/python/target/test_arm_target.py @@ -383,6 +383,7 @@ def _test_accuracy(input_values, output_values, build_mod): tvm.testing.assert_allclose(output_buf.numpy(), output_values) +@tvm.testing.skip_if_32bit(reason="Skipping test for i386 due to old version of LLVM") def test_vectorize_to_sve(): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" vscale = tvm.tir.vscale() @@ -409,6 +410,7 @@ def main(A: T.Buffer((buffer_size,), "float32"), B: T.Buffer((buffer_size,), "fl _test_accuracy(np_ones, np_ones, build_mod) +@tvm.testing.skip_if_32bit(reason="Skipping test for i386 due to old version of LLVM") def test_vectorize_to_sve_with_broadcast(): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" vscale = tvm.tir.vscale() @@ -437,6 +439,7 @@ def main(A: T.Buffer((buffer_size,), "float32"), B: T.Buffer((buffer_size,), "fl _test_accuracy(np_ones, output_values, build_mod) +@tvm.testing.skip_if_32bit(reason="Skipping test for i386 due to old version of LLVM") def test_sve_full_stack(): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" vscale = tvm.tir.vscale() From c82f4c713d970322e8d4c99d23cc602740ea6110 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Tue, 16 Jan 2024 10:50:24 +0000 Subject: [PATCH 6/6] Add support for disable_predication for te.split Plumb this argument through the te.split implementation to achieve feature parity with tir.split. Change-Id: I92707bc08e857b9fd5678153d998aeecebcce228 --- include/tvm/te/schedule.h | 14 +++++++--- python/tvm/te/schedule.py | 13 ++++++--- src/te/schedule/message_passing.cc | 3 ++- src/te/schedule/schedule_lang.cc | 30 ++++++++++++--------- tests/python/te/test_te_schedule.py | 42 +++++++++++++++++------------ 5 files changed, 64 insertions(+), 38 deletions(-) diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index 9ffcb105a7ba..c2f206f5c8c0 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -131,10 +131,11 @@ class Stage : public ObjectRef { * \param factor The split factor of the loop. * \param p_outer The result outer domain * \param p_inner The result inner domain. + * \param disable_predication Whether to not predicate the loop * \return reference to self. */ - TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, - IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner, + bool disable_predication = false); // NOLINT(*) /*! * \brief Split the iteration with given number of parts. * @@ -142,10 +143,11 @@ class Stage : public ObjectRef { * \param nparts The number of parts in the outer domain. * \param p_outer The result outer domain. * \param p_inner The result inner domain. + * \param disable_predication Whether to not predicate the loop * \return reference to self. */ TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, - IterVar* p_inner); // NOLINT(*) + IterVar* p_inner, bool disable_predication = false); // NOLINT(*) /*! * \brief Fuse the inner outer domain to the target * \param outer The outer domain to be fused. @@ -761,6 +763,8 @@ class SplitNode : public IterVarRelationNode { PrimExpr factor; /*! \brief Number of parts, only factor or nparts can be given */ PrimExpr nparts; + /*! \brief Whether to disable the predication */ + bool disable_predication; void VisitAttrs(AttrVisitor* v) { v->Visit("parent", &parent); @@ -768,6 +772,7 @@ class SplitNode : public IterVarRelationNode { v->Visit("inner", &inner); v->Visit("factor", &factor); v->Visit("nparts", &nparts); + v->Visit("disable_predication", &disable_predication); } static constexpr const char* _type_key = "Split"; @@ -780,7 +785,8 @@ class SplitNode : public IterVarRelationNode { */ class Split : public IterVarRelation { public: - TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts); + TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts, + bool disable_predication); TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode); }; diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index 936ead654dc8..180f334d1486 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -201,7 +201,7 @@ def rfactor(self, tensor, axis, factor_axis=0): class Stage(Object): """A Stage represents schedule for one operation.""" - def split(self, parent, factor=None, nparts=None): + def split(self, parent, factor=None, nparts=None, disable_predication=False): """Split the stage either by factor providing outer scope, or both Parameters @@ -215,6 +215,13 @@ def split(self, parent, factor=None, nparts=None): nparts : Expr, optional The number of outer parts. + disable_predication : bool + If enabled, don't create a predicate for guarding the loop. This + can be useful when splitting with scalable factors that the + schedule writer knows are divisible. + Warning: enabling this feature may result in incorrect code + generation if not used carefully. + Returns ------- outer : IterVar @@ -226,11 +233,11 @@ def split(self, parent, factor=None, nparts=None): if nparts is not None: if factor is not None: raise ValueError("Do not need to provide both outer and nparts") - outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts) + outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts, disable_predication) else: if factor is None: raise ValueError("Either nparts or factor need to be provided") - outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor) + outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor, disable_predication) return outer, inner def fuse(self, *args): diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 233663feac6d..e8f0d9332a16 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -637,7 +637,8 @@ void PassUpBoundCheck(const Stage& s, const Map& dom_map, if (outer || inner) { state[s->parent] = true; } else { - if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) { + if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step) || + s->disable_predication) { state[s->parent] = false; } else { state[s->parent] = true; diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 44e742eee4cf..9e142b1bf76c 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -70,7 +70,7 @@ DataType MatchDataType(std::vector dtypes) { } void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, - IterVar* p_outer, IterVar* p_inner) { + IterVar* p_outer, IterVar* p_inner, bool disable_predication) { // Check if split is valid. ICHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || parent->iter_type == kOrdered) @@ -83,7 +83,7 @@ void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr npar Array& all_vars = self->all_iter_vars; Array& leaf_vars = self->leaf_iter_vars; size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent); - self->relations.push_back(Split(parent, outer, inner, factor, nparts)); + self->relations.push_back(Split(parent, outer, inner, factor, nparts, disable_predication)); // add vars to all vars all_vars.push_back(outer); all_vars.push_back(inner); @@ -226,17 +226,17 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) { return *this; } -Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, - IterVar* p_inner) { // NOLINT(*) +Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner, + bool disable_predication) { // NOLINT(*) With ctx(operator->()->attach_sch, __func__); - SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); + SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner, disable_predication); return *this; } -Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, - IterVar* p_inner) { // NOLINT(*) +Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner, + bool disable_predication) { // NOLINT(*) With ctx(operator->()->attach_sch, __func__); - SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); + SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner, disable_predication); return *this; } @@ -805,13 +805,15 @@ void ScheduleContext::ExitWithScope() { } } -Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { +Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts, + bool disable_predication) { auto n = make_object(); n->parent = parent; n->outer = outer; n->inner = inner; n->factor = factor; n->nparts = nparts; + n->disable_predication = disable_predication; data_ = std::move(n); } @@ -927,6 +929,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ", nparts="; p->Print(op->nparts); } + p->stream << ", disable_predication="; + p->stream << op->disable_predication; p->stream << ')'; }) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -973,16 +977,16 @@ TVM_REGISTER_GLOBAL("te.StageSetScope").set_body_method(&Stage::set_scope); TVM_REGISTER_GLOBAL("te.StageBind").set_body_method(&Stage::bind); TVM_REGISTER_GLOBAL("te.StageSplitByFactor") - .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { + .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor, bool disable_predication) { IterVar outer, inner; - stage.split(parent, factor, &outer, &inner); + stage.split(parent, factor, &outer, &inner, disable_predication); return Array({outer, inner}); }); TVM_REGISTER_GLOBAL("te.StageSplitByNParts") - .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { + .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts, bool disable_predication) { IterVar outer, inner; - stage.split_by_nparts(parent, nparts, &outer, &inner); + stage.split_by_nparts(parent, nparts, &outer, &inner, disable_predication); return Array({outer, inner}); }); diff --git a/tests/python/te/test_te_schedule.py b/tests/python/te/test_te_schedule.py index ed224883478e..fa04e729cf5b 100644 --- a/tests/python/te/test_te_schedule.py +++ b/tests/python/te/test_te_schedule.py @@ -19,6 +19,7 @@ import pytest import tvm from tvm import te +from tvm.driver.build_module import schedule_to_module def test_schedule_create(): @@ -354,21 +355,28 @@ def invalid_compute_at_loop(): invalid_compute_at_loop() +@pytest.mark.parametrize("split_factor", [4, 4 * tvm.tir.vscale()]) +@pytest.mark.parametrize("disable_predication", [True, False]) +def test_split_disable_predicate(disable_predication, split_factor): + A = te.placeholder((43,), name="A") + B = te.compute(A.shape, lambda i: A[i] + 2, name="C") + + sch = te.create_schedule(B.op) + (i,) = sch[B].op.axis + _, _ = sch[B].split(i, factor=split_factor, disable_predication=disable_predication) + + mod = schedule_to_module(sch, [A, B], "main") + + predicates = [] + + def _find_predicates(stmt): + if isinstance(stmt, tvm.tir.stmt.IfThenElse): + predicates.append(stmt) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _find_predicates) + + assert bool(len(predicates)) != disable_predication + + if __name__ == "__main__": - test_singleton() - test_pragma() - test_tensor_intrin() - test_tensor_intrin_scalar_params() - test_rfactor() - test_schedule_create() - test_reorder() - test_tile() - test_split() - test_fuse() - test_fuse_with_split() - test_fuse_with_out_of_order_axis() - test_fuse_with_out_of_order_axis_with_reorder() - test_vectorize() - test_vectorize_commreduce() - test_legalize_invalid_attach() - test_compute_at() + tvm.testing.main()