From b5860de58b9d4118fe6a4d7f5d0d7d0af3688cf4 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 29 May 2024 07:33:04 +0000 Subject: [PATCH] [TOPI] Fix SME conv2d schedule import and intrin argument Fixes a merge conflict between #16981 and #17003. Change-Id: Ifcc983ef0b8c00250568a048fd682933adfdcde4 --- python/tvm/topi/arm_cpu/conv2d.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index 58c909301ede..d0fe251e7e23 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -729,7 +729,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): # pylint: disable=import-outside-toplevel from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes from tvm.tir.tensor_intrin.arm_cpu import ( - ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE, + ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE, ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA, ARM_SME_INIT, get_sme_gemm_interleaved_mopa_2svlx2svl_intrin, @@ -743,7 +743,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True) sch.parallel(b) sch.reorder(b, ko, mo, ki, mi) - sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE) + sch.tensorize(ki, ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE) # Split and reorder the loops of the GeMM for tensorization b, m, n, k = sch.get_loops(gemm_block) @@ -760,7 +760,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule): sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}" tvm.tir.TensorIntrin.register( sme_gemm_interleaved_intrin_name, - *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded), + *get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, dtype), override=True, ) sch.tensorize(mi, sme_gemm_interleaved_intrin_name)