diff --git a/monai/_extensions/gmm/gmm_cuda.cu b/monai/_extensions/gmm/gmm_cuda.cu
index 2cf70a9920..0c808d3165 100644
--- a/monai/_extensions/gmm/gmm_cuda.cu
+++ b/monai/_extensions/gmm/gmm_cuda.cu
@@ -21,6 +21,13 @@ limitations under the License.
 #define EPSILON 1e-5
 #define BLOCK_SIZE 32
 #define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1)
+#ifdef __HIP_PLATFORM_AMD__
+#define __SHFL_DOWN(a, b) __shfl_down(a, b)
+#define __SHFL_XOR(a, b) __shfl_xor(a, b)
+#else
+#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
+#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
+#endif
 
 template
 __global__ void CovarianceReductionKernel(
@@ -82,13 +89,11 @@ __global__ void CovarianceReductionKernel(
 
     for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) {
         float matrix_component = matrix[i];
-
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
-
+        matrix_component += __SHFL_DOWN(matrix_component, 16);
+        matrix_component += __SHFL_DOWN(matrix_component, 8);
+        matrix_component += __SHFL_DOWN(matrix_component, 4);
+        matrix_component += __SHFL_DOWN(matrix_component, 2);
+        matrix_component += __SHFL_DOWN(matrix_component, 1);
         if (lane_index == 0) {
             s_matrix_component[warp_index] = matrix_component;
         }
@@ -97,23 +102,21 @@
 
         if (warp_index == 0) {
             matrix_component = s_matrix_component[lane_index];
-
             if (warp_count >= 32) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+                matrix_component += __SHFL_DOWN(matrix_component, 16);
             }
             if (warp_count >= 16) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+                matrix_component += __SHFL_DOWN(matrix_component, 8);
             }
             if (warp_count >= 8) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+                matrix_component += __SHFL_DOWN(matrix_component, 4);
             }
             if (warp_count >= 4) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+                matrix_component += __SHFL_DOWN(matrix_component, 2);
             }
             if (warp_count >= 2) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+                matrix_component += __SHFL_DOWN(matrix_component, 1);
             }
-
             if (lane_index == 0) {
                 g_batch_matrices[matrix_offset + i] = matrix_component;
             }
@@ -156,13 +159,11 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
                 matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index];
             }
         }
-
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
-
+        matrix_component += __SHFL_DOWN(matrix_component, 16);
+        matrix_component += __SHFL_DOWN(matrix_component, 8);
+        matrix_component += __SHFL_DOWN(matrix_component, 4);
+        matrix_component += __SHFL_DOWN(matrix_component, 2);
+        matrix_component += __SHFL_DOWN(matrix_component, 1);
         if (lane_index == 0) {
             s_matrix_component[warp_index] = matrix_component;
         }
@@ -171,23 +172,21 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
 
         if (warp_index == 0) {
             matrix_component = s_matrix_component[lane_index];
-
            if (warp_count >= 32) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+                matrix_component += __SHFL_DOWN(matrix_component, 16);
             }
             if (warp_count >= 16) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+                matrix_component += __SHFL_DOWN(matrix_component, 8);
             }
             if (warp_count >= 8) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+                matrix_component += __SHFL_DOWN(matrix_component, 4);
             }
             if (warp_count >= 4) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+                matrix_component += __SHFL_DOWN(matrix_component, 2);
             }
             if (warp_count >= 2) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+                matrix_component += __SHFL_DOWN(matrix_component, 1);
             }
-
             if (lane_index == 0) {
                 float constant = i == 0 ? 0.0f : s_gmm[i] * s_gmm[j];
 
@@ -261,13 +260,11 @@ __global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) {
     }
 
     float max_value = eigenvalue;
-
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
-
+    max_value = max(max_value, __SHFL_XOR(max_value, 16));
+    max_value = max(max_value, __SHFL_XOR(max_value, 8));
+    max_value = max(max_value, __SHFL_XOR(max_value, 4));
+    max_value = max(max_value, __SHFL_XOR(max_value, 2));
+    max_value = max(max_value, __SHFL_XOR(max_value, 1));
     if (max_value == eigenvalue) {
         GMMSplit_t split;
 
@@ -347,12 +344,11 @@ __global__ void GMMcommonTerm(float* g_gmm) {
 
     float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f;
     float sum = gmm_n;
-
-    sum += __shfl_xor_sync(0xffffffff, sum, 1);
-    sum += __shfl_xor_sync(0xffffffff, sum, 2);
-    sum += __shfl_xor_sync(0xffffffff, sum, 4);
-    sum += __shfl_xor_sync(0xffffffff, sum, 8);
-    sum += __shfl_xor_sync(0xffffffff, sum, 16);
+    sum += __SHFL_XOR(sum, 1);
+    sum += __SHFL_XOR(sum, 2);
+    sum += __SHFL_XOR(sum, 4);
+    sum += __SHFL_XOR(sum, 8);
+    sum += __SHFL_XOR(sum, 16);
 
     if (threadIdx.x < MIXTURE_SIZE) {
         float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON;
@@ -446,13 +442,14 @@ void GMMInitialize(
     for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k += MIXTURE_COUNT) {
         for (unsigned int i = 0; i < k; ++i) {
             CovarianceReductionKernel
-                <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+                <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
         }
 
-        CovarianceFinalizationKernel<<<{k, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+        CovarianceFinalizationKernel<<<dim3(k, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);
 
-        GMMFindSplit<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm);
-        GMMDoSplit<<<{TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count}, BLOCK_SIZE>>>(
+        GMMFindSplit<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(
+            gmm_split_scratch, k / MIXTURE_COUNT, gmm);
+        GMMDoSplit<<<dim3(TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count), BLOCK_SIZE>>>(
             gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count);
     }
 }
@@ -472,12 +469,13 @@ void GMMUpdate(
 
     for (unsigned int i = 0; i < gmm_N; ++i) {
         CovarianceReductionKernel
-            <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+            <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
     }
 
-    CovarianceFinalizationKernel<<<{gmm_N, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+    CovarianceFinalizationKernel
+        <<<dim3(gmm_N, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);
 
-    GMMcommonTerm<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
+    GMMcommonTerm<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
 }
 
 void GMMDataTerm(
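Note on the pattern above (not part of the patch): CUDA's `__shfl_down_sync`/`__shfl_xor_sync` take an explicit participation mask, while HIP on AMD exposes the mask-less `__shfl_down`/`__shfl_xor`, so the `__SHFL_DOWN`/`__SHFL_XOR` macros hide that difference at every call site. The sketch below is a minimal, self-contained illustration of the same two-stage block reduction used in `CovarianceReductionKernel` (warp-level shuffle ladder, per-warp partials in shared memory, final reduction by warp 0); the file name, kernel name, and host code are hypothetical and only meant to show how the macro is used.

```cuda
// shfl_shim_example.cu -- illustrative sketch only, not MONAI source.
#include <cstdio>
#include <cuda_runtime.h>

// Same shim as the patch: drop the mask on the HIP/AMD path.
#ifdef __HIP_PLATFORM_AMD__
#define __SHFL_DOWN(a, b) __shfl_down(a, b)
#else
#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
#endif

// Two-stage sum over one block: each warp reduces with the shuffle
// ladder, lane 0 stores a per-warp partial, then warp 0 reduces those.
template <int warp_count>
__global__ void BlockSumKernel(const float* g_in, float* g_out, int count) {
    __shared__ float s_partial[warp_count];

    int thread_index = blockIdx.x * blockDim.x + threadIdx.x;
    int lane_index = threadIdx.x % 32;
    int warp_index = threadIdx.x / 32;

    float value = thread_index < count ? g_in[thread_index] : 0.0f;

    // Stage 1: warp-level shuffle ladder, as in the patched kernels.
    value += __SHFL_DOWN(value, 16);
    value += __SHFL_DOWN(value, 8);
    value += __SHFL_DOWN(value, 4);
    value += __SHFL_DOWN(value, 2);
    value += __SHFL_DOWN(value, 1);
    if (lane_index == 0) {
        s_partial[warp_index] = value;
    }
    __syncthreads();

    // Stage 2: warp 0 reduces the per-warp partials (warp_count <= 32).
    if (warp_index == 0) {
        value = lane_index < warp_count ? s_partial[lane_index] : 0.0f;
        value += __SHFL_DOWN(value, 16);
        value += __SHFL_DOWN(value, 8);
        value += __SHFL_DOWN(value, 4);
        value += __SHFL_DOWN(value, 2);
        value += __SHFL_DOWN(value, 1);
        if (lane_index == 0) {
            g_out[blockIdx.x] = value;
        }
    }
}

int main() {
    const int count = 256;
    float h_in[count];
    for (int i = 0; i < count; i++) {
        h_in[i] = 1.0f;  // expected block sum: 256
    }

    float *d_in, *d_out, h_out = 0.0f;
    cudaMalloc(&d_in, count * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, count * sizeof(float), cudaMemcpyHostToDevice);

    // Explicit dim3 launch configuration, the portable spelling.
    BlockSumKernel<8><<<dim3(1, 1, 1), 256>>>(d_in, d_out, count);

    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block sum = %f\n", h_out);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```

Built with nvcc this takes the `__shfl_down_sync` branch; under hipcc, where `__HIP_PLATFORM_AMD__` is defined, the mask-less branch is selected instead, and the `cuda*` runtime calls in `main` would need their `hip*` equivalents (e.g. via hipify).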
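The `__SHFL_XOR` ladders in `GMMFindSplit` and `GMMcommonTerm` use the butterfly variant of the same idea: after the five XOR steps every lane holds the combined result (a max or a sum), not just lane 0, which is what allows each thread to test `max_value == eigenvalue` or use `sum` directly afterwards. A second, equally hypothetical sketch of that pattern (kernel and host names are not from MONAI):

```cuda
// shfl_xor_example.cu -- illustrative sketch only, not MONAI source.
#include <cstdio>
#include <cuda_runtime.h>

#ifdef __HIP_PLATFORM_AMD__
#define __SHFL_XOR(a, b) __shfl_xor(a, b)
#else
#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
#endif

// Butterfly all-reduce over one 32-thread warp: every lane ends up with
// both the warp-wide sum and the warp-wide max of the input values.
__global__ void WarpAllReduceKernel(const float* g_in, float* g_sum, float* g_max) {
    float value = g_in[threadIdx.x];

    float sum = value;
    sum += __SHFL_XOR(sum, 1);
    sum += __SHFL_XOR(sum, 2);
    sum += __SHFL_XOR(sum, 4);
    sum += __SHFL_XOR(sum, 8);
    sum += __SHFL_XOR(sum, 16);

    float max_value = value;
    max_value = max(max_value, __SHFL_XOR(max_value, 16));
    max_value = max(max_value, __SHFL_XOR(max_value, 8));
    max_value = max(max_value, __SHFL_XOR(max_value, 4));
    max_value = max(max_value, __SHFL_XOR(max_value, 2));
    max_value = max(max_value, __SHFL_XOR(max_value, 1));

    // Every lane now holds the same sum and max; let the lane that owns
    // the maximum write the results, mirroring the patched idiom.
    if (max_value == value) {
        *g_sum = sum;
        *g_max = max_value;
    }
}

int main() {
    float h_in[32];
    for (int i = 0; i < 32; i++) {
        h_in[i] = static_cast<float>(i);  // expected sum: 496, max: 31
    }

    float *d_in, *d_sum, *d_max, h_sum = 0.0f, h_max = 0.0f;
    cudaMalloc(&d_in, 32 * sizeof(float));
    cudaMalloc(&d_sum, sizeof(float));
    cudaMalloc(&d_max, sizeof(float));
    cudaMemcpy(d_in, h_in, 32 * sizeof(float), cudaMemcpyHostToDevice);

    WarpAllReduceKernel<<<dim3(1, 1, 1), 32>>>(d_in, d_sum, d_max);

    cudaMemcpy(&h_sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(&h_max, d_max, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f, max = %f\n", h_sum, h_max);

    cudaFree(d_in);
    cudaFree(d_sum);
    cudaFree(d_max);
    return 0;
}
```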