From 6a9aafd83989809a6eb1e4f7f4df00b216f93cd3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 15 Oct 2023 01:56:23 -0400 Subject: [PATCH 1/2] fix se_a compression for just enough sel and symmetrical coordinates In this case, even when a neighbor's `em_x` is equal to that of the last neighbor, its `em` may not be. This patch therefore also checks whether `em` is zero, which means sel is larger than the actual number of neighbors. Signed-off-by: Jinzhe Zeng --- source/lib/src/gpu/tabulate.cu | 12 +++++++++--- source/lib/src/tabulate.cc | 7 ++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/source/lib/src/gpu/tabulate.cu b/source/lib/src/gpu/tabulate.cu index a22742ae19..71ea17ced5 100644 --- a/source/lib/src/gpu/tabulate.cu +++ b/source/lib/src/gpu/tabulate.cu @@ -200,7 +200,9 @@ __global__ void tabulate_fusion_se_a_fifth_order_polynomial( FPTYPE var[6]; for (int ii = 0; ii < nnei; ii++) { FPTYPE xx = em_x[block_idx * nnei + ii]; - if (xx == ago && is_sorted) { + if (xx == ago && em[block_idx * nnei * 4 + ii * 4 + 1] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 2] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 3] == 0. && is_sorted) { unloop = true; breakpoint = ii; } @@ -286,7 +288,9 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( FPTYPE ago = GpuShuffleSync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0); for (int ii = warp_idx; ii < nnei; ii += KTILE) { FPTYPE xx = em_x[block_idx * nnei + ii]; - if (ago == xx && is_sorted) { + if (ago == xx && em[block_idx * nnei * 4 + ii * 4 + 1] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 2] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 3] == 0. && is_sorted) { unloop = true; breakpoint = ii; } @@ -393,7 +397,9 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial( for (int ii = 0; ii < nnei; ii++) { FPTYPE xx = em_x[block_idx * nnei + ii]; FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei + ii]; - if (xx == ago && is_sorted) { + if (xx == ago && em[block_idx * nnei * 4 + ii * 4 + 1] == 0. 
&& + em[block_idx * nnei * 4 + ii * 4 + 2] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 3] == 0. && is_sorted) { unloop = true; breakpoint = ii; } diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc index 3e2a1bec62..f0b97d085c 100644 --- a/source/lib/src/tabulate.cc +++ b/source/lib/src/tabulate.cc @@ -108,7 +108,7 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out, ll[2] = em[ii * nnei * 4 + jj * 4 + 2]; ll[3] = em[ii * nnei * 4 + jj * 4 + 3]; FPTYPE xx = em_x[ii * nnei + jj]; - if (ago == xx && is_sorted) { + if (ago == xx && ll[1] == 0. && ll[2] == 0. && ll[3] == 0. && is_sorted) { unloop = true; } int table_idx = 0; @@ -195,7 +195,8 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, ll[2] = em[ii * nnei * 4 + jj * 4 + 2]; ll[3] = em[ii * nnei * 4 + jj * 4 + 3]; FPTYPE xx = em_x[ii * nnei + jj]; - if (ago == xx && is_sorted) { + if (ago == xx && &&ll[1] == 0. && ll[2] == 0. && ll[3] == 0. && + is_sorted) { unloop = true; } int table_idx = 0; @@ -298,7 +299,7 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy, hh[3] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 3]; FPTYPE xx = em_x[ii * nnei + jj]; FPTYPE dz_xx = dz_dy_dem_x[ii * nnei + jj]; - if (ago == xx && is_sorted) { + if (ago == xx && ll[1] == 0. && ll[2] == 0. && ll[3] == 0. 
&& is_sorted) { unloop = true; } int table_idx = 0; From 91ee7148307f8fbd49f7db0b14fc11e7cbe15b3f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 15 Oct 2023 01:57:37 -0400 Subject: [PATCH 2/2] fix typo Signed-off-by: Jinzhe Zeng --- source/lib/src/tabulate.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc index f0b97d085c..9352980351 100644 --- a/source/lib/src/tabulate.cc +++ b/source/lib/src/tabulate.cc @@ -195,8 +195,7 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, ll[2] = em[ii * nnei * 4 + jj * 4 + 2]; ll[3] = em[ii * nnei * 4 + jj * 4 + 3]; FPTYPE xx = em_x[ii * nnei + jj]; - if (ago == xx && &&ll[1] == 0. && ll[2] == 0. && ll[3] == 0. && - is_sorted) { + if (ago == xx && ll[1] == 0. && ll[2] == 0. && ll[3] == 0. && is_sorted) { unloop = true; } int table_idx = 0;