From a431bc7474c11547626b6744f361f288728809b0 Mon Sep 17 00:00:00 2001 From: Meera Nakrani Date: Wed, 28 Jul 2021 12:35:10 +0000 Subject: [PATCH] Adding initial SVE support to TVM Prototype containing initial VLA and predication implementation --- include/tvm/runtime/data_type.h | 53 +- include/tvm/te/schedule.h | 7 + include/tvm/tir/expr.h | 6 +- include/tvm/tir/op.h | 3 +- include/tvm/tir/stmt.h | 11 +- include/tvm/tir/transform.h | 9 + include/tvm/tir/var.h | 8 +- python/tvm/autotvm/task/space.py | 2 +- python/tvm/te/hybrid/calls.py | 1 + python/tvm/te/schedule.py | 12 +- python/tvm/tir/ir_builder.py | 2 + python/tvm/tir/stmt.py | 1 + python/tvm/tir/transform/transform.py | 17 + src/autotvm/feature_visitor.cc | 3 + src/autotvm/feature_visitor.h | 1 + src/driver/driver_api.cc | 1 + src/printer/tir_text_printer.cc | 30 +- src/target/intrin_rule.cc | 4 +- src/target/llvm/codegen_aarch64.cc | 298 +++++++ src/target/llvm/codegen_llvm.cc | 21 +- src/target/llvm/codegen_llvm.h | 2 +- src/te/operation/op_utils.cc | 3 + src/te/schedule/schedule_lang.cc | 11 + src/tir/ir/expr.cc | 81 +- src/tir/ir/expr_functor.cc | 4 +- src/tir/ir/stmt.cc | 17 +- src/tir/op/op.cc | 9 +- src/tir/transforms/lower_intrin.cc | 2 +- src/tir/transforms/vectorize_loop_scalable.cc | 613 +++++++++++++ .../unittest/test_target_codegen_arm.py | 87 ++ .../unittest/test_target_codegen_llvm.py | 822 ------------------ .../test_tir_transform_vectorize_scalable.py | 177 ++++ 32 files changed, 1444 insertions(+), 874 deletions(-) create mode 100644 src/target/llvm/codegen_aarch64.cc create mode 100644 src/tir/transforms/vectorize_loop_scalable.cc delete mode 100644 tests/python/unittest/test_target_codegen_llvm.py create mode 100644 tests/python/unittest/test_tir_transform_vectorize_scalable.py diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 3b767547357b..f03ea5d0a546 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -69,14 +69,26 @@ class DataType { * \param bits The number of bits in the type. * \param lanes The number of lanes. */ - DataType(int code, int bits, int lanes) { + DataType(int code, int bits, int lanes, bool is_scalable = false) { data_.code = static_cast(code); data_.bits = static_cast(bits); data_.lanes = static_cast(lanes); + is_scalable_ = is_scalable; if (code == kBFloat) { ICHECK_EQ(bits, 16); } } + // DataType(int code, int bits) { + // data_.code = static_cast(code); + // data_.bits = static_cast(bits); + // is_scalable_ = true; + // std::cout<(8); // minimal lanes + // + //// if (code == kBFloat) { + //// ICHECK_EQ(bits, 16); + //// } + // } /*! \return The type code. */ int code() const { return static_cast(data_.code); } /*! \return number of bits in the data. */ @@ -107,6 +119,13 @@ class DataType { bool is_vector_bool() const { return is_vector() && bits() == 1; } /*! \return whether type is a Void type. */ bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; } + bool is_scalable() const { return is_scalable_; } + + DataType with_scalable_lanes() const { + int min_num_lanes = 128 / bits(); + ICHECK(min_num_lanes != 0); + return DataType(data_.code, data_.bits, min_num_lanes, true); + } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. 
@@ -131,7 +150,7 @@ class DataType { */ bool operator==(const DataType& other) const { return data_.code == other.data_.code && data_.bits == other.data_.bits && - data_.lanes == other.data_.lanes; + data_.lanes == other.data_.lanes; // && is_scalable_ == other.is_scalable_; } /*! * \brief NotEqual comparator. @@ -151,21 +170,27 @@ class DataType { * \param lanes The number of lanes. * \return The constructed data type. */ - static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); } + static DataType Int(int bits, int lanes = 1, bool is_scalable = false) { + return DataType(kDLInt, bits, lanes, is_scalable); + } /*! * \brief Construct an uint type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); } + static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) { + return DataType(kDLUInt, bits, lanes, is_scalable); + } /*! * \brief Construct an float type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); } + static DataType Float(int bits, int lanes = 1, bool is_scalable = false) { + return DataType(kDLFloat, bits, lanes, is_scalable); + } /*! * \brief Construct an bfloat type. * \param bits The number of bits in the type. @@ -178,7 +203,9 @@ class DataType { * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); } + static DataType Bool(int lanes = 1, bool is_scalable = false) { + return DataType::UInt(1, lanes, is_scalable); + } /*! * \brief Construct a handle type. * \param bits The number of bits in the type. @@ -204,6 +231,7 @@ class DataType { } private: + bool is_scalable_{false}; DLDataType data_; }; @@ -285,6 +313,8 @@ inline DLDataType String2DLDataType(std::string s); */ inline std::string DLDataType2String(DLDataType t); +inline std::string VLADataType2String(DLDataType t); + // implementation details inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { switch (static_cast(type_code)) { @@ -336,6 +366,17 @@ inline std::string DLDataType2String(DLDataType t) { return os.str(); } +inline std::string VLADataType2String(DataType t) { + if (t.bits() == 0) return ""; + std::ostringstream os; + os << t.operator DLDataType(); + // auto const str_to_parse = os.str(); + // auto pos = str_to_parse.find("x"); + // auto stem= str_to_parse.substr(0, pos); + os << "xVL"; + return os.str(); +} + inline DLDataType String2DLDataType(std::string s) { DLDataType t; // handle void type diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index 6f26d07dc8a5..6803b5c8c870 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -196,6 +196,13 @@ class Stage : public ObjectRef { * \return reference to self. */ TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*) + + /*! + * \brief Vectorize iteration with a scalable vector length(VL). + * \param var The axis to be vectorized. + * \return reference to self. + */ + TVM_DLL Stage& vectorize_scalable(IterVar var); // NOLINT(*) /*! * \brief Replace computation of the current stage by tensor intrinsic f. * \param var The axis marks beginning of tensorization. 
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 8ea48dd592d5..f87f154d35d2 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -768,6 +768,7 @@ class RampNode : public PrimExprNode { PrimExpr stride; /*! \brief Total number of lanes. */ int lanes; + bool is_scalable; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -779,7 +780,7 @@ class RampNode : public PrimExprNode { bool SEqualReduce(const RampNode* other, SEqualReducer equal) const { return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) && - equal(lanes, other->lanes); + equal(lanes, other->lanes) && equal(is_scalable, other->is_scalable); } void SHashReduce(SHashReducer hash_reduce) const { @@ -787,6 +788,7 @@ class RampNode : public PrimExprNode { hash_reduce(base); hash_reduce(stride); hash_reduce(lanes); + hash_reduce(is_scalable); } static constexpr const char* _type_key = "tir.Ramp"; @@ -800,6 +802,7 @@ class RampNode : public PrimExprNode { class Ramp : public PrimExpr { public: TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span()); + TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, bool is_scalable, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); }; @@ -839,6 +842,7 @@ class BroadcastNode : public PrimExprNode { class Broadcast : public PrimExpr { public: TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span()); + TVM_DLL Broadcast(PrimExpr value, int lanes, bool is_scalable, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); }; diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 9cf7d0a3cd1f..c7e47fc47adb 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -1110,7 +1110,8 @@ inline PrimExpr make_const(DataType t, ValueType value, Span span) { if (t.lanes() == 1) { return MakeConstScalar(t, value, span); } else { - return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); + return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), t.is_scalable(), + span); } } diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0da8e55be023..5563acc8bc39 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -783,7 +783,9 @@ enum class ForKind : int { * the loop is simply removed and the loop variable is * mapped to the corresponding context thread. */ - kThreadBinding = 4 + kThreadBinding = 4, + /*! \brief Loop is vectorized but the vector length (VL) is unknown. */ + kVectorizedScalable = 5 }; /*! @@ -822,6 +824,8 @@ class ForNode : public StmtNode { * and can be ignored in most passes. 
*/ Map annotations; + bool is_vla; + int stride; void VisitAttrs(AttrVisitor* v) { v->Visit("loop_var", &loop_var); @@ -862,7 +866,8 @@ class For : public Stmt { public: TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional thread_binding = NullOpt, - Map annotations = Map(), Span span = Span()); + Map annotations = Map(), + Span span = Span(), bool is_vla = false, int stride = 1); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode); @@ -1369,6 +1374,8 @@ inline const char* ForKind2String(ForKind t) { return "parallel"; case ForKind::kVectorized: return "vectorized"; + case ForKind::kVectorizedScalable: + return "vectorized_scalable"; case ForKind::kUnrolled: return "unroll"; case ForKind::kThreadBinding: diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index d1308fe0059e..b8465be53135 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -121,6 +121,15 @@ TVM_DLL Pass LoopPartition(); */ TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true); +/*! + * \brief Lower vectorization loops. + * + * \param enable_vectorize Whether vectorization is enabled. + * + * \return The pass. + */ +TVM_DLL Pass VectorizeLoopScalable(bool enable_vectorize = true); + /*! * \brief Inject virtual thread loops. * diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 65c5c12a701b..0d73fcaa8e2c 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -229,7 +229,11 @@ enum IterVarType : int { /*! * \brief Marks boundary of tensorization intrinsic. */ - kTensorized = 8 + kTensorized = 8, + /*! + * \brief The loop is vectorized with a scalable vector length + */ + kVectorizedScalable = 9 }; /*! @@ -324,6 +328,8 @@ inline const char* IterVarType2String(IterVarType t) { return "Parallelized"; case kTensorized: return "Tensorized"; + case kVectorizedScalable: + return "VectorizedScalable"; } return "Unknown"; } diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index afbfb4c03988..37905480e5cf 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -602,7 +602,7 @@ def apply( elif ann == "vec": if vec_size and axis_lens[i] not in vec_size: cfg.raise_error("Wrong size of lanes in vectorization") - sch[op].vectorize(axes[i]) + sch[op].vectorize_scalable(axes[i]) elif ann == "blockIdx.x": sch[op].bind(axes[i], thread_axis("blockIdx.x")) elif ann == "blockIdx.y": diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 462066106a9d..29c82381242e 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -34,6 +34,7 @@ "unroll": ForKind.UNROLLED, "parallel": ForKind.PARALLEL, "vectorize": ForKind.VECTORIZED, + "vectorize_scalable": ForKind.VECTORIZED_SCALABLE, "const_range": (ForKind.UNROLLED,), } diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index 7bd7dceb03e5..5aceb63c9cf6 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -379,10 +379,20 @@ def vectorize(self, var): Parameters ---------- var : IterVar - The iteration to be vectorize + The iteration to be vectorized """ _ffi_api.StageVectorize(self, var) + def vectorize_scalable(self, var): + """Vectorize the iteration. 
+ + Parameters + ---------- + var : IterVar + The iteration to be vectorized + """ + _ffi_api.StageVectorizeScalable(self, var) + def tensorize(self, var, tensor_intrin): """Tensorize the computation enclosed by var with tensor_intrin diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 978c630b17ad..53e2355daccb 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -257,6 +257,8 @@ def _exit_cb(): kind_id = _stmt.ForKind.VECTORIZED elif kind == "unroll": kind_id = _stmt.ForKind.UNROLLED + elif kind == "vectorize_scalable": + kind_id = _stmt.ForKind.VECTORIZED_SCALABLE else: raise ValueError("Unknown kind") self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq())) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index d57077f08b52..12eff9098851 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -106,6 +106,7 @@ class ForKind(IntEnum): VECTORIZED = 2 UNROLLED = 3 THREAD_BINDING = 4 + VECTORIZED_SCALABLE = 5 @tvm._ffi.register_object("tir.For") diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 537499a27fa9..c96f759163f6 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -169,6 +169,23 @@ def VectorizeLoop(enable_vectorize: bool = True): return _ffi_api.VectorizeLoop(enable_vectorize) # type: ignore +def VectorizeLoopScalable(enable_vectorize=True): + """Lower vectorization loops. + + Parameters + ---------- + enable_vectorize : bool + Whether vectorization is enabled. + Will lower to scalar loop when it is turned off. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.VectorizeLoopScalable(enable_vectorize) + + def InjectVirtualThread(): """Inject virtual thread loops. 
diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index 59cac9cc9827..dc8134e3179e 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -44,6 +44,9 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) { case ForKind::kVectorized: ann = kVectorized; break; + case ForKind::kVectorizedScalable: + ann = kVectorizedScalable; + break; case ForKind::kSerial: ann = kSerial; break; diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h index 8180839b0668..a4c519aa4750 100644 --- a/src/autotvm/feature_visitor.h +++ b/src/autotvm/feature_visitor.h @@ -49,6 +49,7 @@ enum AnnotationType { kThreadZ, kUnrolled, kVectorized, + kVectorizedScalable, kParallel, kSerial, kVirtualThread, diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 2008fe5e47b8..15ea90cd8af7 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -237,6 +237,7 @@ Array CreatePassList(bool disable_loop_partition, bool for pass_list.push_back(tir::transform::LoopPartition()); } + pass_list.push_back(tir::transform::VectorizeLoopScalable(!disable_vectorize)); pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::InjectDoubleBuffer()); diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index f232994480f8..d31283d1144b 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -351,13 +351,22 @@ Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) { Doc TIRTextPrinter::VisitExpr_(const RampNode* op) { Doc doc; - doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")"; + if (op->is_scalable) { + doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << "xVL" + << ")"; + } else { + doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")"; + } return doc; } Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) { Doc doc; - doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + if (op->dtype.is_scalable()) { + doc << "broadcast(" << Print(op->value) << ", " << op->lanes << "xVL)"; + } else { + doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + } return doc; } @@ -489,8 +498,14 @@ Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { Doc doc; - doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " - << Print(op->min + op->extent) << ")"; + if (op->is_vla) { + doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " + << Print(op->min + op->extent) << ", " << Print(op->loop_var) << "+=" << op->stride + << "xVL)"; + } else { + doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " + << Print(op->min + op->extent) << ", " << Print(op->loop_var) << "++)"; + } if (op->kind != ForKind::kSerial) { doc << " " << Doc::StrLiteral(ForKind2String(op->kind)); } @@ -628,7 +643,12 @@ Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) { } Doc TIRTextPrinter::PrintDType(DataType dtype) { - return Doc::Text(runtime::DLDataType2String(dtype)); + if (dtype.is_scalable()) { + return Doc::Text(runtime::VLADataType2String(dtype)); + + } else { + return Doc::Text(runtime::DLDataType2String(dtype)); + } } template diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index e697d9b60273..117e60f67371 100644 --- a/src/target/intrin_rule.cc 
+++ b/src/target/intrin_rule.cc @@ -197,8 +197,8 @@ TVM_REGISTER_OP("tir.q_multiply_shift") ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); - DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); - DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); + DataType hp_dtype = DataType::Int(64, x.dtype().lanes(), x.dtype().is_scalable()); + DataType lp_dtype = DataType::Int(32, x.dtype().lanes(), x.dtype().is_scalable()); // 1) Calculating the integer multiplier and integer shift PrimExpr zero = make_const(s.dtype(), 0); diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc new file mode 100644 index 000000000000..e91924bba6ab --- /dev/null +++ b/src/target/llvm/codegen_aarch64.cc @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_aarch64.cc + * \brief AArch64 specific code generator with initial SVE support + */ +#ifdef TVM_LLVM_VERSION + +#include + +#include "codegen_cpu.h" + +namespace tvm { +namespace codegen { + +// ------------------------- +// Utility functions to remove +void print_LLVM_type(llvm::Type* type) { + std::string type_str; + llvm::raw_string_ostream rso(type_str); + type->print(rso); + std::cout << rso.str() << std::endl; +} + +void print_LLVM_val(llvm::Value* val) { + std::string type_str; + llvm::raw_string_ostream rso(type_str); + val->print(rso); + std::cout << rso.str() << std::endl; +} +// ------------------------- + +// AArch64 code generation +class CodeGenAArch64 final : public CodeGenCPU { + public: + void InitTarget(llvm::TargetMachine* tm) final { + // set native vector bits.
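+    // 128 bits (16 bytes) is the architectural minimum SVE vector length; the actual
+    // register width is a runtime multiple of it, queried via the sve.cnt* intrinsics below.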
+ native_vector_bits_ = 16 * 8; + CodeGenCPU::InitTarget(tm); + } + llvm::Value* VisitExpr_(const LoadNode* op); + void VisitStmt_(const ForNode* op); + void VisitStmt_(const StoreNode* op); + + private: + // SVE LLVM intrinsics + llvm::Value* sve_stride(int min_lanes); + llvm::Value* sve_whilelt(llvm::Value* a, llvm::Value* b, int min_lanes); + llvm::Value* sve_store(llvm::Value* ptr, llvm::Value* val, DataType t); + llvm::Value* sve_load(llvm::Value* ptr, DataType t); + void CreateSVEFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, const Var& loop_var, + const Stmt& body, int min_lanes); + + // Predicate + llvm::Value* mask_; +}; + +llvm::Value* CodeGenAArch64::sve_stride(int min_lanes) { + llvm::Intrinsic::ID cnt_id; + + switch (min_lanes) { + case 16: + cnt_id = llvm::Function::lookupIntrinsicID("llvm.aarch64.sve.cntb"); + break; + case 8: // half + cnt_id = llvm::Function::lookupIntrinsicID("llvm.aarch64.sve.cnth"); + break; + case 4: // float + cnt_id = llvm::Function::lookupIntrinsicID("llvm.aarch64.sve.cntw"); + break; + default: // double + cnt_id = llvm::Function::lookupIntrinsicID("llvm.aarch64.sve.cntd"); + } + + // All pattern + int all_pattern = 31; + + llvm::Value* in_param = llvm::ConstantInt::get(*ctx_, llvm::APInt(32, all_pattern)); + std::vector arg_value{in_param}; + std::vector arg_type{builder_->getInt32Ty()}; + llvm::Type* return_type = builder_->getInt64Ty(); + llvm::Function* func_cnt = GetIntrinsicDecl(cnt_id, return_type, arg_type); + llvm::Value* vec_stride = builder_->CreateCall(func_cnt, arg_value); + llvm::Value* vec_stride_int32 = + builder_->CreateTruncOrBitCast(vec_stride, builder_->getInt32Ty()); + return vec_stride_int32; +} + +llvm::Value* CodeGenAArch64::sve_whilelt(llvm::Value* a, llvm::Value* b, int min_lanes) { + llvm::Intrinsic::ID whilelt_id = llvm::Function::lookupIntrinsicID("llvm.aarch64.sve.whilelt"); + std::vector arg_value{a, b}; + std::vector arg_type{builder_->getInt32Ty(), builder_->getInt32Ty()}; + + // Needs to be a vector type + llvm::Type* bool_type = llvm::Type::getIntNTy(*ctx_, 1); + llvm::Type* return_type = llvm::ScalableVectorType::get(bool_type, min_lanes); + + llvm::Function* func_whilelt = GetIntrinsicDecl(whilelt_id, return_type, arg_type); + llvm::Value* whilelt = builder_->CreateCall(func_whilelt, arg_value); + return whilelt; +} + +llvm::Value* CodeGenAArch64::sve_store(llvm::Value* ptr, llvm::Value* val, DataType t) { + // Get the intrinsic + llvm::Intrinsic::ID st1_id = llvm::Function::lookupIntrinsicID("llvm.aarch64.sve.st1"); + std::vector arg_value{val, mask_, ptr}; + + // Get the pointer type + llvm::PointerType* ptr_type = llvm::dyn_cast(ptr->getType()); + ICHECK(ptr_type != nullptr); + + // Input types + llvm::Type* mask_type = mask_->getType(); + llvm::Type* scalar_type = ptr_type->getElementType(); + llvm::Type* store_type = llvm::ScalableVectorType::get(scalar_type, t.lanes()); + std::vector arg_type{store_type, mask_type, ptr_type}; + + // Return type (void) + llvm::Type* return_type = llvm::Type::getVoidTy(*ctx_); + llvm::Function* func_st1 = GetIntrinsicDecl(st1_id, return_type, arg_type); + + // Create the call + llvm::Value* st1 = builder_->CreateCall(func_st1, arg_value); + return st1; +} + +llvm::Value* CodeGenAArch64::sve_load(llvm::Value* ptr, DataType t) { + llvm::Intrinsic::ID ld1_id = llvm::Function::lookupIntrinsicID("llvm.aarch64.sve.ld1"); + std::vector arg_value{mask_, ptr}; + llvm::Type* ptr_type = ptr->getType(); + llvm::Type* mask_type = mask_->getType(); + + std::vector 
arg_type{mask_type, ptr_type}; + llvm::PointerType* ptype = llvm::dyn_cast(ptr_type); + ICHECK(ptype != nullptr); + + llvm::Type* scalar_type = ptype->getElementType(); + llvm::Type* return_type = llvm::ScalableVectorType::get(scalar_type, t.lanes()); + llvm::Function* func_ld1 = GetIntrinsicDecl(ld1_id, return_type, arg_type); + + llvm::Value* ld1 = builder_->CreateCall(func_ld1, arg_value); + return ld1; +} + +void CodeGenAArch64::CreateSVEFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, + const Var& loop_var, const Stmt& body, int min_lanes) { + using llvm::BasicBlock; + BasicBlock* for_begin = builder_->GetInsertBlock(); + BasicBlock* for_body = BasicBlock::Create(*ctx_, "for_body", function_); + BasicBlock* for_end = BasicBlock::Create(*ctx_, "for_end", function_); + + // for_begin block + builder_->SetInsertPoint(for_begin); + llvm::Value* vec_stride = sve_stride(min_lanes); + builder_->CreateBr(for_body); + + // for_body + builder_->SetInsertPoint(for_body); + llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2); + mask_ = sve_whilelt(loop_value, end, min_lanes); + loop_value->addIncoming(begin, for_begin); + ICHECK(!var_map_.count(loop_var.get())); + var_map_[loop_var.get()] = loop_value; + + this->VisitStmt(body); + var_map_.erase(loop_var.get()); + llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, vec_stride); + loop_value->addIncoming(loop_next, builder_->GetInsertBlock()); + builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end, + md_very_likely_branch_); + builder_->SetInsertPoint(for_end); + function_->print(llvm::errs()); +} + +llvm::Value* CodeGenAArch64::VisitExpr_(const LoadNode* op) { + DataType t = op->dtype; + if (!t.is_scalable()) return CodeGenCPU::VisitExpr_(op); + llvm::Value* buffer = MakeValue(op->buffer_var); + + // scalable vector load + const RampNode* ramp = op->index.as(); + ICHECK(ramp); + // TODO(giuseros): use gather to address a load-with-stride-greater-than-1 + ICHECK(is_one(ramp->stride)); + + int alignment, native_bits; + GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); + ICHECK_EQ(ramp->lanes, t.lanes()); + llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + + llvm::Value* load = sve_load(ptr, t); + return load; +} + +void CodeGenAArch64::VisitStmt_(const StoreNode* op) { + ICHECK(is_one(op->predicate)) << op->predicate; + DataType t = op->value.dtype(); + bool is_volatile = volatile_buf_.count(op->buffer_var.get()); + llvm::Value* buffer = MakeValue(op->buffer_var); + llvm::Value* index = MakeValue(op->index); + llvm::Value* value = MakeValue(op->value); + + if (t.lanes() == 1) { + int alignment, native_bits; + GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); + llvm::Value* ptr = CreateBufferPtr(t, buffer, index); +#if TVM_LLVM_VERSION >= 110 + llvm::StoreInst* store = + builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile); +#else + llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); +#endif + AddAliasInfo(store, op->buffer_var.get(), op->index); + return; + } else { + // vector store + unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); + if (const RampNode* ramp = op->index.as()) { + if (is_one(ramp->stride)) { + int alignment, native_bits; + GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); + ICHECK_EQ(ramp->lanes, t.lanes()); + llvm::Value* ptr = 
CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + if (!t.is_scalable()) { + ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + } +#if TVM_LLVM_VERSION >= 110 + if (t.is_scalable()) { + sve_store(ptr, value, t); + return; + } + llvm::StoreInst* store = + builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile); +#else + llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); +#endif + AddAliasInfo(store, op->buffer_var.get(), op->index); + return; + } + } + } + ICHECK_GE(t.bits(), 8); + // scalarized store. + int basic_align = t.bits() / 8; + auto f = [&](int i, llvm::Value* index) { + llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); +#if TVM_LLVM_VERSION >= 110 + llvm::StoreInst* store = builder_->CreateAlignedStore( + builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile); +#else + llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), + ptr, basic_align, is_volatile); +#endif + AddAliasInfo(store, op->buffer_var.get(), PrimExpr()); + }; + this->Scalarize(op->index, f); +} + +void CodeGenAArch64::VisitStmt_(const ForNode* op) { + ICHECK(is_zero(op->min)); + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + if (op->is_vla) { + CreateSVEFor(MakeValue(op->min), MakeValue(op->extent), + llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body, + op->stride); + } else { + CodeGenCPU::VisitStmt_(op); + } +} + +TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64") + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + CodeGenLLVM* cg = new CodeGenAArch64(); + *rv = static_cast(cg); + }); + +} // namespace codegen +} // namespace tvm +#endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index b83748b784b6..79967cf9095b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -403,6 +403,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { } if (dtype.lanes() != 1) { #if TVM_LLVM_VERSION >= 110 + if (dtype.is_scalable()) { + return llvm::ScalableVectorType::get(etype, dtype.lanes()); + } return llvm::FixedVectorType::get(etype, dtype.lanes()); #else return llvm::VectorType::get(etype, dtype.lanes()); @@ -558,9 +561,15 @@ std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Modul return debug_info; } -llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { +llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes, bool is_scalable) { #if TVM_LLVM_VERSION >= 110 - llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes); + llvm::Type* type; + + if (is_scalable) { + type = llvm::ScalableVectorType::get(value->getType(), lanes); + } else { + type = llvm::FixedVectorType::get(value->getType(), lanes); + } #else llvm::Type* type = llvm::VectorType::get(value->getType(), lanes); #endif @@ -568,10 +577,12 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { llvm::Constant* zero = ConstInt32(0); value = builder_->CreateInsertElement(undef, value, zero); #if TVM_LLVM_VERSION >= 120 - llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero); + llvm::Constant* mask = + (is_scalable ? 
llvm::ConstantVector::getSplat(llvm::ElementCount::getScalable(lanes), zero) + : llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero)); #elif TVM_LLVM_VERSION >= 110 llvm::Constant* mask = - llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero); + llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/is_scalable), zero); #else llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero); #endif @@ -1253,7 +1264,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { - return CreateBroadcast(MakeValue(op->value), op->lanes); + return CreateBroadcast(MakeValue(op->value), op->lanes, op->dtype.is_scalable()); } void CodeGenLLVM::VisitStmt_(const StoreNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 52c5b98a0025..c93f03fb0cb5 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -280,7 +280,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); + llvm::Value* CreateBroadcast(llvm::Value* value, int lanes, bool is_scalable = false); llvm::Value* CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); // Vector concatenation. llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc index b3897e142545..86f615d57467 100644 --- a/src/te/operation/op_utils.cc +++ b/src/te/operation/op_utils.cc @@ -97,6 +97,9 @@ std::vector > MakeLoopNest(const Stage& stage, break; case kTensorized: break; + case kVectorizedScalable: + kind = ForKind::kVectorizedScalable; + break; default: LOG(FATAL) << "Unknown iter type" << it_attr->iter_type << " in the iter_var_attrs"; } diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 8964c1013a53..2f2fe1457aed 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -353,6 +353,15 @@ Stage& Stage::vectorize(IterVar var) { // NOLINT(*) return *this; } +Stage& Stage::vectorize_scalable(IterVar var) { // NOLINT(*) + ICHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || var->iter_type == kUnrolled || + var->iter_type == kVectorized || var->iter_type == kTensorized || + var->iter_type == kParallelized) + << "Cannot vectorize on " << IterVarType2String(var->iter_type); + SetAttrIterType(operator->(), var, kVectorizedScalable); + return *this; +} + Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) { n->iter_type = kTensorized; @@ -867,6 +876,8 @@ TVM_REGISTER_GLOBAL("te.StageUnroll").set_body_method(&Stage::unroll); TVM_REGISTER_GLOBAL("te.StageVectorize").set_body_method(&Stage::vectorize); +TVM_REGISTER_GLOBAL("te.StageVectorizeScalable").set_body_method(&Stage::vectorize_scalable); + TVM_REGISTER_GLOBAL("te.StageTensorize").set_body_method(&Stage::tensorize); TVM_REGISTER_GLOBAL("te.StageParallel").set_body_method(&Stage::parallel); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index afc5c36ebb92..5f01bdf11410 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -46,18 +46,18 @@ namespace tir { data_ = std::move(node); \ } -#define 
TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ - Name::Name(PrimExpr a, PrimExpr b, Span span) { \ - using T = Name::ContainerType; \ - ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ - ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ - ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \ - ObjectPtr node = make_object(); \ - node->dtype = DataType::Bool(a.dtype().lanes()); \ - node->a = std::move(a); \ - node->b = std::move(b); \ - node->span = std::move(span); \ - data_ = std::move(node); \ +#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b, Span span) { \ + using T = Name::ContainerType; \ + ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ + ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ + ICHECK_EQ(a.dtype(), b.dtype()) << "TypeError: mismatched types\n"; \ + ObjectPtr node = make_object(); \ + node->dtype = DataType::Bool(a.dtype().lanes(), a.dtype().is_scalable()); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + node->span = std::move(span); \ + data_ = std::move(node); \ } // Var @@ -188,7 +188,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Cast Cast::Cast(DataType t, PrimExpr value, Span span) { ICHECK(value.defined()); - ICHECK_EQ(t.lanes(), value.dtype().lanes()); + ICHECK((t.lanes() == value.dtype().lanes()) || (t.is_scalable() == value.dtype().is_scalable())); ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); @@ -648,10 +648,16 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S // vectorized/non-vectorized arrays as needed. Ideally, these // should be changed to explicit casts in the TIR graph, rather than // being handled at the code-gen level. - ICHECK((dtype.lanes() == element_lanes * index.dtype().lanes()) || - (dtype.lanes() == index.dtype().lanes())); - ICHECK((dtype.lanes() == element_lanes * predicate.dtype().lanes()) || - (dtype.lanes() == index.dtype().lanes())); + if (!dtype.is_scalable() && !index.dtype().is_scalable()) { + ICHECK((dtype.lanes() == element_lanes * index.dtype().lanes()) || + (dtype.lanes() == index.dtype().lanes())); + } + if (!dtype.is_scalable() && !predicate.dtype().is_scalable()) { + ICHECK((dtype.lanes() == element_lanes * predicate.dtype().lanes()) || + (dtype.lanes() == index.dtype().lanes()) || + (dtype.lanes() == predicate.dtype().lanes())); + } + ObjectPtr node = make_object(); node->dtype = dtype; @@ -706,6 +712,23 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) { data_ = std::move(node); } +Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, bool is_scalable, Span span) { + ICHECK(base.defined()); + ICHECK(stride.defined()); + ICHECK(base.dtype().is_scalar()); + ICHECK(stride.dtype().is_scalar()); + ICHECK_EQ(stride.dtype(), base.dtype()); + + ObjectPtr node = make_object(); + node->dtype = base.dtype().with_lanes(lanes); + node->base = base; + node->stride = stride; + node->span = std::move(span); + node->is_scalable = is_scalable; // is_scalable means xVL + node->lanes = lanes; + data_ = std::move(node); +} + TVM_REGISTER_GLOBAL("tir.Ramp") .set_body_typed([](PrimExpr base, PrimExpr stride, int lanes, Span span) { return Ramp(base, stride, lanes, span); @@ -720,7 +743,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->base); p->stream << ", "; p->Print(op->stride); - p->stream << ", " << op->lanes << ")"; + bool is_vla = op->is_scalable; + p->stream << ", " << op->lanes << (is_vla ? 
"xVL)" : ")"); }); // Broadcast @@ -737,6 +761,21 @@ Broadcast::Broadcast(PrimExpr value, int lanes, Span span) { data_ = node; } +// VLA Broadcast +Broadcast::Broadcast(PrimExpr value, int lanes, bool is_scalable, Span span) { + ICHECK(value.defined()); + ICHECK(value.dtype().is_scalar()); + ICHECK_GT(lanes, 1); + + ObjectPtr node = make_object(); + node->dtype = DataType(value.dtype().code(), value.dtype().bits(), lanes, is_scalable); + // value.dtype().with_lanes(lanes); + node->value = std::move(value); + node->lanes = lanes; + node->span = std::move(span); + data_ = node; +} + TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes, Span span) { return Broadcast(value, lanes, span); }); @@ -746,7 +785,11 @@ TVM_REGISTER_NODE_TYPE(BroadcastNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << "x" << op->lanes << "("; + if (op->dtype.is_scalable()) { + p->stream << op->lanes << "xVL("; + } else { + p->stream << "x" << op->lanes << "("; + } p->Print(op->value); p->stream << ")"; }); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 4c5ea5bfd2d0..ececa79de2aa 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -276,7 +276,7 @@ PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { if (base.same_as(op->base) && stride.same_as(op->stride)) { return GetRef(op); } else { - return Ramp(base, stride, op->lanes); + return Ramp(base, stride, op->lanes, op->is_scalable); } } @@ -284,6 +284,8 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); + } else if (op->dtype.is_scalable()) { + return Broadcast(value, op->lanes, true); } else { return Broadcast(value, op->lanes); } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index d59c94dc5753..20f89b4481e4 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -132,7 +132,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - Optional thread_binding, Map annotations, Span span) { + Optional thread_binding, Map annotations, Span span, + bool is_vla, int stride) { ICHECK(min.defined()); ICHECK(extent.defined()); ICHECK(min.dtype().is_scalar()); @@ -149,6 +150,8 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, node->thread_binding = std::move(thread_binding); node->annotations = std::move(annotations); node->span = std::move(span); + node->is_vla = is_vla; + node->stride = stride; data_ = std::move(node); } @@ -178,6 +181,9 @@ std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) case ForKind::kThreadBinding: out << "launch_thread"; break; + case ForKind::kVectorizedScalable: + out << "vectorized_scalable"; + break; } return out; } @@ -190,6 +196,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->min); p->stream << ", "; p->Print(op->extent); + if (op->is_vla) { + p->stream << ", " << op->loop_var << "+=" << op->stride << "xVL"; + } p->stream << ") {\n"; p->indent += 2; @@ -257,9 +266,11 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, } ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) || - (value.dtype().lanes() == index.dtype().lanes())); + (value.dtype().lanes() == index.dtype().lanes()) || + (value.dtype().is_scalable() == index.dtype().is_scalable())); 
ICHECK((value.dtype().lanes() == element_lanes * predicate.dtype().lanes()) || - (value.dtype().lanes() == index.dtype().lanes())); + (value.dtype().lanes() == index.dtype().lanes()) || + (value.dtype().is_scalable() == predicate.dtype().is_scalable())); ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index aca6d1b50b0e..e444d4fdacd7 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -108,8 +108,13 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) lhs = tir::Broadcast(lhs, rtype.lanes()); } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { rhs = tir::Broadcast(rhs, ltype.lanes()); + } else if (rtype.is_scalable() && ltype.is_scalar()) { + rhs = tir::Broadcast(rhs, ltype.lanes(), true); + } else if (rtype.is_scalar() && ltype.is_scalable()) { + rhs = tir::Broadcast(rhs, ltype.lanes(), true); } else { - ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; + ICHECK(ltype.lanes() == rtype.lanes() || (ltype.is_scalable() && rtype.is_scalable())) + << "Cannot match type " << ltype << " vs " << rtype; } if (lhs.dtype() == rhs.dtype()) return; @@ -288,7 +293,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { } return tir::Broadcast(value, t.lanes(), span); } else { - ICHECK(value.dtype().lanes() == t.lanes()); + ICHECK(value.dtype().lanes() == t.lanes() || value.dtype().is_scalable() == t.is_scalable()); return tir::Cast(t, value, span); } } diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 2555002d29b0..14cd529cf106 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -256,7 +256,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { }; if (should_swap()) { - PrimExpr new_bcast = Broadcast(cast->value, bcast->lanes); + PrimExpr new_bcast = Broadcast(cast->value, bcast->lanes, bcast->dtype.is_scalable()); return Cast(bcast->dtype, new_bcast); } } diff --git a/src/tir/transforms/vectorize_loop_scalable.cc b/src/tir/transforms/vectorize_loop_scalable.cc new file mode 100644 index 000000000000..8c303313957d --- /dev/null +++ b/src/tir/transforms/vectorize_loop_scalable.cc @@ -0,0 +1,613 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file vectorize_loop_scalable.cc + */ +// Loop vectorizer as in Halide pipeline. 
+#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace tir { + +inline PrimExpr BroadcastToVL(PrimExpr e, int min_num_lanes) { + if (e.dtype().is_scalable()) return e; + // In the VLA world a ramp is always scalable + if (e.as()) { + return e; + } + if (const BroadcastNode* op = e.as()) { + return Broadcast(op->value, min_num_lanes, true); + } + return Broadcast(e, min_num_lanes, true); +} + +// Rewrite vectorized allocation access +// This is necessary for making each vector component containing its own workspace. +// Originates from Halide's loop vectorizer +// +// s[i] = s[i * lanes + var] +// +// The same principle applies when using one thread to simulate multiple context. +// +class VecAllocAccess : public StmtExprMutator { + public: + VecAllocAccess(const VarNode* buf, Var var, int var_lanes) + : buf_(buf), var_(var), var_lanes_(var_lanes) {} + // Load + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + // type_ = expr->dtype; + if (op->buffer_var.get() == buf_) { + return Load(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, op->predicate); + } else { + return expr; + } + } + // Store + Stmt VisitStmt_(const StoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + if (op->buffer_var.get() == buf_) { + return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate); + } else { + return stmt; + } + } + + private: + // buffer var + const VarNode* buf_; + // variable to be replaced + Var var_; + // the lanes. + int var_lanes_; + // The type + // DataType type_; +}; + +// We use ExprFunctor directly instead of StmtExprMutator +// This is because the transformation can change the dtype of the Expr +// The existing ExprMutator transformation rules may not be well defined. 
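+// Rough shape of the transformation, as an illustrative sketch only (the loop, tensors and
+// lane count below are made-up examples, not symbols from this file): a float32 loop marked
+// kVectorizedScalable,
+//
+//   for (i, 0, n) "vectorized_scalable" { C[i] = A[i] + 1f }
+//
+// is rewritten into a serial For with is_vla = true and stride equal to the minimum number of
+// lanes (min_vector_len_bits_ / element bits, e.g. 128 / 32 = 4), whose body uses scalable
+// Ramp/Broadcast nodes, printed roughly as
+//
+//   for (i, 0, n, i += 4xVL) { C[ramp(i, 1, 4xVL)] = A[ramp(i, 1, 4xVL)] + 4xVL(1f) }
+//
+// CodeGenAArch64 later lowers such a loop to an SVE loop that steps by the runtime vector
+// length (sve.cnt*) and predicates the tail with whilelt, using ld1/st1 for the memory accesses
+// (see codegen_aarch64.cc).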
+class VectorizerVLA : public StmtMutator, public ExprFunctor { + public: + using ExprFunctor::VisitExpr; + using StmtMutator::operator(); + + VectorizerVLA(Var var, PrimExpr min, int var_lanes) + : var_(var), min_(min), var_lanes_(var_lanes) { + // ramp_ = Ramp(var_, 1); + } + + Stmt VisitStmt(const Stmt& stmt) final { + ICHECK(!need_scalarize_); + Stmt ret = StmtMutator::VisitStmt(stmt); + if (need_scalarize_) { + need_scalarize_ = false; + return Scalarize(stmt); + } else { + return ret; + } + } + + PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); } + + PrimExpr VisitExpr_(const AddNode* op) final { + return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; }); + } + + PrimExpr VisitExpr_(const SubNode* op) final { + return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; }); + } + + PrimExpr VisitExpr_(const MulNode* op) final { + PrimExpr a = this->VisitExpr(op->a); + PrimExpr b = this->VisitExpr(op->b); + if (a.same_as(op->a) && b.same_as(op->b)) { + return GetRef(op); + } else { + int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); + if (lanes != 1) { + const RampNode* b_ramp = b.as(); + const RampNode* a_ramp = a.as(); + + // This happens when we have a stride*i index into the tensor with stride >1 + // In this case scalarize, since for now it is not supported + if ((a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) || + (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0))) { + need_scalarize_ = true; + return GetRef(op); + } + } + return Mul(BroadcastToVL(a, type_.lanes()), BroadcastToVL(b, type_.lanes())); + } + return BinaryVec(op); + } + PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec
(op); } + PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec(op); } + + PrimExpr VisitExpr_(const NotNode* op) final { + PrimExpr a = this->VisitExpr(op->a); + if (a.same_as(op->a)) { + return GetRef(op); + } else { + return !(a); + } + } + + PrimExpr VisitExpr_(const RampNode* op) final { + // This happens when the data tensor is a vector type. We scalarize in this + // case + need_scalarize_ = true; + return GetRef(op); + } + + PrimExpr VisitExpr_(const BroadcastNode* op) final { + PrimExpr value = this->VisitExpr(op->value); + if (value.dtype().lanes() != 1) { + need_scalarize_ = true; + return GetRef(op); + } + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return Broadcast(op->value, op->lanes); + } + } + + PrimExpr VisitExpr_(const SelectNode* op) final { + PrimExpr cond = this->VisitExpr(op->condition); + PrimExpr t = this->VisitExpr(op->true_value); + PrimExpr f = this->VisitExpr(op->false_value); + if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { + return GetRef(op); + } else { + return Select(cond, BroadcastToVL(t, type_.lanes()), BroadcastToVL(f, type_.lanes())); + } + } + PrimExpr VisitExpr_(const CastNode* op) final { + PrimExpr value = this->VisitExpr(op->value); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + auto base_type = op->dtype; + auto variable_type = DataType(base_type.code(), base_type.bits(), type_.lanes(), true); + return Cast(variable_type, value); + } + } + + PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef(op); } + + PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef(op); } + + PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef(op); } + + // Variable + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + + if (var.same_as(var_)) { + return Ramp(var_, 1, type_.lanes(), true); + } + auto it = let_binding_.find(var); + if (it != let_binding_.end()) { + return it->second; + } else { + return std::move(var); + } + } + // IfThenElse expr + PrimExpr MutateIfThenElseExpr_(const CallNode* op) { + PrimExpr cond = this->VisitExpr(op->args[0]); + if (cond.dtype().is_vector()) { + need_scalarize_ = true; + return GetRef(op); + } + PrimExpr t = this->VisitExpr(op->args[1]); + PrimExpr f = this->VisitExpr(op->args[2]); + if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { + return GetRef(op); + } else { + t = BroadcastToVL(t, type_.lanes()); + f = BroadcastToVL(f, type_.lanes()); + return Call(op->dtype.with_scalable_lanes(), op->op, {cond, t, f}); + } + } + // Call + PrimExpr VisitExpr_(const CallNode* op) final { + // 
TODO (giuseros): we could remove @tir.likely since we are using + // predication. It is not trivial to do that here (would be simpler to do when we split) + // but should be doable. + if (op->op.same_as(builtin::if_then_else())) { + type_ = op->dtype.with_scalable_lanes(); + return MutateIfThenElseExpr_(op); + } + auto* op_ptr = op->op.as(); + bool vectorizable = op_ptr && op_vectorizable_.get(GetRef(op_ptr), false); + + if (!vectorizable) { + // Cannot vectorize this op + Array new_args; + for (auto arg : op->args) { + auto new_arg = this->VisitExpr(arg); + if (new_arg.dtype().is_vector()) { + need_scalarize_ = true; + return GetRef(op); + } + new_args.push_back(new_arg); + } + if (op->args.same_as(new_args)) { + return GetRef(op); + } else { + return Call(op->dtype, op->op, new_args); + } + } else { + int lane = 0; + Array new_args = MutateArray(op->args, &lane); + // normal code path. + if (op->args.same_as(new_args)) { + return GetRef(op); + } else { + auto base_type = op->dtype; + return Call(DataType(base_type.code(), base_type.bits(), type_.lanes(), true), op->op, + new_args); + } + } + } + // Load + PrimExpr VisitExpr_(const LoadNode* op) final { + DataType base_type = op->dtype; + auto load_type = DataType(base_type.code(), base_type.bits(), type_.lanes(), true); + PrimExpr index = this->VisitExpr(op->index); + PrimExpr pred = this->VisitExpr(op->predicate); + if (index.same_as(op->index) && pred.same_as(op->predicate)) { + return GetRef(op); + } else { + // int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes()); + return Load(load_type, op->buffer_var, BroadcastToVL(index, type_.lanes()), + BroadcastToVL(pred, type_.lanes())); + } + } + // Let + PrimExpr VisitExpr_(const LetNode* op) final { + PrimExpr value = this->VisitExpr(op->value); + // Weaker SSA condition + // A single var can be binded in multiple lets + // but they have to bind to the same value. + // This is used to allow cases when we reuse a single let + // expression to cosntruct a nested expr. 
+ // (let x = 1 in x + 1) * (let x = 1 in x + 1) + auto it = let_binding_.find(op->var); + if (it != let_binding_.end()) { + ICHECK(deep_equal_(it->second, value)) + << "Let cannot bind the same var to two different values"; + } + if (value.dtype().lanes() != op->value.dtype().lanes()) { + Var new_var(op->var->name_hint, value.dtype().with_scalable_lanes()); + let_binding_[op->var] = new_var; + return Let(new_var, value, this->VisitExpr(op->body)); + } else { + let_binding_[op->var] = op->var; + PrimExpr body = this->VisitExpr(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return GetRef(op); + } else { + return Let(op->var, value, body); + } + } + } + // Store + Stmt VisitStmt_(const StoreNode* op) final { + // type_ = op->buffer_var->dtype.with_scalable_lanes(); + DataType base_type = op->value.dtype(); + type_ = + DataType(base_type.code(), base_type.bits(), min_vector_len_bits_ / base_type.bits(), true); + + PrimExpr value = this->VisitExpr(op->value); + // type_ = value.dtype().with_scalable_lanes(); + + PrimExpr index = this->VisitExpr(op->index); + PrimExpr pred = this->VisitExpr(op->predicate); + if (value.same_as(op->value) && index.same_as(op->index)) { + return GetRef(op); + } else { + int min_lanes = type_.lanes(); + auto vla_loop_body = Store(op->buffer_var, BroadcastToVL(value, min_lanes), + BroadcastToVL(index, min_lanes), BroadcastToVL(pred, min_lanes)); + if (need_loop_) { + need_loop_ = false; + return For(var_, min_, var_lanes_, ForKind::kSerial, vla_loop_body, NullOpt, + Map(), Span(), true, type_.lanes()); + + // TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, + // Optional thread_binding = NullOpt, + // Map annotations = Map(), Span span = Span(), + // bool is_vla = false, int stride = 1); + } else { + return vla_loop_body; + } + } + } + // For + Stmt VisitStmt_(const ForNode* op) final { + // TODO(giuseros): Add a configuration parameter to enable + // For loop vectorization. 
For VLA it boils down to have a + // gather primitive in LLVM + return Scalarize(GetRef(op)); + } + + // IfThenElse + Stmt VisitStmt_(const IfThenElseNode* op) final { + ICHECK(!op->condition.dtype().is_vector()); + // Evaluating then_case first to get to the data type + Stmt then_case = this->VisitStmt(op->then_case); + + PrimExpr condition = this->VisitExpr(op->condition); + if (condition.dtype().is_vector()) { + return Scalarize(GetRef(op)); + } + Stmt else_case; + if (op->else_case.defined()) { + else_case = this->VisitStmt(op->else_case); + } + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return GetRef(op); + } else { + return IfThenElse(condition, then_case, else_case); + } + } + // LetStmt + Stmt VisitStmt_(const LetStmtNode* op) final { + PrimExpr value = this->VisitExpr(op->value); + ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice"; + let_binding_[op->var] = value; + + if (value.dtype().lanes() != op->value.dtype().lanes()) { + Var new_var(op->var->name_hint, value.dtype().with_scalable_lanes()); + let_binding_[op->var] = new_var; + need_loop_ = false; + auto let_stmt = LetStmt(new_var, value, this->VisitStmt(op->body)); + return For(var_, min_, var_lanes_, ForKind::kSerial, let_stmt, NullOpt, + Map(), Span(), true, type_.lanes()); + } else { + let_binding_[op->var] = op->var; + Stmt body = this->VisitStmt(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } + } + // Allocate + Stmt VisitStmt_(const AllocateNode* op) final { + PrimExpr condition = this->VisitExpr(op->condition); + if (condition.dtype().is_vector()) { + LOG(WARNING) << "Cannot handle vector extent in alloc "; + return Scalarize(GetRef(op)); + } + Array extents; + for (size_t i = 0; i < op->extents.size(); i++) { + PrimExpr new_ext = this->VisitExpr(op->extents[i]); + if (new_ext.dtype().is_vector()) { + LOG(WARNING) << "Cannot handle vector extent in alloc "; + return Scalarize(GetRef(op)); + } + extents.push_back(new_ext); + } + // place the vector lanes in least significant dimension. + extents.push_back(var_lanes_); + // rewrite access to buffer internally. + Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); + body = this->VisitStmt(body); + return Allocate(op->buffer_var, op->dtype, extents, condition, body); + } + + // scalarize the statment + Stmt Scalarize(Stmt stmt) { + Var idx(var_->name_hint + ".s", var_->dtype); + Map values{{var_, idx}}; + stmt = Substitute(stmt, values); + return For(idx, 0, var_lanes_, ForKind::kSerial, stmt); + } + // ProducerStore + Stmt VisitStmt_(const ProducerStoreNode* op) final { + LOG(FATAL) << "ProducerProvide is cannot appear in a TIR PrimFunc"; + return Stmt(); + } + + DataType vla_type() { return type_; } + + int extent() { return var_lanes_; } + + private: + // analyzer + arith::Analyzer analyzer_; + // deep equal + ExprDeepEqual deep_equal_; + // variable to be replaced + Var var_; + // the lanes. + PrimExpr min_; + int var_lanes_; + // ramp representing the var. + PrimExpr ramp_; + DataType type_; + // flag to mark requirment of scalarization. 
+  bool need_scalarize_{false};
+  bool need_loop_{true};
+  // Minimum vector length in bits; should eventually be configured from the target.
+  int min_vector_len_bits_{128};
+  int scalable_lanes_;
+  // Let binding
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
+  // vectorizable property
+  OpAttrMap<bool> op_vectorizable_ = Op::GetAttrMap<bool>("TVectorizable");
+
+  // Mutate the array with the given lane requirement;
+  // when finished, p_lanes holds the updated lane requirement.
+  Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
+    if (arr.size() == 0) return arr;
+    int& lanes = *p_lanes;
+    bool changed = false;
+    std::vector<PrimExpr> new_arr(arr.size());
+    for (size_t i = 0; i < arr.size(); i++) {
+      PrimExpr old_elem = arr[i];
+      PrimExpr new_elem = this->VisitExpr(old_elem);
+      if (!new_elem.same_as(old_elem)) changed = true;
+      new_arr[i] = new_elem;
+      lanes = std::max(lanes, new_elem.dtype().lanes());
+    }
+
+    for (size_t i = 0; i < arr.size(); ++i) {
+      if (new_arr[i].dtype().lanes() != lanes) {
+        new_arr[i] = BroadcastToVL(new_arr[i], type_.lanes());
+        changed = true;
+      }
+    }
+    if (!changed) return arr;
+    return Array<PrimExpr>(new_arr);
+  }
+  template <typename TOp, typename T>
+  PrimExpr BinaryVec(const T* op) {
+    static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint");
+    PrimExpr a = this->VisitExpr(op->a);
+    PrimExpr b = this->VisitExpr(op->b);
+    if (a.same_as(op->a) && b.same_as(op->b)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      auto ba = BroadcastToVL(a, type_.lanes());
+      auto bb = BroadcastToVL(b, type_.lanes());
+      auto bin_op = TOp(ba, bb);
+      return bin_op;
+    }
+  }
+  template <typename T, typename FCompute>
+  PrimExpr AddSubVec(const T* op, FCompute fcompute) {
+    PrimExpr a = this->VisitExpr(op->a);
+    PrimExpr b = this->VisitExpr(op->b);
+    if (a.same_as(op->a) && b.same_as(op->b)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
+      if (lanes != 1) {
+        const RampNode* b_ramp = b.as<RampNode>();
+        const RampNode* a_ramp = a.as<RampNode>();
+        if (a.dtype().lanes() == 1 && b_ramp) {
+          PrimExpr new_stride = fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride);
+
+          if (analyzer_.CanProve(new_stride != 1)) {
+            // TODO(giuros01): add support for gather also when stride != 1
+            need_scalarize_ = true;
+            return GetRef<PrimExpr>(op);
+          }
+
+          return Ramp(fcompute(a, b_ramp->base), new_stride, b_ramp->lanes, true);
+        }
+        if (b.dtype().lanes() == 1 && a_ramp) {
+          return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes, true);
+        }
+      }
+      return fcompute(BroadcastToVL(a, type_.lanes()), BroadcastToVL(b, type_.lanes()));
+    }
+  }
+};
+
+class LoopVectorizerVLA : public StmtMutator {
+ public:
+  Stmt VisitStmt_(const ForNode* op) final {
+    if (op->kind == ForKind::kVectorizedScalable) {
+      ICHECK(is_zero(op->min));
+      auto* extent_as_int = op->extent.as<IntImmNode>();
+      if (!extent_as_int || extent_as_int->value < 1) {
+        LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
+      }
+      VectorizerVLA vla_vectorizer(op->loop_var, op->min, static_cast<int>(extent_as_int->value));
+      auto vla_loop_body = vla_vectorizer(op->body);
+      return vla_loop_body;
+    } else {
+      return StmtMutator::VisitStmt_(op);
+    }
+  }
+};
+
+Stmt VectorizeLoopScalable(Stmt stmt) { return LoopVectorizerVLA()(std::move(stmt)); }
+
+class VectorizeVLASkipper : public StmtMutator {
+ public:
+  Stmt VisitStmt_(const ForNode* op) final {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    op = stmt.as<ForNode>();
+    if (op->kind == ForKind::kVectorized) {
+      return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body);
+    } else {
+      return stmt;
+    }
+  }
+};
+
+Stmt SkipVectorizeScalable(Stmt stmt) { return VectorizeVLASkipper()(std::move(stmt)); }
+
+namespace transform {
+
+// TODO(tvm-team): Make this a target property.
+Pass VectorizeLoopScalable(bool enable_vectorize) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    if (enable_vectorize) {
+      n->body = LoopVectorizerVLA()(std::move(n->body));
+    } else {
+      n->body = VectorizeVLASkipper()(std::move(n->body));
+    }
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoopScalable", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoopScalable").set_body_typed(VectorizeLoopScalable);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_target_codegen_arm.py b/tests/python/unittest/test_target_codegen_arm.py
index b5c69d6df1a6..aa8156dc8377 100644
--- a/tests/python/unittest/test_target_codegen_arm.py
+++ b/tests/python/unittest/test_target_codegen_arm.py
@@ -21,6 +21,93 @@ import ctypes
+def test_llvm_flip_pipeline_sve():
+    target = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve"
+
+    def check_llvm(nn, base):
+        n = tvm.runtime.convert(nn)
+        A = te.placeholder((n + base), name="A")
+        C = te.compute((n,), lambda i: A(nn + base - i - 1), name="C")
+        s = te.create_schedule(C.op)
+        xo, xi = s[C].split(C.op.axis[0], factor=4)
+        s[C].parallel(xo)
+        s[C].vectorize_scalable(xi)
+
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, C], target)
+
+        # ctx = remote.context(target)
+        # # launch the kernel.
+        # n = nn
+        # a = tvm.nd.array(np.random.uniform(size=(n + base)).astype(A.dtype), ctx)
+        # c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        # f(a, c)
+        # tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy()[::-1][:n])
+
+    check_llvm(4, 0)
+    check_llvm(128, 8)
+    check_llvm(3, 0)
+    check_llvm(128, 1)
+
+
+def test_llvm_vadd_pipeline_sve():
+    target = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve"
+
+    def check_llvm(n, lanes):
+        A = te.placeholder((n,), name="A", dtype="float32x%d" % lanes)
+        B = te.compute((n,), lambda i: A[i], name="B")
+        C = te.compute((n,), lambda i: B[i] + tvm.tir.const(1, A.dtype), name="C")
+        s = te.create_schedule(C.op)
+        xo, xi = s[C].split(C.op.axis[0], nparts=2)
+        _, xi = s[C].split(xi, factor=2)
+        s[C].parallel(xo)
+        s[C].vectorize_scalable(xi)
+        s[B].compute_at(s[C], xo)
+        xo, xi = s[B].split(B.op.axis[0], factor=2)
+        s[B].vectorize_scalable(xi)
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, C], target)
+
+        # ctx = remote.context(target)
+        # # launch the kernel.
+        # a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np.random.uniform(size=(n, lanes)))
+        # c = tvm.nd.empty((n,), C.dtype, ctx)
+        # f(a, c)
+        # tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
+
+    check_llvm(64, 2)
+    check_llvm(512, 2)
+
+
+def test_llvm_madd_pipeline_sve():
+    target = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve"
+
+    def check_llvm(nn, base, stride):
+        n = tvm.runtime.convert(nn)
+        A = te.placeholder((n + base, stride), name="A")
+        C = te.compute((n, stride), lambda i, j: A(base + i, j) + 1, name="C")
+        s = te.create_schedule(C.op)
+        xo, xi = s[C].split(C.op.axis[0], factor=4)
+        s[C].parallel(xo)
+        s[C].vectorize_scalable(xi)
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, C], target)
+
+        # ctx = remote.context(target)
+        # # launch the kernel.
+ # n = nn + # a = tvm.nd.array(np.random.uniform(size=(n + base, stride)).astype(A.dtype), ctx) + # c = tvm.nd.array(np.zeros((n, stride), dtype=C.dtype), ctx) + # f(a, c) + # tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy()[base:] + 1) + + check_llvm(64, 0, 2) + check_llvm(4, 0, 1) + + with tvm.transform.PassContext(config={"tir.noalias": False}): + check_llvm(4, 0, 3) + + def test_popcount(): target = "llvm -mtriple=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon" diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py deleted file mode 100644 index 10cbcd68f362..000000000000 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ /dev/null @@ -1,822 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import collections -import ctypes -import json -import sys - -import tvm -import tvm.testing -from tvm import te -from tvm import topi -from tvm.contrib import utils -import numpy as np -import ctypes -import math -import re -import pytest - - -@tvm.testing.requires_llvm -def test_llvm_intrin(): - ib = tvm.tir.ir_builder.create() - n = tvm.runtime.convert(4) - A = ib.pointer("float32", name="A") - args = [tvm.tir.call_intrin("handle", "tir.address_of", A[0]), 0, 3, 1] - ib.emit(tvm.tir.Evaluate(tvm.tir.Call("int32", "tir.prefetch", args))) - body = ib.get() - - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "prefetch")) - fcode = tvm.build(mod, None, "llvm") - - -@tvm.testing.requires_llvm -def test_llvm_void_intrin(): - ib = tvm.tir.ir_builder.create() - A = ib.pointer("uint8", name="A") - # Create an intrinsic that returns void. - x = tvm.tir.call_llvm_intrin("", "llvm.va_start", tvm.tir.const(1, "uint32"), A) - ib.emit(x) - body = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) - fcode = tvm.build(mod, None, "llvm") - - -@tvm.testing.requires_llvm -def test_llvm_overloaded_intrin(): - # Name lookup for overloaded intrinsics in LLVM 4- requires a name - # that includes the overloaded types. 
- if tvm.target.codegen.llvm_version_major() < 5: - return - - def use_llvm_intrinsic(A, C): - ib = tvm.tir.ir_builder.create() - L = A.vload((0, 0)) - I = tvm.tir.call_llvm_pure_intrin( - "int32", "llvm.ctlz", tvm.tir.const(2, "uint32"), L, tvm.tir.const(0, "int1") - ) - S = C.vstore((0, 0), I) - ib.emit(S) - return ib.get() - - A = tvm.te.placeholder((1, 1), dtype="int32", name="A") - C = tvm.te.extern( - (1, 1), [A], lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]), name="C", dtype="int32" - ) - s = tvm.te.create_schedule(C.op) - f = tvm.build(s, [A, C], target="llvm") - - -@tvm.testing.requires_llvm -def test_llvm_lookup_intrin(): - ib = tvm.tir.ir_builder.create() - A = ib.pointer("uint8x8", name="A") - z = tvm.tir.const(0, "int32") - x = tvm.tir.call_llvm_pure_intrin( - "uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, "uint32"), A[z] - ) - ib.emit(x) - body = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) - fcode = tvm.build(mod, None, "llvm") - - -@tvm.testing.requires_llvm -def test_llvm_large_uintimm(): - value = (1 << 63) + 123 - other = tvm.tir.const(3, "uint64") - A = te.compute((), lambda: tvm.tir.const(value, "uint64") + other, name="A") - s = te.create_schedule(A.op) - - def check_llvm(): - f = tvm.build(s, [A], "llvm") - dev = tvm.cpu(0) - # launch the kernel. - a = tvm.nd.empty((), dtype=A.dtype, device=dev) - f(a) - assert a.numpy() == value + 3 - - check_llvm() - - -@tvm.testing.requires_llvm -def test_llvm_persist_parallel(): - n = 128 - A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B") - C = te.compute(A.shape, lambda *i: te.sqrt(B(*i)) * 2 + 2, name="C") - s = te.create_schedule(C.op) - xo, xi = s[C].split(C.op.axis[0], factor=8) - xo1, xo2 = s[C].split(xo, nparts=1) - s[B].compute_at(s[C], xo1) - s[B].parallel(s[B].op.axis[0]) - s[B].pragma(s[B].op.axis[0], "parallel_barrier_when_finish") - s[C].parallel(xi) - s[C].pragma(xo1, "parallel_launch_point") - s[C].pragma(xi, "parallel_stride_pattern") - - def check_llvm(): - # BUILD and invoke the kernel. - f = tvm.build(s, [A, C], "llvm") - dev = tvm.cpu(0) - # launch the kernel. - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) - f(a, c) - tvm.testing.assert_allclose(c.numpy(), np.sqrt(a.numpy() + 1) * 2 + 2, rtol=1e-5) - - check_llvm() - - -@tvm.testing.requires_llvm -def test_llvm_flip_pipeline(): - def check_llvm(nn, base): - n = tvm.runtime.convert(nn) - A = te.placeholder((n + base), name="A") - C = te.compute((n,), lambda i: A(nn + base - i - 1), name="C") - s = te.create_schedule(C.op) - xo, xi = s[C].split(C.op.axis[0], factor=4) - s[C].parallel(xo) - s[C].vectorize(xi) - # build and invoke the kernel. - f = tvm.build(s, [A, C], "llvm") - dev = tvm.cpu(0) - # launch the kernel. 
- n = nn - a = tvm.nd.array(np.random.uniform(size=(n + base)).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) - f(a, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy()[::-1][:n]) - - check_llvm(4, 0) - check_llvm(128, 8) - check_llvm(3, 0) - check_llvm(128, 1) - - -@tvm.testing.requires_llvm -def test_llvm_vadd_pipeline(): - def check_llvm(n, lanes): - A = te.placeholder((n,), name="A", dtype="float32x%d" % lanes) - B = te.compute((n,), lambda i: A[i], name="B") - C = te.compute((n,), lambda i: B[i] + tvm.tir.const(1, A.dtype), name="C") - s = te.create_schedule(C.op) - xo, xi = s[C].split(C.op.axis[0], nparts=2) - _, xi = s[C].split(xi, factor=2) - s[C].parallel(xo) - s[C].vectorize(xi) - s[B].compute_at(s[C], xo) - xo, xi = s[B].split(B.op.axis[0], factor=2) - s[B].vectorize(xi) - # build and invoke the kernel. - f = tvm.build(s, [A, C], "llvm") - dev = tvm.cpu(0) - # launch the kernel. - a = tvm.nd.empty((n,), A.dtype).copyfrom(np.random.uniform(size=(n, lanes))) - c = tvm.nd.empty((n,), C.dtype, dev) - f(a, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) - - check_llvm(64, 2) - check_llvm(512, 2) - - -@tvm.testing.requires_llvm -def test_llvm_madd_pipeline(): - def check_llvm(nn, base, stride): - n = tvm.runtime.convert(nn) - A = te.placeholder((n + base, stride), name="A") - C = te.compute((n, stride), lambda i, j: A(base + i, j) + 1, name="C") - s = te.create_schedule(C.op) - xo, xi = s[C].split(C.op.axis[0], factor=4) - s[C].parallel(xo) - s[C].vectorize(xi) - # build and invoke the kernel. - f = tvm.build(s, [A, C], "llvm") - dev = tvm.cpu(0) - # launch the kernel. - n = nn - a = tvm.nd.array(np.random.uniform(size=(n + base, stride)).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros((n, stride), dtype=C.dtype), dev) - f(a, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy()[base:] + 1) - - check_llvm(64, 0, 2) - check_llvm(4, 0, 1) - - with tvm.transform.PassContext(config={"tir.noalias": False}): - check_llvm(4, 0, 3) - - -@tvm.testing.requires_llvm -def test_llvm_temp_space(): - nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda i: A(i) + 1, name="B") - C = te.compute(A.shape, lambda i: B(i) + 1, name="C") - s = te.create_schedule(C.op) - - def check_llvm(): - # build and invoke the kernel. - f = tvm.build(s, [A, C], "llvm") - dev = tvm.cpu(0) - # launch the kernel. - n = nn - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) - f(a, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1 + 1) - - check_llvm() - - -@tvm.testing.requires_llvm -def test_multiple_func(): - nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") - s = te.create_schedule(C.op) - xo, xi = s[C].split(C.op.axis[0], factor=4) - s[C].parallel(xo) - s[C].vectorize(xi) - - def check_llvm(): - # build two functions - f2 = tvm.lower(s, [A, B, C], name="fadd1") - f1 = tvm.lower(s, [A, B, C], name="fadd2") - m = tvm.build([f1, f2], "llvm") - fadd2 = m["fadd2"] - fadd1 = m["fadd1"] - - dev = tvm.cpu(0) - # launch the kernel. 
- n = nn - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) - fadd1(a, b, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) - fadd2(a, b, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) - - check_llvm() - - -@tvm.testing.requires_llvm -def test_llvm_condition(): - def check_llvm(n, offset): - A = te.placeholder((n,), name="A") - C = te.compute((n,), lambda i: tvm.tir.if_then_else(i >= offset, A[i], 0.0), name="C") - s = te.create_schedule(C.op) - # build and invoke the kernel. - f = tvm.build(s, [A, C], "llvm") - dev = tvm.cpu(0) - # launch the kernel. - a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), dev) - c = tvm.nd.empty((n,), A.dtype, dev) - f(a, c) - c_np = a.numpy() - c_np[:offset] = 0 - tvm.testing.assert_allclose(c.numpy(), c_np) - - check_llvm(64, 8) - - -@tvm.testing.requires_llvm -def test_llvm_bool(): - def check_llvm(n): - A = te.placeholder((n,), name="A", dtype="int32") - C = te.compute((n,), lambda i: A[i].equal(1).astype("float"), name="C") - s = te.create_schedule(C.op) - # build and invoke the kernel. - f = tvm.build(s, [A, C], "llvm") - dev = tvm.cpu(0) - # launch the kernel. - a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) - c = tvm.nd.empty((n,), C.dtype, dev) - f(a, c) - c_np = a.numpy() == 1 - tvm.testing.assert_allclose(c.numpy(), c_np) - - check_llvm(64) - - -@tvm.testing.requires_llvm -def test_rank_zero(): - def check_llvm(n): - A = te.placeholder((n,), name="A") - scale = te.placeholder((), name="scale") - k = te.reduce_axis((0, n), name="k") - C = te.compute((), lambda: te.sum(A[k] * scale(), axis=k), name="C") - D = te.compute((), lambda: C() + 1) - s = te.create_schedule(D.op) - # build and invoke the kernel. - f = tvm.build(s, [A, scale, D], "llvm") - dev = tvm.cpu(0) - # launch the kernel. - a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) - sc = tvm.nd.array(np.random.randint(0, 2, size=()).astype(scale.dtype), dev) - d = tvm.nd.empty((), D.dtype, dev) - f(a, sc, d) - d_np = np.sum(a.numpy()) * sc.numpy() + 1 - tvm.testing.assert_allclose(d.numpy(), d_np) - - check_llvm(64) - - -@tvm.testing.requires_llvm -def test_rank_zero_bound_checkers(): - def check_llvm(n): - with tvm.transform.PassContext(config={"tir.instrument_bound_checkers": True}): - A = te.placeholder((n,), name="A") - scale = te.placeholder((), name="scale") - k = te.reduce_axis((0, n), name="k") - C = te.compute((), lambda: te.sum(A[k] * scale(), axis=k), name="C") - D = te.compute((), lambda: C() + 1) - s = te.create_schedule(D.op) - # build and invoke the kernel. - f = tvm.build(s, [A, scale, D], "llvm") - dev = tvm.cpu(0) - # launch the kernel. 
- a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) - sc = tvm.nd.array(np.random.randint(0, 2, size=()).astype(scale.dtype), dev) - d = tvm.nd.empty((), D.dtype, dev) - f(a, sc, d) - d_np = np.sum(a.numpy()) * sc.numpy() + 1 - tvm.testing.assert_allclose(d.numpy(), d_np) - - check_llvm(64) - - -@tvm.testing.requires_llvm -def test_alignment(): - n = tvm.runtime.convert(1024) - A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda i: A[i] * 3, name="B") - s = te.create_schedule(B.op) - bx, tx = s[B].split(B.op.axis[0], factor=8) - s[B].vectorize(tx) - f = tvm.build(s, [A, B], "llvm", name="test_alignment") - - lines = f.get_source().split("\n") - - # Check alignment on load/store. - for l in lines: - if "align" in l and "4 x float" in l: - assert "align 32" in l - - # Check parameter alignment. This looks for the definition of the - # outlined "compute_" function to see if there is an "align" attribute - # listed there. - def has_param_alignment(): - for l in lines: - if re.search(r"test_alignment_compute_\([^(]*align [0-9]", l): - return True - return False - - if tvm.target.codegen.llvm_version_major() >= 5: - assert has_param_alignment() - - # Check for assume intrinsics. This isn't 100% accurate, since it just - # checks if the llvm.assume is there, but detailed check would require - # a much more detailed analysis of the LLVM IR. - def has_call_to_assume(): - for l in lines: - if re.search(r"call.*llvm.assume", l): - return True - return False - - assert has_call_to_assume() - - -@tvm.testing.requires_llvm -def test_llvm_div(): - """Check that the semantics of div and mod is correct""" - - def check(start, end, dstart, dend, dtype, floor_div=False): - div = tvm.te.floordiv if floor_div else tvm.tir.truncdiv - mod = tvm.te.floormod if floor_div else tvm.tir.truncmod - - # A are dividends, B are divisors. Note that we add 1 to make include end in the range. 
- A = te.placeholder((end - start + 1,), name="A", dtype=dtype) - B = te.placeholder((dend - dstart + 1,), name="B", dtype=dtype) - # We clip values with min and max so that simplifiers know the ranges of values - - def clipa(x): - return tvm.te.min(tvm.tir.const(end, dtype), tvm.te.max(tvm.tir.const(start, dtype), x)) - - def clipb(x): - return tvm.te.min( - tvm.tir.const(dend, dtype), tvm.te.max(tvm.tir.const(dstart, dtype), x) - ) - - # If the range is just a single point, use the constant itself - if start == end: - - def clipa(x): - return tvm.tir.const(start, dtype) - - if dstart == dend: - - def clipb(x): - return tvm.tir.const(dstart, dtype) - - # D are division results and M are modulo results - [D, M] = te.compute( - (end - start + 1, dend - dstart + 1), - lambda i, j: (div(clipa(A[i]), clipb(B[j])), mod(clipa(A[i]), clipb(B[j]))), - ) - - s = te.create_schedule([D.op, M.op]) - f = tvm.build(s, [A, B, D, M], "llvm") - - # Fill input arrays with values - A_arr = tvm.nd.empty((end - start + 1,), dtype) - B_arr = tvm.nd.empty((dend - dstart + 1,), dtype) - A_arr.copyfrom(np.arange(start, end + 1, dtype=dtype)) - B_np = np.arange(dstart, dend + 1, dtype=dtype) - # If the range of the divisor contains 0, replace it with 1 to avoid division by zero - if dend >= 0 and dstart <= 0: - B_np[-dstart] = 1 - B_arr.copyfrom(B_np) - D_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype) - M_arr = tvm.nd.empty((end - start + 1, dend - dstart + 1), dtype) - - # Run the function and convert the results to numpy - f(A_arr, B_arr, D_arr, M_arr) - D_arr = D_arr.numpy() - M_arr = M_arr.numpy() - - # This helper just prints additional info on failure - def _show_info(): - print("dtype: {}".format(dtype)) - print("dividend range: [{}, {}]".format(start, end)) - print("divisor range: [{}, {}]".format(dstart, dend)) - lowered = tvm.lower(s, [A, B, D, M], simple_mode=True) - print("Lowered code:") - print(lowered) - - # Check that the computed values are correct - for i in range(start, end + 1): - for j in range(dstart, dend + 1): - if j == 0: - continue - - if floor_div: - dref = i // j - mref = i % j - else: - dref = int(float(i) / j) - mref = int(math.fmod(i, j)) - - if D_arr[i - start, j - dstart] != dref: - _show_info() - raise AssertionError( - "Incorrect division result: {}({}, {}) is {} " - "but should be {}".format( - div.__name__, i, j, D_arr[i - start, j - dstart], dref - ) - ) - if M_arr[i - start, j - dstart] != mref: - _show_info() - raise AssertionError( - "Incorrect modulo result: {}({}, {}) is {} " - "but should be {}".format( - mod.__name__, i, j, M_arr[i - start, j - dstart], mref - ) - ) - - # Try different ranges to cover different cases - for start, end in [ - (-12, -12), - (-11, -1), - (-11, 0), - (0, 0), - (12, 12), - (1, 11), - (0, 11), - (-11, 11), - ]: - for dstart, dend in [ - (-11, -1), - (-11, 0), - (-4, -4), - (-2, -2), - (1, 11), - (0, 11), - (4, 4), - (2, 2), - (-11, 11), - ]: - if end < start or dend < dstart or (dend == 0 and dstart == 0): - continue - check(start, end, dstart, dend, "int32", floor_div=False) - check(start, end, dstart, dend, "int32", floor_div=True) - check(start, end, dstart, dend, "int8", floor_div=False) - check(start, end, dstart, dend, "int8", floor_div=True) - if start >= 0 and dstart >= 0: - check(start, end, dstart, dend, "uint32", floor_div=False) - check(start, end, dstart, dend, "uint32", floor_div=True) - - # Additional tests for uint8 - for dstart, dend in [(0, 11), (1, 11), (2, 2), (4, 4)]: - check(123, 133, dstart, dend, 
"uint8", floor_div=False) - check(123, 133, dstart, dend, "uint8", floor_div=True) - check(0, 255, dstart, dend, "uint8", floor_div=False) - check(0, 255, dstart, dend, "uint8", floor_div=True) - - -@tvm.testing.requires_llvm -def test_llvm_fp_math(): - def check_llvm_reciprocal(n): - A = te.placeholder((n,), name="A") - B = te.compute((n,), lambda i: te.div(1.0, (1e37 * A[i])), name="B") - - s = te.create_schedule(B.op) - f = tvm.build(s, [A, B], "llvm") - - a = tvm.nd.array(np.full((n,), 100, "float32")) - b = tvm.nd.empty((n,), "float32") - f(a, b) - tvm.testing.assert_allclose(b.numpy(), np.zeros((n,), "float32")) - - check_llvm_reciprocal(4) - check_llvm_reciprocal(8) - check_llvm_reciprocal(16) - - def check_llvm_sigmoid(n): - A = te.placeholder((n,), name="A") - B = te.compute((n,), lambda i: te.sigmoid(A[i]), name="B") - - s = te.create_schedule(B.op) - f = tvm.build(s, [A, B], "llvm") - - a = tvm.nd.array(np.full((n,), -1000, "float32")) - b = tvm.nd.empty((n,), "float32") - f(a, b) - tvm.testing.assert_allclose(b.numpy(), np.zeros((n,), "float32")) - - check_llvm_sigmoid(4) - check_llvm_sigmoid(8) - check_llvm_sigmoid(16) - - -@tvm.testing.requires_llvm -def test_dwarf_debug_information(): - nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") - s = te.create_schedule(C.op) - xo, xi = s[C].split(C.op.axis[0], factor=4) - s[C].parallel(xo) - s[C].vectorize(xi) - - def check_llvm_object(): - if tvm.target.codegen.llvm_version_major() < 5: - return - if tvm.target.codegen.llvm_version_major() > 6: - return - # build two functions - f2 = tvm.lower(s, [A, B, C], name="fadd1") - f1 = tvm.lower(s, [A, B, C], name="fadd2") - m = tvm.build([f1, f2], "llvm") - temp = utils.tempdir() - o_path = temp.relpath("temp.o") - m.save(o_path) - import shutil - import subprocess - import sys - - # Try the dwarfdump utility (OS X) - if shutil.which("dwarfdump"): - output = subprocess.check_output(["dwarfdump", o_path]) - assert re.search(r"""DW_AT_name\\t\("fadd1"\)""", str(output)) - assert re.search(r"""DW_AT_name\\t\("fadd2"\)""", str(output)) - - # Try gobjdump (OS X) - if shutil.which("gobjdump"): - output = subprocess.check_output(["gobjdump", "--dwarf", o_path]) - assert re.search(r"""DW_AT_name.*fadd1""", str(output)) - assert re.search(r"""DW_AT_name.*fadd2""", str(output)) - - # Try objdump (Linux) - Darwin objdump has different DWARF syntax. - if shutil.which("objdump") and sys.platform != "darwin": - output = subprocess.check_output(["objdump", "--dwarf", o_path]) - assert re.search(r"""DW_AT_name.*fadd1""", str(output)) - assert re.search(r"""DW_AT_name.*fadd2""", str(output)) - - def check_llvm_ir(): - if tvm.target.codegen.llvm_version_major() < 5: - return - if tvm.target.codegen.llvm_version_major() > 6: - return - # build two functions - f2 = tvm.lower(s, [A, B, C], name="fadd1") - f1 = tvm.lower(s, [A, B, C], name="fadd2") - m = tvm.build([f1, f2], target="llvm -mtriple=aarch64-linux-gnu") - ll = m.get_source("ll") - - # On non-Darwin OS, don't explicitly specify DWARF version. 
- import re - - assert not re.search(r""""Dwarf Version""" "", ll) - assert re.search(r"""llvm.dbg.value""", ll) - - # Try Darwin, require DWARF-2 - m = tvm.build([f1, f2], target="llvm -mtriple=x86_64-apple-darwin-macho") - ll = m.get_source("ll") - assert re.search(r"""i32 4, !"Dwarf Version", i32 2""", ll) - assert re.search(r"""llvm.dbg.value""", ll) - - check_llvm_object() - check_llvm_ir() - - -@tvm.testing.requires_llvm -def test_llvm_shuffle(): - a = te.placeholder((8,), "int32") - b = te.placeholder((8,), "int32") - c = te.compute((8,), lambda x: a[x] + b[7 - x]) - sch = te.create_schedule(c.op) - - def my_vectorize(): - def vectorizer(op): - store = op.body - idx = tvm.tir.Ramp(tvm.tir.const(0, "int32"), tvm.tir.const(1, "int32"), 8) - all_ones = tvm.tir.const(1, "int32x8") - value = store.value - b_idx = tvm.tir.Shuffle([idx], [tvm.tir.const(i, "int32") for i in range(7, -1, -1)]) - new_a = tvm.tir.Load("int32x8", value.a.buffer_var, idx, all_ones) - new_b = tvm.tir.Load("int32x8", value.b.buffer_var, b_idx, all_ones) - value = new_a + new_b - return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) - - def _transform(f, *_): - return f.with_body( - tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ["tir.For"]) - ) - - return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize") - - with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, my_vectorize())]}): - ir = tvm.lower(sch, [a, b, c], simple_mode=True) - module = tvm.build(sch, [a, b, c]) - a_ = tvm.nd.array(np.arange(1, 9, dtype="int32")) - b_ = tvm.nd.array(np.arange(8, 0, -1, dtype="int32")) - c_ = tvm.nd.array(np.zeros((8,), dtype="int32")) - module(a_, b_, c_) - tvm.testing.assert_allclose(c_.numpy(), (a_.numpy() * 2).astype("int32")) - - -def np_float2np_bf16(arr): - """Convert a numpy array of float to a numpy array - of bf16 in uint16""" - orig = arr.view("= n): + A[i] = A[i] + 1 + stmt = ib.get() + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoopScalable()(mod)["main"].body + + assert isinstance(stmt, tvm.tir.For) + + +def test_vectorize_if_then_else(): + n = te.var("n") + x = te.var("x") + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + with ib.for_range(0, 4, kind="vectorize_scalable") as i: + A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i] + 1, A[i]) + stmt = ib.get() + print(stmt) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt)) + stmt = tvm.tir.transform.VectorizeLoopScalable()(mod)["main"].body + print(stmt) + assert isinstance(stmt, tvm.tir.For) + + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + with ib.for_range(0, n) as k: + with ib.for_range(0, 4, kind="vectorize_scalable") as i: + A[k * 4 + i] = tvm.tir.call_intrin( + "float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0 + ) + stmt = ib.get() + print(stmt) + assert isinstance(stmt, tvm.tir.For) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoopScalable()(mod)["main"].body + print(stmt) + assert isinstance(stmt.body, tvm.tir.For) + + assert not isinstance(stmt.body.body, tvm.tir.For) + assert isinstance(stmt.body.body.value.args[2], tvm.tir.Broadcast) + + +if __name__ == "__main__": + test_vectorize_vector() + test_vectorize_with_if() + test_vectorize_loop() + test_vectorize_if_then_else() + test_vectorize_with_le_cond() + test_vectorize_with_ge_cond() + test_vectorize_let()
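Usage note: below is a minimal sketch of driving the new pass directly from Python, mirroring the new unit tests above. It assumes the bindings this patch adds (the "vectorize_scalable" loop kind in tir.ir_builder and tir.transform.VectorizeLoopScalable); the buffer name and the loop extent are illustrative only.

import tvm

# Build a trivial TIR loop annotated for scalable vectorization.
# (Names and extent are illustrative; the pattern follows the tests above.)
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, 16, kind="vectorize_scalable") as i:
    A[i] = A[i] + 1.0
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], stmt))
# Rewrite the annotated loop into its vector-length-agnostic (VLA) form.
vla_mod = tvm.tir.transform.VectorizeLoopScalable()(mod)
print(vla_mod["main"])

At the TE level the same behaviour is reached through the schedule primitive, e.g. s[C].vectorize_scalable(xi), as exercised by the new tests in test_target_codegen_arm.py.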