diff --git a/source/lib/src/gpu/tabulate.cu b/source/lib/src/gpu/tabulate.cu index a22742ae19..71ea17ced5 100644 --- a/source/lib/src/gpu/tabulate.cu +++ b/source/lib/src/gpu/tabulate.cu @@ -200,7 +200,9 @@ __global__ void tabulate_fusion_se_a_fifth_order_polynomial( FPTYPE var[6]; for (int ii = 0; ii < nnei; ii++) { FPTYPE xx = em_x[block_idx * nnei + ii]; - if (xx == ago && is_sorted) { + if (xx == ago && em[block_idx * nnei * 4 + ii * 4 + 1] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 2] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 3] == 0. && is_sorted) { unloop = true; breakpoint = ii; } @@ -286,7 +288,9 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( FPTYPE ago = GpuShuffleSync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0); for (int ii = warp_idx; ii < nnei; ii += KTILE) { FPTYPE xx = em_x[block_idx * nnei + ii]; - if (ago == xx && is_sorted) { + if (ago == xx && em[block_idx * nnei * 4 + ii * 4 + 1] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 2] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 3] == 0. && is_sorted) { unloop = true; breakpoint = ii; } @@ -393,7 +397,9 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial( for (int ii = 0; ii < nnei; ii++) { FPTYPE xx = em_x[block_idx * nnei + ii]; FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei + ii]; - if (xx == ago && is_sorted) { + if (xx == ago && em[block_idx * nnei * 4 + ii * 4 + 1] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 2] == 0. && + em[block_idx * nnei * 4 + ii * 4 + 3] == 0. && is_sorted) { unloop = true; breakpoint = ii; } diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc index 3e2a1bec62..9352980351 100644 --- a/source/lib/src/tabulate.cc +++ b/source/lib/src/tabulate.cc @@ -108,7 +108,7 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out, ll[2] = em[ii * nnei * 4 + jj * 4 + 2]; ll[3] = em[ii * nnei * 4 + jj * 4 + 3]; FPTYPE xx = em_x[ii * nnei + jj]; - if (ago == xx && is_sorted) { + if (ago == xx && ll[1] == 0. && ll[2] == 0. && ll[3] == 0. && is_sorted) { unloop = true; } int table_idx = 0; @@ -195,7 +195,7 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, ll[2] = em[ii * nnei * 4 + jj * 4 + 2]; ll[3] = em[ii * nnei * 4 + jj * 4 + 3]; FPTYPE xx = em_x[ii * nnei + jj]; - if (ago == xx && is_sorted) { + if (ago == xx && ll[1] == 0. && ll[2] == 0. && ll[3] == 0. && is_sorted) { unloop = true; } int table_idx = 0; @@ -298,7 +298,7 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy, hh[3] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 3]; FPTYPE xx = em_x[ii * nnei + jj]; FPTYPE dz_xx = dz_dy_dem_x[ii * nnei + jj]; - if (ago == xx && is_sorted) { + if (ago == xx && ll[1] == 0. && ll[2] == 0. && ll[3] == 0. && is_sorted) { unloop = true; } int table_idx = 0;