Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions source/lib/src/rocm/prod_force.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ __global__ void force_deriv_wrt_neighbors_r(
const int nnei)
{
// idy -> nnei
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
const unsigned int idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const int ndescrpt = nnei * 1;
if (idx >= nloc) {
if (idy >= nnei) {
return;
}
// deriv wrt neighbors
Expand Down Expand Up @@ -146,8 +146,8 @@ namespace deepmd {
net_deriv, in_deriv, ndescrpt);

const int LEN = 64;
const int nblock = (nloc + LEN -1) / LEN;
dim3 block_grid(nblock, nnei);
const int nblock = (nnei + LEN -1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(LEN, 3);
hipLaunchKernelGGL(force_deriv_wrt_neighbors_r, block_grid, thread_grid, 0, 0,
force,
Expand Down
10 changes: 5 additions & 5 deletions source/lib/src/rocm/prod_virial.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ __global__ void virial_deriv_wrt_neighbors_r(
// idz = dd0 * 3 + dd1
// dd0 = idz / 3
// dd1 = idz % 3
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
const unsigned int idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const int ndescrpt = nnei * 1;

if (idx >= nloc) {
if (idy >= nnei) {
return;
}
int j_idx = nlist[idx * nnei + idy];
Expand Down Expand Up @@ -154,8 +154,8 @@ void prod_virial_r_gpu_rocm(
0.0, sizeof(FPTYPE) * 9 * nall));

const int LEN = 16;
int nblock = (nloc + LEN -1) / LEN;
dim3 block_grid(nblock, nnei);
int nblock = (nnei + LEN -1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(LEN, 9);
// compute virial of a frame
hipLaunchKernelGGL(virial_deriv_wrt_neighbors_r, block_grid, thread_grid, 0, 0,
Expand Down