diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py
index 49c12c3e9dce..5e5749055bb0 100644
--- a/python/tvm/tir/tensor_intrin/hexagon.py
+++ b/python/tvm/tir/tensor_intrin/hexagon.py
@@ -104,6 +104,48 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
     return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy
 
 
+def generate_dot_product_32x2_i16i16i32(mem_scope="global"):
+    @T.prim_func
+    def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope)
+        B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope)
+        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+        with T.block("root"):
+            T.reads(C[0:32], A[0:2], B[0:32, 0:2])
+            T.writes(C[0:32])
+            for i in T.serial(0, 32):
+                for k in T.serial(0, 2):
+                    with T.block("update"):
+                        vi, vk = T.axis.remap("SR", [i, k])
+                        C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+    @T.prim_func
+    def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope)
+        B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope)
+        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+        with T.block("root"):
+            T.reads(C[0:32], A[0:2], B[0:32, 0:2])
+            T.writes(C[0:32])
+
+            A_i16x2 = A.vload([0], "int16x2")
+            A_i32 = T.reinterpret(A_i16x2, dtype="int32")
+
+            B_i16x64 = B.vload([0, 0], dtype="int16x64")
+            B_i32x32 = T.reinterpret(B_i16x64, dtype="int32x32")
+
+            C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
+                T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vdmpyhvsat.acc.128B"),
+                T.uint32(3),
+                C[T.ramp(T.int32(0), 1, 32)],
+                T.Broadcast(A_i32, 32),
+                B_i32x32,
+                dtype="int32x32",
+            )
+
+    return dot_product_32x2_i16i16i32_desc, dot_product_32x2_i16i16i32_vdmpy
+
+
 VRMPY_u8u8i32_INTRIN = "dot_32x4_u8u8i32_vrmpy"
 
 TensorIntrin.register(VRMPY_u8u8i32_INTRIN, *generate_dot_product_32x4_u8u8i32())
@@ -112,6 +154,10 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
 TensorIntrin.register(VRMPY_u8i8i32_INTRIN, *generate_dot_product_32x4_u8i8i32())
 
 
+VDMPY_i16i16i32_INTRIN = "dot_product_32x2_i16i16i32_vdmpy"
+
+TensorIntrin.register(VDMPY_i16i16i32_INTRIN, *generate_dot_product_32x2_i16i16i32())
+
 VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy"
 
 TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8u8i32("global.vtcm"))
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index 21cc39b71402..fc0bdc146c88 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -30,7 +30,7 @@
 )
 from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
 from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
-from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
+from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN, VDMPY_i16i16i32_INTRIN
 
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -540,33 +540,31 @@ def test_tensorize_with_annotation():
     verify_trace_roundtrip(sch=s, mod=func)
 
 
-def get_matmul_packed(m, n, k, lhs_type, int32_lanes, rhs_dtype="int8"):
+def get_matmul_packed(m, n, k, lhs_type, rhs_dtype="int8"):
     X = te.placeholder((m, k), name="X", dtype=lhs_type)
-    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype=rhs_dtype)
+    W = te.placeholder((n, k), name="W", dtype=rhs_dtype)
 
     ak = te.reduce_axis((0, k), name="k")
     matmul = te.compute(
         (m, n),
         lambda i, j: te.sum(
-            X[i, ak].astype("int32")
-            * packed_W[
-                tvm.tir.indexdiv(j, int32_lanes), tvm.tir.indexdiv(ak, 4), j % int32_lanes, ak % 4
-            ].astype("int32"),
+            X[i, ak].astype("int32") * W[j, ak].astype("int32"),
            axis=ak,
         ),
         name="compute",
     )
 
-    return te.create_prim_func([X, packed_W, matmul])
+    return te.create_prim_func([X, W, matmul])
 
 
 def test_tensorize_vnni():
     m, n, k = 128, 128, 128
 
-    func = get_matmul_packed(m, n, k, "uint8", 16)
+    func = get_matmul_packed(m, n, k, "uint8")
 
     sch = tir.Schedule(func, debug_mask="all")
     block = sch.get_block("compute")
+    sch.transform_layout(block, "W", lambda i, j: [i//16, j//4, i%16, j%4])
     _, j, k = sch.get_loops(block)
 
     _, ji = sch.split(j, factors=[None, 16])
@@ -582,11 +580,12 @@ def test_tensorize_vnni():
 def test_tensorize_arm_dot():
     m, n, k = 128, 128, 128
 
-    func = get_matmul_packed(m, n, k, "int8", 4)
+    func = get_matmul_packed(m, n, k, "int8")
 
     for intrin in [ARM_DOT_4x4_i8_SDOT_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN]:
         sch = tir.Schedule(func, debug_mask="all")
         block = sch.get_block("compute")
+        sch.transform_layout(block, "W", lambda i, j: [i//4, j//4, i%4, j%4])
         _, j, k = sch.get_loops(block)
 
         _, ji = sch.split(j, factors=[None, 4])
@@ -602,10 +601,11 @@ def test_tensorize_arm_dot():
 def test_tensorize_vrmpy():
     m, n, k = 128, 128, 128
 
-    func = get_matmul_packed(m, n, k, "uint8", 32, "uint8")
+    func = get_matmul_packed(m, n, k, "uint8", "uint8")
 
     sch = tir.Schedule(func, debug_mask="all")
     block = sch.get_block("compute")
+    sch.transform_layout(block, "W", lambda i, j: [i//32, j//4, i%32, j%4])
     _, j, k = sch.get_loops(block)
 
     _, ji = sch.split(j, factors=[None, 32])
@@ -618,6 +618,26 @@ def test_tensorize_vrmpy():
     verify_trace_roundtrip(sch=sch, mod=func)
 
 
+def test_tensorize_vdmpy():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "int16", "int16")
+
+    sch = tir.Schedule(func, debug_mask="all")
+    block = sch.get_block("compute")
+    sch.transform_layout(block, "W", lambda i, j: [i//32, j//2, i%32, j%2])
+    _, j, k = sch.get_loops(block)
+
+    _, ji = sch.split(j, factors=[None, 32])
+    ko, ki = sch.split(k, factors=[None, 2])
+    sch.reorder(ko, ji, ki)
+
+    sch.decompose_reduction(block, ko)
+    sch.tensorize(ji, VDMPY_i16i16i32_INTRIN)
+
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+
 def test_tensorize_dpa4():
     m, n, k = 128, 128, 128