Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions source/lib/src/rocm/prod_force.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,23 @@ __global__ void force_deriv_wrt_neighbors_a(
const int nnei)
{
// idy -> nnei
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
const unsigned int idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const unsigned int idw = threadIdx.z;
const int ndescrpt = nnei * 4;
if (idx >= nloc) {
if (idy >= nnei) {
return;
}
// deriv wrt neighbors
int j_idx = nlist[idx * nnei + idy];
if (j_idx < 0) {
return;
}
atomicAdd(
force + j_idx * 3 + idz,
net_deriv[idx * ndescrpt + idy * 4 + idw] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz]);
FPTYPE force_tmp = 0.f;
for (int idw = 0; idw < 4; ++idw) {
force_tmp += net_deriv[idx * ndescrpt + idy * 4 + idw] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz];
}
atomicAdd(force + j_idx * 3 + idz, force_tmp);
}

template<typename FPTYPE>
Expand Down Expand Up @@ -117,9 +118,9 @@ namespace deepmd {
net_deriv, in_deriv, ndescrpt);

const int LEN = 64;
const int nblock = (nloc + LEN -1) / LEN;
dim3 block_grid(nblock, nnei);
dim3 thread_grid(LEN, 3, 4);
const int nblock = (nnei + LEN - 1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(LEN, 3);
hipLaunchKernelGGL(force_deriv_wrt_neighbors_a, block_grid, thread_grid, 0, 0,
force,
net_deriv, in_deriv, nlist, nloc, nnei);
Expand Down
24 changes: 11 additions & 13 deletions source/lib/src/rocm/prod_virial.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,22 @@ __global__ void virial_deriv_wrt_neighbors_a(
// idz = dd0 * 3 + dd1
// dd0 = idz / 3
// dd1 = idz % 3
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
const unsigned int idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const unsigned int idw = threadIdx.z;
const int ndescrpt = nnei * 4;
if (idx >= nloc) {
if (idy >= nnei) {
return;
}
int j_idx = nlist[idx * nnei + idy];
if (j_idx < 0) {
return;
}
// atomicAdd(
// virial + idz,
// net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz / 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz % 3]);
atomicAdd(
atom_virial + j_idx * 9 + idz,
net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz % 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz / 3]);
FPTYPE virial_tmp = 0.f;
for (int idw = 0; idw < 4; ++idw) {
virial_tmp += net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz % 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz / 3];
}
atomicAdd(atom_virial + j_idx * 9 + idz, virial_tmp);
}

template<typename FPTYPE>
Expand Down Expand Up @@ -123,9 +121,9 @@ void prod_virial_a_gpu_rocm(
0.0, sizeof(FPTYPE) * 9 * nall));

const int LEN = 16;
int nblock = (nloc + LEN -1) / LEN;
dim3 block_grid(nblock, nnei);
dim3 thread_grid(LEN, 9, 4);
int nblock = (nnei + LEN -1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(LEN, 9);
// compute virial of a frame
hipLaunchKernelGGL(virial_deriv_wrt_neighbors_a, block_grid, thread_grid, 0, 0,
virial, atom_virial,
Expand Down