diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index b8a2c6a15f13..9ad6f3f89af3 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -206,8 +206,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- if is_inner_reduction is None: return None elif is_inner_reduction: - self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) - return sch + return self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) elif target.kind.name == "opencl" and "android" in str(target.host): ret = self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) if ret is None: @@ -313,7 +312,8 @@ def apply( # load vector into shared memory, shape should be the whole vector if LOAD_V_SHARED: - assert len(vector_input_buffers) == 1 + if len(vector_input_buffers) != 1: + return None V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") sch.compute_at(V_shared, tr, preserve_unit_loops=True) l = sch.get_loops(block=V_shared)[-1]