From c38e15f261b75ef28a595b71f1add3073932533c Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 16 Nov 2022 17:52:54 +0900
Subject: [PATCH] [Hexagon] Fix vrmpy tensorization

---
 python/tvm/tir/tensor_intrin/hexagon.py      |  4 ---
 .../unittest/test_tir_schedule_tensorize.py  | 26 ++++++++++++++++---
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py
index 6fa9dd8f00ae..306c8cd2e14e 100644
--- a/python/tvm/tir/tensor_intrin/hexagon.py
+++ b/python/tvm/tir/tensor_intrin/hexagon.py
@@ -32,8 +32,6 @@ def dot_product_32x4_u8u8i32_desc(
         for i in T.serial(0, 32):
             for k in T.serial(0, 4):
                 with T.block("update"):
-                    with T.init():
-                        C[i] = T.int32(0)
                     vi, vk = T.axis.remap("SR", [i, k])
                     C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
 
@@ -76,8 +74,6 @@ def dot_product_32x4_u8i8i32_desc(
         for i in T.serial(0, 32):
             for k in T.serial(0, 4):
                 with T.block("update"):
-                    with T.init():
-                        C[i] = T.int32(0)
                     vi, vk = T.axis.remap("SR", [i, k])
                     C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
 
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index f30e91b892c5..0129cee53254 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -30,6 +30,7 @@
 )
 from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
 from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
 
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -539,9 +540,9 @@ def test_tensorize_with_annotation():
     verify_trace_roundtrip(sch=s, mod=func)
 
 
-def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
+def get_matmul_packed(m, n, k, lhs_type, int32_lanes, rhs_dtype="int8"):
     X = te.placeholder((m, k), name="X", dtype=lhs_type)
-    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8")
+    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype=rhs_dtype)
 
     ak = te.reduce_axis((0, k), name="k")
     matmul = te.compute(
@@ -549,7 +550,7 @@ def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
         lambda i, j: te.sum(
             X[i, ak].astype("int32")
             * packed_W[
-                tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4
+                tvm.tir.indexdiv(j, int32_lanes), tvm.tir.indexdiv(ak, 4), j % int32_lanes, ak % 4
             ].astype("int32"),
             axis=ak,
         ),
@@ -598,6 +599,25 @@ def test_tensorize_arm_dot():
     verify_trace_roundtrip(sch=sch, mod=func)
 
 
+def test_tensorize_vrmpy():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "uint8", 32, "uint8")
+
+    sch = tir.Schedule(func, debug_mask="all")
+    block = sch.get_block("compute")
+    _, j, k = sch.get_loops(block)
+
+    _, ji = sch.split(j, factors=[None, 32])
+    ko, ki = sch.split(k, factors=[None, 4])
+    sch.reorder(ko, ji, ki)
+
+    sch.decompose_reduction(block, ko)
+    sch.tensorize(ji, VRMPY_u8u8i32_INTRIN)
+
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+
 def test_tensorize_dpa4():
     m, n, k = 128, 128, 128
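
Reviewer note, not part of the patch: the sketch below replays the new test's scheduling sequence as a standalone script, for anyone who wants to try the fix interactively. It rebuilds the packed uint8 matmul inline because the tail of get_matmul_packed is outside the hunks above, so the te.create_prim_func call and the name="compute" choice are assumptions for illustration; the split/reorder/decompose_reduction/tensorize steps follow test_tensorize_vrmpy verbatim.

# Standalone sketch of the tensorization path exercised by
# test_tensorize_vrmpy (assumptions noted above).
import tvm
from tvm import te, tir
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN

m = n = k = 128
X = te.placeholder((m, k), name="X", dtype="uint8")
packed_W = te.placeholder((n // 32, k // 4, 32, 4), name="packedW", dtype="uint8")
ak = te.reduce_axis((0, k), name="k")
matmul = te.compute(
    (m, n),
    lambda i, j: te.sum(
        X[i, ak].astype("int32")
        * packed_W[
            tvm.tir.indexdiv(j, 32), tvm.tir.indexdiv(ak, 4), j % 32, ak % 4
        ].astype("int32"),
        axis=ak,
    ),
    name="compute",
)

func = te.create_prim_func([X, packed_W, matmul])
sch = tir.Schedule(func)
block = sch.get_block("compute")
_, j, kk = sch.get_loops(block)

# Tile so the innermost loop nest has the 32x4 shape of vrmpy.
_, ji = sch.split(j, factors=[None, 32])
ko, ki = sch.split(kk, factors=[None, 4])
sch.reorder(ko, ji, ki)

# The intrinsic descriptions no longer carry T.init(), so the reduction
# init has to be peeled into its own block before the remaining
# update-only block can match the intrinsic.
sch.decompose_reduction(block, ko)
sch.tensorize(ji, VRMPY_u8u8i32_INTRIN)

print(sch.mod.script())

If the decompose_reduction call is dropped, tensorize should fail to match because the block still carries its init; as I read it, that is the behavioral difference this patch addresses.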