diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 5efa5f3b9085..f6a7d424ed7d 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -110,7 +111,7 @@ class DataType { return -lanes_as_int; } /*! \return whether type is a scalar type. */ - bool is_scalar() const { return lanes() == 1; } + bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; } /*! \return whether type is a scalar type. */ bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } /*! \return whether type is a float type. */ @@ -389,9 +390,12 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) os << "custom[" << GetCustomTypeName(t.code) << "]"; } if (t.code == kTVMOpaqueHandle) return os; + int16_t lanes = static_cast(t.lanes); os << static_cast(t.bits); - if (t.lanes != 1) { - os << 'x' << static_cast(t.lanes); + if (lanes > 1) { + os << 'x' << lanes; + } else if (lanes < -1) { + os << "xvscalex" << -lanes; } return os; } @@ -456,9 +460,14 @@ inline DLDataType String2DLDataType(std::string s) { char* xdelim; // emulate sscanf("%ux%u", bits, lanes) uint8_t bits = static_cast(strtoul(scan, &xdelim, 10)); if (bits != 0) t.bits = bits; + int scalable_multiplier = 1; + if (strncmp(xdelim, "xvscale", 7) == 0) { + scalable_multiplier = -1; + xdelim += 7; + } char* endpt = xdelim; if (*xdelim == 'x') { - t.lanes = static_cast(strtoul(xdelim + 1, &endpt, 10)); + t.lanes = static_cast(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10)); } ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s; return t; diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 54e4d8f205a1..06f2d4c7e6b6 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -135,7 +135,11 @@ def __init__(self, type_str): arr = type_str.split("x") head = arr[0] - self.lanes = int(arr[1]) if len(arr) > 1 else 1 + if len(arr) == 3: + assert arr[1] == "vscale", f"Invalid data type. Expected 'vscale' but got '{arr[1]}'" + self.lanes = ctypes.c_uint16(-int(arr[2])) + elif len(arr) > 1: + self.lanes = ctypes.c_uint16(int(arr[1])) bits = 32 if head.startswith("int"): @@ -188,8 +192,11 @@ def __repr__(self): type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code) x = "%s%d" % (type_name, self.bits) - if self.lanes != 1: + lanes_as_int = ctypes.c_int16(self.lanes).value + if lanes_as_int > 1: x += "x%d" % self.lanes + elif lanes_as_int < -1: + x += "xvscalex%d" % -lanes_as_int return x def __eq__(self, other): diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index b329d25b5471..c46a8c2643f5 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -342,7 +342,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { using tir::FloatImmNode; if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations - if (t.lanes() == 1) { + if (t.is_scalar()) { if (const IntImmNode* op = value.as()) { return make_const(t, op->value, op->span); } else if (const FloatImmNode* op = value.as()) { diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index daa4dfe72912..23decef69e5a 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -24,12 +24,14 @@ #include #include +#include "../../src/script/printer/utils.h" + using ::testing::HasSubstr; // --------- // Data Type // --------- -TEST(TIR, TestCreateScalableType) { +TEST(ScalableDataType, TestCreateScalableType) { tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); ASSERT_EQ(scalable_type.code(), kDLInt); ASSERT_EQ(scalable_type.bits(), 32); @@ -38,7 +40,7 @@ TEST(TIR, TestCreateScalableType) { ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector()); } -TEST(TIR, TestScalableWithBits) { +TEST(ScalableDataType, TestScalableWithBits) { tvm::DataType scalable_type = tvm::DataType(kDLInt, 1, 8, true); scalable_type = scalable_type.with_bits(32); ASSERT_EQ(scalable_type.bits(), 32); @@ -46,7 +48,7 @@ TEST(TIR, TestScalableWithBits) { ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector()); } -TEST(TIR, TestScalableWithVscaleFactor) { +TEST(ScalableDataType, TestScalableWithVscaleFactor) { tvm::DataType type = tvm::DataType(kDLInt, 32, 1); tvm::DataType scalable_type = type.with_scalable_vscale_factor(4); ASSERT_EQ(scalable_type.vscale_factor(), 4); @@ -54,18 +56,54 @@ TEST(TIR, TestScalableWithVscaleFactor) { ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector()); } -TEST(TIR, TestAssignScalableDataType) { +TEST(ScalableDataType, TestAssignScalableDataType) { tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 2, true); tvm::DataType scalable_type_copy = scalable_type; ASSERT_TRUE(scalable_type_copy.is_scalable_vector()); ASSERT_TRUE(scalable_type_copy.is_scalable_or_fixed_length_vector()); } -TEST(TIR, TestScalableDataTypeAndNonScalableDataTypeInequality) { +TEST(ScalableDataType, TestScalableDataTypeEquality) { + ASSERT_TRUE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32, 4, true)); +} + +TEST(ScalableDataType, TestScalableDataTypeAndNonScalableDataTypeInequality) { ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32, 4)); } -TEST(TIR, TestGetScalableVectorBytesError) { +TEST(ScalableDataType, TestIsScalar) { + ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, true).is_scalar()); + ASSERT_TRUE(tvm::DataType(kDLInt, 32, 1, false).is_scalar()); + ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, false).is_scalar()); + ASSERT_FALSE(tvm::DataType(kDLOpaqueHandle, 1, 0, false).is_scalar()); +} + +TEST(ScalableDataType, TestScalableDataTypeToString) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + EXPECT_EQ(tvm::runtime::DLDataType2String(scalable_type), "int32xvscalex4"); +} + +TEST(ScalableDataType, TestStringToScalableDataType) { + std::string scalable_type_str = "int32xvscalex4"; + EXPECT_EQ(tvm::DataType(tvm::runtime::String2DLDataType(scalable_type_str)), + tvm::DataType(kDLInt, 32, 4, true)); +} + +TEST(ScalableDataType, TestInvalidStringToScalableDataType) { + std::string scalable_type_str = "int32x4xvscale"; + EXPECT_THROW( + { + try { + tvm::runtime::String2DLDataType(scalable_type_str); + } catch (const tvm::InternalError& e) { + EXPECT_THAT(e.what(), HasSubstr("unknown type int32x4xvscale")); + throw; + } + }, + tvm::InternalError); +} + +TEST(ScalableDataType, TestGetScalableVectorBytes) { tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); EXPECT_THROW( { @@ -80,7 +118,7 @@ TEST(TIR, TestGetScalableVectorBytesError) { tvm::InternalError); } -TEST(TIR, TestScalableDataTypeInvalidLanesError) { +TEST(ScalableDataType, TestScalableDataTypeInvalidLanesError) { EXPECT_THROW( { try { @@ -93,7 +131,7 @@ TEST(TIR, TestScalableDataTypeInvalidLanesError) { tvm::InternalError); } -TEST(TIR, TestScalableDataTypeInvalidVscaleFactorAccess) { +TEST(ScalableDataType, TestScalableDataTypeInvalidVscaleFactorAccess) { tvm::DataType fixed_length_type = tvm::DataType(kDLFloat, 32, 4); ASSERT_TRUE(fixed_length_type.is_fixed_length_vector()); ASSERT_TRUE(fixed_length_type.is_scalable_or_fixed_length_vector()); @@ -109,7 +147,7 @@ TEST(TIR, TestScalableDataTypeInvalidVscaleFactorAccess) { tvm::InternalError); } -TEST(TIR, TestScalableDataTypeInvalidLanesAccess) { +TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) { tvm::DataType scalable_type = tvm::DataType(kDLFloat, 32, 4, true); EXPECT_THROW( { @@ -123,3 +161,23 @@ TEST(TIR, TestScalableDataTypeInvalidLanesAccess) { }, tvm::InternalError); } + +// ----------- +// Integration +// ----------- +#if TVM_LLVM_VERSION >= 130 +TEST(ScalableDataType, TestScalableIntrinCall) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + tvm::tir::Call call = tvm::tir::Call( + scalable_type, tvm::tir::builtin::call_llvm_intrin(), + {tvm::IntImm(tvm::DataType::Int(32), ::llvm::Intrinsic::experimental_stepvector)}); + ASSERT_EQ(call->dtype, scalable_type); + ASSERT_EQ(call->Script(), + "T.call_llvm_intrin(\"int32xvscalex4\", \"llvm.experimental.stepvector\")"); +} +#endif + +TEST(ScalableDataType, TestTIRScriptScalableDtype2Str) { + tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); + ASSERT_EQ(tvm::script::printer::DType2Str(scalable_type), "int32xvscalex4"); +} diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index 5b55c432b055..f3498f8ec753 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -439,21 +439,15 @@ def test_broadcast_to_scalable_vec(): assert broadcast.lanes.b == 4 -@pytest.mark.xfail( - reason="Support for scalable data type string will be added in P3 of https://github.com/apache/tvm/issues/16455" -) def test_buffer_load_scalable_vec(): buf = tvm.tir.decl_buffer((24,), "float32") index = tvm.tir.expr.Ramp(1, 1, 8 * tvm.tir.vscale()) load = tvm.tir.BufferLoad(buf, [index]) assert isinstance(load, tvm.tir.BufferLoad) - assert load.dtype == "float32x8xvscale" + assert load.dtype == "float32xvscalex8" -@pytest.mark.xfail( - reason="Support for scalable data type string will be added in P3 of https://github.com/apache/tvm/issues/16455" -) def test_buffer_store_scalable_vec(): b = tvm.tir.decl_buffer((24,), "int32") value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) @@ -461,15 +455,12 @@ def test_buffer_store_scalable_vec(): store = tvm.tir.BufferStore(b, value, [index]) assert isinstance(store, tvm.tir.BufferStore) - assert store.value.dtype == "int32x4xvscale" + assert store.value.dtype == "int32xvscalex4" -@pytest.mark.xfail( - reason="Support for scalable data type string will be added in P3 of https://github.com/apache/tvm/issues/16455" -) def test_scalable_vec_cast(): b = tvm.tir.decl_buffer((24,), "float32") - value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32x12xvscale") + value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32xvscalex12") index = tvm.tir.expr.Ramp(0, 1, 12 * tvm.tir.vscale()) store = tvm.tir.BufferStore(b, value, [index]) diff --git a/tests/python/tir-base/test_tir_scalable_datatype.py b/tests/python/tir-base/test_tir_scalable_datatype.py new file mode 100644 index 000000000000..41a367e6e543 --- /dev/null +++ b/tests/python/tir-base/test_tir_scalable_datatype.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +from tvm import tir +from tvm.script import tir as T +from tvm.target.codegen import llvm_version_major + +""" +Tests for scalable data types. +""" + + +def test_create_scalable_data_type_python_api(): + dtype = tvm.DataType("float32xvscalex4") + assert str(dtype) == "float32xvscalex4" + + +@pytest.mark.skipif(llvm_version_major() < 13, reason="Stepvector intrinsic was added in LLVM 13.") +def test_create_scalable_tir_intrin(): + intrin = tir.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector") + assert intrin.dtype == "int32xvscalex4" + assert str(intrin) == 'T.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector")' + + +@pytest.mark.skipif(llvm_version_major() < 13, reason="Stepvector intrinsic was added in LLVM 13.") +def test_tvm_script_create_scalable_tir_intrin(): + @T.prim_func + def my_func(): + T.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector") + + assert ( + 'T.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector")' in my_func.script() + ) + + +def test_invalid_data_type(): + err_msg = "Invalid data type. Expected 'vscale' but got '4'" + with pytest.raises(AssertionError, match=err_msg): + tvm.DataType("float32x4xvscale") + + +if __name__ == "__main__": + tvm.testing.main()