From eeb70699a62d5912a05245d6dd7a6bb32d86633e Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 31 May 2024 17:03:08 +0800 Subject: [PATCH] [DLight] Skip GEMV rules when more than one vector The current dlight GEMV rule require only one vector buffer, otherwise raise an error. This PR change this behavior to skip the rule. --- python/tvm/dlight/gpu/gemv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index b8a2c6a15f13..9ad6f3f89af3 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -206,8 +206,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- if is_inner_reduction is None: return None elif is_inner_reduction: - self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) - return sch + return self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) elif target.kind.name == "opencl" and "android" in str(target.host): ret = self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) if ret is None: @@ -313,7 +312,8 @@ def apply( # load vector into shared memory, shape should be the whole vector if LOAD_V_SHARED: - assert len(vector_input_buffers) == 1 + if len(vector_input_buffers) != 1: + return None V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") sch.compute_at(V_shared, tr, preserve_unit_loops=True) l = sch.get_loops(block=V_shared)[-1]