From 53b33b7c817b83419436a1e093a19ec111f56ce4 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Mon, 9 Mar 2026 00:16:12 +0800 Subject: [PATCH] Use a unified hpsi_func/spsi_func for all ks_solvers --- source/source_hsolver/hsolver_pw.cpp | 118 +++++++-------------------- 1 file changed, 29 insertions(+), 89 deletions(-) diff --git a/source/source_hsolver/hsolver_pw.cpp b/source/source_hsolver/hsolver_pw.cpp index 0d74f72162..204b9b53ed 100644 --- a/source/source_hsolver/hsolver_pw.cpp +++ b/source/source_hsolver/hsolver_pw.cpp @@ -249,6 +249,18 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int cur_nbasis = psi.get_current_nbas(); + // Shared matrix-blockvector operators used by all iterative solvers. + auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto psi_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); + psi::Range bands_range(true, 0, 0, nvec - 1); + using hpsi_info = typename hamilt::Operator::hpsi_info; + hpsi_info info(&psi_wrapper, bands_range, hpsi_out); + hm->ops->hPsi(info); + }; + auto spsi_func = [hm](const T* psi_in, T* spsi_out, const int ld_psi, const int nvec) { + hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec); + }; + if (this->method == "cg") { // wrap the subspace_func into a lambda function @@ -272,28 +284,6 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, this->diag_iter_max, this->nproc_in_pool); - // wrap the hpsi_func and spsi_func into lambda functions - auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { - auto psi_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); - psi::Range all_bands_range(true, 0, 0, nvec - 1); - using hpsi_info = typename hamilt::Operator::hpsi_info; - hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out); - hm->ops->hPsi(info); - }; - auto spsi_func = [this, hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) { - if (this->use_uspp) - { - hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec); - } - else - { - base_device::memory::synchronize_memory_op()( - spsi_out, - psi_in, - static_cast(nvec) * static_cast(ld_psi)); - } - }; - DiagoIterAssist::avg_iter += static_cast( cg.diag(hpsi_func, spsi_func, @@ -313,42 +303,14 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband_l = psi.get_nbands(); const int nbasis = psi.get_nbasis(); const int ndim = psi.get_current_ngk(); - // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { - - // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); - - psi::Range bands_range(true, 0, 0, nvec - 1); - - using hpsi_info = typename hamilt::Operator::hpsi_info; - hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out); - hm->ops->hPsi(info); - }; DiagoBPCG bpcg(pre_condition.data()); bpcg.init_iter(PARAM.inp.nbands, nband_l, nbasis, ndim); bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band); } else if (this->method == "dav_subspace") { - // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { - - // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); - - psi::Range bands_range(true, 0, 0, nvec - 1); - - using hpsi_info = typename hamilt::Operator::hpsi_info; - hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out); - hm->ops->hPsi(info); - }; bool scf = this->calculation_type == "nscf" ? false : true; - auto spsi_func = [hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) { - hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec); - }; - Diago_DavSubspace dav_subspace(pre_condition, psi.get_nbands(), psi.get_k_first() ? psi.get_current_ngk() @@ -361,8 +323,13 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, PARAM.inp.nb2d); DiagoIterAssist::avg_iter += static_cast( - dav_subspace - .diag(hpsi_func, spsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf)); + dav_subspace.diag(hpsi_func, + spsi_func, + psi.get_pointer(), + psi.get_nbasis(), + eigenvalue, + this->ethr_band, + scf)); } else if (this->method == "dav") { @@ -383,45 +350,18 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband = psi.get_nbands(); /// number of eigenpairs sought const int ld_psi = psi.get_nbasis(); /// leading dimension of psi - // Davidson matrix-blockvector functions - /// wrap hpsi into lambda function, Matrix \times blockvector - // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec - auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { - - // Convert pointer of psi_in to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); - - psi::Range bands_range(true, 0, 0, nvec - 1); - - using hpsi_info = typename hamilt::Operator::hpsi_info; - hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out); - hm->ops->hPsi(info); - }; - - /// wrap spsi into lambda function, Matrix \times blockvector - /// spsi(X, SX, ld, nvec) - /// ld is leading dimension of psi and spsi - auto spsi_func = [hm](const T* psi_in, - T* spsi_out, - const int ld_psi, // Leading dimension of psi and spsi. - const int nvec // Number of vectors(bands) - ) { - // sPsi determines S=I or not by PARAM.globalv.use_uspp inside - // sPsi(psi, spsi, nrow, npw, nbands) - hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec); - }; - DiagoDavid david(pre_condition.data(), nband, dim, PARAM.inp.pw_diag_ndim, this->use_paw, comm_info); // do diag and add davidson iteration counts up to avg_iter - DiagoIterAssist::avg_iter += static_cast(david.diag(hpsi_func, - spsi_func, - ld_psi, - psi.get_pointer(), - eigenvalue, - this->ethr_band, - david_maxiter, - ntry_max, - notconv_max)); + DiagoIterAssist::avg_iter += static_cast( + david.diag(hpsi_func, + spsi_func, + ld_psi, + psi.get_pointer(), + eigenvalue, + this->ethr_band, + david_maxiter, + ntry_max, + notconv_max)); } ModuleBase::timer::tick("HSolverPW", "solve_psik"); return;