From 8226c815768cdc9b9844383ab44d3cd2d900d674 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 10 Feb 2025 23:59:17 -0500 Subject: [PATCH] [CUDA] Remove htanh from unsupported math ops for CUDA 12.8 This PR removes htanh from the list of unsupported CUDA half operators, as it is started to be supported since CUDA 12.8. Specifically, we added a CUDA version check in the generated CUDA code, so that when the CUDA version is older than 12.8, htanh will still be treated as an unsupported operator and fall back to the packed operation. While for newer CUDA versions, we directly use the function that is defined in `cuda_fp16.h`. --- src/target/source/literal/cuda_half_t.h | 4 ++++ .../test_meta_schedule_mma_m16n8k8_auto_tensorization.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index c5ecda07a4d3..abdf22df2616 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -317,7 +317,9 @@ static inline __device__ __host__ half HALF_MATH_NAME(half x) { \ #if defined(__CUDA_ARCH__) #if (__CUDA_ARCH__ >= 530) CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf) +#if ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 8))) CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf) +#endif CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf) @@ -358,7 +360,9 @@ static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x) { } CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf) +#if ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 8))) CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf) +#endif CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf) diff --git a/tests/python/meta_schedule/test_meta_schedule_mma_m16n8k8_auto_tensorization.py b/tests/python/meta_schedule/test_meta_schedule_mma_m16n8k8_auto_tensorization.py index 68f26bd3ee6c..ea8fee672461 100644 --- a/tests/python/meta_schedule/test_meta_schedule_mma_m16n8k8_auto_tensorization.py +++ b/tests/python/meta_schedule/test_meta_schedule_mma_m16n8k8_auto_tensorization.py @@ -717,7 +717,9 @@ class TVM_ALIGNED(2) half { #if defined(__CUDA_ARCH__) #if (__CUDA_ARCH__ >= 530) CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf) +#if ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 8))) CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf) +#endif CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)