From 9dc59bd3bbcfceabe251309e4f6b1362e50390c9 Mon Sep 17 00:00:00 2001 From: "Xia, Yu" Date: Tue, 29 Jun 2021 11:43:49 +0800 Subject: [PATCH] speedup ROCm kernels which use atomicAdd --- source/lib/src/rocm/prod_force.hip.cu | 21 +++++++++++---------- source/lib/src/rocm/prod_virial.hip.cu | 24 +++++++++++------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/source/lib/src/rocm/prod_force.hip.cu b/source/lib/src/rocm/prod_force.hip.cu index 3c56b8155a..9a0b07e282 100644 --- a/source/lib/src/rocm/prod_force.hip.cu +++ b/source/lib/src/rocm/prod_force.hip.cu @@ -51,12 +51,11 @@ __global__ void force_deriv_wrt_neighbors_a( const int nnei) { // idy -> nnei - const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; - const unsigned int idy = blockIdx.y; + const unsigned int idx = blockIdx.x; + const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x; const unsigned int idz = threadIdx.y; - const unsigned int idw = threadIdx.z; const int ndescrpt = nnei * 4; - if (idx >= nloc) { + if (idy >= nnei) { return; } // deriv wrt neighbors @@ -64,9 +63,11 @@ __global__ void force_deriv_wrt_neighbors_a( if (j_idx < 0) { return; } - atomicAdd( - force + j_idx * 3 + idz, - net_deriv[idx * ndescrpt + idy * 4 + idw] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz]); + FPTYPE force_tmp = 0.f; + for (int idw = 0; idw < 4; ++idw) { + force_tmp += net_deriv[idx * ndescrpt + idy * 4 + idw] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz]; + } + atomicAdd(force + j_idx * 3 + idz, force_tmp); } template @@ -117,9 +118,9 @@ namespace deepmd { net_deriv, in_deriv, ndescrpt); const int LEN = 64; - const int nblock = (nloc + LEN -1) / LEN; - dim3 block_grid(nblock, nnei); - dim3 thread_grid(LEN, 3, 4); + const int nblock = (nnei + LEN - 1) / LEN; + dim3 block_grid(nloc, nblock); + dim3 thread_grid(LEN, 3); hipLaunchKernelGGL(force_deriv_wrt_neighbors_a, block_grid, thread_grid, 0, 0, force, net_deriv, in_deriv, nlist, nloc, nnei); diff --git a/source/lib/src/rocm/prod_virial.hip.cu b/source/lib/src/rocm/prod_virial.hip.cu index a285a1789b..d6ef5546e1 100644 --- a/source/lib/src/rocm/prod_virial.hip.cu +++ b/source/lib/src/rocm/prod_virial.hip.cu @@ -46,24 +46,22 @@ __global__ void virial_deriv_wrt_neighbors_a( // idz = dd0 * 3 + dd1 // dd0 = idz / 3 // dd1 = idz % 3 - const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; - const unsigned int idy = blockIdx.y; + const unsigned int idx = blockIdx.x; + const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x; const unsigned int idz = threadIdx.y; - const unsigned int idw = threadIdx.z; const int ndescrpt = nnei * 4; - if (idx >= nloc) { + if (idy >= nnei) { return; } int j_idx = nlist[idx * nnei + idy]; if (j_idx < 0) { return; } - // atomicAdd( - // virial + idz, - // net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz / 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz % 3]); - atomicAdd( - atom_virial + j_idx * 9 + idz, - net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz % 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz / 3]); + FPTYPE virial_tmp = 0.f; + for (int idw = 0; idw < 4; ++idw) { + virial_tmp += net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz % 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz / 3]; + } + atomicAdd(atom_virial + j_idx * 9 + idz, virial_tmp); } template @@ -123,9 +121,9 @@ void prod_virial_a_gpu_rocm( 0.0, sizeof(FPTYPE) * 9 * nall)); const int LEN = 16; - int nblock = (nloc + LEN -1) / LEN; - dim3 block_grid(nblock, nnei); - dim3 thread_grid(LEN, 9, 4); + int nblock = (nnei + LEN -1) / LEN; + dim3 block_grid(nloc, nblock); + dim3 thread_grid(LEN, 9); // compute virial of a frame hipLaunchKernelGGL(virial_deriv_wrt_neighbors_a, block_grid, thread_grid, 0, 0, virial, atom_virial,