From 692e1191f63da8b0bac0c1d5a4875c10bf2c6aa1 Mon Sep 17 00:00:00 2001 From: Haozhi Han Date: Fri, 26 Jan 2024 22:22:05 +0800 Subject: [PATCH] Update hsolver_pw.cpp when use_uspp==false, overlap matrix should be E. --- source/module_hsolver/hsolver_pw.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 16fa5f335b..ae784d2009 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -624,17 +624,31 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::P hm->ops->hPsi(info); ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); }; - auto spsi_func = [hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) { + auto spsi_func = [this, hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) { ModuleBase::timer::tick("DiagoCG_New", "spsi_func"); // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - // Convert a Tensor object to a psi::Psi object - hm->sPsi(psi_in.data(), spsi_out.data(), + + if (GlobalV::use_uspp) + { + // Convert a Tensor object to a psi::Psi object + hm->sPsi(psi_in.data(), spsi_out.data(), ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), ndim == 1 ? 1 : psi_in.shape().dim_size(0)); + } else + { + psi::memory::synchronize_memory_op()( + this->ctx, + this->ctx, + spsi_out.data(), + psi_in.data(), + static_cast((ndim == 1 ? 1 : psi_in.shape().dim_size(0)) + * (ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1)))); + } + ModuleBase::timer::tick("DiagoCG_New", "spsi_func"); }; auto psi_tensor = ct::TensorMap( @@ -776,4 +790,4 @@ template class HSolverPW, psi::DEVICE_GPU>; template class HSolverPW, psi::DEVICE_GPU>; #endif -} // namespace hsolver \ No newline at end of file +} // namespace hsolver