From 70cb2263c40fcd1e0a9c9a72c24f1ec831c0650e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=B3=BD=E5=AE=87?= Date: Wed, 30 Jun 2021 02:48:02 +0800 Subject: [PATCH] Synchronize CUDA _r modifications to ROCM --- source/lib/src/rocm/prod_force.hip.cu | 10 +++++----- source/lib/src/rocm/prod_virial.hip.cu | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/source/lib/src/rocm/prod_force.hip.cu b/source/lib/src/rocm/prod_force.hip.cu index 9a0b07e282..48b12dfa50 100644 --- a/source/lib/src/rocm/prod_force.hip.cu +++ b/source/lib/src/rocm/prod_force.hip.cu @@ -80,11 +80,11 @@ __global__ void force_deriv_wrt_neighbors_r( const int nnei) { // idy -> nnei - const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; - const unsigned int idy = blockIdx.y; + const unsigned int idx = blockIdx.x; + const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x; const unsigned int idz = threadIdx.y; const int ndescrpt = nnei * 1; - if (idx >= nloc) { + if (idy >= nnei) { return; } // deriv wrt neighbors @@ -146,8 +146,8 @@ namespace deepmd { net_deriv, in_deriv, ndescrpt); const int LEN = 64; - const int nblock = (nloc + LEN -1) / LEN; - dim3 block_grid(nblock, nnei); + const int nblock = (nnei + LEN -1) / LEN; + dim3 block_grid(nloc, nblock); dim3 thread_grid(LEN, 3); hipLaunchKernelGGL(force_deriv_wrt_neighbors_r, block_grid, thread_grid, 0, 0, force, diff --git a/source/lib/src/rocm/prod_virial.hip.cu b/source/lib/src/rocm/prod_virial.hip.cu index d6ef5546e1..ff8017a687 100644 --- a/source/lib/src/rocm/prod_virial.hip.cu +++ b/source/lib/src/rocm/prod_virial.hip.cu @@ -80,12 +80,12 @@ __global__ void virial_deriv_wrt_neighbors_r( // idz = dd0 * 3 + dd1 // dd0 = idz / 3 // dd1 = idz % 3 - const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; - const unsigned int idy = blockIdx.y; + const unsigned int idx = blockIdx.x; + const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x; const unsigned int idz = threadIdx.y; const int ndescrpt = nnei * 1; - if (idx >= nloc) { + if (idy >= nnei) { return; } int j_idx = nlist[idx * nnei + idy]; @@ -154,8 +154,8 @@ void prod_virial_r_gpu_rocm( 0.0, sizeof(FPTYPE) * 9 * nall)); const int LEN = 16; - int nblock = (nloc + LEN -1) / LEN; - dim3 block_grid(nblock, nnei); + int nblock = (nnei + LEN -1) / LEN; + dim3 block_grid(nloc, nblock); dim3 thread_grid(LEN, 9); // compute virial of a frame hipLaunchKernelGGL(virial_deriv_wrt_neighbors_r, block_grid, thread_grid, 0, 0,