From 6492856cc3298829272379ed5bba3546e15eec95 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 31 May 2021 23:45:17 +0800 Subject: [PATCH 1/2] fix bug in op definition prod_virial_grad_multi_device.cc. This fixes 679 --- source/op/prod_virial_grad_multi_device.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/op/prod_virial_grad_multi_device.cc b/source/op/prod_virial_grad_multi_device.cc index ac74d1d141..e8c117f50e 100644 --- a/source/op/prod_virial_grad_multi_device.cc +++ b/source/op/prod_virial_grad_multi_device.cc @@ -117,7 +117,7 @@ class ProdVirialSeAGradOp : public OpKernel // loop over frames for (int kk = 0; kk < nframes; ++kk){ FPTYPE * grad_net = p_grad_net + kk * nloc * ndescrpt; - const FPTYPE * grad = p_grad + kk * nloc * 3; + const FPTYPE * grad = p_grad + kk * 9; const FPTYPE * in_deriv = p_in_deriv + kk * nloc * ndescrpt * 3; const FPTYPE * rij = p_rij + kk * nloc * nnei * 3; const int * nlist = p_nlist + kk * nloc * nnei; @@ -272,4 +272,4 @@ REGISTER_KERNEL_BUILDER( ProdVirialSeRGradOp); REGISTER_GPU(float); REGISTER_GPU(double); -#endif // GOOGLE_CUDA \ No newline at end of file +#endif // GOOGLE_CUDA From e3c6fd5652759f0fe5245f0c5921b0a8ffe21c8e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 1 Jun 2021 07:37:01 +0800 Subject: [PATCH 2/2] fix bug in op definition prod_virial_grad_multi_device.cc. --- source/op/prod_virial_grad_multi_device.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/op/prod_virial_grad_multi_device.cc b/source/op/prod_virial_grad_multi_device.cc index e8c117f50e..d8c8898c37 100644 --- a/source/op/prod_virial_grad_multi_device.cc +++ b/source/op/prod_virial_grad_multi_device.cc @@ -229,7 +229,7 @@ class ProdVirialSeRGradOp : public OpKernel // loop over frames for (int kk = 0; kk < nframes; ++kk){ FPTYPE * grad_net = p_grad_net + kk * nloc * ndescrpt; - const FPTYPE * grad = p_grad + kk * nloc * 3; + const FPTYPE * grad = p_grad + kk * 9; const FPTYPE * in_deriv = p_in_deriv + kk * nloc * ndescrpt * 3; const FPTYPE * rij = p_rij + kk * nloc * nnei * 3; const int * nlist = p_nlist + kk * nloc * nnei;