From 749493a2a28df4c688877fc05fae5de89bb87b56 Mon Sep 17 00:00:00 2001 From: LKFEIYI <38131547+LKFEIYI@users.noreply.github.com> Date: Wed, 7 Jan 2026 01:16:02 -0700 Subject: [PATCH] reduce FFT calls in Leps2 --- .../module_surchem/minimize_cg.cpp | 50 ++++++------------- source/source_hamilt/module_surchem/surchem.h | 7 +-- 2 files changed, 18 insertions(+), 39 deletions(-) diff --git a/source/source_hamilt/module_surchem/minimize_cg.cpp b/source/source_hamilt/module_surchem/minimize_cg.cpp index 451a8fc233..da41053cd5 100644 --- a/source/source_hamilt/module_surchem/minimize_cg.cpp +++ b/source/source_hamilt/module_surchem/minimize_cg.cpp @@ -26,16 +26,12 @@ void surchem::minimize_cg(const UnitCell& ucell, std::complex<double> *gradphi_G_work = new std::complex<double>[rho_basis->npw]; - // Removed unused phi_work allocation - // std::complex<double> *phi_work = new std::complex<double>[rho_basis->npw]; - // ========================================================== // PRE-ALLOCATION FOR LEPS2 (Avoids allocation inside loop) // ========================================================== ModuleBase::Vector3<double> *aux_grad_phi = new ModuleBase::Vector3<double>[rho_basis->nrxx]; - std::complex<double> *aux_grad_grad_phi_G = new std::complex<double>[rho_basis->npw]; - double *aux_lp_real = new double[rho_basis->nrxx]; double *aux_grad_grad_phi_real = new double[rho_basis->nrxx]; + // remove aux_grad_grad_phi_G and aux_lp_real ModuleBase::GlobalFunc::ZEROS(resid, rho_basis->npw); ModuleBase::GlobalFunc::ZEROS(z, rho_basis->npw); @@ -68,7 +64,7 @@ void surchem::minimize_cg(const UnitCell& ucell, // call leps to calculate div ( epsilon * grad ) phi // Updated Leps2 call with new buffers Leps2(ucell, rho_basis, phi, d_eps, gradphi_G_work, lp, - aux_grad_phi, aux_grad_grad_phi_G, aux_lp_real, aux_grad_grad_phi_real); + aux_grad_phi, aux_grad_grad_phi_real); // the residue // r = A*phi + (chtot + N) @@ -106,7 +102,7 @@ void surchem::minimize_cg(const UnitCell& ucell, // Updated Leps2 call inside loop Leps2(ucell, rho_basis, d, d_eps, 
gradphi_G_work, lp, - aux_grad_phi, aux_grad_grad_phi_G, aux_lp_real, aux_grad_grad_phi_real); + aux_grad_phi, aux_grad_grad_phi_real); // calculate alpha alpha = -rinvLr / ModuleBase::GlobalFunc::ddot_real(rho_basis->npw, d, lp); @@ -161,12 +157,13 @@ void surchem::minimize_cg(const UnitCell& ucell, // Clean up auxiliary buffers delete[] aux_grad_phi; - delete[] aux_grad_grad_phi_G; - delete[] aux_lp_real; + // delete[] aux_grad_grad_phi_G; // Removed + // delete[] aux_lp_real; // Removed delete[] aux_grad_grad_phi_real; } // avoid creating large temporary matrices inside its iteration loop +// reduce the intermediate FFT related calls void surchem::Leps2(const UnitCell& ucell, const ModulePW::PW_Basis* rho_basis, std::complex<double>* phi, @@ -174,8 +171,6 @@ void surchem::Leps2(const UnitCell& ucell, std::complex<double>* gradphi_G_work, std::complex<double>* lp, ModuleBase::Vector3<double>* grad_phi_R, // size: nrxx - std::complex<double>* aux_G, // size: npw - double* lp_real, // size: nrxx double* aux_R) // size: nrxx { @@ -190,45 +185,32 @@ void surchem::Leps2(const UnitCell& ucell, } - ModuleBase::GlobalFunc::ZEROS(lp_real, rho_basis->nrxx); + ModuleBase::GlobalFunc::ZEROS(lp, rho_basis->npw); - // 1. 
R -> G + // R -> G for (int ir = 0; ir < rho_basis->nrxx; ir++) aux_R[ir] = grad_phi_R[ir].x; - rho_basis->real2recip(aux_R, gradphi_G_work); // + rho_basis->real2recip(aux_R, gradphi_G_work); - for(int ig=0; ig<rho_basis->npw; ig++) { - aux_G[ig] = ModuleBase::IMAG_UNIT * gradphi_G_work[ig] * rho_basis->gcar[ig][0]; // 0 = x - } - rho_basis->recip2real(aux_G, aux_R); - for(int ir=0; ir<rho_basis->nrxx; ir++) { - lp_real[ir] += aux_R[ir] * ucell.tpiba; + // Divergence in G space: div(F) -> i * G * F(G) + lp[ig] += ModuleBase::IMAG_UNIT * gradphi_G_work[ig] * rho_basis->gcar[ig][0]; // 0 = x } - for (int ir = 0; ir < rho_basis->nrxx; ir++) aux_R[ir] = grad_phi_R[ir].y; rho_basis->real2recip(aux_R, gradphi_G_work); for(int ig=0; ig<rho_basis->npw; ig++) { - aux_G[ig] = ModuleBase::IMAG_UNIT * gradphi_G_work[ig] * rho_basis->gcar[ig][1]; // 1 = y - } - rho_basis->recip2real(aux_G, aux_R); - for(int ir=0; ir<rho_basis->nrxx; ir++) { - lp_real[ir] += aux_R[ir] * ucell.tpiba; + lp[ig] += ModuleBase::IMAG_UNIT * gradphi_G_work[ig] * rho_basis->gcar[ig][1]; // 1 = y } - for (int ir = 0; ir < rho_basis->nrxx; ir++) aux_R[ir] = grad_phi_R[ir].z; rho_basis->real2recip(aux_R, gradphi_G_work); for(int ig=0; ig<rho_basis->npw; ig++) { - aux_G[ig] = ModuleBase::IMAG_UNIT * gradphi_G_work[ig] * rho_basis->gcar[ig][2]; // 2 = z - } - rho_basis->recip2real(aux_G, aux_R); - for(int ir=0; ir<rho_basis->nrxx; ir++) { - lp_real[ir] += aux_R[ir] * ucell.tpiba; + lp[ig] += ModuleBase::IMAG_UNIT * gradphi_G_work[ig] * rho_basis->gcar[ig][2]; // 2 = z } - - rho_basis->real2recip(lp_real, lp); + for(int ig=0; ig<rho_basis->npw; ig++) { + lp[ig] *= ucell.tpiba; + } } \ No newline at end of file diff --git a/source/source_hamilt/module_surchem/surchem.h b/source/source_hamilt/module_surchem/surchem.h index 9c057aaffb..f7aa0a4b02 100644 --- a/source/source_hamilt/module_surchem/surchem.h +++ b/source/source_hamilt/module_surchem/surchem.h @@ -100,11 +100,8 @@ class surchem double* epsilon, // epsilon from shapefunc, dim=nrxx std::complex<double>* gradphi_G_work, std::complex<double>* lp, - // 
New buffers - ModuleBase::Vector3<double>* grad_phi_R, - std::complex<double>* grad_grad_phi_G, - double* lp_real, - double* grad_grad_phi_real); + ModuleBase::Vector3<double>* grad_phi_R, // size: nrxx + double* aux_R); void v_correction(const UnitCell& cell, const Parallel_Grid& pgrid,