From 018609c69498001d3102a772ee72efb4a932939d Mon Sep 17 00:00:00 2001 From: Yaoming Mu Date: Wed, 15 Mar 2023 22:18:23 +0100 Subject: [PATCH 1/6] add Dockerfile.amd to build MONAI docker image for AMD GPU. The file is based on Dockerfile for NVIDIA GPU. --- Dockerfile.amd | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 Dockerfile.amd diff --git a/Dockerfile.amd b/Dockerfile.amd new file mode 100644 index 0000000000..d595a2f175 --- /dev/null +++ b/Dockerfile.amd @@ -0,0 +1,41 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To build with a different base image +# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. +#ARG ARCH=ROCM +ARG PYTORCH_IMAGE=rocm/pytorch:rocm5.4_ubuntu20.04_py3.8_pytorch_1.12.1 +FROM ${PYTORCH_IMAGE} + +LABEL maintainer="monai.contact@gmail.com" + +WORKDIR /opt/monai + +# install full deps +COPY requirements.txt requirements-min.txt requirements-dev.txt /tmp/ +RUN cp /tmp/requirements.txt /tmp/req.bak \ + && sed -i '/cucim/d' /tmp/requirements-dev.txt \ + && awk '!/torch/' /tmp/requirements.txt > /tmp/tmp && mv /tmp/tmp /tmp/requirements.txt \ + && python -m pip install --upgrade --no-cache-dir pip \ + && python -m pip install --no-cache-dir -r /tmp/requirements-dev.txt + + +# compile ext and remove temp files +# TODO: remark for issue [revise the dockerfile #1276](https://github.com/Project-MONAI/MONAI/issues/1276) +# please specify exact files and folders to be copied -- else, basically always, the Docker build process cannot cache +# this or anything below it and always will build from at most here; one file change leads to no caching from here on... + +COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./ +COPY tests ./tests +COPY monai ./monai +RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \ + && rm -rf build __pycache__ +WORKDIR /opt/monai From a8cfcf683bd90ba3da102b3ef55e42a7580ff43e Mon Sep 17 00:00:00 2001 From: Yaoming Mu Date: Wed, 15 Mar 2023 22:42:05 +0100 Subject: [PATCH 2/6] add ROCm HIP support 1. replace __shfl_down_sync(0xffffffff, ...) with __shfl_down(...) __shfl_xor_sync(0xffffffff,...) with __shfl_xor(...) for AMD ROCm HIP 2. 
replace dim3 integer vector initialization { , , } with dim3( , , ) for AMD ROCm HIP --- monai/_extensions/gmm/gmm_cuda.cu | 95 ++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 19 deletions(-) diff --git a/monai/_extensions/gmm/gmm_cuda.cu b/monai/_extensions/gmm/gmm_cuda.cu index 2cf70a9920..64deae8abd 100644 --- a/monai/_extensions/gmm/gmm_cuda.cu +++ b/monai/_extensions/gmm/gmm_cuda.cu @@ -82,13 +82,19 @@ __global__ void CovarianceReductionKernel( for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) { float matrix_component = matrix[i]; - +#ifdef __HIP_PLATFORM_AMD__ + matrix_component += __shfl_down(matrix_component, 16); + matrix_component += __shfl_down(matrix_component, 8); + matrix_component += __shfl_down(matrix_component, 4); + matrix_component += __shfl_down(matrix_component, 2); + matrix_component += __shfl_down(matrix_component, 1); +#else matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); - +#endif if (lane_index == 0) { s_matrix_component[warp_index] = matrix_component; } @@ -97,7 +103,23 @@ __global__ void CovarianceReductionKernel( if (warp_index == 0) { matrix_component = s_matrix_component[lane_index]; - +#ifdef __HIP_PLATFORM_AMD__ + if (warp_count >= 32) { + matrix_component += __shfl_down(matrix_component, 16); + } + if (warp_count >= 16) { + matrix_component += __shfl_down(matrix_component, 8); + } + if (warp_count >= 8) { + matrix_component += __shfl_down(matrix_component, 4); + } + if (warp_count >= 4) { + matrix_component += __shfl_down(matrix_component, 2); + } + if (warp_count >= 2) { + matrix_component += __shfl_down(matrix_component, 1); + } +#else if (warp_count >= 32) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); } @@ -113,7 +135,7 @@ __global__ void CovarianceReductionKernel( if (warp_count >= 2) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); } - +#endif if (lane_index == 0) { g_batch_matrices[matrix_offset + i] = matrix_component; } @@ -156,13 +178,19 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index]; } } - +#ifdef __HIP_PLATFORM_AMD__ + matrix_component += __shfl_down(matrix_component, 16); + matrix_component += __shfl_down(matrix_component, 8); + matrix_component += __shfl_down(matrix_component, 4); + matrix_component += __shfl_down(matrix_component, 2); + matrix_component += __shfl_down(matrix_component, 1); +#else matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); - +#endif if (lane_index == 0) { s_matrix_component[warp_index] = matrix_component; } @@ -171,7 +199,23 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g if (warp_index == 0) { matrix_component = s_matrix_component[lane_index]; - +#ifdef __HIP_PLATFORM_AMD__ + if (warp_count >= 32) { + matrix_component += __shfl_down(matrix_component, 16); + } + if (warp_count >= 
16) { + matrix_component += __shfl_down(matrix_component, 8); + } + if (warp_count >= 8) { + matrix_component += __shfl_down(matrix_component, 4); + } + if (warp_count >= 4) { + matrix_component += __shfl_down(matrix_component, 2); + } + if (warp_count >= 2) { + matrix_component += __shfl_down(matrix_component, 1); + } +#else if (warp_count >= 32) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); } @@ -187,7 +231,7 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g if (warp_count >= 2) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); } - +#endif if (lane_index == 0) { float constant = i == 0 ? 0.0f : s_gmm[i] * s_gmm[j]; @@ -261,13 +305,19 @@ __global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) { } float max_value = eigenvalue; - +#ifdef __HIP_PLATFORM_AMD__ + max_value = max(max_value, __shfl_xor(max_value, 16)); + max_value = max(max_value, __shfl_xor(max_value, 8)); + max_value = max(max_value, __shfl_xor(max_value, 4)); + max_value = max(max_value, __shfl_xor(max_value, 2)); + max_value = max(max_value, __shfl_xor(max_value, 1)); +#else max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 16)); max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 8)); max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 4)); max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 2)); max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 1)); - +#endif if (max_value == eigenvalue) { GMMSplit_t split; @@ -347,13 +397,20 @@ __global__ void GMMcommonTerm(float* g_gmm) { float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f; float sum = gmm_n; - +#ifdef __HIP_PLATFORM_AMD__ + sum += __shfl_xor(sum, 1); + sum += __shfl_xor(sum, 2); + sum += __shfl_xor(sum, 4); + sum += __shfl_xor(sum, 8); + sum += __shfl_xor(sum, 16); + +#else sum += __shfl_xor_sync(0xffffffff, sum, 1); sum += __shfl_xor_sync(0xffffffff, sum, 2); sum += __shfl_xor_sync(0xffffffff, sum, 4); sum += __shfl_xor_sync(0xffffffff, sum, 8); sum += __shfl_xor_sync(0xffffffff, sum, 16); - +#endif if (threadIdx.x < MIXTURE_SIZE) { float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON; float commonTerm = det > 0.0f ? 
gmm_n / (sqrtf(det) * sum) : gmm_n / sum; @@ -446,13 +503,13 @@ void GMMInitialize( for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k += MIXTURE_COUNT) { for (unsigned int i = 0; i < k; ++i) { CovarianceReductionKernel - <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count); + <<>>(i, image, alpha, block_gmm_scratch, element_count); } - CovarianceFinalizationKernel<<<{k, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count); + CovarianceFinalizationKernel<<>>(block_gmm_scratch, gmm, block_count); - GMMFindSplit<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm); - GMMDoSplit<<<{TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count}, BLOCK_SIZE>>>( + GMMFindSplit<<>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm); + GMMDoSplit<<>>( gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count); } } @@ -472,12 +529,12 @@ void GMMUpdate( for (unsigned int i = 0; i < gmm_N; ++i) { CovarianceReductionKernel - <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count); + <<>>(i, image, alpha, block_gmm_scratch, element_count); } - CovarianceFinalizationKernel<<<{gmm_N, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count); + CovarianceFinalizationKernel<<>>(block_gmm_scratch, gmm, block_count); - GMMcommonTerm<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm); + GMMcommonTerm<<>>(gmm); } void GMMDataTerm( From e83e77fcfa2917bfd82cd37c40dbd3b6364deada Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Mar 2023 20:27:44 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- Dockerfile.amd | 2 +- monai/_extensions/gmm/gmm_cuda.cu | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Dockerfile.amd b/Dockerfile.amd index d595a2f175..ef47282462 100644 --- a/Dockerfile.amd +++ b/Dockerfile.amd @@ -25,7 +25,7 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \ && sed -i '/cucim/d' /tmp/requirements-dev.txt \ && awk '!/torch/' /tmp/requirements.txt > /tmp/tmp && mv /tmp/tmp /tmp/requirements.txt \ && python -m pip install --upgrade --no-cache-dir pip \ - && python -m pip install --no-cache-dir -r /tmp/requirements-dev.txt + && python -m pip install --no-cache-dir -r /tmp/requirements-dev.txt # compile ext and remove temp files diff --git a/monai/_extensions/gmm/gmm_cuda.cu b/monai/_extensions/gmm/gmm_cuda.cu index 64deae8abd..46f7dcb05c 100644 --- a/monai/_extensions/gmm/gmm_cuda.cu +++ b/monai/_extensions/gmm/gmm_cuda.cu @@ -184,7 +184,7 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g matrix_component += __shfl_down(matrix_component, 4); matrix_component += __shfl_down(matrix_component, 2); matrix_component += __shfl_down(matrix_component, 1); -#else +#else matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8); matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4); @@ -215,7 +215,7 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g if (warp_count >= 2) { matrix_component += __shfl_down(matrix_component, 1); } -#else +#else if (warp_count >= 32) { matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); } @@ -404,7 +404,7 @@ __global__ void GMMcommonTerm(float* g_gmm) { sum += 
__shfl_xor(sum, 8); sum += __shfl_xor(sum, 16); -#else +#else sum += __shfl_xor_sync(0xffffffff, sum, 1); sum += __shfl_xor_sync(0xffffffff, sum, 2); sum += __shfl_xor_sync(0xffffffff, sum, 4); From 59f7827fe34c18b5f9f793005f47a13c4f8aef7d Mon Sep 17 00:00:00 2001 From: Yaoming Mu Date: Wed, 24 May 2023 10:48:57 -0500 Subject: [PATCH 4/6] DCO Remediation Commit for Yaoming Mu I, Yaoming Mu , hereby add my Signed-off-by to this commit: 018609c69498001d3102a772ee72efb4a932939d I, Yaoming Mu , hereby add my Signed-off-by to this commit: a8cfcf683bd90ba3da102b3ef55e42a7580ff43e As discussed with MONAI team, moves DOckerfile.amd from this PR. The existing file "Dockerfile" will be refactored to allow build images for different GPUs. Signed-off-by: Yaoming Mu --- Dockerfile.amd | 41 ----------------------------------------- 1 file changed, 41 deletions(-) delete mode 100644 Dockerfile.amd diff --git a/Dockerfile.amd b/Dockerfile.amd deleted file mode 100644 index ef47282462..0000000000 --- a/Dockerfile.amd +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# To build with a different base image -# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -#ARG ARCH=ROCM -ARG PYTORCH_IMAGE=rocm/pytorch:rocm5.4_ubuntu20.04_py3.8_pytorch_1.12.1 -FROM ${PYTORCH_IMAGE} - -LABEL maintainer="monai.contact@gmail.com" - -WORKDIR /opt/monai - -# install full deps -COPY requirements.txt requirements-min.txt requirements-dev.txt /tmp/ -RUN cp /tmp/requirements.txt /tmp/req.bak \ - && sed -i '/cucim/d' /tmp/requirements-dev.txt \ - && awk '!/torch/' /tmp/requirements.txt > /tmp/tmp && mv /tmp/tmp /tmp/requirements.txt \ - && python -m pip install --upgrade --no-cache-dir pip \ - && python -m pip install --no-cache-dir -r /tmp/requirements-dev.txt - - -# compile ext and remove temp files -# TODO: remark for issue [revise the dockerfile #1276](https://github.com/Project-MONAI/MONAI/issues/1276) -# please specify exact files and folders to be copied -- else, basically always, the Docker build process cannot cache -# this or anything below it and always will build from at most here; one file change leads to no caching from here on... - -COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./ -COPY tests ./tests -COPY monai ./monai -RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \ - && rm -rf build __pycache__ -WORKDIR /opt/monai From c218ceae69be5e42901d7f76529e8a7ee4bd6b3c Mon Sep 17 00:00:00 2001 From: Yaoming Mu Date: Wed, 24 May 2023 11:10:34 -0500 Subject: [PATCH 5/6] As suggested by ericspod( Eric Kerfoot), using macros functions to reduce the amount of codes. Currently AMD HIP does not support __shfl_xxx_sync functions as NVIDIA CUDA, but when mask is 0xffffffff, __shfl_xxx_sync would be replaced by __shfl_xxx for the codes in gmm_cuda.cu run unittest tests/test_gmm.py for NVIDIA and AMD GPUs and tests passed. 
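
For reference, the block below is a minimal, self-contained sketch of the portability-macro pattern this commit applies in gmm_cuda.cu; the kernel name warp_sum_kernel, the input layout, and the single-warp launch are illustrative assumptions, not code taken from the patch.

// The *_sync shuffle intrinsics are CUDA-only; with a full 0xffffffff mask
// they can be mapped onto the plain HIP shuffles behind one macro name.
#ifdef __HIP_PLATFORM_AMD__
#define __SHFL_DOWN(a, b) __shfl_down(a, b)
#define __SHFL_XOR(a, b) __shfl_xor(a, b)
#else
#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
#endif

// Warp-level sum reduction using the macro, mirroring the 16/8/4/2/1 step
// pattern of the covariance kernels; launch with a single warp of 32
// threads, e.g. warp_sum_kernel<<<1, 32>>>(d_in, d_out).
__global__ void warp_sum_kernel(const float* in, float* out) {
  float value = in[threadIdx.x];
  value += __SHFL_DOWN(value, 16);
  value += __SHFL_DOWN(value, 8);
  value += __SHFL_DOWN(value, 4);
  value += __SHFL_DOWN(value, 2);
  value += __SHFL_DOWN(value, 1);
  if (threadIdx.x == 0) {
    *out = value;  // lane 0 now holds the warp total
  }
}

GMMFindSplit and GMMcommonTerm use __SHFL_XOR in the same way, only with a butterfly exchange instead of a down-shift.
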
Signed-off-by: Yaoming Mu --- monai/_extensions/gmm/gmm_cuda.cu | 135 ++++++++---------------------- 1 file changed, 37 insertions(+), 98 deletions(-) diff --git a/monai/_extensions/gmm/gmm_cuda.cu b/monai/_extensions/gmm/gmm_cuda.cu index 46f7dcb05c..b3e05540a3 100644 --- a/monai/_extensions/gmm/gmm_cuda.cu +++ b/monai/_extensions/gmm/gmm_cuda.cu @@ -21,6 +21,13 @@ limitations under the License. #define EPSILON 1e-5 #define BLOCK_SIZE 32 #define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1) +#ifdef __HIP_PLATFORM_AMD__ + #define __SHFL_DOWN(a, b) __shfl_down(a, b) + #define __SHFL_XOR(a, b) __shfl_xor(a, b) +#else + #define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff,a, b) + #define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b) +#endif template __global__ void CovarianceReductionKernel( @@ -82,19 +89,11 @@ __global__ void CovarianceReductionKernel( for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) { float matrix_component = matrix[i]; -#ifdef __HIP_PLATFORM_AMD__ - matrix_component += __shfl_down(matrix_component, 16); - matrix_component += __shfl_down(matrix_component, 8); - matrix_component += __shfl_down(matrix_component, 4); - matrix_component += __shfl_down(matrix_component, 2); - matrix_component += __shfl_down(matrix_component, 1); -#else - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8); - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4); - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2); - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); -#endif + matrix_component += __SHFL_DOWN(matrix_component, 16); + matrix_component += __SHFL_DOWN(matrix_component, 8); + matrix_component += __SHFL_DOWN(matrix_component, 4); + matrix_component += __SHFL_DOWN(matrix_component, 2); + matrix_component += __SHFL_DOWN(matrix_component, 1); if (lane_index == 0) { s_matrix_component[warp_index] = matrix_component; } @@ -103,39 +102,21 @@ __global__ void CovarianceReductionKernel( if (warp_index == 0) { matrix_component = s_matrix_component[lane_index]; -#ifdef __HIP_PLATFORM_AMD__ if (warp_count >= 32) { - matrix_component += __shfl_down(matrix_component, 16); + matrix_component += __SHFL_DOWN(matrix_component, 16); } if (warp_count >= 16) { - matrix_component += __shfl_down(matrix_component, 8); + matrix_component += __SHFL_DOWN(matrix_component, 8); } if (warp_count >= 8) { - matrix_component += __shfl_down(matrix_component, 4); + matrix_component += __SHFL_DOWN(matrix_component, 4); } if (warp_count >= 4) { - matrix_component += __shfl_down(matrix_component, 2); + matrix_component += __SHFL_DOWN(matrix_component, 2); } if (warp_count >= 2) { - matrix_component += __shfl_down(matrix_component, 1); + matrix_component += __SHFL_DOWN(matrix_component, 1); } -#else - if (warp_count >= 32) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); - } - if (warp_count >= 16) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8); - } - if (warp_count >= 8) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4); - } - if (warp_count >= 4) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2); - } - if (warp_count >= 2) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); - } -#endif if (lane_index == 0) { g_batch_matrices[matrix_offset + i] = matrix_component; } @@ -178,19 +159,11 @@ __global__ void CovarianceFinalizationKernel(const float* 
g_matrices, float* g_g matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index]; } } -#ifdef __HIP_PLATFORM_AMD__ - matrix_component += __shfl_down(matrix_component, 16); - matrix_component += __shfl_down(matrix_component, 8); - matrix_component += __shfl_down(matrix_component, 4); - matrix_component += __shfl_down(matrix_component, 2); - matrix_component += __shfl_down(matrix_component, 1); -#else - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8); - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4); - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2); - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); -#endif + matrix_component += __SHFL_DOWN(matrix_component, 16); + matrix_component += __SHFL_DOWN(matrix_component, 8); + matrix_component += __SHFL_DOWN(matrix_component, 4); + matrix_component += __SHFL_DOWN(matrix_component, 2); + matrix_component += __SHFL_DOWN(matrix_component, 1); if (lane_index == 0) { s_matrix_component[warp_index] = matrix_component; } @@ -199,39 +172,21 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g if (warp_index == 0) { matrix_component = s_matrix_component[lane_index]; -#ifdef __HIP_PLATFORM_AMD__ - if (warp_count >= 32) { - matrix_component += __shfl_down(matrix_component, 16); - } - if (warp_count >= 16) { - matrix_component += __shfl_down(matrix_component, 8); - } - if (warp_count >= 8) { - matrix_component += __shfl_down(matrix_component, 4); - } - if (warp_count >= 4) { - matrix_component += __shfl_down(matrix_component, 2); - } - if (warp_count >= 2) { - matrix_component += __shfl_down(matrix_component, 1); - } -#else if (warp_count >= 32) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16); + matrix_component += __SHFL_DOWN(matrix_component, 16); } if (warp_count >= 16) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8); + matrix_component += __SHFL_DOWN(matrix_component, 8); } if (warp_count >= 8) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4); + matrix_component += __SHFL_DOWN(matrix_component, 4); } if (warp_count >= 4) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2); + matrix_component += __SHFL_DOWN(matrix_component, 2); } if (warp_count >= 2) { - matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1); + matrix_component += __SHFL_DOWN(matrix_component, 1); } -#endif if (lane_index == 0) { float constant = i == 0 ? 
0.0f : s_gmm[i] * s_gmm[j]; @@ -305,19 +260,11 @@ __global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) { } float max_value = eigenvalue; -#ifdef __HIP_PLATFORM_AMD__ - max_value = max(max_value, __shfl_xor(max_value, 16)); - max_value = max(max_value, __shfl_xor(max_value, 8)); - max_value = max(max_value, __shfl_xor(max_value, 4)); - max_value = max(max_value, __shfl_xor(max_value, 2)); - max_value = max(max_value, __shfl_xor(max_value, 1)); -#else - max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 16)); - max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 8)); - max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 4)); - max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 2)); - max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 1)); -#endif + max_value = max(max_value, __SHFL_XOR(max_value, 16)); + max_value = max(max_value, __SHFL_XOR(max_value, 8)); + max_value = max(max_value, __SHFL_XOR(max_value, 4)); + max_value = max(max_value, __SHFL_XOR(max_value, 2)); + max_value = max(max_value, __SHFL_XOR(max_value, 1)); if (max_value == eigenvalue) { GMMSplit_t split; @@ -397,20 +344,12 @@ __global__ void GMMcommonTerm(float* g_gmm) { float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f; float sum = gmm_n; -#ifdef __HIP_PLATFORM_AMD__ - sum += __shfl_xor(sum, 1); - sum += __shfl_xor(sum, 2); - sum += __shfl_xor(sum, 4); - sum += __shfl_xor(sum, 8); - sum += __shfl_xor(sum, 16); + sum += __SHFL_XOR(sum, 1); + sum += __SHFL_XOR(sum, 2); + sum += __SHFL_XOR(sum, 4); + sum += __SHFL_XOR(sum, 8); + sum += __SHFL_XOR(sum, 16); -#else - sum += __shfl_xor_sync(0xffffffff, sum, 1); - sum += __shfl_xor_sync(0xffffffff, sum, 2); - sum += __shfl_xor_sync(0xffffffff, sum, 4); - sum += __shfl_xor_sync(0xffffffff, sum, 8); - sum += __shfl_xor_sync(0xffffffff, sum, 16); -#endif if (threadIdx.x < MIXTURE_SIZE) { float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON; float commonTerm = det > 0.0f ? gmm_n / (sqrtf(det) * sum) : gmm_n / sum; From 8445097456a67f39ba160042c37da00b49ad0b03 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 24 May 2023 19:43:45 +0000 Subject: [PATCH 6/6] [MONAI] code formatting Signed-off-by: monai-bot --- monai/_extensions/gmm/gmm_cuda.cu | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/monai/_extensions/gmm/gmm_cuda.cu b/monai/_extensions/gmm/gmm_cuda.cu index b3e05540a3..0c808d3165 100644 --- a/monai/_extensions/gmm/gmm_cuda.cu +++ b/monai/_extensions/gmm/gmm_cuda.cu @@ -22,11 +22,11 @@ limitations under the License. 
#define BLOCK_SIZE 32 #define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1) #ifdef __HIP_PLATFORM_AMD__ - #define __SHFL_DOWN(a, b) __shfl_down(a, b) - #define __SHFL_XOR(a, b) __shfl_xor(a, b) +#define __SHFL_DOWN(a, b) __shfl_down(a, b) +#define __SHFL_XOR(a, b) __shfl_xor(a, b) #else - #define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff,a, b) - #define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b) +#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b) +#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b) #endif template @@ -447,7 +447,8 @@ void GMMInitialize( CovarianceFinalizationKernel<<>>(block_gmm_scratch, gmm, block_count); - GMMFindSplit<<>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm); + GMMFindSplit<<>>( + gmm_split_scratch, k / MIXTURE_COUNT, gmm); GMMDoSplit<<>>( gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count); } @@ -471,7 +472,8 @@ void GMMUpdate( <<>>(i, image, alpha, block_gmm_scratch, element_count); } - CovarianceFinalizationKernel<<>>(block_gmm_scratch, gmm, block_count); + CovarianceFinalizationKernel + <<>>(block_gmm_scratch, gmm, block_count); GMMcommonTerm<<>>(gmm); }
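
As a footnote to the other change described in patch 2 — passing dim3(x, y, z) to <<<...>>> instead of a brace initializer, which per that commit message HIP does not accept — the snippet below is a minimal sketch of the pattern in isolation. ScaleKernel and LaunchScale are hypothetical names, and BLOCK_SIZE and TILE simply mirror the definitions in gmm_cuda.cu.

#define BLOCK_SIZE 32
#define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1)

// Hypothetical per-batch kernel: one z-slice of the grid per batch item.
__global__ void ScaleKernel(float* data, float factor, int element_count) {
  int element = blockIdx.x * blockDim.x + threadIdx.x;
  if (element < element_count) {
    data[blockIdx.z * element_count + element] *= factor;
  }
}

void LaunchScale(float* data, float factor, int element_count, int batch_count) {
  // Explicit dim3 constructor instead of {x, 1, batch_count}, so the same
  // launch configuration compiles under both nvcc and hipcc.
  ScaleKernel<<<dim3(TILE(element_count, BLOCK_SIZE), 1, batch_count), BLOCK_SIZE>>>(
      data, factor, element_count);
}
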