diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py index ef6bb1db91e1..404b73a6f0cc 100644 --- a/python/tvm/dlight/gpu/general_reduction.py +++ b/python/tvm/dlight/gpu/general_reduction.py @@ -40,6 +40,9 @@ def apply( # pylint: disable=too-many-locals if target.kind.name == "cuda": len_tx = 256 unroll_depth = 256 + elif target.kind.name == "opencl": + len_tx = 256 + unroll_depth = 64 else: len_tx = 64 unroll_depth = 64 diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py index f8b2bb4a172d..4047721c9aa8 100644 --- a/python/tvm/dlight/gpu/rmsnorm.py +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -82,6 +82,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring ) -> tir.Schedule: if target.kind.name == "cuda": num_tx = 512 + elif target.kind.name == "opencl": + num_tx = 256 else: num_tx = 64 diff --git a/python/tvm/dlight/gpu/transpose.py b/python/tvm/dlight/gpu/transpose.py index d4496756a2d0..3bef3d61e536 100644 --- a/python/tvm/dlight/gpu/transpose.py +++ b/python/tvm/dlight/gpu/transpose.py @@ -57,6 +57,10 @@ def apply( # pylint: disable=too-many-locals len_tx = 16 len_ty = 8 unroll_depth = 256 + elif target.kind.name == "opencl": + len_tx = 16 + len_ty = 8 + unroll_depth = 64 else: len_tx = 8 len_ty = 4 diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py index 4f2df5cfa0c9..e27a6969ad88 100644 --- a/python/tvm/dlight/gpu/utils.py +++ b/python/tvm/dlight/gpu/utils.py @@ -55,6 +55,8 @@ def suggest_threads_per_block( threads = 256 elif target.kind.name == "metal": threads = 256 + elif target.kind.name == "opencl": + threads = 256 else: threads = 64 results: List[Optional[int]] = []