diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index f6a7d424ed7d..8f3ae9b42460 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -110,6 +110,8 @@ class DataType {
     }
     return -lanes_as_int;
   }
+  /*! \return get vscale factor or lanes depending on scalability of the vector. */
+  int get_lanes_or_vscale_factor() { return is_scalable_vector() ? vscale_factor() : lanes(); }
   /*! \return whether type is a scalar type. */
   bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
   /*! \return whether type is a scalar type. */
@@ -211,10 +213,13 @@ class DataType {
   /*!
    * \brief Construct an uint type.
    * \param bits The number of bits in the type.
-   * \param lanes The number of lanes
+   * \param lanes The number of lanes.
+   * \param is_scalable Whether the data type is scalable.
    * \return The constructed data type.
    */
-  static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
+  static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
+    return DataType(kDLUInt, bits, lanes, is_scalable);
+  }
   /*!
    * \brief Construct an float type.
    * \param bits The number of bits in the type.
@@ -243,10 +248,13 @@ class DataType {
   static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
   /*!
    * \brief Construct a bool type.
-   * \param lanes The number of lanes
+   * \param lanes The number of lanes.
+   * \param is_scalable Whether the data type is scalable.
    * \return The constructed data type.
    */
-  static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); }
+  static DataType Bool(int lanes = 1, bool is_scalable = false) {
+    return DataType::UInt(1, lanes, is_scalable);
+  }
   /*!
    * \brief Construct a handle type.
    * \param bits The number of bits in the type.
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 6e23a84bc290..e1b1c654570a 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1045,6 +1045,13 @@ def _has_cpu_feat(features):
 )
 
 
+requires_aarch64_sve = Feature(
+    "arm_sve",
+    "AArch64 SVE",
+    run_time_check=lambda: _has_cpu_feat("sve"),
+)
+
+
 requires_x86_vnni = Feature(
     "x86_vnni",
     "x86 VNNI Extensions",
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index eae26e5cac5b..bba1488274e2 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -587,10 +587,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
         LOG(FATAL) << "do not support " << dtype;
     }
   }
-  if (dtype.lanes() != 1) {
+  if (!dtype.is_scalar()) {
 #if TVM_LLVM_VERSION >= 110
-    return llvm::FixedVectorType::get(etype, dtype.lanes());
+    if (dtype.is_scalable_vector()) {
+      return llvm::VectorType::get(etype, dtype.vscale_factor(), true);
+    } else {
+      return llvm::FixedVectorType::get(etype, dtype.lanes());
+    }
 #else
+    ICHECK(!dtype.is_scalable_vector())
+        << "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
+           "version.";
     return llvm::VectorType::get(etype, dtype.lanes());
 #endif
   } else {
@@ -749,26 +756,6 @@ std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Modul
   return debug_info;
 }
 
-llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
-#if TVM_LLVM_VERSION >= 110
-  llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes);
-#else
-  llvm::Type* type = llvm::VectorType::get(value->getType(), lanes);
-#endif
-  llvm::Constant* undef = llvm::UndefValue::get(type);
-  llvm::Constant* zero = ConstInt32(0);
-  value = builder_->CreateInsertElement(undef, value, zero);
-#if TVM_LLVM_VERSION >= 120
-  llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero);
-#elif TVM_LLVM_VERSION >= 110
-  llvm::Constant* mask =
-      llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero);
-#else
-  llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
-#endif
-  return builder_->CreateShuffleVector(value, undef, mask);
-}
-
 llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
   int num_elems = GetVectorNumElements(vec);
   if (extent == num_elems && begin == 0) return vec;
@@ -1693,7 +1680,8 @@ void CodeGenLLVM::BufferAccessHelper(
   }
 
   PrimExpr last_index = indices[indices.size() - 1];
-  ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());
+  ICHECK_EQ(value_dtype.get_lanes_or_vscale_factor(),
+            last_index.dtype().get_lanes_or_vscale_factor() * buffer_element_dtype.lanes());
 
   // Record index and elemtype in original form used for alias info
   PrimExpr last_index_origin = last_index;
@@ -1736,8 +1724,6 @@ void CodeGenLLVM::BufferAccessHelper(
     llvm::Value* last_index_value;
     int subelement_i = i;
     if (const RampNode* ramp = last_index.as<RampNode>()) {
-      // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
-      ICHECK(!last_index.dtype().is_scalable_vector());
       PrimExpr offset = ramp->base + (ramp->stride * i);
       last_index_value = MakeValue(offset);
     } else if (last_index.dtype().lanes() > 1) {
@@ -1754,8 +1740,13 @@ void CodeGenLLVM::BufferAccessHelper(
     all_index_values.push_back(last_index_value);
 
     TypedPointer buffer_ptr =
-        CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
-                        value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
+        value_dtype.is_scalable_vector()
+            ? CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
+                              value_dtype.with_scalable_vscale_factor(value_dtype.vscale_factor() /
+                                                                      last_index.dtype().lanes()))
+            : CreateBufferPtr(
+                  MakeValue(buffer->data), buffer_element_dtype, all_index_values,
+                  value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
     auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
     AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin);
   }
@@ -1870,10 +1861,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
-  // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
-  ICHECK(!op->dtype.is_scalable_vector());
-  int lanes = op->dtype.lanes();
-  return CreateBroadcast(MakeValue(op->value), lanes);
+  DataType dtype = op->dtype;
+  llvm::Value* value = MakeValue(op->value);
+  llvm::Type* type = DTypeToLLVMType(dtype);
+  llvm::Constant* undef = llvm::UndefValue::get(type);
+  llvm::Constant* zero = ConstInt32(0);
+  value = builder_->CreateInsertElement(undef, value, zero);
+#if TVM_LLVM_VERSION >= 110
+  llvm::ElementCount ec =
+      llvm::ElementCount::get(dtype.get_lanes_or_vscale_factor(), dtype.is_scalable_vector());
+  llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero);
+#else
+  ICHECK(!dtype.is_scalable_vector())
+      << "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
+         "version.";
+  llvm::Constant* mask = llvm::ConstantVector::getSplat(dtype.lanes(), zero);
+#endif
+  return builder_->CreateShuffleVector(value, undef, mask);
 }
 
 void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 2efac0307345..0f7aa847ecb8 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -468,7 +468,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
   llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
-  llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
   virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
                                        llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype);
   // Vector concatenation.
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 2bd1e0608374..2d2c097be494 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -451,7 +451,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
 
   Buffer new_buffer = GetRemappedBuffer(op->buffer);
   auto value = this->VisitExpr(op->value);
-  if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
+  if (new_buffer->dtype != value->dtype && value->dtype.is_scalar()) {
     value = cast(new_buffer->dtype, value);
   }
   auto indices = VisitIndices(op->indices);
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 1b611d453418..c2baad209624 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -58,7 +58,9 @@ namespace tir {
     CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
" \ << b.dtype() << "\n"; \ ObjectPtr node = make_object(); \ - node->dtype = DataType::Bool(a.dtype().lanes()); \ + DataType a_dtype = a.dtype(); \ + node->dtype = \ + DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); \ node->a = std::move(a); \ node->b = std::move(b); \ node->span = std::move(span); \ @@ -393,7 +395,8 @@ Not::Not(PrimExpr a, Span span) { ICHECK(a.dtype().is_bool()); ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); + DataType a_dtype = a.dtype(); + node->dtype = DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); node->a = std::move(a); node->span = std::move(span); data_ = std::move(node); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index e40f683e21f8..3f34f2e870fd 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1275,6 +1275,13 @@ class VectorTypeAccessChecker : public StmtExprVisitor { auto it = info_map_.find(buffer); ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer << ") occurred before its declaration."; + + if (value_dtype.is_scalable_vector()) { + // Scalable types are not currently supported in storage_rewrite. Scalable buffer + // accesses are not currently checked and therefore are not rewritten. + return; + } + BufferVarInfo& var_info = it->second; if (value_dtype.element_of() == DataType::Bool()) { diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index 23decef69e5a..4b4764555f7b 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -162,6 +162,22 @@ TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) { tvm::InternalError); } +TEST(ScalableDataType, TestScalableBool) { + tvm::DataType scalable_type = tvm::DataType::Bool(4, true); + ASSERT_EQ(scalable_type.code(), kDLUInt); + ASSERT_EQ(scalable_type.bits(), 1); + ASSERT_EQ(scalable_type.vscale_factor(), 4); + ASSERT_TRUE(scalable_type.is_scalable_vector()); +} + +TEST(ScalableDataType, TestScalableUInt) { + tvm::DataType scalable_type = tvm::DataType::UInt(1, 4, true); + ASSERT_EQ(scalable_type.code(), kDLUInt); + ASSERT_EQ(scalable_type.bits(), 1); + ASSERT_EQ(scalable_type.vscale_factor(), 4); + ASSERT_TRUE(scalable_type.is_scalable_vector()); +} + // ----------- // Integration // ----------- diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 4e75f916d9b2..773c113f4a42 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -492,5 +492,46 @@ def main(A: T.Buffer((5,), "int32")): assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +def test_scalable_buffer_load_store(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + @T.prim_func + def my_func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128,), "float32") + B = T.match_buffer(b, (128,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] + + mod = tvm.build(my_func, target=target) + llvm = mod.get_source("ll") + + assert re.findall(r"load ", llvm), "No scalable load in generated LLVM." 
+    assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
+
+
+@pytest.mark.skipif(
+    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
+)
+def test_scalable_broadcast():
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+
+    @T.prim_func
+    def my_func(a: T.handle):
+        A = T.match_buffer(a, (128,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
+
+    mod = tvm.build(my_func, target=target)
+    llvm = mod.get_source("ll")
+
+    assert re.findall(
+        r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x float>", llvm
+    ), "No scalable broadcast in generated LLVM."
+    assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py
index dc8452710a8a..158d941073c6 100644
--- a/tests/python/target/test_arm_target.py
+++ b/tests/python/target/test_arm_target.py
@@ -14,9 +14,16 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+import subprocess
+import tempfile
+import re
+
 import pytest
+import numpy as np
 
 import tvm
+from tvm.script import tir as T
 from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support
 from tvm.target import codegen
 
@@ -61,3 +68,121 @@ def test_arm_conv2d_int8_support(
     with tvm.target.Target(arm_target):
         monkeypatch.setattr(codegen, "llvm_version_major", lambda: llvm_version)
         assert is_int8_hw_support(input_dtype, kernel_dtype) == is_supported
+
+
+@pytest.fixture(scope="session")
+def sve_device_vector_length():
+    c_code = r"""
+    #include <arm_sve.h>
+    #include <stdio.h>
+
+    int main() {
+        printf("%ld\n", svcntb() * 8);
+    }
+    """
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        c_path = f"{tmp_dir}/vl.c"
+        o_path = f"{tmp_dir}/out.o"
+        with open(c_path, "w") as f:
+            f.write(c_code)
+        tvm.contrib.cc.create_executable(o_path, c_path, ["-march=native"])
+        out = subprocess.check_output(o_path, shell=True).strip().decode()
+
+    return int(out)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_div(sve_device_vector_length):
+    np.random.seed(0)
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(a: T.handle):
+        A = T.match_buffer(a, (1,), "int32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        A[0] = T.Div(10000, 4 * T.vscale())
+
+    mod = tvm.build(my_func, target=target)
+
+    A_nd = tvm.nd.array(np.empty((1,), dtype="int32"), device=dev)
+    mod(A_nd)
+
+    ref = 10000 // (sve_device_vector_length // 32)
+    tvm.testing.assert_allclose(A_nd.numpy()[0], ref)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_buffer_load_store(sve_device_vector_length):
+    np.random.seed(0)
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    num_elements = sve_device_vector_length // 32
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (num_elements,), "float32")
+        B = T.match_buffer(b, (num_elements,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
+
+    mod = tvm.build(my_func, target=target)
+
+    A_np = np.random.uniform(size=(num_elements,)).astype("float32")
+    B_np = np.zeros((num_elements,)).astype("float32")
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    mod(A_nd, B_nd)
+
+    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_loop_bound(sve_device_vector_length):
+    np.random.seed(0)
+
+    dtype = "float32"
+    num_elements = sve_device_vector_length // 32
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (num_elements,), "float32")
+        B = T.match_buffer(b, (num_elements,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i in T.serial(0, 4 * T.vscale()):
+            B[i] = A[i]
+
+    mod = tvm.build(my_func, target=target)
+
+    A_np = np.random.uniform(size=(num_elements,)).astype(dtype)
+    B_np = np.zeros((num_elements,)).astype(dtype)
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    mod(A_nd, B_nd)
+
+    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_broadcast(sve_device_vector_length):
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    num_elements = sve_device_vector_length // 32
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(a: T.handle):
+        A = T.match_buffer(a, (num_elements,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
+
+    mod = tvm.build(my_func, target=target)
+
+    A_np = np.zeros((num_elements,)).astype("float32")
+    A_nd = tvm.nd.array(A_np, device=dev)
+    mod(A_nd)
+
+    ref = np.ones((num_elements,))
+    tvm.testing.assert_allclose(A_nd.numpy(), ref)