diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index cbef6235c098..c2de31965afa 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -706,7 +706,7 @@ def apply( if LOAD_V_SHARED is False: LOAD_V_TILE = 1 - if not isinstance(len_r, int): + if not isinstance(len_r, int) or len_r < LOAD_V_TILE * TR * SCALE_PACK * DEC_PACK: return None if isinstance(len_s, int) and len_s > 32000: