Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 46 additions & 48 deletions monai/_extensions/gmm/gmm_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ limitations under the License.
#define EPSILON 1e-5
#define BLOCK_SIZE 32
#define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1)
#ifdef __HIP_PLATFORM_AMD__
#define __SHFL_DOWN(a, b) __shfl_down(a, b)
#define __SHFL_XOR(a, b) __shfl_xor(a, b)
#else
#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
#endif

template <int warp_count, int load_count>
__global__ void CovarianceReductionKernel(
Expand Down Expand Up @@ -82,13 +89,11 @@ __global__ void CovarianceReductionKernel(

for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) {
float matrix_component = matrix[i];

matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);

matrix_component += __SHFL_DOWN(matrix_component, 16);
matrix_component += __SHFL_DOWN(matrix_component, 8);
matrix_component += __SHFL_DOWN(matrix_component, 4);
matrix_component += __SHFL_DOWN(matrix_component, 2);
matrix_component += __SHFL_DOWN(matrix_component, 1);
if (lane_index == 0) {
s_matrix_component[warp_index] = matrix_component;
}
Expand All @@ -97,23 +102,21 @@ __global__ void CovarianceReductionKernel(

if (warp_index == 0) {
matrix_component = s_matrix_component[lane_index];

if (warp_count >= 32) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
matrix_component += __SHFL_DOWN(matrix_component, 16);
}
if (warp_count >= 16) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
matrix_component += __SHFL_DOWN(matrix_component, 8);
}
if (warp_count >= 8) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
matrix_component += __SHFL_DOWN(matrix_component, 4);
}
if (warp_count >= 4) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
matrix_component += __SHFL_DOWN(matrix_component, 2);
}
if (warp_count >= 2) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
matrix_component += __SHFL_DOWN(matrix_component, 1);
}

if (lane_index == 0) {
g_batch_matrices[matrix_offset + i] = matrix_component;
}
Expand Down Expand Up @@ -156,13 +159,11 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index];
}
}

matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);

matrix_component += __SHFL_DOWN(matrix_component, 16);
matrix_component += __SHFL_DOWN(matrix_component, 8);
matrix_component += __SHFL_DOWN(matrix_component, 4);
matrix_component += __SHFL_DOWN(matrix_component, 2);
matrix_component += __SHFL_DOWN(matrix_component, 1);
if (lane_index == 0) {
s_matrix_component[warp_index] = matrix_component;
}
Expand All @@ -171,23 +172,21 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g

if (warp_index == 0) {
matrix_component = s_matrix_component[lane_index];

if (warp_count >= 32) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
matrix_component += __SHFL_DOWN(matrix_component, 16);
}
if (warp_count >= 16) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
matrix_component += __SHFL_DOWN(matrix_component, 8);
}
if (warp_count >= 8) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
matrix_component += __SHFL_DOWN(matrix_component, 4);
}
if (warp_count >= 4) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
matrix_component += __SHFL_DOWN(matrix_component, 2);
}
if (warp_count >= 2) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
matrix_component += __SHFL_DOWN(matrix_component, 1);
}

if (lane_index == 0) {
float constant = i == 0 ? 0.0f : s_gmm[i] * s_gmm[j];

Expand Down Expand Up @@ -261,13 +260,11 @@ __global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) {
}

float max_value = eigenvalue;

max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));

max_value = max(max_value, __SHFL_XOR(max_value, 16));
max_value = max(max_value, __SHFL_XOR(max_value, 8));
max_value = max(max_value, __SHFL_XOR(max_value, 4));
max_value = max(max_value, __SHFL_XOR(max_value, 2));
max_value = max(max_value, __SHFL_XOR(max_value, 1));
if (max_value == eigenvalue) {
GMMSplit_t split;

Expand Down Expand Up @@ -347,12 +344,11 @@ __global__ void GMMcommonTerm(float* g_gmm) {
float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f;

float sum = gmm_n;

sum += __shfl_xor_sync(0xffffffff, sum, 1);
sum += __shfl_xor_sync(0xffffffff, sum, 2);
sum += __shfl_xor_sync(0xffffffff, sum, 4);
sum += __shfl_xor_sync(0xffffffff, sum, 8);
sum += __shfl_xor_sync(0xffffffff, sum, 16);
sum += __SHFL_XOR(sum, 1);
sum += __SHFL_XOR(sum, 2);
sum += __SHFL_XOR(sum, 4);
sum += __SHFL_XOR(sum, 8);
sum += __SHFL_XOR(sum, 16);

if (threadIdx.x < MIXTURE_SIZE) {
float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON;
Expand Down Expand Up @@ -446,13 +442,14 @@ void GMMInitialize(
for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k += MIXTURE_COUNT) {
for (unsigned int i = 0; i < k; ++i) {
CovarianceReductionKernel<WARPS, LOAD>
<<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
<<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
}

CovarianceFinalizationKernel<WARPS, false><<<{k, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
CovarianceFinalizationKernel<WARPS, false><<<dim3(k, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);

GMMFindSplit<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm);
GMMDoSplit<<<{TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count}, BLOCK_SIZE>>>(
GMMFindSplit<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(
gmm_split_scratch, k / MIXTURE_COUNT, gmm);
GMMDoSplit<<<dim3(TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count), BLOCK_SIZE>>>(
gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count);
}
}
Expand All @@ -472,12 +469,13 @@ void GMMUpdate(

for (unsigned int i = 0; i < gmm_N; ++i) {
CovarianceReductionKernel<WARPS, LOAD>
<<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
<<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
}

CovarianceFinalizationKernel<WARPS, true><<<{gmm_N, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
CovarianceFinalizationKernel<WARPS, true>
<<<dim3(gmm_N, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);

GMMcommonTerm<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
GMMcommonTerm<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
}

void GMMDataTerm(
Expand Down