From c018ec4596dc3c081a6c840b2523c3bdb3ed4d05 Mon Sep 17 00:00:00 2001 From: Lesheng Jin Date: Wed, 17 Apr 2024 14:27:23 -0400 Subject: [PATCH 1/3] [Bugfix] rocm shared memory issue on MI250 --- python/tvm/dlight/gpu/gemv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 644f4e6dfa7a..ed0d894e2f52 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -470,7 +470,8 @@ def apply( elif target.kind.name == "rocm": VEC_C = 4 LOAD_V_SHARED = True - LOAD_V_VEC = 8 + # TODO: for MI250, set LOAD_V_VEC = 4 for now + LOAD_V_VEC = 4 UNROLL = 256 if isinstance(len_S, int): if len_S > len_R: From c730bfb0182acf0953186299515967ac137c2d6d Mon Sep 17 00:00:00 2001 From: Lesheng Jin Date: Wed, 17 Apr 2024 14:49:33 -0400 Subject: [PATCH 2/3] set LOAD_V_SHARED=False --- python/tvm/dlight/gpu/gemv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index ed0d894e2f52..bbbd8fb86f0b 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -469,9 +469,9 @@ def apply( TS, TR = 2, 64 elif target.kind.name == "rocm": VEC_C = 4 - LOAD_V_SHARED = True - # TODO: for MI250, set LOAD_V_VEC = 4 for now - LOAD_V_VEC = 4 + # TODO: set LOAD_V_SHARED = False for now + LOAD_V_SHARED = False + LOAD_V_VEC = 8 UNROLL = 256 if isinstance(len_S, int): if len_S > len_R: From 9cd7a94ffdd5a4347e0231eeeaea85fbd164145a Mon Sep 17 00:00:00 2001 From: Lesheng Jin Date: Wed, 17 Apr 2024 14:54:45 -0400 Subject: [PATCH 3/3] upd the comment --- python/tvm/dlight/gpu/gemv.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index bbbd8fb86f0b..ed32ea77858f 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -470,6 +470,8 @@ def apply( elif target.kind.name == "rocm": VEC_C = 4 # TODO: set LOAD_V_SHARED = False for now + # rocm might have some issues when load/store of shared do not belong to same data type + # and only works for certain vector lens, our commonly useful vector lens are in 4 LOAD_V_SHARED = False LOAD_V_VEC = 8 UNROLL = 256