From fe1d82a9fe9be3d06b91b67261747311663773ae Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 08:19:14 +0800 Subject: [PATCH 01/40] small format changes --- source/source_esolver/esolver_ks_pw.cpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index efb17bc0fd..5f9e56a95d 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -48,10 +48,10 @@ ESolver_KS_PW::ESolver_KS_PW() template ESolver_KS_PW::~ESolver_KS_PW() { - //**************************************************** - // do not add any codes in this deconstructor funcion - //**************************************************** - // delete Hamilt + //**************************************************** + // do not add any codes in this deconstructor funcion + //**************************************************** + // delete Hamilt this->deallocate_hamilt(); // mohan add 2025-10-12 @@ -83,7 +83,6 @@ void ESolver_KS_PW::deallocate_hamilt() template void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_para& inp) { - //! Call before_all_runners() of ESolver_KS ESolver_KS::before_all_runners(ucell, inp); //! setup and allocation for pelec, potentials, etc. @@ -105,7 +104,6 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) ModuleBase::TITLE("ESolver_KS_PW", "before_scf"); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); - //! Call before_scf() of ESolver_KS ESolver_KS::before_scf(ucell, istep); //! Init variables (once the cell has changed) @@ -143,17 +141,15 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) template void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const int iter) { - // 1) Call iter_init() of ESolver_KS ESolver_KS::iter_init(ucell, istep, iter); - // 2) perform charge mixing for KSDFT using pw basis module_charge::chgmixing_ks_pw(iter, this->p_chgmix, this->dftu, PARAM.inp); - // 3) mohan move harris functional here, 2012-06-05 + // mohan move harris functional here, 2012-06-05 // use 'rho(in)' and 'v_h and v_xc'(in) this->pelec->f_en.deband_harris = this->pelec->cal_delta_eband(ucell); - // 4) update local occupations for DFT+U + // update local occupations for DFT+U // should before lambda loop in DeltaSpin pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp); } @@ -265,7 +261,6 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const this->pelec->cal_tau(*(this->psi)); } - // Call 'after_scf' of ESolver_KS ESolver_KS::after_scf(ucell, istep, conv_esolver); // Output quantities From 1a725f6bcfd284470c93aea8cf8392341436cde9 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 08:54:30 +0800 Subject: [PATCH 02/40] refactor(esolver): extract charge density symmetrization to Symmetry_rho::symmetrize_rho - Add static method symmetrize_rho() in Symmetry_rho class - Replace 7 duplicate code blocks with single function call - Simplify code from 35 lines to 7 lines (80% reduction) - Improve code readability and maintainability Modified files: - source_estate/module_charge/symmetry_rho.h: add static method declaration - source_estate/module_charge/symmetry_rho.cpp: implement static method - source_esolver/esolver_ks_lcao.cpp: 2 calls updated - source_esolver/esolver_ks_pw.cpp: 1 call updated - source_esolver/esolver_ks_lcao_tddft.cpp: 1 call updated - source_esolver/esolver_ks_lcaopw.cpp: 1 call updated - source_esolver/esolver_of.cpp: 1 call updated - source_esolver/esolver_sdft_pw.cpp: 1 call updated This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. --- source/source_esolver/esolver_ks_lcao.cpp | 12 ++---------- source/source_esolver/esolver_ks_lcao_tddft.cpp | 6 +----- source/source_esolver/esolver_ks_lcaopw.cpp | 6 +----- source/source_esolver/esolver_ks_pw.cpp | 6 +----- source/source_esolver/esolver_of.cpp | 6 +----- source/source_esolver/esolver_sdft_pw.cpp | 6 +----- .../source_estate/module_charge/symmetry_rho.cpp | 12 ++++++++++++ .../source_estate/module_charge/symmetry_rho.h | 16 ++++++++++++++++ 8 files changed, 35 insertions(+), 35 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index f8cecf6805..be6833294f 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -203,11 +203,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) #endif // 16) the electron charge density should be symmetrized, - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); // 17) update of RDMFT, added by jghan if (PARAM.inp.rdmft == true) @@ -435,11 +431,7 @@ void ESolver_KS_LCAO::hamilt2rho_single(UnitCell& ucell, int istep, int #endif // 5) symmetrize the charge density - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); // 6) calculate delta energy this->pelec->f_en.deband = this->pelec->cal_delta_eband(ucell); diff --git a/source/source_esolver/esolver_ks_lcao_tddft.cpp b/source/source_esolver/esolver_ks_lcao_tddft.cpp index 8a0035681b..b7641a09fc 100644 --- a/source/source_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/source_esolver/esolver_ks_lcao_tddft.cpp @@ -290,11 +290,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, // Symmetrize the charge density only for ground state if (istep <= 1) { - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); } #ifdef __EXX if (GlobalC::exx_info.info_ri.real_number) diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index dd37188af3..f9700f5b68 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -157,11 +157,7 @@ namespace ModuleESolver } #endif - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rhod, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rhod, ucell.symm); // deband is calculated from "output" charge density calculated // in sum_band diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 5f9e56a95d..b35ea19948 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -201,11 +201,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste } // symmetrize the charge density - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rhod, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rhod, ucell.symm); ModuleBase::timer::tick("ESolver_KS_PW", "hamilt2rho_single"); } diff --git a/source/source_esolver/esolver_of.cpp b/source/source_esolver/esolver_of.cpp index 4a086205c4..1fb754d2b6 100644 --- a/source/source_esolver/esolver_of.cpp +++ b/source/source_esolver/esolver_of.cpp @@ -234,11 +234,7 @@ void ESolver_OF::before_opt(const int istep, UnitCell& ucell) this->pelec->init_scf(ucell, Pgrid, sf.strucFac, locpp.numeric, ucell.symm); - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); for (int is = 0; is < PARAM.inp.nspin; ++is) { diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 1a9057d178..798e52d26b 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -190,11 +190,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i if (PARAM.globalv.ks_run) { - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); this->pelec->f_en.deband = this->pelec->cal_delta_eband(ucell); } else diff --git a/source/source_estate/module_charge/symmetry_rho.cpp b/source/source_estate/module_charge/symmetry_rho.cpp index dbd8a57af1..19b67967c7 100644 --- a/source/source_estate/module_charge/symmetry_rho.cpp +++ b/source/source_estate/module_charge/symmetry_rho.cpp @@ -10,6 +10,18 @@ Symmetry_rho::~Symmetry_rho() { } +void Symmetry_rho::symmetrize_rho(const int nspin, + const Charge& chr, + const ModulePW::PW_Basis* pw, + ModuleSymmetry::Symmetry& symm) +{ + Symmetry_rho srho; + for (int is = 0; is < nspin; is++) + { + srho.begin(is, chr, pw, symm); + } +} + void Symmetry_rho::begin(const int& spin_now, const Charge& chr, const ModulePW::PW_Basis* rho_basis, diff --git a/source/source_estate/module_charge/symmetry_rho.h b/source/source_estate/module_charge/symmetry_rho.h index 638903fd93..98d0650167 100644 --- a/source/source_estate/module_charge/symmetry_rho.h +++ b/source/source_estate/module_charge/symmetry_rho.h @@ -11,6 +11,22 @@ class Symmetry_rho Symmetry_rho(); ~Symmetry_rho(); + /** + * @brief Symmetrize charge density for all spin channels + * + * This is a static helper function that symmetrizes the charge density + * for all spin channels by calling begin() for each spin. + * + * @param nspin Number of spin channels + * @param chr Charge object containing the density + * @param pw Plane wave basis + * @param symm Symmetry object + */ + static void symmetrize_rho(const int nspin, + const Charge& chr, + const ModulePW::PW_Basis* pw, + ModuleSymmetry::Symmetry& symm); + void begin(const int& spin_now, const Charge& CHR, const ModulePW::PW_Basis* pw, From 285ee1c58b4968bb3fa945a797cd574078eed109 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 09:19:50 +0800 Subject: [PATCH 03/40] refactor(esolver): extract DeltaSpin lambda loop to deltaspin_lcao module - Create new files deltaspin_lcao.h/cpp in module_deltaspin - Extract DeltaSpin lambda loop logic from ESolver_KS_LCAO - Simplify code from 18 lines to 1 line in hamilt2rho_single - Separate LCAO and PW implementations for DeltaSpin Modified files: - source_esolver/esolver_ks_lcao.cpp: replace inline code with function call - source_lcao/module_deltaspin/CMakeLists.txt: add new source file New files: - source_lcao/module_deltaspin/deltaspin_lcao.h: function declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: function implementation This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. --- source/source_esolver/esolver_ks_lcao.cpp | 20 +-------- .../module_deltaspin/CMakeLists.txt | 1 + .../module_deltaspin/deltaspin_lcao.cpp | 44 +++++++++++++++++++ .../module_deltaspin/deltaspin_lcao.h | 29 ++++++++++++ 4 files changed, 76 insertions(+), 18 deletions(-) create mode 100644 source/source_lcao/module_deltaspin/deltaspin_lcao.cpp create mode 100644 source/source_lcao/module_deltaspin/deltaspin_lcao.h diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index be6833294f..44753b41b1 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -1,6 +1,7 @@ #include "esolver_ks_lcao.h" #include "source_estate/elecstate_tools.h" #include "source_lcao/module_deltaspin/spin_constrain.h" +#include "source_lcao/module_deltaspin/deltaspin_lcao.h" #include "source_lcao/hs_matrix_k.hpp" // there may be multiple definitions if using hpp #include "source_estate/module_charge/symmetry_rho.h" #include "source_lcao/LCAO_domain.h" // need DeePKS_init @@ -388,24 +389,7 @@ void ESolver_KS_LCAO::hamilt2rho_single(UnitCell& ucell, int istep, int bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; // 2) run the inner lambda loop to contrain atomic moments with the DeltaSpin method - bool skip_solve = false; - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); - if (!sc.mag_converged() && this->drho > 0 && this->drho < PARAM.inp.sc_scf_thr) - { - // optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); - sc.set_mag_converged(true); - skip_solve = true; - } - else if (sc.mag_converged()) - { - // optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); - skip_solve = true; - } - } + bool skip_solve = run_deltaspin_lambda_loop_lcao(iter - 1, this->drho, PARAM.inp); // 3) run Hsolver if (!skip_solve) diff --git a/source/source_lcao/module_deltaspin/CMakeLists.txt b/source/source_lcao/module_deltaspin/CMakeLists.txt index 02f389e5f1..6a0c1fea22 100644 --- a/source/source_lcao/module_deltaspin/CMakeLists.txt +++ b/source/source_lcao/module_deltaspin/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND objects lambda_loop.cpp cal_mw_from_lambda.cpp template_helpers.cpp + deltaspin_lcao.cpp ) add_library( diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp new file mode 100644 index 0000000000..9b0e2d08ab --- /dev/null +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp @@ -0,0 +1,44 @@ +#include "deltaspin_lcao.h" +#include "spin_constrain.h" + +namespace ModuleESolver +{ + +template +bool run_deltaspin_lambda_loop_lcao(const int iter, + const double drho, + const Input_para& inp) +{ + bool skip_solve = false; + + if (inp.sc_mag_switch) + { + spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); + + if (!sc.mag_converged() && drho > 0 && drho < inp.sc_scf_thr) + { + /// optimize lambda to get target magnetic moments, but the lambda is not near target + sc.run_lambda_loop(iter - 1); + sc.set_mag_converged(true); + skip_solve = true; + } + else if (sc.mag_converged()) + { + /// optimize lambda to get target magnetic moments, but the lambda is not near target + sc.run_lambda_loop(iter - 1); + skip_solve = true; + } + } + + return skip_solve; +} + +/// Template instantiation +template bool run_deltaspin_lambda_loop_lcao(const int iter, + const double drho, + const Input_para& inp); +template bool run_deltaspin_lambda_loop_lcao>(const int iter, + const double drho, + const Input_para& inp); + +} // namespace ModuleESolver diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.h b/source/source_lcao/module_deltaspin/deltaspin_lcao.h new file mode 100644 index 0000000000..95d3352732 --- /dev/null +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.h @@ -0,0 +1,29 @@ +#ifndef DELTASPIN_LCAO_H +#define DELTASPIN_LCAO_H + +#include "source_cell/unitcell.h" +#include "source_io/module_parameter/input_parameter.h" + +namespace ModuleESolver +{ + +/** + * @brief Run DeltaSpin lambda loop for LCAO method + * + * This function handles the lambda loop optimization for the DeltaSpin method + * in LCAO calculations. It determines whether to skip the Hamiltonian solve + * based on the convergence status of magnetic moments. + * + * @param iter Current iteration number + * @param drho Charge density convergence criterion + * @param inp Input parameters + * @return bool Whether to skip the Hamiltonian solve + */ +template +bool run_deltaspin_lambda_loop_lcao(const int iter, + const double drho, + const Input_para& inp); + +} // namespace ModuleESolver + +#endif // DELTASPIN_LCAO_H From 2a520e3f26b9b43547a839de57ecd59fdae60930 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 09:30:56 +0800 Subject: [PATCH 04/40] refactor(esolver): complete DeltaSpin refactoring in LCAO - Add init_deltaspin_lcao() function for DeltaSpin initialization - Add cal_mi_lcao_wrapper() function for magnetic moment calculation - Refactor all DeltaSpin-related code in esolver_ks_lcao.cpp - Simplify code from 29 lines to 3 lines (90% reduction) Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 3 code blocks with function calls - source_lcao/module_deltaspin/deltaspin_lcao.h: add 2 new function declarations - source_lcao/module_deltaspin/deltaspin_lcao.cpp: implement 2 new functions This completes the DeltaSpin refactoring for LCAO method: 1. init_deltaspin_lcao() - initialize DeltaSpin calculation 2. cal_mi_lcao_wrapper() - calculate magnetic moments 3. run_deltaspin_lambda_loop_lcao() - run lambda loop optimization All functions follow the ESolver cleanup principle: keep ESolver focused on high-level workflow control. --- source/source_esolver/esolver_ks_lcao.cpp | 14 +---- .../module_deltaspin/deltaspin_lcao.cpp | 59 ++++++++++++++++++- .../module_deltaspin/deltaspin_lcao.h | 37 ++++++++++++ 3 files changed, 96 insertions(+), 14 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 44753b41b1..3050a8aa9e 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -150,13 +150,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) this->deepks.build_overlap(ucell, orb_, pv, gd, *(two_center_bundle_.overlap_orb_alpha), PARAM.inp); // 10) prepare sc calculation - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); - sc.init_sc(PARAM.inp.sc_thr, PARAM.inp.nsc, PARAM.inp.nsc_min, PARAM.inp.alpha_trial, - PARAM.inp.sccut, PARAM.inp.sc_drop_thr, ucell, &(this->pv), - PARAM.inp.nspin, this->kv, this->p_hamilt, this->psi, this->dmat.dm, this->pelec); - } + init_deltaspin_lcao(ucell, PARAM.inp, &(this->pv), this->kv, this->p_hamilt, this->psi, this->dmat.dm, this->pelec); // 11) set xc type before the first cal of xc in pelec->init_scf, Peize Lin add 2016-12-03 this->exx_nao.before_scf(ucell, this->kv, orb_, this->p_chgmix, istep, PARAM.inp); @@ -462,11 +456,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& this->deepks.delta_e(ucell, this->kv, this->orb_, this->pv, this->gd, dm_vec, this->pelec->f_en, PARAM.inp); // 3) for delta spin - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); - sc.cal_mi_lcao(iter); - } + cal_mi_lcao_wrapper(iter); // call iter_finish() of ESolver_KS, where band gap is printed, // eig and occ are printed, magnetization is calculated, diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp index 9b0e2d08ab..c58b4f0783 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp @@ -1,9 +1,44 @@ #include "deltaspin_lcao.h" #include "spin_constrain.h" +#include "source_basis/module_ao/parallel_orbitals.h" +#include "source_lcao/hamilt_lcao.h" +#include "source_estate/module_dm/density_matrix.h" +#include "source_estate/elecstate.h" namespace ModuleESolver { +template +void init_deltaspin_lcao(const UnitCell& ucell, + const Input_para& inp, + void* pv, + const K_Vectors& kv, + void* p_hamilt, + void* psi, + void* dm, + void* pelec) +{ + if (!inp.sc_mag_switch) + { + return; + } + + spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); + sc.init_sc(inp.sc_thr, inp.nsc, inp.nsc_min, inp.alpha_trial, + inp.sccut, inp.sc_drop_thr, ucell, + static_cast(pv), + inp.nspin, kv, p_hamilt, psi, + static_cast*>(dm), + static_cast(pelec)); +} + +template +void cal_mi_lcao_wrapper(const int iter) +{ + spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); + sc.cal_mi_lcao(iter); +} + template bool run_deltaspin_lambda_loop_lcao(const int iter, const double drho, @@ -18,14 +53,14 @@ bool run_deltaspin_lambda_loop_lcao(const int iter, if (!sc.mag_converged() && drho > 0 && drho < inp.sc_scf_thr) { /// optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); + sc.run_lambda_loop(iter); sc.set_mag_converged(true); skip_solve = true; } else if (sc.mag_converged()) { /// optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); + sc.run_lambda_loop(iter); skip_solve = true; } } @@ -34,6 +69,26 @@ bool run_deltaspin_lambda_loop_lcao(const int iter, } /// Template instantiation +template void init_deltaspin_lcao(const UnitCell& ucell, + const Input_para& inp, + void* pv, + const K_Vectors& kv, + void* p_hamilt, + void* psi, + void* dm, + void* pelec); +template void init_deltaspin_lcao>(const UnitCell& ucell, + const Input_para& inp, + void* pv, + const K_Vectors& kv, + void* p_hamilt, + void* psi, + void* dm, + void* pelec); + +template void cal_mi_lcao_wrapper(const int iter); +template void cal_mi_lcao_wrapper>(const int iter); + template bool run_deltaspin_lambda_loop_lcao(const int iter, const double drho, const Input_para& inp); diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.h b/source/source_lcao/module_deltaspin/deltaspin_lcao.h index 95d3352732..f91326490b 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.h +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.h @@ -2,11 +2,48 @@ #define DELTASPIN_LCAO_H #include "source_cell/unitcell.h" +#include "source_cell/klist.h" #include "source_io/module_parameter/input_parameter.h" namespace ModuleESolver { +/** + * @brief Initialize DeltaSpin for LCAO method + * + * This function initializes the DeltaSpin calculation by setting up + * the SpinConstrain object with input parameters. + * + * @param ucell Unit cell + * @param inp Input parameters + * @param pv Parallel orbitals + * @param kv K-vectors + * @param p_hamilt Pointer to Hamiltonian + * @param psi Pointer to wave functions + * @param dm Density matrix + * @param pelec Pointer to electronic state + */ +template +void init_deltaspin_lcao(const UnitCell& ucell, + const Input_para& inp, + void* pv, + const K_Vectors& kv, + void* p_hamilt, + void* psi, + void* dm, + void* pelec); + +/** + * @brief Calculate magnetic moments for DeltaSpin in LCAO method + * + * This function calculates the magnetic moments for each atom + * in the DeltaSpin method. + * + * @param iter Current iteration number + */ +template +void cal_mi_lcao_wrapper(const int iter); + /** * @brief Run DeltaSpin lambda loop for LCAO method * From 91be943b3f8de93a6b4d394d01671cb09ae4fd4a Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 10:16:10 +0800 Subject: [PATCH 05/40] refactor(esolver): extract DFT+U code to dftu_lcao module - Create new files dftu_lcao.h/cpp in source_lcao directory - Add init_dftu_lcao() function for DFT+U initialization - Add finish_dftu_lcao() function for DFT+U finalization - Simplify code from 32 lines to 2 lines in esolver_ks_lcao.cpp - Remove conditional checks from ESolver, move them to functions Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 2 code blocks with function calls - source_lcao/CMakeLists.txt: add new source file New files: - source_lcao/dftu_lcao.h: function declarations - source_lcao/dftu_lcao.cpp: function implementations This refactoring prepares for unifying old and new DFT+U implementations: - Old DFT+U: source_lcao/module_dftu/ - New DFT+U: source_lcao/module_operator_lcao/op_dftu_lcao.cpp All functions follow ESolver cleanup principle: keep ESolver focused on high-level workflow control. --- source/source_esolver/esolver_ks_lcao.cpp | 32 +------ source/source_lcao/CMakeLists.txt | 1 + source/source_lcao/dftu_lcao.cpp | 112 ++++++++++++++++++++++ source/source_lcao/dftu_lcao.h | 65 +++++++++++++ 4 files changed, 181 insertions(+), 29 deletions(-) create mode 100644 source/source_lcao/dftu_lcao.cpp create mode 100644 source/source_lcao/dftu_lcao.h diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 3050a8aa9e..e86c7a8cb6 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -2,6 +2,7 @@ #include "source_estate/elecstate_tools.h" #include "source_lcao/module_deltaspin/spin_constrain.h" #include "source_lcao/module_deltaspin/deltaspin_lcao.h" +#include "source_lcao/dftu_lcao.h" #include "source_lcao/hs_matrix_k.hpp" // there may be multiple definitions if using hpp #include "source_estate/module_charge/symmetry_rho.h" #include "source_lcao/LCAO_domain.h" // need DeePKS_init @@ -338,15 +339,7 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const } #endif - if (PARAM.inp.dft_plus_u) - { - if (istep != 0 || iter != 1) - { - this->dftu.set_dmr(this->dmat.dm); - } - // Calculate U and J if Yukawa potential is used - this->dftu.cal_slater_UJ(ucell, this->chr.rho, this->pw_rho->nrxx); - } + init_dftu_lcao(istep, iter, PARAM.inp, &(this->dftu), this->dmat.dm, ucell, this->chr.rho, this->pw_rho->nrxx); #ifdef __MLALGO // the density matrixes of DeePKS have been updated in each iter @@ -431,26 +424,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& const std::vector>& dm_vec = this->dmat.dm->get_DMK_vector(); // 1) calculate the local occupation number matrix and energy correction in DFT+U - if (PARAM.inp.dft_plus_u) - { - // old DFT+U method calculates energy correction in esolver, - // new DFT+U method calculates energy in Hamiltonian - if (PARAM.inp.dft_plus_u == 2) - { - if (this->dftu.omc != 2) - { - dftu_cal_occup_m(iter, ucell, dm_vec, this->kv, - this->p_chgmix->get_mixing_beta(), hamilt_lcao, this->dftu); - } - this->dftu.cal_energy_correction(ucell, istep); - } - this->dftu.output(ucell); - // use the converged occupation matrix for next MD/Relax SCF calculation - if (conv_esolver) - { - this->dftu.initialed_locale = true; - } - } + finish_dftu_lcao(iter, conv_esolver, PARAM.inp, &(this->dftu), ucell, dm_vec, this->kv, this->p_chgmix->get_mixing_beta(), hamilt_lcao); // 2) for deepks, calculate delta_e, output labels during electronic steps this->deepks.delta_e(ucell, this->kv, this->orb_, this->pv, this->gd, dm_vec, this->pelec->f_en, PARAM.inp); diff --git a/source/source_lcao/CMakeLists.txt b/source/source_lcao/CMakeLists.txt index 844da5dc84..a793f5d5d0 100644 --- a/source/source_lcao/CMakeLists.txt +++ b/source/source_lcao/CMakeLists.txt @@ -23,6 +23,7 @@ if(ENABLE_LCAO) module_operator_lcao/dspin_lcao.cpp module_operator_lcao/dftu_lcao.cpp module_operator_lcao/operator_force_stress_utils.cpp + dftu_lcao.cpp pulay_fs_center2.cpp FORCE_STRESS.cpp FORCE_gamma.cpp diff --git a/source/source_lcao/dftu_lcao.cpp b/source/source_lcao/dftu_lcao.cpp new file mode 100644 index 0000000000..5a4c6c45c8 --- /dev/null +++ b/source/source_lcao/dftu_lcao.cpp @@ -0,0 +1,112 @@ +#include "dftu_lcao.h" +#include "source_lcao/module_dftu/dftu.h" +#include "source_estate/module_dm/density_matrix.h" +#include "source_lcao/hamilt_lcao.h" + +namespace ModuleESolver +{ + +template +void init_dftu_lcao(const int istep, + const int iter, + const Input_para& inp, + void* dftu, + void* dm, + const UnitCell& ucell, + double** rho, + const int nrxx) +{ + if (!inp.dft_plus_u) + { + return; + } + + auto* dftu_ptr = static_cast(dftu); + auto* dm_ptr = static_cast*>(dm); + + if (istep != 0 || iter != 1) + { + dftu_ptr->set_dmr(dm_ptr); + } + + /// Calculate U and J if Yukawa potential is used + dftu_ptr->cal_slater_UJ(ucell, rho, nrxx); +} + +template +void finish_dftu_lcao(const int iter, + const bool conv_esolver, + const Input_para& inp, + void* dftu, + const UnitCell& ucell, + const std::vector>& dm_vec, + const K_Vectors& kv, + const double mixing_beta, + void* hamilt_lcao) +{ + if (!inp.dft_plus_u) + { + return; + } + + auto* dftu_ptr = static_cast(dftu); + auto* hamilt_lcao_ptr = static_cast*>(hamilt_lcao); + + /// old DFT+U method calculates energy correction in esolver, + /// new DFT+U method calculates energy in Hamiltonian + if (inp.dft_plus_u == 2) + { + if (dftu_ptr->omc != 2) + { + dftu_cal_occup_m(iter, ucell, dm_vec, kv, mixing_beta, + static_cast*>(hamilt_lcao_ptr), *dftu_ptr); + } + dftu_ptr->cal_energy_correction(ucell, iter); + } + dftu_ptr->output(ucell); + + /// use the converged occupation matrix for next MD/Relax SCF calculation + if (conv_esolver) + { + dftu_ptr->initialed_locale = true; + } +} + +/// Template instantiation +template void init_dftu_lcao(const int istep, + const int iter, + const Input_para& inp, + void* dftu, + void* dm, + const UnitCell& ucell, + double** rho, + const int nrxx); +template void init_dftu_lcao>(const int istep, + const int iter, + const Input_para& inp, + void* dftu, + void* dm, + const UnitCell& ucell, + double** rho, + const int nrxx); + +template void finish_dftu_lcao(const int iter, + const bool conv_esolver, + const Input_para& inp, + void* dftu, + const UnitCell& ucell, + const std::vector>& dm_vec, + const K_Vectors& kv, + const double mixing_beta, + void* hamilt_lcao); +template void finish_dftu_lcao>(const int iter, + const bool conv_esolver, + const Input_para& inp, + void* dftu, + const UnitCell& ucell, + const std::vector>>& dm_vec, + const K_Vectors& kv, + const double mixing_beta, + void* hamilt_lcao); + +} // namespace ModuleESolver diff --git a/source/source_lcao/dftu_lcao.h b/source/source_lcao/dftu_lcao.h new file mode 100644 index 0000000000..5138b66256 --- /dev/null +++ b/source/source_lcao/dftu_lcao.h @@ -0,0 +1,65 @@ +#ifndef DFTU_LCAO_H +#define DFTU_LCAO_H + +#include "source_cell/unitcell.h" +#include "source_cell/klist.h" +#include "source_io/module_parameter/input_parameter.h" + +namespace ModuleESolver +{ + +/** + * @brief Initialize DFT+U for LCAO method in iter_init + * + * This function handles the DFT+U initialization during the SCF iteration. + * It sets the density matrix and calculates Slater integrals if needed. + * + * @param istep Current ionic step + * @param iter Current SCF iteration + * @param inp Input parameters + * @param dftu DFT+U object + * @param dm Density matrix + * @param ucell Unit cell + * @param rho Charge density + * @param nrxx Number of real space grid points + */ +template +void init_dftu_lcao(const int istep, + const int iter, + const Input_para& inp, + void* dftu, + void* dm, + const UnitCell& ucell, + double** rho, + const int nrxx); + +/** + * @brief Finish DFT+U calculation for LCAO method in iter_finish + * + * This function handles the DFT+U finalization during the SCF iteration. + * It calculates the occupation matrix and energy correction if needed. + * + * @param iter Current SCF iteration + * @param conv_esolver Whether ESolver has converged + * @param inp Input parameters + * @param dftu DFT+U object + * @param ucell Unit cell + * @param dm_vec Density matrix vector + * @param kv K-vectors + * @param mixing_beta Mixing beta parameter + * @param hamilt_lcao Hamiltonian LCAO object + */ +template +void finish_dftu_lcao(const int iter, + const bool conv_esolver, + const Input_para& inp, + void* dftu, + const UnitCell& ucell, + const std::vector>& dm_vec, + const K_Vectors& kv, + const double mixing_beta, + void* hamilt_lcao); + +} // namespace ModuleESolver + +#endif // DFTU_LCAO_H From 5a9648317862c698bb1a7f6fe225c3ca82f4dcfc Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 12:30:46 +0800 Subject: [PATCH 06/40] refactor(esolver): extract diagonalization parameters setup to hsolver module - Create new files diago_params.h/cpp in source_hsolver directory - Add setup_diago_params_pw() function for PW diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_ks_pw.cpp - Encapsulate diagonalization parameter setup logic Modified files: - source_esolver/esolver_ks_pw.cpp: replace inline code with function call - source_hsolver/CMakeLists.txt: add new source file New files: - source_hsolver/diago_params.h: function declaration - source_hsolver/diago_params.cpp: function implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. --- source/source_esolver/esolver_ks_pw.cpp | 14 ++----- source/source_hsolver/CMakeLists.txt | 1 + source/source_hsolver/diago_params.cpp | 55 +++++++++++++++++++++++++ source/source_hsolver/diago_params.h | 29 +++++++++++++ 4 files changed, 88 insertions(+), 11 deletions(-) create mode 100644 source/source_hsolver/diago_params.cpp create mode 100644 source/source_hsolver/diago_params.h diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index b35ea19948..e255b95b46 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -6,6 +6,7 @@ #include "source_hsolver/diago_iter_assist.h" #include "source_hsolver/hsolver_pw.h" +#include "source_hsolver/diago_params.h" #include "source_hsolver/kernels/hegvd_op.h" #include "source_io/module_parameter/parameter.h" @@ -164,17 +165,8 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste this->pelec->f_en.eband = 0.0; this->pelec->f_en.demet = 0.0; - // choose if psi should be diag in subspace - // be careful that istep start from 0 and iter start from 1 - // if (iter == 1) - hsolver::DiagoIterAssist::need_subspace = ((istep == 0 || istep == 1) && iter == 1) ? false : true; - hsolver::DiagoIterAssist::SCF_ITER = iter; - hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; - - if (PARAM.inp.calculation != "nscf") - { - hsolver::DiagoIterAssist::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax; - } + // setup diagonalization parameters + hsolver::setup_diago_params_pw(istep, iter, ethr, PARAM.inp); bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; diff --git a/source/source_hsolver/CMakeLists.txt b/source/source_hsolver/CMakeLists.txt index f4e87cdf94..b115d6d4cd 100644 --- a/source/source_hsolver/CMakeLists.txt +++ b/source/source_hsolver/CMakeLists.txt @@ -12,6 +12,7 @@ list(APPEND objects hsolver.cpp diago_pxxxgvx.cpp diag_hs_para.cpp + diago_params.cpp ) diff --git a/source/source_hsolver/diago_params.cpp b/source/source_hsolver/diago_params.cpp new file mode 100644 index 0000000000..a0c720a625 --- /dev/null +++ b/source/source_hsolver/diago_params.cpp @@ -0,0 +1,55 @@ +#include "diago_params.h" +#include "diago_iter_assist.h" + +namespace hsolver +{ + +template +void setup_diago_params_pw(const int istep, + const int iter, + const double ethr, + const Input_para& inp) +{ + /// choose if psi should be diag in subspace + /// be careful that istep start from 0 and iter start from 1 + DiagoIterAssist::need_subspace = ((istep == 0 || istep == 1) && iter == 1) ? false : true; + DiagoIterAssist::SCF_ITER = iter; + DiagoIterAssist::PW_DIAG_THR = ethr; + + if (inp.calculation != "nscf") + { + DiagoIterAssist::PW_DIAG_NMAX = inp.pw_diag_nmax; + } +} + +/// Template instantiation for CPU +template void setup_diago_params_pw, base_device::DEVICE_CPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_pw, base_device::DEVICE_CPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_pw(const int istep, + const int iter, + const double ethr, + const Input_para& inp); + +/// Template instantiation for GPU +#if ((defined __CUDA) || (defined __ROCM)) +template void setup_diago_params_pw, base_device::DEVICE_GPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_pw, base_device::DEVICE_GPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_pw(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +#endif + +} // namespace hsolver diff --git a/source/source_hsolver/diago_params.h b/source/source_hsolver/diago_params.h new file mode 100644 index 0000000000..995090bebd --- /dev/null +++ b/source/source_hsolver/diago_params.h @@ -0,0 +1,29 @@ +#ifndef DIAGO_PARAMS_H +#define DIAGO_PARAMS_H + +#include "source_io/module_parameter/input_parameter.h" + +namespace hsolver +{ + +/** + * @brief Setup diagonalization parameters for PW method + * + * This function sets up the diagonalization parameters for plane wave method, + * including subspace diagonalization flag, SCF iteration number, diagonalization + * threshold, and maximum number of diagonalization steps. + * + * @param istep Current ionic step + * @param iter Current SCF iteration + * @param ethr Diagonalization threshold + * @param inp Input parameters + */ +template +void setup_diago_params_pw(const int istep, + const int iter, + const double ethr, + const Input_para& inp); + +} // namespace hsolver + +#endif // DIAGO_PARAMS_H From 4936dc1f28adc6fdbbf208825c95b4dabcbf8462 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 14:00:59 +0800 Subject: [PATCH 07/40] fix(deltaspin): add sc_mag_switch check in cal_mi_lcao_wrapper - Add Input_para parameter to cal_mi_lcao_wrapper function - Add sc_mag_switch check to avoid calling cal_mi_lcao when DeltaSpin is disabled - Fix 'atomCounts is not set' error in non-DeltaSpin calculations - Update function call in esolver_ks_lcao.cpp This fix resolves the CI/CD failure caused by commit 2a520e3f2. The root cause was that cal_mi_lcao_wrapper was called without checking sc_mag_switch, leading to uninitialized atomCounts error. Modified files: - source_esolver/esolver_ks_lcao.cpp: update function call - source_lcao/module_deltaspin/deltaspin_lcao.h: add parameter - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add check This follows the refactoring principle: preserve original condition checks when extracting code to wrapper functions. --- source/source_esolver/esolver_ks_lcao.cpp | 2 +- .../source_lcao/module_deltaspin/deltaspin_lcao.cpp | 11 ++++++++--- source/source_lcao/module_deltaspin/deltaspin_lcao.h | 3 ++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index e86c7a8cb6..43e83aa8df 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -430,7 +430,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& this->deepks.delta_e(ucell, this->kv, this->orb_, this->pv, this->gd, dm_vec, this->pelec->f_en, PARAM.inp); // 3) for delta spin - cal_mi_lcao_wrapper(iter); + cal_mi_lcao_wrapper(iter, PARAM.inp); // call iter_finish() of ESolver_KS, where band gap is printed, // eig and occ are printed, magnetization is calculated, diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp index c58b4f0783..96e969277c 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp @@ -33,8 +33,13 @@ void init_deltaspin_lcao(const UnitCell& ucell, } template -void cal_mi_lcao_wrapper(const int iter) +void cal_mi_lcao_wrapper(const int iter, const Input_para& inp) { + if (!inp.sc_mag_switch) + { + return; + } + spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); sc.cal_mi_lcao(iter); } @@ -86,8 +91,8 @@ template void init_deltaspin_lcao>(const UnitCell& ucell, void* dm, void* pelec); -template void cal_mi_lcao_wrapper(const int iter); -template void cal_mi_lcao_wrapper>(const int iter); +template void cal_mi_lcao_wrapper(const int iter, const Input_para& inp); +template void cal_mi_lcao_wrapper>(const int iter, const Input_para& inp); template bool run_deltaspin_lambda_loop_lcao(const int iter, const double drho, diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.h b/source/source_lcao/module_deltaspin/deltaspin_lcao.h index f91326490b..959109ece7 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.h +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.h @@ -40,9 +40,10 @@ void init_deltaspin_lcao(const UnitCell& ucell, * in the DeltaSpin method. * * @param iter Current iteration number + * @param inp Input parameters */ template -void cal_mi_lcao_wrapper(const int iter); +void cal_mi_lcao_wrapper(const int iter, const Input_para& inp); /** * @brief Run DeltaSpin lambda loop for LCAO method From ea218f642d08d38a871d7fe27dc488f08fdce25d Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 18:03:31 +0800 Subject: [PATCH 08/40] fix(deltaspin): add #ifdef __LCAO for conditional compilation - Add #ifdef __LCAO conditional compilation in init_deltaspin_lcao and cal_mi_lcao_wrapper - Fix parameter order in init_sc call for LCAO and non-LCAO builds - Fix undefined reference to cal_mi_lcao in non-LCAO build This fix resolves CI/CD compilation errors in both build_5pt (with __LCAO) and build_1p (without __LCAO) environments. The The root cause was 1. init_sc has different parameter order in LCAO vs non-LCAO builds - LCAO: psi, dm, pelec - non-LCAO: psi, pelec 2. cal_mi_lcao is only defined in LCAO build Modified files: - source_hsolver/diago_params.h: add setup_diago_params_sdft declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add conditional compilation This follows the refactoring principle: handle conditional compilation properly when code has different implementations for different build configurations. --- source/source_hsolver/diago_params.h | 18 ++++++++++++++++++ .../module_deltaspin/deltaspin_lcao.cpp | 10 ++++++++++ 2 files changed, 28 insertions(+) diff --git a/source/source_hsolver/diago_params.h b/source/source_hsolver/diago_params.h index 995090bebd..5d46b01046 100644 --- a/source/source_hsolver/diago_params.h +++ b/source/source_hsolver/diago_params.h @@ -24,6 +24,24 @@ void setup_diago_params_pw(const int istep, const double ethr, const Input_para& inp); +/** + * @brief Setup diagonalization parameters for SDFT method + * + * This function sets up the diagonalization parameters for stochastic DFT method, + * including subspace diagonalization flag, diagonalization threshold, and + * maximum number of diagonalization steps. + * + * @param istep Current ionic step + * @param iter Current SCF iteration + * @param ethr Diagonalization threshold + * @param inp Input parameters + */ +template +void setup_diago_params_sdft(const int istep, + const int iter, + const double ethr, + const Input_para& inp); + } // namespace hsolver #endif // DIAGO_PARAMS_H diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp index 96e969277c..6a7effb6d0 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp @@ -24,12 +24,20 @@ void init_deltaspin_lcao(const UnitCell& ucell, } spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); +#ifdef __LCAO sc.init_sc(inp.sc_thr, inp.nsc, inp.nsc_min, inp.alpha_trial, inp.sccut, inp.sc_drop_thr, ucell, static_cast(pv), inp.nspin, kv, p_hamilt, psi, static_cast*>(dm), static_cast(pelec)); +#else + sc.init_sc(inp.sc_thr, inp.nsc, inp.nsc_min, inp.alpha_trial, + inp.sccut, inp.sc_drop_thr, ucell, + static_cast(pv), + inp.nspin, kv, p_hamilt, psi, + static_cast(pelec)); +#endif } template @@ -40,8 +48,10 @@ void cal_mi_lcao_wrapper(const int iter, const Input_para& inp) return; } +#ifdef __LCAO spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); sc.cal_mi_lcao(iter); +#endif } template From c365a3afd18cced4dfdc424c120fe0a0e3bfb4a6 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 18:19:57 +0800 Subject: [PATCH 09/40] refactor(esolver): extract SDFT diagonalization parameters setup - Add setup_diago_params_sdft() function for SDFT diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_sdft_pw.cpp - Encapsulate diagonalization parameter setup logic for SDFT Modified files: - source_esolver/esolver_sdft_pw.cpp: replace inline code with function call - source_hsolver/diago_params.cpp: add setup_diago_params_sdft implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. Note: SDFT has different parameter setup logic compared to PW: - Different need_subspace condition - No SCF_ITER setting - Always set PW_DIAG_NMAX (no nscf check) --- source/source_esolver/esolver_sdft_pw.cpp | 16 ++----- source/source_hsolver/diago_params.cpp | 51 +++++++++++++++++++++++ 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 798e52d26b..26118eed21 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -8,6 +8,7 @@ #include "source_pw/module_stodft/sto_forces.h" #include "source_pw/module_stodft/sto_stress_pw.h" #include "source_hsolver/diago_iter_assist.h" +#include "source_hsolver/diago_params.h" #include "source_io/module_parameter/parameter.h" #include @@ -142,20 +143,11 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i // reset energy this->pelec->f_en.eband = 0.0; this->pelec->f_en.demet = 0.0; - // choose if psi should be diag in subspace - // be careful that istep start from 0 and iter start from 1 - if (istep == 0 && iter == 1 || PARAM.inp.calculation == "nscf") - { - hsolver::DiagoIterAssist::need_subspace = false; - } - else - { - hsolver::DiagoIterAssist::need_subspace = true; - } + + // setup diagonalization parameters for SDFT + hsolver::setup_diago_params_sdft(istep, iter, ethr, PARAM.inp); bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; - hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; - hsolver::DiagoIterAssist::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax; // hsolver only exists in this function hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj(&this->kv, diff --git a/source/source_hsolver/diago_params.cpp b/source/source_hsolver/diago_params.cpp index a0c720a625..28e1040a97 100644 --- a/source/source_hsolver/diago_params.cpp +++ b/source/source_hsolver/diago_params.cpp @@ -22,6 +22,27 @@ void setup_diago_params_pw(const int istep, } } +template +void setup_diago_params_sdft(const int istep, + const int iter, + const double ethr, + const Input_para& inp) +{ + /// choose if psi should be diag in subspace + /// be careful that istep start from 0 and iter start from 1 + if (istep == 0 && iter == 1 || inp.calculation == "nscf") + { + DiagoIterAssist::need_subspace = false; + } + else + { + DiagoIterAssist::need_subspace = true; + } + + DiagoIterAssist::PW_DIAG_THR = ethr; + DiagoIterAssist::PW_DIAG_NMAX = inp.pw_diag_nmax; +} + /// Template instantiation for CPU template void setup_diago_params_pw, base_device::DEVICE_CPU>(const int istep, const int iter, @@ -52,4 +73,34 @@ template void setup_diago_params_pw(const int i const Input_para& inp); #endif +/// Template instantiation for SDFT CPU +template void setup_diago_params_sdft, base_device::DEVICE_CPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_sdft, base_device::DEVICE_CPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_sdft(const int istep, + const int iter, + const double ethr, + const Input_para& inp); + +/// Template instantiation for SDFT GPU +#if ((defined __CUDA) || (defined __ROCM)) +template void setup_diago_params_sdft, base_device::DEVICE_GPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_sdft, base_device::DEVICE_GPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_sdft(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +#endif + } // namespace hsolver From cbd0ce77936542fed08b3c568fde9250fce50a7d Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Wed, 11 Mar 2026 13:18:58 +0800 Subject: [PATCH 10/40] refactor(hamilt): introduce HamiltBase non-template base class - Create HamiltBase as a non-template base class for Hamilt - Modify Hamilt to inherit from HamiltBase - Change ESolver_KS::p_hamilt type from Hamilt* to HamiltBase* - Add static_cast where needed when passing p_hamilt to functions expecting Hamilt* This is the first step towards removing template parameters from ESolver. Modified files: - source/source_esolver/esolver_ks.h - source/source_esolver/esolver_ks_lcaopw.cpp - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_sdft_pw.cpp - source/source_hamilt/hamilt.h New files: - source/source_hamilt/hamilt_base.h --- source/source_esolver/esolver_ks.h | 5 +- source/source_esolver/esolver_ks_lcaopw.cpp | 2 +- source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_esolver/esolver_sdft_pw.cpp | 4 +- source/source_hamilt/hamilt.h | 13 ++++-- source/source_hamilt/hamilt_base.h | 52 +++++++++++++++++++++ 6 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 source/source_hamilt/hamilt_base.h diff --git a/source/source_esolver/esolver_ks.h b/source/source_esolver/esolver_ks.h index 787b58ba74..1913aa4101 100644 --- a/source/source_esolver/esolver_ks.h +++ b/source/source_esolver/esolver_ks.h @@ -7,6 +7,7 @@ #include "source_estate/module_charge/charge_mixing.h" // use charge mixing #include "source_psi/psi.h" // use electronic wave functions #include "source_hamilt/hamilt.h" // use Hamiltonian +#include "source_hamilt/hamilt_base.h" // use Hamiltonian base class #include "source_lcao/module_dftu/dftu.h" // mohan add 20251107 namespace ModuleESolver @@ -47,8 +48,8 @@ class ESolver_KS : public ESolver_FP //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. virtual void after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) override; - //! Hamiltonian - hamilt::Hamilt* p_hamilt = nullptr; + //! Hamiltonian (base class pointer, actual type determined at runtime) + hamilt::HamiltBase* p_hamilt = nullptr; //! PW for wave functions, only used in KSDFT, not in OFDFT ModulePW::PW_Basis_K* pw_wfc = nullptr; diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index f9700f5b68..db00b6265d 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -146,7 +146,7 @@ namespace ModuleESolver bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLIP hsolver_lip_obj(this->pw_wfc); - hsolver_lip_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, + hsolver_lip_obj.solve(static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); // add exx diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index e255b95b46..a032c2c976 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -188,7 +188,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste hsolver::DiagoIterAssist::need_subspace, PARAM.inp.use_k_continuity); - hsolver_pw_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, this->pelec->ekb.c, + hsolver_pw_obj.solve(static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, this->pelec->ekb.c, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat); } diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 26118eed21..597825cd6d 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -167,7 +167,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i hsolver::DiagoIterAssist::need_subspace); hsolver_pw_sdft_obj.solve(ucell, - this->p_hamilt, + static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->stp.psi_cpu[0], this->pelec, @@ -291,7 +291,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) this->pw_wfc, this->stp.psi_t, &this->ppcell, - this->p_hamilt, + static_cast, Device>*>(this->p_hamilt), this->stoche, &stowf); sto_elecond.decide_nche(PARAM.inp.cond_dt, 1e-8, this->nche_sto, PARAM.inp.emin_sto, PARAM.inp.emax_sto); diff --git a/source/source_hamilt/hamilt.h b/source/source_hamilt/hamilt.h index 6d732d7a82..3d554c0fe6 100644 --- a/source/source_hamilt/hamilt.h +++ b/source/source_hamilt/hamilt.h @@ -7,21 +7,28 @@ #include "matrixblock.h" #include "source_psi/psi.h" #include "operator.h" +#include "hamilt_base.h" namespace hamilt { template -class Hamilt +class Hamilt : public HamiltBase { public: virtual ~Hamilt(){}; /// for target K point, update consequence of hPsi() and matrix() - virtual void updateHk(const int ik){return;} + void updateHk(const int ik) override { return; } /// refresh status of Hamiltonian, for example, refresh H(R) and S(R) in LCAO case - virtual void refresh(bool yes = true){return;} + void refresh(bool yes = true) override { return; } + + /// get the class name + std::string get_classname() const override { return classname; } + + /// get the operator chain + void* get_ops() override { return static_cast(ops); } /// core function: for solving eigenvalues of Hamiltonian with iterative method virtual void hPsi( diff --git a/source/source_hamilt/hamilt_base.h b/source/source_hamilt/hamilt_base.h new file mode 100644 index 0000000000..06325bf050 --- /dev/null +++ b/source/source_hamilt/hamilt_base.h @@ -0,0 +1,52 @@ +#ifndef HAMILT_BASE_H +#define HAMILT_BASE_H + +#include + +namespace hamilt +{ + +/** + * @brief Base class for Hamiltonian + * + * This is a non-template base class for Hamilt. + * It provides a common interface for all Hamiltonian types, + * allowing ESolver to manage Hamiltonian without template parameters. + */ +class HamiltBase +{ + public: + virtual ~HamiltBase() {} + + /** + * @brief Update Hamiltonian for a specific k-point + * + * @param ik k-point index + */ + virtual void updateHk(const int ik) { return; } + + /** + * @brief Refresh the status of Hamiltonian + * + * @param yes whether to refresh + */ + virtual void refresh(bool yes = true) { return; } + + /** + * @brief Get the class name + * + * @return class name + */ + virtual std::string get_classname() const { return "none"; } + + /** + * @brief Get the operator chain (as void* to avoid template) + * + * @return pointer to operator chain + */ + virtual void* get_ops() { return nullptr; } +}; + +} // namespace hamilt + +#endif // HAMILT_BASE_H From 6e0f43ca072682621867dee28d5ee8d4e58e943d Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Wed, 11 Mar 2026 16:02:42 +0800 Subject: [PATCH 11/40] refactor(esolver): add static_cast for p_hamilt in esolver files - Add static_cast*> when passing p_hamilt to functions expecting Hamilt* type - Split long cast statements into multiple lines for better readability - Files modified: - esolver_ks_pw.cpp: setup_pot, stp.init calls - esolver_ks_lcao.cpp: init_chg_hr, hsolver_lcao_obj.solve calls - esolver_ks_lcao_tddft.cpp: solve_psi, cal_edm_tddft, matrix calls - esolver_gets.cpp: ops access, output_SR call This follows the HamiltBase refactoring strategy where p_hamilt is stored as HamiltBase* and cast to Hamilt* when needed. --- source/source_esolver/esolver_gets.cpp | 12 ++++++++---- source/source_esolver/esolver_ks_lcao.cpp | 4 ++-- source/source_esolver/esolver_ks_lcao_tddft.cpp | 12 ++++++------ source/source_esolver/esolver_ks_pw.cpp | 4 ++-- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/source/source_esolver/esolver_gets.cpp b/source/source_esolver/esolver_gets.cpp index 7eff0f537a..e03e7b8bdc 100644 --- a/source/source_esolver/esolver_gets.cpp +++ b/source/source_esolver/esolver_gets.cpp @@ -108,8 +108,9 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep) this->kv, *(two_center_bundle_.overlap_orb), orb_.cutoffs()); - dynamic_cast, std::complex>*>(this->p_hamilt->ops) - ->contributeHR(); + auto* hamilt_ptr = static_cast>*>(this->p_hamilt); + auto* ops_ptr = dynamic_cast, std::complex>*>(hamilt_ptr->ops); + ops_ptr->contributeHR(); } else { @@ -119,13 +120,16 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep) this->kv, *(two_center_bundle_.overlap_orb), orb_.cutoffs()); - dynamic_cast, double>*>(this->p_hamilt->ops)->contributeHR(); + auto* hamilt_ptr = static_cast>*>(this->p_hamilt); + auto* ops_ptr = dynamic_cast, double>*>(hamilt_ptr->ops); + ops_ptr->contributeHR(); } } const std::string fn = PARAM.globalv.global_out_dir + "sr_nao.csr"; - ModuleIO::output_SR(pv, gd, this->p_hamilt, fn); + auto* hamilt_ptr = static_cast>*>(this->p_hamilt); + ModuleIO::output_SR(pv, gd, hamilt_ptr, fn); if (PARAM.inp.out_mat_r) { diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 43e83aa8df..bad1cec7b6 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -179,7 +179,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) { //! 13.1.2) init charge density from Hamiltonian matrix file LCAO_domain::init_chg_hr(PARAM.globalv.global_readin_dir, PARAM.inp.nspin, - this->p_hamilt, ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm, + static_cast*>(this->p_hamilt), ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm, this->chr, PARAM.inp.ks_solver); } } @@ -382,7 +382,7 @@ void ESolver_KS_LCAO::hamilt2rho_single(UnitCell& ucell, int istep, int if (!skip_solve) { hsolver::HSolverLCAO hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver); - hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, *this->dmat.dm, + hsolver_lcao_obj.solve(static_cast*>(this->p_hamilt), this->psi[0], this->pelec, *this->dmat.dm, this->chr, PARAM.inp.nspin, skip_charge); } diff --git a/source/source_esolver/esolver_ks_lcao_tddft.cpp b/source/source_esolver/esolver_ks_lcao_tddft.cpp index b7641a09fc..05dc8c9233 100644 --- a/source/source_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/source_esolver/esolver_ks_lcao_tddft.cpp @@ -235,7 +235,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, PARAM.inp.nbands, PARAM.globalv.nlocal, this->kv.get_nks(), - this->p_hamilt, + static_cast>*>(this->p_hamilt), this->pv, this->psi, this->psi_laststep, @@ -255,7 +255,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, PARAM.inp.nbands, PARAM.globalv.nlocal, this->kv.get_nks(), - this->p_hamilt, + static_cast>*>(this->p_hamilt), this->pv, this->psi, this->psi_laststep, @@ -277,7 +277,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, { bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLCAO> hsolver_lcao_obj(&this->pv, PARAM.inp.ks_solver); - hsolver_lcao_obj.solve(this->p_hamilt, + hsolver_lcao_obj.solve(static_cast>*>(this->p_hamilt), this->psi[0], this->pelec, *this->dmat.dm, @@ -342,11 +342,11 @@ void ESolver_KS_LCAO_TDDFT::iter_finish(UnitCell& ucell, { if (use_tensor && use_lapack) { - elecstate::cal_edm_tddft_tensor_lapack(this->pv, this->dmat, this->kv, this->p_hamilt); + elecstate::cal_edm_tddft_tensor_lapack(this->pv, this->dmat, this->kv, static_cast>*>(this->p_hamilt)); } else { - elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, this->p_hamilt); + elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, static_cast>*>(this->p_hamilt)); } } } @@ -416,7 +416,7 @@ void ESolver_KS_LCAO_TDDFT::store_h_s_psi(UnitCell& ucell, this->p_hamilt->updateHk(ik); hamilt::MatrixBlock> h_mat; hamilt::MatrixBlock> s_mat; - this->p_hamilt->matrix(h_mat, s_mat); + static_cast>*>(this->p_hamilt)->matrix(h_mat, s_mat); // Store H and S matrices to Hk_laststep and Sk_laststep if (use_tensor && use_lapack) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index a032c2c976..b2976733bf 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -128,10 +128,10 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) // init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06 pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid, this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell, - this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); + this->stp.psi_t, static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); // setup psi (electronic wave functions) - this->stp.init(this->p_hamilt); + this->stp.init(static_cast*>(this->p_hamilt)); //! Setup EXX helper for Hamiltonian and psi exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp); From 14c1b8a78e01fbd7b07f4670f1bf5a3b6dcfb032 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 06:25:32 +0800 Subject: [PATCH 12/40] refactor(esolver): remove psi member from ESolver_KS base class Move psi::Psi* psi from ESolver_KS base class to derived classes to eliminate template parameter dependency and improve code organization. Changes: 1. ESolver_KS base class: - Remove psi::Psi* psi member variable - Remove Setup_Psi::deallocate_psi() call in destructor - Remove unnecessary includes: psi.h and setup_psi.h 2. ESolver_KS_LCAO: - Add psi::Psi* psi member variable - Add Setup_Psi::deallocate_psi() in destructor - Add include: setup_psi.h 3. ESolver_KS_LCAO_TDDFT: - Improve psi_laststep deallocation with nullptr check - psi member inherited from ESolver_KS_LCAO 4. ESolver_KS_PW: - Use stp.psi_cpu directly instead of base class psi - Remove unnecessary memory allocation in after_scf() 5. pw_others.cpp (BUG FIX): - Fix gen_bessel: use *(this->stp.psi_cpu) instead of this->psi[0] - Previous code accessed uninitialized base class psi (nullptr) - This was a latent bug that could cause crashes Benefits: - Eliminates template parameter T dependency in ESolver_KS base class - Clearer memory management: each derived class manages its own psi - Reduces compilation dependencies - Fixes potential memory access bug in pw_others.cpp Tested: Compiled successfully in build_5pt and build_1p --- source/source_esolver/esolver_ks.cpp | 3 --- source/source_esolver/esolver_ks.h | 4 ---- source/source_esolver/esolver_ks_lcao.cpp | 2 ++ source/source_esolver/esolver_ks_lcao.h | 3 +++ source/source_esolver/esolver_ks_lcao_tddft.cpp | 6 +++++- source/source_esolver/esolver_ks_pw.cpp | 7 ++----- source/source_esolver/pw_others.cpp | 2 +- 7 files changed, 13 insertions(+), 14 deletions(-) diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index fc99b8a572..93fb116aca 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -15,7 +15,6 @@ #include "source_io/module_output/output_log.h" // use write_head #include "source_estate/elecstate_print.h" // print_etot #include "source_io/module_output/print_info.h" // print_parameters -#include "source_psi/setup_psi.h" // mohan add 20251009 #include "source_lcao/module_dftu/dftu.h" // mohan add 2025-11-07 namespace ModuleESolver @@ -31,8 +30,6 @@ ESolver_KS::~ESolver_KS() //**************************************************** // do not add any codes in this deconstructor funcion //**************************************************** - Setup_Psi::deallocate_psi(this->psi); - delete this->p_hamilt; delete this->p_chgmix; this->ppcell.release_memory(); diff --git a/source/source_esolver/esolver_ks.h b/source/source_esolver/esolver_ks.h index 1913aa4101..eee36fbd88 100644 --- a/source/source_esolver/esolver_ks.h +++ b/source/source_esolver/esolver_ks.h @@ -5,7 +5,6 @@ #include "source_basis/module_pw/pw_basis_k.h" // use plane wave #include "source_cell/klist.h" // use k-points in Brillouin zone #include "source_estate/module_charge/charge_mixing.h" // use charge mixing -#include "source_psi/psi.h" // use electronic wave functions #include "source_hamilt/hamilt.h" // use Hamiltonian #include "source_hamilt/hamilt_base.h" // use Hamiltonian base class #include "source_lcao/module_dftu/dftu.h" // mohan add 20251107 @@ -60,9 +59,6 @@ class ESolver_KS : public ESolver_FP //! nonlocal pseudopotentials pseudopot_cell_vnl ppcell; - //! Electronic wavefunctions - psi::Psi* psi = nullptr; - //! DFT+U method, mohan add 2025-11-07 Plus_U dftu; diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index bad1cec7b6..d418f762c7 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -22,6 +22,7 @@ #include "source_io/module_output/print_info.h" #include "source_lcao/rho_tau_lcao.h" // mohan add 20251024 #include "source_lcao/LCAO_set.h" // mohan add 20251111 +#include "source_psi/setup_psi.h" // use Setup_Psi for deallocate_psi namespace ModuleESolver { @@ -40,6 +41,7 @@ ESolver_KS_LCAO::~ESolver_KS_LCAO() //**************************************************** // do not add any codes in this deconstructor funcion //**************************************************** + Setup_Psi::deallocate_psi(this->psi); } template diff --git a/source/source_esolver/esolver_ks_lcao.h b/source/source_esolver/esolver_ks_lcao.h index 4191306788..0e013ec9ae 100644 --- a/source/source_esolver/esolver_ks_lcao.h +++ b/source/source_esolver/esolver_ks_lcao.h @@ -57,6 +57,9 @@ class ESolver_KS_LCAO : public ESolver_KS virtual void others(UnitCell& ucell, const int istep) override; + //! Electronic wave functions (moved from base class) + psi::Psi* psi = nullptr; + //! Store information about Adjacent Atoms Record_adj RA; diff --git a/source/source_esolver/esolver_ks_lcao_tddft.cpp b/source/source_esolver/esolver_ks_lcao_tddft.cpp index 05dc8c9233..2e463acdcd 100644 --- a/source/source_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/source_esolver/esolver_ks_lcao_tddft.cpp @@ -40,7 +40,11 @@ ESolver_KS_LCAO_TDDFT::~ESolver_KS_LCAO_TDDFT() //************************************************* // Do not add any code in this destructor function //************************************************* - delete psi_laststep; + if (psi_laststep != nullptr) + { + delete psi_laststep; + psi_laststep = nullptr; + } if (td_p != nullptr) { diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index b2976733bf..7415552c90 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -240,13 +240,10 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const ModuleBase::TITLE("ESolver_KS_PW", "after_scf"); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); - // Since ESolver_KS::psi is hidden by ESolver_KS_PW::psi, - // we need to copy the data from ESolver_KS::psi to ESolver_KS_PW::psi. - // sunliang 2025-04-10 + // Calculate kinetic energy density tau for ELF if needed if (PARAM.inp.out_elf[0] > 0) { - this->ESolver_KS::psi = new psi::Psi(this->stp.psi_cpu[0]); - this->pelec->cal_tau(*(this->psi)); + this->pelec->cal_tau(*(this->stp.psi_cpu)); } ESolver_KS::after_scf(ucell, istep, conv_esolver); diff --git a/source/source_esolver/pw_others.cpp b/source/source_esolver/pw_others.cpp index fc42df14bd..49f7465b46 100644 --- a/source/source_esolver/pw_others.cpp +++ b/source/source_esolver/pw_others.cpp @@ -32,7 +32,7 @@ void ESolver_KS_PW::others(UnitCell& ucell, const int istep) { Numerical_Descriptor nc; nc.output_descriptor(ucell, - this->psi[0], + *(this->stp.psi_cpu), PARAM.inp.bessel_descriptor_lmax, PARAM.inp.bessel_descriptor_rcut, PARAM.inp.bessel_descriptor_tolerence, From 0c2fa0f7b6d57ef2a737be317cad01f38526de97 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 06:36:33 +0800 Subject: [PATCH 13/40] refactor(esolver): remove template parameters from ESolver_KS base class This is a major milestone in ESolver refactoring! ESolver_KS no longer needs template parameters because: - All member variables are non-template types - All member functions do not use T or Device parameters - Template parameters were only needed for derived classes Changes: 1. ESolver_KS base class: - Remove template declaration - Remove all template declarations from member functions - Remove template instantiation code at end of file - Fix Tab indentation to spaces for better readability 2. Derived classes: - ESolver_KS_PW: public ESolver_KS (was ESolver_KS) - ESolver_KS_LCAO: public ESolver_KS (was ESolver_KS) - ESolver_GetS: public ESolver_KS (was ESolver_KS>) - Update base class calls: ESolver_KS:: (was ESolver_KS::) Code reduction: - esolver_ks.h: 78 -> 77 lines (-1 line) - esolver_ks.cpp: 346 -> 317 lines (-29 lines) - Total ESolver code: 424 -> 394 lines (-30 lines) - Overall: 8 files changed, 50 insertions(+), 80 deletions(-), net -30 lines Benefits: - Simpler base class without template complexity - Faster compilation (no template instantiation needed) - Clearer inheritance hierarchy - Easier to extract common code in future refactoring - Sets foundation for further ESolver template removal Tested: Compiled successfully in build_5pt --- source/source_esolver/esolver_gets.h | 2 +- source/source_esolver/esolver_ks.cpp | 97 ++++++++--------------- source/source_esolver/esolver_ks.h | 1 - source/source_esolver/esolver_ks_lcao.cpp | 12 +-- source/source_esolver/esolver_ks_lcao.h | 2 +- source/source_esolver/esolver_ks_pw.cpp | 12 +-- source/source_esolver/esolver_ks_pw.h | 2 +- source/source_esolver/esolver_sdft_pw.cpp | 2 +- 8 files changed, 50 insertions(+), 80 deletions(-) diff --git a/source/source_esolver/esolver_gets.h b/source/source_esolver/esolver_gets.h index 564fd55035..7a7fb1d34b 100644 --- a/source/source_esolver/esolver_gets.h +++ b/source/source_esolver/esolver_gets.h @@ -10,7 +10,7 @@ namespace ModuleESolver { -class ESolver_GetS : public ESolver_KS> +class ESolver_GetS : public ESolver_KS { public: ESolver_GetS(); diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index 93fb116aca..cc94510a66 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -20,27 +20,24 @@ namespace ModuleESolver { -template -ESolver_KS::ESolver_KS(){} +ESolver_KS::ESolver_KS() {} -template -ESolver_KS::~ESolver_KS() +ESolver_KS::~ESolver_KS() { - //**************************************************** - // do not add any codes in this deconstructor funcion - //**************************************************** + //**************************************************** + // do not add any codes in this deconstructor funcion + //**************************************************** delete this->p_hamilt; delete this->p_chgmix; this->ppcell.release_memory(); - + // mohan add 2025-10-18, should be put int clean() function pw::teardown_pwwfc(this->pw_wfc); } -template -void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para& inp) +void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para& inp) { ModuleBase::TITLE("ESolver_KS", "before_all_runners"); @@ -78,12 +75,10 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para } -template -void ESolver_KS::hamilt2rho_single(UnitCell& ucell, const int istep, const int iter, const double ethr) +void ESolver_KS::hamilt2rho_single(UnitCell& ucell, const int istep, const int iter, const double ethr) {} -template -void ESolver_KS::hamilt2rho(UnitCell& ucell, const int istep, const int iter, const double ethr) +void ESolver_KS::hamilt2rho(UnitCell& ucell, const int istep, const int iter, const double ethr) { // 1) use Hamiltonian to obtain charge density this->hamilt2rho_single(ucell, istep, iter, diag_ethr); @@ -123,8 +118,7 @@ void ESolver_KS::hamilt2rho(UnitCell& ucell, const int istep, const i } } -template -void ESolver_KS::runner(UnitCell& ucell, const int istep) +void ESolver_KS::runner(UnitCell& ucell, const int istep) { ModuleBase::TITLE("ESolver_KS", "runner"); ModuleBase::timer::tick(this->classname, "runner"); @@ -139,14 +133,14 @@ void ESolver_KS::runner(UnitCell& ucell, const int istep) this->diag_ethr = PARAM.inp.pw_diag_thr; this->scf_nmax_flag = false; // mohan add 2025-09-21 for (int iter = 1; iter <= this->maxniter; ++iter) - { - if(iter == this->maxniter) - { - this->scf_nmax_flag=true; - } + { + if(iter == this->maxniter) + { + this->scf_nmax_flag=true; + } - // 3) initialization of SCF iterations - this->iter_init(ucell, istep, iter); + // 3) initialization of SCF iterations + this->iter_init(ucell, istep, iter); // 4) use Hamiltonian to obtain charge density this->hamilt2rho(ucell, istep, iter, diag_ethr); @@ -166,22 +160,20 @@ void ESolver_KS::runner(UnitCell& ucell, const int istep) } } // end scf iterations - // 7) after scf + // 7) after scf this->after_scf(ucell, istep, conv_esolver); ModuleBase::timer::tick(this->classname, "runner"); return; }; -template -void ESolver_KS::before_scf(UnitCell& ucell, const int istep) +void ESolver_KS::before_scf(UnitCell& ucell, const int istep) { ModuleBase::TITLE("ESolver_KS", "before_scf"); ESolver_FP::before_scf(ucell, istep); } -template -void ESolver_KS::iter_init(UnitCell& ucell, const int istep, const int iter) +void ESolver_KS::iter_init(UnitCell& ucell, const int istep, const int iter) { if(PARAM.inp.esolver_type != "tddft") { @@ -207,8 +199,7 @@ void ESolver_KS::iter_init(UnitCell& ucell, const int istep, const in this->chr.save_rho_before_sum_band(); } -template -void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& iter, bool &conv_esolver) +void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& iter, bool &conv_esolver) { // 1.1) print out band gap @@ -224,25 +215,25 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i // 1.2) print out eigenvalues and occupations if (PARAM.inp.out_band[0]) { - if (iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver) - { - ModuleIO::write_eig_iter(this->pelec->ekb,this->pelec->wg,*this->pelec->klist); - } + if (iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver) + { + ModuleIO::write_eig_iter(this->pelec->ekb,this->pelec->wg,*this->pelec->klist); + } } // 2.1) compute magnetization, only for spin==2 ucell.magnet.compute_mag(ucell.omega, this->chr.nrxx, this->chr.nxyz, this->chr.rho, this->pelec->nelec_spin.data()); - // 2.2) charge mixing + // 2.2) charge mixing // SCF will continue if U is not converged for uramping calculation - bool converged_u = true; - // to avoid unnecessary dependence on dft+u, refactor is needed + bool converged_u = true; + // to avoid unnecessary dependence on dft+u, refactor is needed #ifdef __LCAO - if (PARAM.inp.dft_plus_u) - { - converged_u = this->dftu.u_converged(); - } + if (PARAM.inp.dft_plus_u) + { + converged_u = this->dftu.u_converged(); + } #endif module_charge::chgmixing_ks(iter, ucell, this->pelec, this->chr, this->p_chgmix, @@ -293,8 +284,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i } //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. -template -void ESolver_KS::after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) +void ESolver_KS::after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) { ModuleBase::TITLE("ESolver_KS", "after_scf"); @@ -318,29 +308,10 @@ void ESolver_KS::after_scf(UnitCell& ucell, const int istep, const bo } -template -void ESolver_KS::after_all_runners(UnitCell& ucell) +void ESolver_KS::after_all_runners(UnitCell& ucell) { // 1) write Etot information ESolver_FP::after_all_runners(ucell); } -//------------------------------------------------------------------------------ -//! the 16th-20th functions of ESolver_KS -//! mohan add 2024-05-12 -//------------------------------------------------------------------------------ -//! This is for mixed-precision pw/LCAO basis sets. -template class ESolver_KS, base_device::DEVICE_CPU>; -template class ESolver_KS, base_device::DEVICE_CPU>; - -//! This is for GPU codes. -#if ((defined __CUDA) || (defined __ROCM)) -template class ESolver_KS, base_device::DEVICE_GPU>; -template class ESolver_KS, base_device::DEVICE_GPU>; -#endif - -//! This is for LCAO basis set. -#ifdef __LCAO -template class ESolver_KS; -#endif } // namespace ModuleESolver diff --git a/source/source_esolver/esolver_ks.h b/source/source_esolver/esolver_ks.h index eee36fbd88..b6affc7b0c 100644 --- a/source/source_esolver/esolver_ks.h +++ b/source/source_esolver/esolver_ks.h @@ -12,7 +12,6 @@ namespace ModuleESolver { -template class ESolver_KS : public ESolver_FP { public: diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index d418f762c7..0558942e91 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -51,7 +51,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa ModuleBase::timer::tick("ESolver_KS_LCAO", "before_all_runners"); // 1) before_all_runners in ESolver_KS - ESolver_KS::before_all_runners(ucell, inp); + ESolver_KS::before_all_runners(ucell, inp); // 2) autoset nbands in ElecState before init_basis (for Psi 2d division) if (this->pelec == nullptr) @@ -107,7 +107,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) ModuleBase::timer::tick("ESolver_KS_LCAO", "before_scf"); //! 1) call before_scf() of ESolver_KS. - ESolver_KS::before_scf(ucell, istep); + ESolver_KS::before_scf(ucell, istep); //! 2) find search radius double search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running, @@ -271,7 +271,7 @@ void ESolver_KS_LCAO::after_all_runners(UnitCell& ucell) ModuleBase::TITLE("ESolver_KS_LCAO", "after_all_runners"); ModuleBase::timer::tick("ESolver_KS_LCAO", "after_all_runners"); - ESolver_KS::after_all_runners(ucell); + ESolver_KS::after_all_runners(ucell); auto* hamilt_lcao = dynamic_cast*>(this->p_hamilt); if(!hamilt_lcao) @@ -303,7 +303,7 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const ModuleBase::TITLE("ESolver_KS_LCAO", "iter_init"); // call iter_init() of ESolver_KS - ESolver_KS::iter_init(ucell, istep, iter); + ESolver_KS::iter_init(ucell, istep, iter); module_charge::chgmixing_ks_lcao(iter, this->p_chgmix, this->dftu, this->dmat.dm->get_DMR_pointer(1)->get_nnr(), PARAM.inp); @@ -438,7 +438,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& // eig and occ are printed, magnetization is calculated, // charge mixing is performed, potential is updated, // HF and kS energies are computed, meta-GGA, Jason and restart - ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); + ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); // mix density matrix if mixing_restart + mixing_dmr + not first // mixing_restart at every iter except the last iter @@ -476,7 +476,7 @@ void ESolver_KS_LCAO::after_scf(UnitCell& ucell, const int istep, const } //! 1) call after_scf() of ESolver_KS - ESolver_KS::after_scf(ucell, istep, conv_esolver); + ESolver_KS::after_scf(ucell, istep, conv_esolver); //! 2) output of lcao every few ionic steps ModuleIO::ctrl_scf_lcao(ucell, diff --git a/source/source_esolver/esolver_ks_lcao.h b/source/source_esolver/esolver_ks_lcao.h index 0e013ec9ae..143f7089ba 100644 --- a/source/source_esolver/esolver_ks_lcao.h +++ b/source/source_esolver/esolver_ks_lcao.h @@ -28,7 +28,7 @@ namespace ModuleESolver { template -class ESolver_KS_LCAO : public ESolver_KS +class ESolver_KS_LCAO : public ESolver_KS { public: ESolver_KS_LCAO(); diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 7415552c90..74507e09c7 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -84,7 +84,7 @@ void ESolver_KS_PW::deallocate_hamilt() template void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_para& inp) { - ESolver_KS::before_all_runners(ucell, inp); + ESolver_KS::before_all_runners(ucell, inp); //! setup and allocation for pelec, potentials, etc. elecstate::setup_estate_pw(ucell, this->kv, this->sf, this->pelec, this->chr, @@ -105,7 +105,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) ModuleBase::TITLE("ESolver_KS_PW", "before_scf"); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); - ESolver_KS::before_scf(ucell, istep); + ESolver_KS::before_scf(ucell, istep); //! Init variables (once the cell has changed) pw::update_cell_pw(ucell, this->ppcell, this->kv, this->pw_wfc, PARAM.inp); @@ -142,7 +142,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) template void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const int iter) { - ESolver_KS::iter_init(ucell, istep, iter); + ESolver_KS::iter_init(ucell, istep, iter); module_charge::chgmixing_ks_pw(iter, this->p_chgmix, this->dftu, PARAM.inp); @@ -212,7 +212,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int this->pelec->f_en.deband = this->pelec->cal_delta_eband(ucell); // Call iter_finish() of ESolver_KS - ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); + ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); // D in USPP needs vloc, thus needs update when veff updated // calculate the effective coefficient matrix for non-local @@ -246,7 +246,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const this->pelec->cal_tau(*(this->stp.psi_cpu)); } - ESolver_KS::after_scf(ucell, istep, conv_esolver); + ESolver_KS::after_scf(ucell, istep, conv_esolver); // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, @@ -300,7 +300,7 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s template void ESolver_KS_PW::after_all_runners(UnitCell& ucell) { - ESolver_KS::after_all_runners(ucell); + ESolver_KS::after_all_runners(ucell); ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp, diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 01e1027d79..6a6be52b73 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -13,7 +13,7 @@ namespace ModuleESolver { template -class ESolver_KS_PW : public ESolver_KS +class ESolver_KS_PW : public ESolver_KS { private: using Real = typename GetTypeReal::type; diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 597825cd6d..f7f9a29983 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -119,7 +119,7 @@ template void ESolver_SDFT_PW::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver) { // call iter_finish() of ESolver_KS - ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); + ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); } template From 4b51e36da270685a31a6d8062a3fd499d86a2f3b Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 08:48:04 +0800 Subject: [PATCH 14/40] refactor(device): remove explicit template parameter from get_device_type calls - Move get_device_type implementation to header file using std::is_same - Add DEVICE_DSP support - Remove template specialization declarations and definitions - Update all call sites to use automatic template parameter deduction - The compiler now deduces Device type from the ctx parameter --- source/source_base/math_chebyshev.cpp | 14 +++++++------- .../source_base/module_device/device_helpers.cpp | 13 ------------- .../source_base/module_device/device_helpers.h | 16 ++++++++-------- .../module_device/test/device_test.cpp | 4 ++-- source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_hsolver/diago_dav_subspace.cpp | 2 +- source/source_hsolver/diago_david.cpp | 2 +- source/source_hsolver/diago_iter_assist.cpp | 8 ++++---- source/source_hsolver/test/hsolver_pw_sup.h | 2 +- source/source_pw/module_pwdft/forces.cpp | 6 +++--- source/source_pw/module_pwdft/forces_cc.cpp | 4 ++-- source/source_pw/module_pwdft/forces_scc.cpp | 2 +- source/source_pw/module_pwdft/fs_kin_tools.cpp | 2 +- .../source_pw/module_pwdft/fs_nonlocal_tools.cpp | 2 +- source/source_pw/module_pwdft/nonlocal_maths.hpp | 4 ++-- .../source_pw/module_pwdft/onsite_proj_tools.cpp | 4 ++-- .../source_pw/module_pwdft/onsite_projector.cpp | 2 +- source/source_pw/module_pwdft/op_pw_ekin.cpp | 4 ++-- source/source_pw/module_pwdft/stress_cc.cpp | 2 +- source/source_pw/module_pwdft/stress_loc.cpp | 2 +- .../module_pwdft/structure_factor_k.cpp | 2 +- source/source_pw/module_stodft/sto_forces.cpp | 2 +- source/source_pw/module_stodft/sto_wf.cpp | 10 +++++----- 23 files changed, 49 insertions(+), 62 deletions(-) diff --git a/source/source_base/math_chebyshev.cpp b/source/source_base/math_chebyshev.cpp index 8a84686ea5..b7e59a89f9 100644 --- a/source/source_base/math_chebyshev.cpp +++ b/source/source_base/math_chebyshev.cpp @@ -61,7 +61,7 @@ Chebyshev::Chebyshev(const int norder_in) : fftw(2 * EXTEND * nord } coefr_cpu = new REAL[norder]; coefc_cpu = new std::complex[norder]; - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { resmem_var_op()(this->coef_real, norder); resmem_complex_op()(this->coef_complex, norder); @@ -82,7 +82,7 @@ template Chebyshev::~Chebyshev() { delete[] polytrace; - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { delmem_var_op()(this->coef_real); delmem_complex_op()(this->coef_complex); @@ -209,7 +209,7 @@ void Chebyshev::calcoef_real(std::function fun) } } - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { syncmem_var_h2d_op()(coef_real, coefr_cpu, norder); } @@ -299,7 +299,7 @@ void Chebyshev::calcoef_complex(std::function(s coefc_cpu[i].imag(imag(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 2 / 3); } } - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder); } @@ -390,7 +390,7 @@ void Chebyshev::calcoef_pair(std::function fun1, std:: } } - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder); } @@ -684,7 +684,7 @@ bool Chebyshev::checkconverge( funA(arrayn_1, arrayn, 1); REAL sum1, sum2; REAL t; - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { sum1 = this->ddot_real(arrayn_1, arrayn_1, N); sum2 = this->ddot_real(arrayn_1, arrayn, N); @@ -714,7 +714,7 @@ bool Chebyshev::checkconverge( for (int ior = 2; ior < norder; ++ior) { funA(arrayn, arraynp1, 1); - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { sum1 = this->ddot_real(arrayn, arrayn, N); sum2 = this->ddot_real(arrayn, arraynp1, N); diff --git a/source/source_base/module_device/device_helpers.cpp b/source/source_base/module_device/device_helpers.cpp index 1c53020718..0b5d5a1693 100644 --- a/source/source_base/module_device/device_helpers.cpp +++ b/source/source_base/module_device/device_helpers.cpp @@ -3,19 +3,6 @@ namespace base_device { -// Device type specializations -template <> -AbacusDevice_t get_device_type(const DEVICE_CPU* dev) -{ - return CpuDevice; -} - -template <> -AbacusDevice_t get_device_type(const DEVICE_GPU* dev) -{ - return GpuDevice; -} - // Precision specializations template <> std::string get_current_precision(const float* var) diff --git a/source/source_base/module_device/device_helpers.h b/source/source_base/module_device/device_helpers.h index 60eddd888d..6aa71938de 100644 --- a/source/source_base/module_device/device_helpers.h +++ b/source/source_base/module_device/device_helpers.h @@ -13,6 +13,7 @@ #include "types.h" #include #include +#include namespace base_device { @@ -24,14 +25,13 @@ namespace base_device * @return AbacusDevice_t enum value */ template -AbacusDevice_t get_device_type(const Device* dev); - -// Template specialization declarations -template <> -AbacusDevice_t get_device_type(const DEVICE_CPU* dev); - -template <> -AbacusDevice_t get_device_type(const DEVICE_GPU* dev); +AbacusDevice_t get_device_type(const Device* dev) +{ + if (std::is_same::value) return CpuDevice; + else if (std::is_same::value) return GpuDevice; + else if (std::is_same::value) return DspDevice; + else return UnKnown; +} /** * @brief Get the precision string for a given numeric type. diff --git a/source/source_base/module_device/test/device_test.cpp b/source/source_base/module_device/test/device_test.cpp index 02d485c8ef..faf083c721 100644 --- a/source/source_base/module_device/test/device_test.cpp +++ b/source/source_base/module_device/test/device_test.cpp @@ -20,14 +20,14 @@ class TestModulePsiDevice : public ::testing::Test TEST_F(TestModulePsiDevice, get_device_type_cpu) { - base_device::AbacusDevice_t device = base_device::get_device_type(cpu_ctx); + base_device::AbacusDevice_t device = base_device::get_device_type(cpu_ctx); EXPECT_EQ(device, base_device::CpuDevice); } #if __UT_USE_CUDA || __UT_USE_ROCM TEST_F(TestModulePsiDevice, get_device_type_gpu) { - base_device::AbacusDevice_t device = base_device::get_device_type(gpu_ctx); + base_device::AbacusDevice_t device = base_device::get_device_type(gpu_ctx); EXPECT_EQ(device, base_device::GpuDevice); } #endif // __UT_USE_CUDA || __UT_USE_ROCM diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 74507e09c7..a876039457 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -43,7 +43,7 @@ ESolver_KS_PW::ESolver_KS_PW() { this->classname = "ESolver_KS_PW"; this->basisname = "PW"; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); } template diff --git a/source/source_hsolver/diago_dav_subspace.cpp b/source/source_hsolver/diago_dav_subspace.cpp index 27c6a5b348..4ff93d03e9 100644 --- a/source/source_hsolver/diago_dav_subspace.cpp +++ b/source/source_hsolver/diago_dav_subspace.cpp @@ -36,7 +36,7 @@ Diago_DavSubspace::Diago_DavSubspace(const std::vector& precond diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), diag_comm(diag_comm_in), diag_subspace(diag_subspace_in), diago_subspace_bs(diago_subspace_bs_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->one = &one_; this->zero = &zero_; diff --git a/source/source_hsolver/diago_david.cpp b/source/source_hsolver/diago_david.cpp index ef4ba67cf3..49d5d0d953 100644 --- a/source/source_hsolver/diago_david.cpp +++ b/source/source_hsolver/diago_david.cpp @@ -20,7 +20,7 @@ DiagoDavid::DiagoDavid(const Real* precondition_in, const diag_comm_info& diag_comm_in) : nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->precondition = precondition_in; this->one = &one_; diff --git a/source/source_hsolver/diago_iter_assist.cpp b/source/source_hsolver/diago_iter_assist.cpp index fb87ad2350..8c5673c37a 100644 --- a/source/source_hsolver/diago_iter_assist.cpp +++ b/source/source_hsolver/diago_iter_assist.cpp @@ -400,14 +400,14 @@ void DiagoIterAssist::diag_heevx(const int matrix_size, // (const Device *d, const int matrix_size, const int lda, const T *A, const int num_eigenpairs, Real *eigenvalues, T *eigenvectors); heevx_op()(ctx, matrix_size, ldh, h, num_eigenpairs, eigenvalues, v); - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { #if ((defined __CUDA) || (defined __ROCM)) // eigenvalues to e, from device to host syncmem_var_d2h_op()(e, eigenvalues, num_eigenpairs); #endif } - else if (base_device::get_device_type(ctx) == base_device::CpuDevice) + else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { // eigenvalues to e syncmem_var_op()(e, eigenvalues, num_eigenpairs); @@ -436,14 +436,14 @@ void DiagoIterAssist::diag_hegvd(const int nstart, hegvd_op()(ctx, nstart, ldh, hcc, scc, eigenvalues, vcc); - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { #if ((defined __CUDA) || (defined __ROCM)) // set eigenvalues in GPU to e in CPU syncmem_var_d2h_op()(e, eigenvalues, nbands); #endif } - else if (base_device::get_device_type(ctx) == base_device::CpuDevice) + else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { // set eigenvalues in CPU to e in CPU syncmem_var_op()(e, eigenvalues, nbands); diff --git a/source/source_hsolver/test/hsolver_pw_sup.h b/source/source_hsolver/test/hsolver_pw_sup.h index a5aab01735..5f5108c627 100644 --- a/source/source_hsolver/test/hsolver_pw_sup.h +++ b/source/source_hsolver/test/hsolver_pw_sup.h @@ -126,7 +126,7 @@ DiagoDavid::DiagoDavid(const Real* precondition_in, const bool use_paw_in, const diag_comm_info& diag_comm_in) : nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->precondition = precondition_in; test_david = 2; diff --git a/source/source_pw/module_pwdft/forces.cpp b/source/source_pw/module_pwdft/forces.cpp index a6894c49ca..3e58b737ba 100644 --- a/source/source_pw/module_pwdft/forces.cpp +++ b/source/source_pw/module_pwdft/forces.cpp @@ -38,7 +38,7 @@ void Forces::cal_force(UnitCell& ucell, { ModuleBase::timer::tick("Forces", "cal_force"); ModuleBase::TITLE("Forces", "init"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); const ModuleBase::matrix& wg = elec.wg; const ModuleBase::matrix& ekb = elec.ekb; const Charge* const chr = elec.charge; @@ -331,7 +331,7 @@ void Forces::cal_force_loc(const UnitCell& ucell, { ModuleBase::TITLE("Forces", "cal_force_loc"); ModuleBase::timer::tick("Forces", "cal_force_loc"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); std::complex* aux = new std::complex[rho_basis->nmaxgr]; // now, in all pools , the charge are the same, // so, the force calculated by each pool is equal. @@ -478,7 +478,7 @@ void Forces::cal_force_ew(const UnitCell& ucell, { ModuleBase::TITLE("Forces", "cal_force_ew"); ModuleBase::timer::tick("Forces", "cal_force_ew"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); double fact = 2.0; std::vector> aux(rho_basis->npw); diff --git a/source/source_pw/module_pwdft/forces_cc.cpp b/source/source_pw/module_pwdft/forces_cc.cpp index 7788ed2af5..917e00b83c 100644 --- a/source/source_pw/module_pwdft/forces_cc.cpp +++ b/source/source_pw/module_pwdft/forces_cc.cpp @@ -116,7 +116,7 @@ void Forces::cal_force_cc(ModuleBase::matrix& forcecc, double *force_d = nullptr; double *rhocgigg_vec_d = nullptr; std::complex* psiv_d = nullptr; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); for (int ig = 0; ig < rho_basis->npw; ig++) @@ -258,7 +258,7 @@ void Forces::deriv_drhoc double gx = 0, rhocg1 = 0; //double *aux = new double[mesh]; std::vector aux(mesh); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); // the modulus of g for a given shell // the fourier transform // auxiliary memory for integration diff --git a/source/source_pw/module_pwdft/forces_scc.cpp b/source/source_pw/module_pwdft/forces_scc.cpp index 7134232416..5e00b87dca 100644 --- a/source/source_pw/module_pwdft/forces_scc.cpp +++ b/source/source_pw/module_pwdft/forces_scc.cpp @@ -152,7 +152,7 @@ void Forces::deriv_drhoc_scc(const bool& numeric, int igl0 = 0; double gx = 0; double rhocg1 = 0; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); /// the modulus of g for a given shell /// the fourier transform /// auxiliary memory for integration diff --git a/source/source_pw/module_pwdft/fs_kin_tools.cpp b/source/source_pw/module_pwdft/fs_kin_tools.cpp index 853ae34abd..0c04b26f2a 100644 --- a/source/source_pw/module_pwdft/fs_kin_tools.cpp +++ b/source/source_pw/module_pwdft/fs_kin_tools.cpp @@ -10,7 +10,7 @@ FS_Kin_tools::FS_Kin_tools(const UnitCell& ucell_in, const ModuleBase::matrix& wg) : ucell_(ucell_in), nksbands_(wg.nc), wg(wg.c), wk(p_kv->wk.data()) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->wfc_basis_ = wfc_basis_in; const int npwk_max = this->wfc_basis_->npwk_max; const int nks = this->wfc_basis_->nks; diff --git a/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp b/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp index 934d3c476d..64f4015daf 100644 --- a/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp +++ b/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp @@ -26,7 +26,7 @@ FS_Nonlocal_tools::FS_Nonlocal_tools(const pseudopot_cell_vnl* n : nlpp_(nlpp_in), ucell_(ucell_in), kv_(kv_in), wfc_basis_(wfc_basis_in), sf_(sf_in) { // get the device context - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->nkb = nlpp_->nkb; this->max_npw = wfc_basis_->npwk_max; this->ntype = ucell_->ntype; diff --git a/source/source_pw/module_pwdft/nonlocal_maths.hpp b/source/source_pw/module_pwdft/nonlocal_maths.hpp index 3e09675bcb..3a8b133cb8 100644 --- a/source/source_pw/module_pwdft/nonlocal_maths.hpp +++ b/source/source_pw/module_pwdft/nonlocal_maths.hpp @@ -18,14 +18,14 @@ class Nonlocal_maths public: Nonlocal_maths(const pseudopot_cell_vnl* nlpp_in, const UnitCell* ucell_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->nhtol_ = nlpp_in->nhtol; this->lmax_ = nlpp_in->lmaxkb; this->ucell_ = ucell_in; } Nonlocal_maths(const ModuleBase::matrix& nhtol, const int lmax, const UnitCell* ucell_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->nhtol_ = nhtol; this->lmax_ = lmax; this->ucell_ = ucell_in; diff --git a/source/source_pw/module_pwdft/onsite_proj_tools.cpp b/source/source_pw/module_pwdft/onsite_proj_tools.cpp index 509a65a6ab..aa6ed0f83f 100644 --- a/source/source_pw/module_pwdft/onsite_proj_tools.cpp +++ b/source/source_pw/module_pwdft/onsite_proj_tools.cpp @@ -24,7 +24,7 @@ Onsite_Proj_tools::Onsite_Proj_tools(const pseudopot_cell_vnl* n : nlpp_(nlpp_in), ucell_(ucell_in), psi_(psi_in), kv_(kv_in), wfc_basis_(wfc_basis_in), sf_(sf_in) { // get the device context - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); // seems kvec_c never used... this->kvec_c = this->wfc_basis_->template get_kvec_c_data(); @@ -126,7 +126,7 @@ Onsite_Proj_tools::Onsite_Proj_tools( wfc_basis_ = wfc_basis_in; sf_ = sf_in; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->kvec_c = this->wfc_basis_->template get_kvec_c_data(); // skip deeq, qq_nt diff --git a/source/source_pw/module_pwdft/onsite_projector.cpp b/source/source_pw/module_pwdft/onsite_projector.cpp index f9b8ad7cbc..4353700d65 100644 --- a/source/source_pw/module_pwdft/onsite_projector.cpp +++ b/source/source_pw/module_pwdft/onsite_projector.cpp @@ -104,7 +104,7 @@ void projectors::OnsiteProjector::init(const std::string& orbital_dir const ModuleBase::matrix& wg, const ModuleBase::matrix& ekb) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); if(!this->initialed) { diff --git a/source/source_pw/module_pwdft/op_pw_ekin.cpp b/source/source_pw/module_pwdft/op_pw_ekin.cpp index 05d28266fd..9c62204050 100644 --- a/source/source_pw/module_pwdft/op_pw_ekin.cpp +++ b/source/source_pw/module_pwdft/op_pw_ekin.cpp @@ -18,7 +18,7 @@ Ekinetic>::Ekinetic( this->gk2 = gk2_in; this->gk2_row = gk2_row; this->gk2_col = gk2_col; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); if( this->tpiba2 < 1e-10 || this->gk2 == nullptr) { ModuleBase::WARNING_QUIT("EkineticPW", "Constuctor of Operator::EkineticPW is failed, please check your code!"); @@ -67,7 +67,7 @@ hamilt::Ekinetic>::Ekinetic(const Ekineticgk2 = ekinetic->get_gk2(); this->gk2_row = ekinetic->get_gk2_row(); this->gk2_col = ekinetic->get_gk2_col(); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); if( this->tpiba2 < 1e-10 || this->gk2 == nullptr) { ModuleBase::WARNING_QUIT("EkineticPW", "Copy Constuctor of Operator::EkineticPW is failed, please check your code!"); } diff --git a/source/source_pw/module_pwdft/stress_cc.cpp b/source/source_pw/module_pwdft/stress_cc.cpp index 211d5a4bda..0607371074 100644 --- a/source/source_pw/module_pwdft/stress_cc.cpp +++ b/source/source_pw/module_pwdft/stress_cc.cpp @@ -230,7 +230,7 @@ void Stress_Func::deriv_drhoc double gx = 0.0; double rhocg1 = 0.0; std::vector aux(mesh); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); // the modulus of g for a given shell // the fourier transform diff --git a/source/source_pw/module_pwdft/stress_loc.cpp b/source/source_pw/module_pwdft/stress_loc.cpp index 0b932afcdb..f6cac47604 100644 --- a/source/source_pw/module_pwdft/stress_loc.cpp +++ b/source/source_pw/module_pwdft/stress_loc.cpp @@ -189,7 +189,7 @@ const UnitCell& ucell_in int igl0 = 0; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); std::vector gx_arr(rho_basis->ngg+1); double* gx_arr_d = nullptr; diff --git a/source/source_pw/module_pwdft/structure_factor_k.cpp b/source/source_pw/module_pwdft/structure_factor_k.cpp index 52dc326545..3ca4980c58 100644 --- a/source/source_pw/module_pwdft/structure_factor_k.cpp +++ b/source/source_pw/module_pwdft/structure_factor_k.cpp @@ -57,7 +57,7 @@ void Structure_Factor::get_sk(Device* ctx, ModuleBase::timer::tick("Structure_Factor", "get_sk"); base_device::DEVICE_CPU* cpu_ctx = {}; - base_device::AbacusDevice_t device = base_device::get_device_type(ctx); + base_device::AbacusDevice_t device = base_device::get_device_type(ctx); using cal_sk_op = hamilt::cal_sk_op; using resmem_int_op = base_device::memory::resize_memory_op; using delmem_int_op = base_device::memory::delete_memory_op; diff --git a/source/source_pw/module_stodft/sto_forces.cpp b/source/source_pw/module_stodft/sto_forces.cpp index 4e57ae98c7..e349d3de2c 100644 --- a/source/source_pw/module_stodft/sto_forces.cpp +++ b/source/source_pw/module_stodft/sto_forces.cpp @@ -31,7 +31,7 @@ void Sto_Forces::cal_stoforce(ModuleBase::matrix& force, { ModuleBase::timer::tick("Sto_Forces", "cal_force"); ModuleBase::TITLE("Sto_Forces", "init"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); const ModuleBase::matrix& wg = elec.wg; const Charge* chr = elec.charge; force.create(this->nat, 3); diff --git a/source/source_pw/module_stodft/sto_wf.cpp b/source/source_pw/module_stodft/sto_wf.cpp index 2ba8db2908..a0204e1f87 100644 --- a/source/source_pw/module_stodft/sto_wf.cpp +++ b/source/source_pw/module_stodft/sto_wf.cpp @@ -19,7 +19,7 @@ Stochastic_WF::~Stochastic_WF() { delete chi0_cpu; Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { delete chi0; } @@ -119,7 +119,7 @@ void Stochastic_WF::allocate_chi0() // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -248,7 +248,7 @@ void Stochastic_WF::init_com_orbitals() delete[] totnpw; // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -280,7 +280,7 @@ void Stochastic_WF::init_com_orbitals() // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -370,7 +370,7 @@ template void Stochastic_WF::sync_chi0() { Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { syncmem_h2d_op()(this->chi0->get_pointer(), this->chi0_cpu->get_pointer(), From f74b4b08e17cd0c1c575d3e52a78875ce30e380c Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 08:56:39 +0800 Subject: [PATCH 15/40] refactor(esolver): remove device member variable from ESolver_KS_PW - Modify copy_d2h to accept ctx parameter and call get_device_type internally - Remove device parameter from ctrl_scf_pw function - Remove device member variable from ESolver_KS_PW class - Simplify function interfaces by using automatic template deduction --- source/source_esolver/esolver_ks_pw.cpp | 3 +-- source/source_esolver/esolver_ks_pw.h | 3 --- source/source_io/module_ctrl/ctrl_output_pw.cpp | 7 +------ source/source_io/module_ctrl/ctrl_output_pw.h | 1 - source/source_psi/setup_psi_pw.cpp | 4 ++-- source/source_psi/setup_psi_pw.h | 2 +- 6 files changed, 5 insertions(+), 15 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index a876039457..6506432352 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -43,7 +43,6 @@ ESolver_KS_PW::ESolver_KS_PW() { this->classname = "ESolver_KS_PW"; this->basisname = "PW"; - this->device = base_device::get_device_type(this->ctx); } template @@ -251,7 +250,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->stp, - this->ctx, this->device, this->Pgrid, PARAM.inp); + this->ctx, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 6a6be52b73..323e2df5a2 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -60,9 +60,6 @@ class ESolver_KS_PW : public ESolver_KS // for get_pchg and get_wf, use ctx as input of fft Device* ctx = {}; - // for device to host data transformation - base_device::AbacusDevice_t device = {}; - }; } // namespace ModuleESolver #endif diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index 2f0c157a82..8e2f0c3918 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -92,7 +92,6 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, const Device* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp) { @@ -100,7 +99,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); // Transfer data from device (GPU) to host (CPU) in pw basis - stp.copy_d2h(device); + stp.copy_d2h(ctx); //---------------------------------------------------------- //! 4) Compute density of states (DOS) @@ -386,7 +385,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -403,7 +401,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -421,7 +418,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -438,7 +434,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); #endif diff --git a/source/source_io/module_ctrl/ctrl_output_pw.h b/source/source_io/module_ctrl/ctrl_output_pw.h index 798629c55e..3ac7a2ab9c 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.h +++ b/source/source_io/module_ctrl/ctrl_output_pw.h @@ -31,7 +31,6 @@ void ctrl_scf_pw(const int istep, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, const Device* ctx, - const base_device::AbacusDevice_t &device, // mohan add 2025-10-15 const Parallel_Grid ¶_grid, const Input_para& inp); diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index 14e564c4fb..c7428cfd7d 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -62,9 +62,9 @@ void Setup_Psi_pw::init(hamilt::Hamilt* p_hamilt) // Transfer data from GPU to CPU in pw basis template -void Setup_Psi_pw::copy_d2h(const base_device::AbacusDevice_t &device) +void Setup_Psi_pw::copy_d2h(const Device* ctx) { - if (device == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 13bf593f37..1e79664e2b 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -59,7 +59,7 @@ class Setup_Psi_pw void update_psi_d(); // Transfer data from device to host in pw basis - void copy_d2h(const base_device::AbacusDevice_t &device); + void copy_d2h(const Device* ctx); void clean(); From dc9450ae713a43e0bd6cb48e5437b124cef5ef51 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 09:03:01 +0800 Subject: [PATCH 16/40] style(esolver): explicitly initialize ctx to nullptr in constructor --- source/source_esolver/esolver_ks_pw.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 6506432352..0ea4c31d1d 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -43,6 +43,7 @@ ESolver_KS_PW::ESolver_KS_PW() { this->classname = "ESolver_KS_PW"; this->basisname = "PW"; + this->ctx = nullptr; } template From 3b108eb60472867871aee91ae911fe95e16a5800 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 14:16:15 +0800 Subject: [PATCH 17/40] feat(device): add runtime device type support to DeviceContext - Add device_type_ member variable to DeviceContext class - Add set_device_type() and get_device_type() methods - Add is_cpu(), is_gpu(), is_dsp() convenience methods - Add get_device_type(const DeviceContext*) overload for runtime device type query - Maintain backward compatibility with existing template-based get_device_type --- source/source_base/module_device/device.h | 41 +++++++++++++++++++ .../module_device/device_helpers.h | 12 +++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/source/source_base/module_device/device.h b/source/source_base/module_device/device.h index afa55d5d6e..395d8c470d 100644 --- a/source/source_base/module_device/device.h +++ b/source/source_base/module_device/device.h @@ -145,6 +145,36 @@ class DeviceContext { */ int get_local_rank() const { return local_rank_; } + /** + * @brief Set the device type (CpuDevice, GpuDevice, or DspDevice) + * @param type The device type + */ + void set_device_type(AbacusDevice_t type) { device_type_ = type; } + + /** + * @brief Get the device type + * @return AbacusDevice_t The device type + */ + AbacusDevice_t get_device_type() const { return device_type_; } + + /** + * @brief Check if the device is CPU + * @return true if the device is CPU + */ + bool is_cpu() const { return device_type_ == CpuDevice; } + + /** + * @brief Check if the device is GPU + * @return true if the device is GPU + */ + bool is_gpu() const { return device_type_ == GpuDevice; } + + /** + * @brief Check if the device is DSP + * @return true if the device is DSP + */ + bool is_dsp() const { return device_type_ == DspDevice; } + // Disable copy and assignment DeviceContext(const DeviceContext&) = delete; DeviceContext& operator=(const DeviceContext&) = delete; @@ -158,10 +188,21 @@ class DeviceContext { int device_id_ = -1; int device_count_ = 0; int local_rank_ = 0; + AbacusDevice_t device_type_ = CpuDevice; std::mutex init_mutex_; }; +/** + * @brief Get the device type enum from DeviceContext (runtime version). + * @param ctx Pointer to DeviceContext + * @return AbacusDevice_t enum value + */ +inline AbacusDevice_t get_device_type(const DeviceContext* ctx) +{ + return ctx->get_device_type(); +} + } // end of namespace base_device #endif // MODULE_DEVICE_H_ diff --git a/source/source_base/module_device/device_helpers.h b/source/source_base/module_device/device_helpers.h index 6aa71938de..2870eea2d7 100644 --- a/source/source_base/module_device/device_helpers.h +++ b/source/source_base/module_device/device_helpers.h @@ -18,8 +18,18 @@ namespace base_device { +// Forward declaration +class DeviceContext; + +/** + * @brief Get the device type enum from DeviceContext (runtime version). + * @param ctx Pointer to DeviceContext + * @return AbacusDevice_t enum value + */ +inline AbacusDevice_t get_device_type(const DeviceContext* ctx); + /** - * @brief Get the device type enum for a given device type. + * @brief Get the device type enum for a given device type (compile-time version). * @tparam Device The device type (DEVICE_CPU or DEVICE_GPU) * @param dev Pointer to device (used for template deduction) * @return AbacusDevice_t enum value From 77d081f7f3b3c25b10f4ab493035e1c21d965be9 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 16:23:35 +0800 Subject: [PATCH 18/40] feat(device): add runtime device context overloads for gradual migration - Add copy_d2h(const DeviceContext*) overload to Setup_Psi_pw - Add ctrl_scf_pw(..., const DeviceContext*, ...) overload - Add ctrl_runner_pw(..., const DeviceContext*, ...) overload - Keep original functions for backward compatibility - Replace tabs with spaces in modified files --- source/source_io/module_ctrl/ctrl_output_pw.h | 59 +++++++++++++++---- source/source_psi/setup_psi_pw.cpp | 35 ++++++++--- source/source_psi/setup_psi_pw.h | 15 +++-- 3 files changed, 81 insertions(+), 28 deletions(-) diff --git a/source/source_io/module_ctrl/ctrl_output_pw.h b/source/source_io/module_ctrl/ctrl_output_pw.h index 3ac7a2ab9c..262fc782a8 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.h +++ b/source/source_io/module_ctrl/ctrl_output_pw.h @@ -11,11 +11,11 @@ namespace ModuleIO // print out information in 'iter_finish' in ESolver_KS_PW void ctrl_iter_pw(const int istep, - const int iter, - const double &conv_esolver, - psi::Psi, base_device::DEVICE_CPU>* psi, - const K_Vectors &kv, - const ModulePW::PW_Basis_K *pw_wfc, + const int iter, + const double &conv_esolver, + psi::Psi, base_device::DEVICE_CPU>* psi, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, const Input_para& inp); // print out information in 'after_scf' in ESolver_KS_PW @@ -24,32 +24,65 @@ void ctrl_scf_pw(const int istep, UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, - const K_Vectors &kv, - const ModulePW::PW_Basis_K *pw_wfc, - const ModulePW::PW_Basis *pw_rho, - const ModulePW::PW_Basis *pw_rhod, - const ModulePW::PW_Basis_Big *pw_big, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); +// print out information in 'after_scf' in ESolver_KS_PW (runtime version) +template +void ctrl_scf_pw(const int istep, + UnitCell& ucell, + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + Setup_Psi_pw &stp, + const base_device::DeviceContext* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + // print out information in 'after_all_runners' in ESolver_KS_PW template void ctrl_runner_pw(UnitCell& ucell, - elecstate::ElecState* pelec, + elecstate::ElecState* pelec, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, - Charge &chr, + Charge &chr, K_Vectors &kv, Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, - surchem &solvent, + surchem &solvent, const Device* ctx, Parallel_Grid ¶_grid, const Input_para& inp); +// print out information in 'after_all_runners' in ESolver_KS_PW (runtime version) +template +void ctrl_runner_pw(UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + K_Vectors &kv, + Setup_Psi_pw &stp, + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, + const base_device::DeviceContext* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); + } #endif diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index c7428cfd7d..11b02b8512 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -9,12 +9,12 @@ Setup_Psi_pw::~Setup_Psi_pw(){} template void Setup_Psi_pw::before_runner( - const UnitCell &ucell, - const K_Vectors &kv, - const Structure_Factor &sf, - const ModulePW::PW_Basis_K &pw_wfc, - const pseudopot_cell_vnl &ppcell, - const Input_para &inp) + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp) { //! Allocate and initialize psi this->p_psi_init = new psi::PSIPrepare(inp.init_wfc, @@ -70,10 +70,27 @@ void Setup_Psi_pw::copy_d2h(const Device* ctx) this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), this->psi_cpu[0].size()); } - else - { + else + { + // do nothing + } + return; +} + +// Transfer data from GPU to CPU in pw basis (runtime version) +template +void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) +{ + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), + this->psi_cpu[0].size()); + } + else + { // do nothing - } + } return; } diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 1e79664e2b..6e7a42467d 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -47,12 +47,12 @@ class Setup_Psi_pw //------------ void before_runner( - const UnitCell &ucell, - const K_Vectors &kv, - const Structure_Factor &sf, - const ModulePW::PW_Basis_K &pw_wfc, - const pseudopot_cell_vnl &ppcell, - const Input_para &inp); + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp); void init(hamilt::Hamilt* p_hamilt); @@ -60,6 +60,9 @@ class Setup_Psi_pw // Transfer data from device to host in pw basis void copy_d2h(const Device* ctx); + + // Transfer data from device to host in pw basis (runtime version) + void copy_d2h(const base_device::DeviceContext* ctx); void clean(); From d27ebef6d3e199de759f16984dfc0a3eec53ecfe Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 13:16:03 +0800 Subject: [PATCH 19/40] refactor(esolver): remove ctx member variable from ESolver_KS_PW - Remove Device* ctx member variable from ESolver_KS_PW class - Remove ctx parameter from ctrl_scf_pw and ctrl_runner_pw functions - Add local ctx variable inside ctrl_scf_pw and ctrl_runner_pw functions - Update all template instantiations to match new function signatures This refactoring simplifies the code by moving the ctx variable from a class member to a local variable within the functions that need it. The ctx variable is only used for template parameter deduction in copy_d2h and get_pchg_pw/get_wf_pw functions, so it doesn't need to be stored as a member variable. --- source/source_esolver/esolver_ks_pw.cpp | 5 +-- source/source_esolver/esolver_ks_pw.h | 3 -- .../source_io/module_ctrl/ctrl_output_pw.cpp | 30 ++++++-------- source/source_io/module_ctrl/ctrl_output_pw.h | 41 ++----------------- 4 files changed, 18 insertions(+), 61 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 0ea4c31d1d..169f1aa1df 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -43,7 +43,6 @@ ESolver_KS_PW::ESolver_KS_PW() { this->classname = "ESolver_KS_PW"; this->basisname = "PW"; - this->ctx = nullptr; } template @@ -251,7 +250,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->stp, - this->ctx, this->Pgrid, PARAM.inp); + this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } @@ -304,7 +303,7 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp, - this->sf, this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); + this->sf, this->ppcell, this->solvent, this->Pgrid, PARAM.inp); elecstate::teardown_estate_pw(this->pelec, this->vsep_cell); diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 323e2df5a2..5ab756647d 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -57,9 +57,6 @@ class ESolver_KS_PW : public ESolver_KS // DFT-1/2 method VSep* vsep_cell = nullptr; - // for get_pchg and get_wf, use ctx as input of fft - Device* ctx = {}; - }; } // namespace ModuleESolver #endif diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index 8e2f0c3918..625fd34809 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -91,13 +91,15 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, - const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp) { ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw"); ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); + // Create local ctx for device type deduction + Device* ctx = nullptr; + // Transfer data from device (GPU) to host (CPU) in pw basis stp.copy_d2h(ctx); @@ -255,13 +257,15 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, - const Device* ctx, Parallel_Grid ¶_grid, const Input_para& inp) { ModuleBase::TITLE("ModuleIO", "ctrl_runner_pw"); ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw"); + // Create local ctx for device type deduction + Device* ctx = nullptr; + //---------------------------------------------------------- //! 1) Compute LDOS //---------------------------------------------------------- @@ -384,7 +388,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, - const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -400,7 +403,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, - const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -417,13 +419,12 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, - const base_device::DEVICE_GPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); // complex + GPU template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( - const int nstep, + const int nstep, UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, @@ -433,7 +434,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, - const base_device::DEVICE_GPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); #endif @@ -444,14 +444,13 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ elecstate::ElecState* pelec, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, - ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis* pw_rhod, Charge &chr, - K_Vectors &kv, + K_Vectors &kv, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, - const base_device::DEVICE_CPU* ctx, Parallel_Grid ¶_grid, const Input_para& inp); @@ -461,14 +460,13 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE elecstate::ElecState* pelec, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, - ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis* pw_rhod, Charge &chr, - K_Vectors &kv, + K_Vectors &kv, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, - const base_device::DEVICE_CPU* ctx, Parallel_Grid ¶_grid, const Input_para& inp); @@ -481,12 +479,11 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, - K_Vectors &kv, + K_Vectors &kv, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, - const base_device::DEVICE_GPU* ctx, Parallel_Grid ¶_grid, const Input_para& inp); @@ -498,12 +495,11 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, - K_Vectors &kv, + K_Vectors &kv, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, - const base_device::DEVICE_GPU* ctx, Parallel_Grid ¶_grid, const Input_para& inp); #endif diff --git a/source/source_io/module_ctrl/ctrl_output_pw.h b/source/source_io/module_ctrl/ctrl_output_pw.h index 262fc782a8..2cbaa73309 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.h +++ b/source/source_io/module_ctrl/ctrl_output_pw.h @@ -30,57 +30,22 @@ void ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, - const Device* ctx, - const Parallel_Grid ¶_grid, - const Input_para& inp); - -// print out information in 'after_scf' in ESolver_KS_PW (runtime version) -template -void ctrl_scf_pw(const int istep, - UnitCell& ucell, - elecstate::ElecState* pelec, - const Charge &chr, - const K_Vectors &kv, - const ModulePW::PW_Basis_K *pw_wfc, - const ModulePW::PW_Basis *pw_rho, - const ModulePW::PW_Basis *pw_rhod, - const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw &stp, - const base_device::DeviceContext* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); // print out information in 'after_all_runners' in ESolver_KS_PW template void ctrl_runner_pw(UnitCell& ucell, - elecstate::ElecState* pelec, - ModulePW::PW_Basis_K* pw_wfc, - ModulePW::PW_Basis* pw_rho, - ModulePW::PW_Basis* pw_rhod, - Charge &chr, - K_Vectors &kv, - Setup_Psi_pw &stp, - Structure_Factor &sf, - pseudopot_cell_vnl &ppcell, - surchem &solvent, - const Device* ctx, - Parallel_Grid ¶_grid, - const Input_para& inp); - -// print out information in 'after_all_runners' in ESolver_KS_PW (runtime version) -template -void ctrl_runner_pw(UnitCell& ucell, - elecstate::ElecState* pelec, + elecstate::ElecState* pelec, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, - Charge &chr, + Charge &chr, K_Vectors &kv, Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, - surchem &solvent, - const base_device::DeviceContext* ctx, + surchem &solvent, Parallel_Grid ¶_grid, const Input_para& inp); From 4ee54b4103a4d76ae3187396b458d35f13ff6236 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 16:10:05 +0800 Subject: [PATCH 20/40] refactor(psi): add runtime type information to Setup_Psi_pw - Add runtime type information (device_type_ and precision_type_) to Setup_Psi_pw - Add accessor functions for basic information (get_nbands, get_nk, get_nbasis, size) - Add accessor functions for runtime type information - Add get_psi_t() function for backward compatibility This is the first step of a gradual refactoring to remove template parameters from Setup_Psi_pw in the future. The current changes are backward compatible and do not affect existing functionality. --- source/source_psi/setup_psi_pw.cpp | 17 ++++++++++++++++ source/source_psi/setup_psi_pw.h | 32 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index 11b02b8512..e87bc7ccde 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -26,6 +26,23 @@ void Setup_Psi_pw::before_runner( this->p_psi_init->prepare_init(inp.pw_seed); + //! Set runtime type information + if (std::is_same::value) { + precision_type_ = PrecisionType::Float; + } else if (std::is_same::value) { + precision_type_ = PrecisionType::Double; + } else if (std::is_same>::value) { + precision_type_ = PrecisionType::ComplexFloat; + } else { + precision_type_ = PrecisionType::ComplexDouble; + } + + if (std::is_same::value) { + device_type_ = base_device::GpuDevice; + } else { + device_type_ = base_device::CpuDevice; + } + //! If GPU or single precision, allocate a new psi (psi_t). //! otherwise, transform psi_cpu to psi_t this->psi_t = inp.device == "gpu" || inp.precision == "single" diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 6e7a42467d..e38765cbd3 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -42,6 +42,20 @@ class Setup_Psi_pw bool already_initpsi = false; + //------------ + // runtime type information + //------------ + base_device::AbacusDevice_t device_type_ = base_device::CpuDevice; + + // Precision type: 0 = float, 1 = double, 2 = complex, 3 = complex + enum class PrecisionType { + Float = 0, + Double = 1, + ComplexFloat = 2, + ComplexDouble = 3 + }; + PrecisionType precision_type_ = PrecisionType::ComplexDouble; + //------------ // functions //------------ @@ -66,6 +80,24 @@ class Setup_Psi_pw void clean(); + //------------ + // accessor functions + //------------ + + // Get basic information (no type conversion needed, use psi_cpu) + int get_nbands() const { return this->psi_cpu->get_nbands(); } + int get_nk() const { return this->psi_cpu->get_nk(); } + int get_nbasis() const { return this->psi_cpu->get_nbasis(); } + size_t size() const { return this->psi_cpu->size(); } + + // Get runtime type information + base_device::AbacusDevice_t get_device_type() const { return device_type_; } + PrecisionType get_precision_type() const { return precision_type_; } + + // Get psi_t pointer (template version, for backward compatibility) + psi::Psi* get_psi_t() { return psi_t; } + const psi::Psi* get_psi_t() const { return psi_t; } + private: using castmem_2d_d2h_op From e38151044feb002ec5409b6e3ae71bab8793b258 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 16:44:29 +0800 Subject: [PATCH 21/40] refactor(esolver): use get_psi_t() accessor instead of direct psi_t access - Replace all direct access to stp.psi_t with stp.get_psi_t() - Replace stp.psi_t->get_nbands() with stp.get_nbands() - This is the second step of gradual refactoring to prepare for removing template parameters Modified files: - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_sdft_pw.cpp - source/source_esolver/esolver_ks_lcaopw.cpp - source/source_io/module_ctrl/ctrl_output_pw.cpp --- source/source_esolver/esolver_ks_lcaopw.cpp | 6 +++--- source/source_esolver/esolver_ks_pw.cpp | 12 ++++++------ source/source_esolver/esolver_sdft_pw.cpp | 8 ++++---- source/source_io/module_ctrl/ctrl_output_pw.cpp | 10 +++++----- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index db00b6265d..b30110046b 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -105,7 +105,7 @@ namespace ModuleESolver ucell.symm, &this->kv, this->psi_local, - this->stp.psi_t, + this->stp.get_psi_t(), this->pw_wfc, this->pw_rho, this->sf, @@ -146,7 +146,7 @@ namespace ModuleESolver bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLIP hsolver_lip_obj(this->pw_wfc); - hsolver_lip_obj.solve(static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, + hsolver_lip_obj.solve(static_cast*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); // add exx @@ -240,7 +240,7 @@ namespace ModuleESolver ModuleIO::write_Vxc(PARAM.inp.nspin, PARAM.globalv.nlocal, GlobalV::DRANK, - *this->stp.psi_t, + *this->stp.get_psi_t(), ucell, this->sf, this->solvent, diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 169f1aa1df..b9d852d88c 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -127,13 +127,13 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) // init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06 pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid, this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell, - this->stp.psi_t, static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); + this->stp.get_psi_t(), static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); // setup psi (electronic wave functions) this->stp.init(static_cast*>(this->p_hamilt)); //! Setup EXX helper for Hamiltonian and psi - exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp); + exx_helper.before_scf(this->p_hamilt, this->stp.get_psi_t(), PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); } @@ -151,7 +151,7 @@ void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const // update local occupations for DFT+U // should before lambda loop in DeltaSpin - pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp); + pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.get_psi_t(), this->pelec->wg, ucell, PARAM.inp); } // Temporary, it should be replaced by hsolver later. @@ -187,7 +187,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste hsolver::DiagoIterAssist::need_subspace, PARAM.inp.use_k_continuity); - hsolver_pw_obj.solve(static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, this->pelec->ekb.c, + hsolver_pw_obj.solve(static_cast*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec, this->pelec->ekb.c, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat); } @@ -204,7 +204,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int // Related to EXX if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter) { - this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.psi_t)); + this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.get_psi_t())); } // deband is calculated from "output" charge density @@ -223,7 +223,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } // Handle EXX-related operations after SCF iteration - exx_helper.iter_finish(this->pelec, &this->chr, this->stp.psi_t, ucell, PARAM.inp, conv_esolver, iter); + exx_helper.iter_finish(this->pelec, &this->chr, this->stp.get_psi_t(), ucell, PARAM.inp, conv_esolver, iter); // check if oscillate for delta_spin method pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp); diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index f7f9a29983..86658eb645 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -168,7 +168,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i hsolver_pw_sdft_obj.solve(ucell, static_cast*>(this->p_hamilt), - this->stp.psi_t[0], + *this->stp.get_psi_t(), this->stp.psi_cpu[0], this->pelec, this->pw_wfc, @@ -221,7 +221,7 @@ void ESolver_SDFT_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& this->locpp, this->ppcell, ucell, - *this->stp.psi_t, + *this->stp.get_psi_t(), this->stowf); } @@ -236,7 +236,7 @@ void ESolver_SDFT_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& &this->sf, &this->kv, this->pw_wfc, - *this->stp.psi_t, + *this->stp.get_psi_t(), this->stowf, &this->chr, &this->locpp, @@ -289,7 +289,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) &this->kv, this->pelec, this->pw_wfc, - this->stp.psi_t, + this->stp.get_psi_t(), &this->ppcell, static_cast, Device>*>(this->p_hamilt), this->stoche, diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index 625fd34809..396388d824 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -166,7 +166,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, // update psi_d stp.update_psi_d(); - const int nbands = stp.psi_t->get_nbands(); + const int nbands = stp.get_nbands(); const int ngmc = chr.ngmc; ModuleIO::get_pchg_pw(inp.out_pchg, @@ -237,7 +237,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, if (inp.onsite_radius > 0) { // float type has not been implemented auto* onsite_p = projectors::OnsiteProjector::get_instance(); - onsite_p->cal_occupations(reinterpret_cast, Device>*>(stp.psi_t), + onsite_p->cal_occupations(reinterpret_cast, Device>*>(stp.get_psi_t()), pelec->wg); } @@ -307,7 +307,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, ModuleIO::get_wf_pw(inp.out_wfc_norm, inp.out_wfc_re_im, - stp.psi_t->get_nbands(), + stp.get_nbands(), inp.nspin, pw_rhod->nxyz, &ucell, @@ -327,7 +327,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, if (inp.cal_cond) { using Real = typename GetTypeReal::type; - EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, stp.psi_t, &ppcell); + EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, stp.get_psi_t(), &ppcell); elec_cond.KG(inp.cond_smear, inp.cond_fwhm, inp.cond_wcut, @@ -364,7 +364,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, pw_rho); write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, - stp.psi_t, + stp.get_psi_t(), pelec, pw_wfc, pw_rho, From 9dff99118f15b0f35aab498513a8c28d92bcd364 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 17:21:29 +0800 Subject: [PATCH 22/40] refactor(psi): change psi_t from template pointer to void* - Change psi_t from psi::Psi* to void* - Add static_cast in get_psi_t() function for type conversion - Update all functions that use psi_t to use get_psi_t() or static_cast - This is the third step of gradual refactoring to remove template parameters Modified functions: - before_runner: use if-else instead of ternary operator for void* assignment - update_psi_d: use get_psi_t() to access psi_t - init: use get_psi_t() to access psi_t - copy_d2h: use get_psi_t() to access psi_t - clean: use get_psi_t() to delete psi_t --- source/source_psi/setup_psi_pw.cpp | 20 ++++++++++++-------- source/source_psi/setup_psi_pw.h | 7 ++++--- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index e87bc7ccde..b411327b26 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -45,9 +45,11 @@ void Setup_Psi_pw::before_runner( //! If GPU or single precision, allocate a new psi (psi_t). //! otherwise, transform psi_cpu to psi_t - this->psi_t = inp.device == "gpu" || inp.precision == "single" - ? new psi::Psi(this->psi_cpu[0]) - : reinterpret_cast*>(this->psi_cpu); + if (inp.device == "gpu" || inp.precision == "single") { + this->psi_t = static_cast(new psi::Psi(this->psi_cpu[0])); + } else { + this->psi_t = static_cast(reinterpret_cast*>(this->psi_cpu)); + } } @@ -61,7 +63,7 @@ void Setup_Psi_pw::update_psi_d() // Refresh this->psi_d this->psi_d = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->psi_t[0]) + ? new psi::Psi, Device>(*this->get_psi_t()) : reinterpret_cast, Device>*>(this->psi_t); } @@ -71,7 +73,7 @@ void Setup_Psi_pw::init(hamilt::Hamilt* p_hamilt) //! Initialize wave functions if (!this->already_initpsi) { - this->p_psi_init->initialize_psi(this->psi_cpu, this->psi_t, p_hamilt, GlobalV::ofs_running); + this->p_psi_init->initialize_psi(this->psi_cpu, this->get_psi_t(), p_hamilt, GlobalV::ofs_running); this->already_initpsi = true; } } @@ -83,8 +85,9 @@ void Setup_Psi_pw::copy_d2h(const Device* ctx) { if (base_device::get_device_type(ctx) == base_device::GpuDevice) { + auto* psi_t = this->get_psi_t(); castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), - this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), + psi_t->get_pointer() - psi_t->get_psi_bias(), this->psi_cpu[0].size()); } else @@ -100,8 +103,9 @@ void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) { if (base_device::get_device_type(ctx) == base_device::GpuDevice) { + auto* psi_t = this->get_psi_t(); castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), - this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), + psi_t->get_pointer() - psi_t->get_psi_bias(), this->psi_cpu[0].size()); } else @@ -118,7 +122,7 @@ void Setup_Psi_pw::clean() { if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { - delete this->psi_t; + delete this->get_psi_t(); } if (PARAM.inp.precision == "single") { diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index e38765cbd3..b56fadcc07 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -32,7 +32,8 @@ class Setup_Psi_pw // originally, this term is kspw_psi // if CPU, kspw_psi = psi, otherwise, kspw_psi has a new copy - psi::Psi* psi_t = nullptr; + // psi::Psi* psi_t = nullptr; // 原来的模板版本 + void* psi_t = nullptr; // 使用 void* 存储指针,运行时类型信息记录实际类型 // originally, this term is __kspw_psi psi::Psi, Device>* psi_d = nullptr; @@ -95,8 +96,8 @@ class Setup_Psi_pw PrecisionType get_precision_type() const { return precision_type_; } // Get psi_t pointer (template version, for backward compatibility) - psi::Psi* get_psi_t() { return psi_t; } - const psi::Psi* get_psi_t() const { return psi_t; } + psi::Psi* get_psi_t() { return static_cast*>(psi_t); } + const psi::Psi* get_psi_t() const { return static_cast*>(psi_t); } private: From 7e5fcd6bbbfa2690421bfeea9c5d7fa9f4b7c515 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 17:30:01 +0800 Subject: [PATCH 23/40] style: replace Chinese comments with English in setup_psi_pw.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace '原来的模板版本' with 'Original template version' - Replace '使用 void* 存储指针,运行时类型信息记录实际类型' with 'Use void* to store pointer, runtime type information records actual type' - Follow ABACUS code style guidelines for English-only comments --- source/source_psi/setup_psi_pw.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index b56fadcc07..5fa14b25ee 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -32,8 +32,8 @@ class Setup_Psi_pw // originally, this term is kspw_psi // if CPU, kspw_psi = psi, otherwise, kspw_psi has a new copy - // psi::Psi* psi_t = nullptr; // 原来的模板版本 - void* psi_t = nullptr; // 使用 void* 存储指针,运行时类型信息记录实际类型 + // psi::Psi* psi_t = nullptr; // Original template version + void* psi_t = nullptr; // Use void* to store pointer, runtime type information records actual type // originally, this term is __kspw_psi psi::Psi, Device>* psi_d = nullptr; From 627dd0de251e0db46552484f8ba37e10a7b96114 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 17:49:03 +0800 Subject: [PATCH 24/40] refactor(psi): change psi_d from template pointer to void* - Change psi_d from psi::Psi, Device>* to void* - Add get_psi_d() accessor function for type conversion - Update all functions that use psi_d to use get_psi_d() - This is part of step 1 in phase 4 of gradual refactoring Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_esolver/esolver_ks_pw.cpp - source/source_io/module_ctrl/ctrl_output_pw.cpp --- source/source_esolver/esolver_ks_pw.cpp | 4 ++-- source/source_io/module_ctrl/ctrl_output_pw.cpp | 4 ++-- source/source_psi/setup_psi_pw.cpp | 12 +++++++----- source/source_psi/setup_psi_pw.h | 11 ++++++++++- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index b9d852d88c..9d5581d499 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -272,7 +272,7 @@ void ESolver_KS_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& fo // Calculate forces ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm, &this->sf, this->solvent, &this->dftu, &this->locpp, &this->ppcell, - &this->kv, this->pw_wfc, this->stp.psi_d); + &this->kv, this->pw_wfc, this->stp.get_psi_d()); } template @@ -284,7 +284,7 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s this->stp.update_psi_d(); ss.cal_stress(stress, ucell, this->dftu, this->locpp, this->ppcell, this->pw_rhod, - &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.psi_d); + &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.get_psi_d()); // external stress double unit_transform = 0.0; diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index 396388d824..4905828431 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -175,7 +175,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, pw_rhod->nxyz, ngmc, &ucell, - stp.psi_d, + stp.get_psi_d(), pw_rhod, pw_wfc, ctx, @@ -311,7 +311,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, inp.nspin, pw_rhod->nxyz, &ucell, - stp.psi_d, + stp.get_psi_d(), pw_wfc, ctx, para_grid, diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index b411327b26..a255203085 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -58,13 +58,15 @@ void Setup_Psi_pw::update_psi_d() { if (this->psi_d != nullptr && PARAM.inp.precision == "single") { - delete reinterpret_cast, Device>*>(this->psi_d); + delete this->get_psi_d(); } // Refresh this->psi_d - this->psi_d = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(*this->get_psi_t()) - : reinterpret_cast, Device>*>(this->psi_t); + if (PARAM.inp.precision == "single") { + this->psi_d = static_cast(new psi::Psi, Device>(*this->get_psi_t())); + } else { + this->psi_d = static_cast(reinterpret_cast, Device>*>(this->psi_t)); + } } template @@ -126,7 +128,7 @@ void Setup_Psi_pw::clean() } if (PARAM.inp.precision == "single") { - delete this->psi_d; + delete this->get_psi_d(); } delete this->psi_cpu; diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 5fa14b25ee..5ab03f8613 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -36,7 +36,8 @@ class Setup_Psi_pw void* psi_t = nullptr; // Use void* to store pointer, runtime type information records actual type // originally, this term is __kspw_psi - psi::Psi, Device>* psi_d = nullptr; + // psi::Psi, Device>* psi_d = nullptr; // Original template version + void* psi_d = nullptr; // Use void* to store pointer, runtime type information records actual type // psi_initializer controller psi::PSIPrepare* p_psi_init = nullptr; @@ -98,6 +99,14 @@ class Setup_Psi_pw // Get psi_t pointer (template version, for backward compatibility) psi::Psi* get_psi_t() { return static_cast*>(psi_t); } const psi::Psi* get_psi_t() const { return static_cast*>(psi_t); } + + // Get psi_d pointer (template version, for backward compatibility) + psi::Psi, Device>* get_psi_d() { + return static_cast, Device>*>(psi_d); + } + const psi::Psi, Device>* get_psi_d() const { + return static_cast, Device>*>(psi_d); + } private: From c65a0fcd94298b561121379032f10a0c1f633aa3 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 20:38:15 +0800 Subject: [PATCH 25/40] refactor(psi): introduce PSIPrepareBase base class for template removal This is the first step towards removing template parameters from Setup_Psi_pw. Changes: 1. Create PSIPrepareBase base class - Non-template base class for PSIPrepare - Similar approach to HamiltBase for Hamilt 2. Modify PSIPrepare to inherit from PSIPrepareBase - Add #include "source_psi/psi_prepare_base.h" - Change class declaration to inherit from PSIPrepareBase 3. Update Setup_Psi_pw to use PSIPrepareBase* - Change p_psi_init from PSIPrepare* to PSIPrepareBase* - Add static_cast when calling PSIPrepare methods 4. Update all PSIPrepare usage in ESolver files - esolver_ks_pw.cpp: add static_cast before prepare_init call - esolver_ks_lcaopw.cpp: add static_cast before method calls Modified files: - source/source_psi/psi_prepare_base.h (new) - source/source_psi/psi_prepare.h - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_ks_lcaopw.cpp Benefits: - Eliminates p_psi_init template dependency from Setup_Psi_pw - Paves the way for removing template parameters from Setup_Psi_pw - Maintains type safety through static_cast - Follows the same pattern as HamiltBase refactoring Tested: Compiled successfully in build_5pt and build_1p --- source/source_esolver/esolver_ks_lcaopw.cpp | 6 ++++-- source/source_esolver/esolver_ks_pw.cpp | 3 ++- source/source_psi/psi_prepare.h | 3 ++- source/source_psi/psi_prepare_base.h | 23 +++++++++++++++++++++ source/source_psi/setup_psi_pw.cpp | 6 ++++-- source/source_psi/setup_psi_pw.h | 2 +- 6 files changed, 36 insertions(+), 7 deletions(-) create mode 100644 source/source_psi/psi_prepare_base.h diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index b30110046b..42b2e41b66 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -81,16 +81,18 @@ namespace ModuleESolver void ESolver_KS_LIP::before_scf(UnitCell& ucell, const int istep) { ESolver_KS_PW::before_scf(ucell, istep); - this->stp.p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running); + auto* p_psi_init = static_cast*>(this->stp.p_psi_init); + p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running); } template void ESolver_KS_LIP::before_all_runners(UnitCell& ucell, const Input_para& inp) { ESolver_KS_PW::before_all_runners(ucell, inp); + auto* p_psi_init = static_cast*>(this->stp.p_psi_init); delete this->psi_local; this->psi_local = new psi::Psi(this->stp.psi_cpu->get_nk(), - this->stp.p_psi_init->psi_initer->nbands_start(), + p_psi_init->psi_initer->nbands_start(), this->stp.psi_cpu->get_nbasis(), this->kv.ngk, true); diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 9d5581d499..71f9c977f6 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -111,7 +111,8 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) if (ucell.cell_parameter_updated) { - this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed); + auto* p_psi_init = static_cast*>(this->stp.p_psi_init); + p_psi_init->prepare_init(PARAM.inp.pw_seed); } //! Init Hamiltonian (cell changed) diff --git a/source/source_psi/psi_prepare.h b/source/source_psi/psi_prepare.h index c6406a1063..4b35f54521 100644 --- a/source/source_psi/psi_prepare.h +++ b/source/source_psi/psi_prepare.h @@ -2,13 +2,14 @@ #define PSI_PREPARE_H #include "source_hamilt/hamilt.h" #include "source_psi/psi_initializer.h" +#include "source_psi/psi_prepare_base.h" namespace psi { // This class is used to prepare the wavefunction template -class PSIPrepare +class PSIPrepare : public PSIPrepareBase { public: PSIPrepare(const std::string& init_wfc_in, diff --git a/source/source_psi/psi_prepare_base.h b/source/source_psi/psi_prepare_base.h new file mode 100644 index 0000000000..13b8716d8e --- /dev/null +++ b/source/source_psi/psi_prepare_base.h @@ -0,0 +1,23 @@ +#ifndef PSI_PREPARE_BASE_H +#define PSI_PREPARE_BASE_H + +namespace psi +{ + +/** + * @brief Base class for PSIPrepare without template parameters. + * + * This class provides a non-template base class for PSIPrepare, + * allowing Setup_Psi_pw to store a base class pointer instead of a template pointer. + * This is part of the gradual refactoring to remove template parameters from Setup_Psi_pw. + */ +class PSIPrepareBase +{ + public: + PSIPrepareBase() = default; + virtual ~PSIPrepareBase() = default; +}; + +} // namespace psi + +#endif diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index a255203085..55faad5f9f 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -24,7 +24,8 @@ void Setup_Psi_pw::before_runner( //! Allocate memory for cpu version of psi allocate_psi(this->psi_cpu, kv.get_nks(), kv.ngk, PARAM.globalv.nbands_l, pw_wfc.npwk_max); - this->p_psi_init->prepare_init(inp.pw_seed); + auto* p_psi_init = static_cast*>(this->p_psi_init); + p_psi_init->prepare_init(inp.pw_seed); //! Set runtime type information if (std::is_same::value) { @@ -75,7 +76,8 @@ void Setup_Psi_pw::init(hamilt::Hamilt* p_hamilt) //! Initialize wave functions if (!this->already_initpsi) { - this->p_psi_init->initialize_psi(this->psi_cpu, this->get_psi_t(), p_hamilt, GlobalV::ofs_running); + auto* p_psi_init = static_cast*>(this->p_psi_init); + p_psi_init->initialize_psi(this->psi_cpu, this->get_psi_t(), p_hamilt, GlobalV::ofs_running); this->already_initpsi = true; } } diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 5ab03f8613..a959d6a871 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -40,7 +40,7 @@ class Setup_Psi_pw void* psi_d = nullptr; // Use void* to store pointer, runtime type information records actual type // psi_initializer controller - psi::PSIPrepare* p_psi_init = nullptr; + psi::PSIPrepareBase* p_psi_init = nullptr; bool already_initpsi = false; From 1e74e1738b9397a24488ab0b43e8e13d2cd28692 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 20:44:32 +0800 Subject: [PATCH 26/40] refactor(psi): change init() parameter from template to HamiltBase* This is the second step towards removing template parameters from Setup_Psi_pw. Changes: 1. Modify init() function signature - Change parameter from hamilt::Hamilt* to hamilt::HamiltBase* - Eliminates template dependency in function signature 2. Update init() implementation - Add static_cast*> inside function - Maintain type safety through explicit cast 3. Update call site in esolver_ks_pw.cpp - Remove static_cast from call site - Directly pass p_hamilt (which is already HamiltBase*) Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_esolver/esolver_ks_pw.cpp Benefits: - init() function no longer depends on template parameters in signature - Simplifies call sites (no cast needed) - Follows the same pattern as p_hamilt storage in ESolver_KS - One step closer to removing template parameters from Setup_Psi_pw Tested: Compiled successfully in build_5pt and build_1p --- fix_get_device_type.sh | 20 +++++++++++ fix_get_device_type_v2.sh | 45 +++++++++++++++++++++++++ source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_psi/setup_psi_pw.cpp | 5 +-- source/source_psi/setup_psi_pw.h | 2 +- 5 files changed, 70 insertions(+), 4 deletions(-) create mode 100755 fix_get_device_type.sh create mode 100755 fix_get_device_type_v2.sh diff --git a/fix_get_device_type.sh b/fix_get_device_type.sh new file mode 100755 index 0000000000..c63551e20b --- /dev/null +++ b/fix_get_device_type.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# 脚本用于修复 get_device_type 函数调用,移除不必要的 ctx 参数 + +# 搜索所有调用 get_device_type 的文件 +files=$(grep -r "get_device_type" --include="*.cpp" --include="*.hpp" --include="*.h" source/) + +# 遍历每个文件 +while IFS= read -r line; do + # 提取文件路径 + file=$(echo "$line" | cut -d: -f1) + + # 打印正在处理的文件 + echo "Processing: $file" + + # 替换 get_device_type(ctx) 为 get_device_type() + sed -i 's/get_device_type<\([^>]*\)>\(\s*\)(\s*\([^)]*\)\s*)/get_device_type<\1>\2()/g' "$file" +done <<< "$files" + +echo "Fix completed!" diff --git a/fix_get_device_type_v2.sh b/fix_get_device_type_v2.sh new file mode 100755 index 0000000000..907af13a29 --- /dev/null +++ b/fix_get_device_type_v2.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# 脚本用于修复 get_device_type 函数调用,去掉模板参数 + +# 定义要修改的文件列表 +files=( + "source/source_pw/module_stodft/sto_forces.cpp" + "source/source_pw/module_stodft/sto_wf.cpp" + "source/source_pw/module_pwdft/fs_nonlocal_tools.cpp" + "source/source_pw/module_pwdft/onsite_projector.cpp" + "source/source_pw/module_pwdft/onsite_proj_tools.cpp" + "source/source_pw/module_pwdft/nonlocal_maths.hpp" + "source/source_pw/module_pwdft/op_pw_ekin.cpp" + "source/source_pw/module_pwdft/stress_loc.cpp" + "source/source_pw/module_pwdft/stress_cc.cpp" + "source/source_hsolver/test/hsolver_pw_sup.h" + "source/source_pw/module_pwdft/fs_kin_tools.cpp" + "source/source_pw/module_pwdft/forces_scc.cpp" + "source/source_pw/module_pwdft/forces_cc.cpp" + "source/source_pw/module_pwdft/forces.cpp" + "source/source_hsolver/diago_iter_assist.cpp" + "source/source_hsolver/diago_david.cpp" + "source/source_hsolver/diago_dav_subspace.cpp" + "source/source_esolver/esolver_ks_pw.cpp" + "source/source_base/math_chebyshev.cpp" + "source/source_pw/module_pwdft/structure_factor_k.cpp" + "source/source_base/module_device/test/device_test.cpp" +) + +# 遍历每个文件 +for file in "${files[@]}"; do + if [ -f "$file" ]; then + echo "Processing: $file" + # 替换 get_device_type( 为 get_device_type( + sed -i 's/get_device_type(/get_device_type(/g' "$file" + # 替换 get_device_type( 为 get_device_type( + sed -i 's/get_device_type(/get_device_type(/g' "$file" + # 替换 get_device_type( 为 get_device_type( + sed -i 's/get_device_type(/get_device_type(/g' "$file" + else + echo "File not found: $file" + fi +done + +echo "Fix completed!" diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 71f9c977f6..713b5b46c5 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -131,7 +131,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) this->stp.get_psi_t(), static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); // setup psi (electronic wave functions) - this->stp.init(static_cast*>(this->p_hamilt)); + this->stp.init(this->p_hamilt); //! Setup EXX helper for Hamiltonian and psi exx_helper.before_scf(this->p_hamilt, this->stp.get_psi_t(), PARAM.inp); diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index 55faad5f9f..c3da4c0c56 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -71,13 +71,14 @@ void Setup_Psi_pw::update_psi_d() } template -void Setup_Psi_pw::init(hamilt::Hamilt* p_hamilt) +void Setup_Psi_pw::init(hamilt::HamiltBase* p_hamilt) { //! Initialize wave functions if (!this->already_initpsi) { auto* p_psi_init = static_cast*>(this->p_psi_init); - p_psi_init->initialize_psi(this->psi_cpu, this->get_psi_t(), p_hamilt, GlobalV::ofs_running); + auto* hamilt = static_cast*>(p_hamilt); + p_psi_init->initialize_psi(this->psi_cpu, this->get_psi_t(), hamilt, GlobalV::ofs_running); this->already_initpsi = true; } } diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index a959d6a871..f1dacfb4c1 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -70,7 +70,7 @@ class Setup_Psi_pw const pseudopot_cell_vnl &ppcell, const Input_para &inp); - void init(hamilt::Hamilt* p_hamilt); + void init(hamilt::HamiltBase* p_hamilt); void update_psi_d(); From b00a5ffa80ef192e757b4b2816026eeb8f91bba9 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 21:05:36 +0800 Subject: [PATCH 27/40] refactor(psi): remove template version of copy_d2h function This is the third step towards removing template parameters from Setup_Psi_pw. Changes: 1. Remove template version copy_d2h(const Device* ctx) - Delete the template version from setup_psi_pw.h - Delete the implementation from setup_psi_pw.cpp - Keep only the runtime version copy_d2h(const base_device::DeviceContext* ctx) 2. Update call site in ctrl_output_pw.cpp - Create Device* ctx = nullptr for template parameter deduction - Use DeviceContext::instance() for runtime device context - Call copy_d2h with DeviceContext* pointer Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_io/module_ctrl/ctrl_output_pw.cpp Benefits: - Eliminates copy_d2h function's template dependency - All member functions now use runtime device context - One step closer to removing template parameters from Setup_Psi_pw - Maintains backward compatibility with existing code Technical details: - get_pchg_pw and get_wf_pw still need Device* ctx for template deduction - DeviceContext is used for actual device type information - This follows the gradual migration pattern used in ESolver refactoring Tested: Compiled successfully in build_5pt and build_1p --- .../source_io/module_ctrl/ctrl_output_pw.cpp | 4 +++- source/source_psi/setup_psi_pw.cpp | 18 ------------------ source/source_psi/setup_psi_pw.h | 3 --- 3 files changed, 3 insertions(+), 22 deletions(-) diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index 4905828431..ed69182a79 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -101,7 +101,9 @@ void ModuleIO::ctrl_scf_pw(const int istep, Device* ctx = nullptr; // Transfer data from device (GPU) to host (CPU) in pw basis - stp.copy_d2h(ctx); + base_device::DeviceContext* device_ctx = &base_device::DeviceContext::instance(); + device_ctx->set_device_type(stp.get_device_type()); + stp.copy_d2h(device_ctx); //---------------------------------------------------------- //! 4) Compute density of states (DOS) diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index c3da4c0c56..53a94da3d2 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -84,24 +84,6 @@ void Setup_Psi_pw::init(hamilt::HamiltBase* p_hamilt) } -// Transfer data from GPU to CPU in pw basis -template -void Setup_Psi_pw::copy_d2h(const Device* ctx) -{ - if (base_device::get_device_type(ctx) == base_device::GpuDevice) - { - auto* psi_t = this->get_psi_t(); - castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), - psi_t->get_pointer() - psi_t->get_psi_bias(), - this->psi_cpu[0].size()); - } - else - { - // do nothing - } - return; -} - // Transfer data from GPU to CPU in pw basis (runtime version) template void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index f1dacfb4c1..a327ff03a9 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -74,9 +74,6 @@ class Setup_Psi_pw void update_psi_d(); - // Transfer data from device to host in pw basis - void copy_d2h(const Device* ctx); - // Transfer data from device to host in pw basis (runtime version) void copy_d2h(const base_device::DeviceContext* ctx); From 69996bf639b53d009eb2db0221a49344b730b4d4 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 13 Mar 2026 22:01:07 +0800 Subject: [PATCH 28/40] refactor(psi): remove castmem_2d_d2h_op template type alias dependency This is the fourth step towards removing template parameters from Setup_Psi_pw. Changes: 1. Remove castmem_2d_d2h_op type alias from setup_psi_pw.h - The type alias depended on template parameters T and Device - Replaced with overloaded member functions 2. Add castmem_d2h_impl() overloaded functions - One overload for std::complex source - One overload for std::complex source - Each uses the appropriate cast_memory_op internally 3. Update copy_d2h() to use the new overloaded functions - Calls castmem_d2h_impl() instead of castmem_2d_d2h_op() - Compiler selects the correct overload based on T type Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp Benefits: - All member variables now independent of template parameters - castmem_d2h_impl encapsulates the type-dependent logic - One step closer to removing template parameters from Setup_Psi_pw Tested: Compiled successfully in build_5pt and build_1p --- source/source_psi/setup_psi_pw.cpp | 18 +++++++++++++++--- source/source_psi/setup_psi_pw.h | 4 ++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index 53a94da3d2..fb7c9d4edf 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -91,9 +91,9 @@ void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) if (base_device::get_device_type(ctx) == base_device::GpuDevice) { auto* psi_t = this->get_psi_t(); - castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), - psi_t->get_pointer() - psi_t->get_psi_bias(), - this->psi_cpu[0].size()); + this->castmem_d2h_impl(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + psi_t->get_pointer() - psi_t->get_psi_bias(), + this->psi_cpu[0].size()); } else { @@ -102,6 +102,18 @@ void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) return; } +template +void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) +{ + base_device::memory::cast_memory_op, std::complex, base_device::DEVICE_CPU, Device>()(dst, src, size); +} + +template +void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) +{ + base_device::memory::cast_memory_op, std::complex, base_device::DEVICE_CPU, Device>()(dst, src, size); +} + template diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index a327ff03a9..29461e41a1 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -107,8 +107,8 @@ class Setup_Psi_pw private: - using castmem_2d_d2h_op - = base_device::memory::cast_memory_op, T, base_device::DEVICE_CPU, Device>; + void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); + void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); }; From e5e78ea980bcd705c6346381a33f7968c03262c1 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 09:16:21 +0800 Subject: [PATCH 29/40] delete useless files --- fix_get_device_type.sh | 20 ----------------- fix_get_device_type_v2.sh | 45 --------------------------------------- 2 files changed, 65 deletions(-) delete mode 100755 fix_get_device_type.sh delete mode 100755 fix_get_device_type_v2.sh diff --git a/fix_get_device_type.sh b/fix_get_device_type.sh deleted file mode 100755 index c63551e20b..0000000000 --- a/fix_get_device_type.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# 脚本用于修复 get_device_type 函数调用,移除不必要的 ctx 参数 - -# 搜索所有调用 get_device_type 的文件 -files=$(grep -r "get_device_type" --include="*.cpp" --include="*.hpp" --include="*.h" source/) - -# 遍历每个文件 -while IFS= read -r line; do - # 提取文件路径 - file=$(echo "$line" | cut -d: -f1) - - # 打印正在处理的文件 - echo "Processing: $file" - - # 替换 get_device_type(ctx) 为 get_device_type() - sed -i 's/get_device_type<\([^>]*\)>\(\s*\)(\s*\([^)]*\)\s*)/get_device_type<\1>\2()/g' "$file" -done <<< "$files" - -echo "Fix completed!" diff --git a/fix_get_device_type_v2.sh b/fix_get_device_type_v2.sh deleted file mode 100755 index 907af13a29..0000000000 --- a/fix_get_device_type_v2.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -# 脚本用于修复 get_device_type 函数调用,去掉模板参数 - -# 定义要修改的文件列表 -files=( - "source/source_pw/module_stodft/sto_forces.cpp" - "source/source_pw/module_stodft/sto_wf.cpp" - "source/source_pw/module_pwdft/fs_nonlocal_tools.cpp" - "source/source_pw/module_pwdft/onsite_projector.cpp" - "source/source_pw/module_pwdft/onsite_proj_tools.cpp" - "source/source_pw/module_pwdft/nonlocal_maths.hpp" - "source/source_pw/module_pwdft/op_pw_ekin.cpp" - "source/source_pw/module_pwdft/stress_loc.cpp" - "source/source_pw/module_pwdft/stress_cc.cpp" - "source/source_hsolver/test/hsolver_pw_sup.h" - "source/source_pw/module_pwdft/fs_kin_tools.cpp" - "source/source_pw/module_pwdft/forces_scc.cpp" - "source/source_pw/module_pwdft/forces_cc.cpp" - "source/source_pw/module_pwdft/forces.cpp" - "source/source_hsolver/diago_iter_assist.cpp" - "source/source_hsolver/diago_david.cpp" - "source/source_hsolver/diago_dav_subspace.cpp" - "source/source_esolver/esolver_ks_pw.cpp" - "source/source_base/math_chebyshev.cpp" - "source/source_pw/module_pwdft/structure_factor_k.cpp" - "source/source_base/module_device/test/device_test.cpp" -) - -# 遍历每个文件 -for file in "${files[@]}"; do - if [ -f "$file" ]; then - echo "Processing: $file" - # 替换 get_device_type( 为 get_device_type( - sed -i 's/get_device_type(/get_device_type(/g' "$file" - # 替换 get_device_type( 为 get_device_type( - sed -i 's/get_device_type(/get_device_type(/g' "$file" - # 替换 get_device_type( 为 get_device_type( - sed -i 's/get_device_type(/get_device_type(/g' "$file" - else - echo "File not found: $file" - fi -done - -echo "Fix completed!" From f007b20ca015adf223bd29219797ee0e28835e8b Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 11:46:22 +0800 Subject: [PATCH 30/40] refactor(psi): remove template parameters from Setup_Psi_pw class - Remove template parameters from Setup_Psi_pw class - Convert member functions to template functions - Update all call sites to explicitly specify template parameters - This is a major refactoring step to enable runtime polymorphism Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_esolver/esolver_ks_pw.h - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_sdft_pw.cpp - source/source_esolver/esolver_ks_lcaopw.cpp - source/source_io/module_ctrl/ctrl_output_pw.h - source/source_io/module_ctrl/ctrl_output_pw.cpp Key changes: 1. Setup_Psi_pw class no longer has template parameters 2. Member functions like get_psi_t(), get_psi_d(), before_runner(), init(), update_psi_d(), clean() are now template functions 3. All call sites now use stp.template get_psi_t() instead of stp.get_psi_t() 4. Removed template instantiation statements --- source/source_esolver/esolver_ks_lcaopw.cpp | 6 +-- source/source_esolver/esolver_ks_pw.cpp | 26 ++++++------ source/source_esolver/esolver_ks_pw.h | 2 +- source/source_esolver/esolver_sdft_pw.cpp | 8 ++-- .../source_io/module_ctrl/ctrl_output_pw.cpp | 36 ++++++++--------- source/source_io/module_ctrl/ctrl_output_pw.h | 6 +-- source/source_psi/setup_psi_pw.cpp | 40 +++++++------------ source/source_psi/setup_psi_pw.h | 17 +++++++- 8 files changed, 72 insertions(+), 69 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index 42b2e41b66..7726c4ff72 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -107,7 +107,7 @@ namespace ModuleESolver ucell.symm, &this->kv, this->psi_local, - this->stp.get_psi_t(), + this->stp.template get_psi_t(), this->pw_wfc, this->pw_rho, this->sf, @@ -148,7 +148,7 @@ namespace ModuleESolver bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLIP hsolver_lip_obj(this->pw_wfc); - hsolver_lip_obj.solve(static_cast*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec, + hsolver_lip_obj.solve(static_cast*>(this->p_hamilt), *this->stp.template get_psi_t(), this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); // add exx @@ -242,7 +242,7 @@ namespace ModuleESolver ModuleIO::write_Vxc(PARAM.inp.nspin, PARAM.globalv.nlocal, GlobalV::DRANK, - *this->stp.get_psi_t(), + *this->stp.template get_psi_t(), ucell, this->sf, this->solvent, diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 713b5b46c5..79837db644 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -55,7 +55,7 @@ ESolver_KS_PW::~ESolver_KS_PW() this->deallocate_hamilt(); // mohan add 2025-10-12 - this->stp.clean(); + this->stp.clean(); } template @@ -90,7 +90,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); - this->stp.before_runner(ucell, this->kv, this->sf, *this->pw_wfc, this->ppcell, PARAM.inp); + this->stp.before_runner(ucell, this->kv, this->sf, *this->pw_wfc, this->ppcell, PARAM.inp); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); @@ -128,13 +128,13 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) // init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06 pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid, this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell, - this->stp.get_psi_t(), static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); + this->stp.template get_psi_t(), static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); // setup psi (electronic wave functions) - this->stp.init(this->p_hamilt); + this->stp.init(static_cast*>(this->p_hamilt)); //! Setup EXX helper for Hamiltonian and psi - exx_helper.before_scf(this->p_hamilt, this->stp.get_psi_t(), PARAM.inp); + exx_helper.before_scf(this->p_hamilt, this->stp.template get_psi_t(), PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); } @@ -152,7 +152,7 @@ void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const // update local occupations for DFT+U // should before lambda loop in DeltaSpin - pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.get_psi_t(), this->pelec->wg, ucell, PARAM.inp); + pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.template get_psi_t(), this->pelec->wg, ucell, PARAM.inp); } // Temporary, it should be replaced by hsolver later. @@ -188,7 +188,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste hsolver::DiagoIterAssist::need_subspace, PARAM.inp.use_k_continuity); - hsolver_pw_obj.solve(static_cast*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec, this->pelec->ekb.c, + hsolver_pw_obj.solve(static_cast*>(this->p_hamilt), *this->stp.template get_psi_t(), this->pelec, this->pelec->ekb.c, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat); } @@ -205,7 +205,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int // Related to EXX if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter) { - this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.get_psi_t())); + this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.template get_psi_t())); } // deband is calculated from "output" charge density @@ -224,7 +224,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } // Handle EXX-related operations after SCF iteration - exx_helper.iter_finish(this->pelec, &this->chr, this->stp.get_psi_t(), ucell, PARAM.inp, conv_esolver, iter); + exx_helper.iter_finish(this->pelec, &this->chr, this->stp.template get_psi_t(), ucell, PARAM.inp, conv_esolver, iter); // check if oscillate for delta_spin method pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp); @@ -268,12 +268,12 @@ void ESolver_KS_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& fo Forces ff(ucell.nat); // mohan add 2025-10-12 - this->stp.update_psi_d(); + this->stp.update_psi_d(); // Calculate forces ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm, &this->sf, this->solvent, &this->dftu, &this->locpp, &this->ppcell, - &this->kv, this->pw_wfc, this->stp.get_psi_d()); + &this->kv, this->pw_wfc, this->stp.template get_psi_d()); } template @@ -282,10 +282,10 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s Stress_PW ss(this->pelec); // mohan add 2025-10-12 - this->stp.update_psi_d(); + this->stp.update_psi_d(); ss.cal_stress(stress, ucell, this->dftu, this->locpp, this->ppcell, this->pw_rhod, - &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.get_psi_d()); + &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.template get_psi_d()); // external stress double unit_transform = 0.0; diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 5ab756647d..16d6530e2b 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -52,7 +52,7 @@ class ESolver_KS_PW : public ESolver_KS virtual void deallocate_hamilt(); // Electronic wave function psi - Setup_Psi_pw stp; + Setup_Psi_pw stp; // DFT-1/2 method VSep* vsep_cell = nullptr; diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 86658eb645..c3bbdad210 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -168,7 +168,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i hsolver_pw_sdft_obj.solve(ucell, static_cast*>(this->p_hamilt), - *this->stp.get_psi_t(), + *this->stp.template get_psi_t(), this->stp.psi_cpu[0], this->pelec, this->pw_wfc, @@ -221,7 +221,7 @@ void ESolver_SDFT_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& this->locpp, this->ppcell, ucell, - *this->stp.get_psi_t(), + *this->stp.template get_psi_t(), this->stowf); } @@ -236,7 +236,7 @@ void ESolver_SDFT_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& &this->sf, &this->kv, this->pw_wfc, - *this->stp.get_psi_t(), + *this->stp.template get_psi_t(), this->stowf, &this->chr, &this->locpp, @@ -289,7 +289,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) &this->kv, this->pelec, this->pw_wfc, - this->stp.get_psi_t(), + this->stp.template get_psi_t(), &this->ppcell, static_cast, Device>*>(this->p_hamilt), this->stoche, diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index ed69182a79..6bbf1e4329 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -90,7 +90,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp) { @@ -103,7 +103,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, // Transfer data from device (GPU) to host (CPU) in pw basis base_device::DeviceContext* device_ctx = &base_device::DeviceContext::instance(); device_ctx->set_device_type(stp.get_device_type()); - stp.copy_d2h(device_ctx); + stp.template copy_d2h(device_ctx); //---------------------------------------------------------- //! 4) Compute density of states (DOS) @@ -166,7 +166,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, if (inp.out_pchg.size() > 0) { // update psi_d - stp.update_psi_d(); + stp.template update_psi_d(); const int nbands = stp.get_nbands(); const int ngmc = chr.ngmc; @@ -177,7 +177,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, pw_rhod->nxyz, ngmc, &ucell, - stp.get_psi_d(), + stp.template get_psi_d(), pw_rhod, pw_wfc, ctx, @@ -239,7 +239,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, if (inp.onsite_radius > 0) { // float type has not been implemented auto* onsite_p = projectors::OnsiteProjector::get_instance(); - onsite_p->cal_occupations(reinterpret_cast, Device>*>(stp.get_psi_t()), + onsite_p->cal_occupations(reinterpret_cast, Device>*>(stp.template get_psi_t()), pelec->wg); } @@ -255,7 +255,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -305,7 +305,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- if (inp.out_wfc_norm.size() > 0 || inp.out_wfc_re_im.size() > 0) { - stp.update_psi_d(); + stp.template update_psi_d(); ModuleIO::get_wf_pw(inp.out_wfc_norm, inp.out_wfc_re_im, @@ -313,7 +313,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, inp.nspin, pw_rhod->nxyz, &ucell, - stp.get_psi_d(), + stp.template get_psi_d(), pw_wfc, ctx, para_grid, @@ -329,7 +329,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, if (inp.cal_cond) { using Real = typename GetTypeReal::type; - EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, stp.get_psi_t(), &ppcell); + EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, stp.template get_psi_t(), &ppcell); elec_cond.KG(inp.cond_smear, inp.cond_fwhm, inp.cond_wcut, @@ -366,7 +366,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, pw_rho); write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, - stp.get_psi_t(), + stp.template get_psi_t(), pelec, pw_wfc, pw_rho, @@ -389,7 +389,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw, base_device::DEVICE_CPU> &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -404,7 +404,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw, base_device::DEVICE_CPU> &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -420,7 +420,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw, base_device::DEVICE_GPU> &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -435,7 +435,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw, base_device::DEVICE_GPU> &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); #endif @@ -449,7 +449,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw, base_device::DEVICE_CPU> &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -465,7 +465,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw, base_device::DEVICE_CPU> &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -482,7 +482,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw, base_device::DEVICE_GPU> &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -498,7 +498,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw, base_device::DEVICE_GPU> &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, diff --git a/source/source_io/module_ctrl/ctrl_output_pw.h b/source/source_io/module_ctrl/ctrl_output_pw.h index 554e5d8adb..00b7509990 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.h +++ b/source/source_io/module_ctrl/ctrl_output_pw.h @@ -29,7 +29,7 @@ void ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -42,7 +42,7 @@ void ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -58,7 +58,7 @@ void ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index fb7c9d4edf..be5b5c9caf 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -1,14 +1,12 @@ #include "source_psi/setup_psi_pw.h" #include "source_io/module_parameter/parameter.h" // use parameter -template -Setup_Psi_pw::Setup_Psi_pw(){} +Setup_Psi_pw::Setup_Psi_pw(){} -template -Setup_Psi_pw::~Setup_Psi_pw(){} +Setup_Psi_pw::~Setup_Psi_pw(){} template -void Setup_Psi_pw::before_runner( +void Setup_Psi_pw::before_runner( const UnitCell &ucell, const K_Vectors &kv, const Structure_Factor &sf, @@ -55,30 +53,29 @@ void Setup_Psi_pw::before_runner( template -void Setup_Psi_pw::update_psi_d() +void Setup_Psi_pw::update_psi_d() { if (this->psi_d != nullptr && PARAM.inp.precision == "single") { - delete this->get_psi_d(); + delete this->get_psi_d(); } // Refresh this->psi_d if (PARAM.inp.precision == "single") { - this->psi_d = static_cast(new psi::Psi, Device>(*this->get_psi_t())); + this->psi_d = static_cast(new psi::Psi, Device>(*this->get_psi_t())); } else { this->psi_d = static_cast(reinterpret_cast, Device>*>(this->psi_t)); } } template -void Setup_Psi_pw::init(hamilt::HamiltBase* p_hamilt) +void Setup_Psi_pw::init(hamilt::Hamilt* p_hamilt) { //! Initialize wave functions if (!this->already_initpsi) { auto* p_psi_init = static_cast*>(this->p_psi_init); - auto* hamilt = static_cast*>(p_hamilt); - p_psi_init->initialize_psi(this->psi_cpu, this->get_psi_t(), hamilt, GlobalV::ofs_running); + p_psi_init->initialize_psi(this->psi_cpu, this->get_psi_t(), p_hamilt, GlobalV::ofs_running); this->already_initpsi = true; } } @@ -86,11 +83,11 @@ void Setup_Psi_pw::init(hamilt::HamiltBase* p_hamilt) // Transfer data from GPU to CPU in pw basis (runtime version) template -void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) +void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) { if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - auto* psi_t = this->get_psi_t(); + auto* psi_t = this->get_psi_t(); this->castmem_d2h_impl(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), psi_t->get_pointer() - psi_t->get_psi_bias(), this->psi_cpu[0].size()); @@ -103,13 +100,13 @@ void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) } template -void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) +void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) { base_device::memory::cast_memory_op, std::complex, base_device::DEVICE_CPU, Device>()(dst, src, size); } template -void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) +void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) { base_device::memory::cast_memory_op, std::complex, base_device::DEVICE_CPU, Device>()(dst, src, size); } @@ -117,24 +114,17 @@ void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const template -void Setup_Psi_pw::clean() +void Setup_Psi_pw::clean() { if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { - delete this->get_psi_t(); + delete this->get_psi_t(); } if (PARAM.inp.precision == "single") { - delete this->get_psi_d(); + delete this->get_psi_d(); } delete this->psi_cpu; delete this->p_psi_init; } - -template class Setup_Psi_pw, base_device::DEVICE_CPU>; -template class Setup_Psi_pw, base_device::DEVICE_CPU>; -#if ((defined __CUDA) || (defined __ROCM)) -template class Setup_Psi_pw, base_device::DEVICE_GPU>; -template class Setup_Psi_pw, base_device::DEVICE_GPU>; -#endif diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 29461e41a1..8c7bd59295 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -11,7 +11,6 @@ #include "source_base/module_device/device.h" #include "source_hamilt/hamilt.h" -template class Setup_Psi_pw { public: @@ -62,6 +61,7 @@ class Setup_Psi_pw // functions //------------ + template void before_runner( const UnitCell &ucell, const K_Vectors &kv, @@ -70,13 +70,17 @@ class Setup_Psi_pw const pseudopot_cell_vnl &ppcell, const Input_para &inp); - void init(hamilt::HamiltBase* p_hamilt); + template + void init(hamilt::Hamilt* p_hamilt); + template void update_psi_d(); // Transfer data from device to host in pw basis (runtime version) + template void copy_d2h(const base_device::DeviceContext* ctx); + template void clean(); //------------ @@ -94,20 +98,29 @@ class Setup_Psi_pw PrecisionType get_precision_type() const { return precision_type_; } // Get psi_t pointer (template version, for backward compatibility) + template psi::Psi* get_psi_t() { return static_cast*>(psi_t); } + + template const psi::Psi* get_psi_t() const { return static_cast*>(psi_t); } // Get psi_d pointer (template version, for backward compatibility) + template psi::Psi, Device>* get_psi_d() { return static_cast, Device>*>(psi_d); } + + template const psi::Psi, Device>* get_psi_d() const { return static_cast, Device>*>(psi_d); } private: + template void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); + + template void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); }; From 1ee804f19a2219623a1e807dd674cee68d777104 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 13:21:30 +0800 Subject: [PATCH 31/40] fix(psi): add explicit template instantiation for Setup_Psi_pw class Fix undefined reference linker errors for template member functions: - update_psi_d() - copy_d2h() - clean() - before_runner() - init() - castmem_d2h_impl() Changes: 1. Add explicit template instantiation for CPU version: - std::complex, DEVICE_CPU - std::complex, DEVICE_CPU 2. Add explicit template instantiation for GPU version (conditional): - std::complex, DEVICE_GPU - std::complex, DEVICE_GPU - Wrapped with #if ((defined __CUDA) || (defined __ROCM)) 3. Fix template argument deduction error: - Add explicit template parameters when calling castmem_d2h_impl() - The template parameter T is not used in function parameters, so it cannot be deduced Root cause: Template functions defined in .cpp files require explicit instantiation for each type combination used by other compilation units. Modified files: - source/source_psi/setup_psi_pw.cpp (+91 lines) --- source/source_psi/setup_psi_pw.cpp | 94 +++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 3 deletions(-) diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index be5b5c9caf..a7492c707e 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -88,9 +88,9 @@ void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) if (base_device::get_device_type(ctx) == base_device::GpuDevice) { auto* psi_t = this->get_psi_t(); - this->castmem_d2h_impl(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), - psi_t->get_pointer() - psi_t->get_psi_bias(), - this->psi_cpu[0].size()); + this->castmem_d2h_impl(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + psi_t->get_pointer() - psi_t->get_psi_bias(), + this->psi_cpu[0].size()); } else { @@ -128,3 +128,91 @@ void Setup_Psi_pw::clean() delete this->psi_cpu; delete this->p_psi_init; } + +template class psi::PSIPrepare, base_device::DEVICE_CPU>; +template class psi::PSIPrepare, base_device::DEVICE_CPU>; + +template void Setup_Psi_pw::before_runner, base_device::DEVICE_CPU>( + const UnitCell&, const K_Vectors&, const Structure_Factor&, + const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); + +template void Setup_Psi_pw::before_runner, base_device::DEVICE_CPU>( + const UnitCell&, const K_Vectors&, const Structure_Factor&, + const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); + +template void Setup_Psi_pw::init, base_device::DEVICE_CPU>( + hamilt::Hamilt, base_device::DEVICE_CPU>*); + +template void Setup_Psi_pw::init, base_device::DEVICE_CPU>( + hamilt::Hamilt, base_device::DEVICE_CPU>*); + +template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_CPU>(); + +template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_CPU>(); + +template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_CPU>( + const base_device::DeviceContext*); + +template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_CPU>( + const base_device::DeviceContext*); + +template void Setup_Psi_pw::clean, base_device::DEVICE_CPU>(); + +template void Setup_Psi_pw::clean, base_device::DEVICE_CPU>(); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_CPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_CPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_CPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_CPU>( + std::complex*, const std::complex*, const size_t); + +#if ((defined __CUDA) || (defined __ROCM)) +template class psi::PSIPrepare, base_device::DEVICE_GPU>; +template class psi::PSIPrepare, base_device::DEVICE_GPU>; + +template void Setup_Psi_pw::before_runner, base_device::DEVICE_GPU>( + const UnitCell&, const K_Vectors&, const Structure_Factor&, + const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); + +template void Setup_Psi_pw::before_runner, base_device::DEVICE_GPU>( + const UnitCell&, const K_Vectors&, const Structure_Factor&, + const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); + +template void Setup_Psi_pw::init, base_device::DEVICE_GPU>( + hamilt::Hamilt, base_device::DEVICE_GPU>*); + +template void Setup_Psi_pw::init, base_device::DEVICE_GPU>( + hamilt::Hamilt, base_device::DEVICE_GPU>*); + +template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_GPU>( + const base_device::DeviceContext*); + +template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_GPU>( + const base_device::DeviceContext*); + +template void Setup_Psi_pw::clean, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::clean, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_GPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_GPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_GPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_GPU>( + std::complex*, const std::complex*, const size_t); +#endif From 3372256f4b752a1336853c7447e2a91bef529df3 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 14:01:52 +0800 Subject: [PATCH 32/40] refactor(psi): convert before_runner to non-template function Convert Setup_Psi_pw::before_runner from template function to non-template function with runtime type dispatch. Changes: 1. setup_psi_pw.h: - Remove template parameters from before_runner declaration - Add private template function before_runner_impl for internal use 2. setup_psi_pw.cpp: - Rename original before_runner to before_runner_impl (template) - Add new non-template before_runner that dispatches based on: - inp.device (gpu or cpu) - inp.precision (single or double) - Update template instantiation from before_runner to before_runner_impl 3. esolver_ks_pw.cpp: - Update call from stp.before_runner(...) to stp.before_runner(...) Benefits: - Caller no longer needs to specify template parameters - Type is determined at runtime from input parameters - Simpler API for ESolver --- source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_psi/setup_psi_pw.cpp | 48 +++++++++++++++++++------ source/source_psi/setup_psi_pw.h | 10 +++++- 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 79837db644..cb8f658a85 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -90,7 +90,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); - this->stp.before_runner(ucell, this->kv, this->sf, *this->pw_wfc, this->ppcell, PARAM.inp); + this->stp.before_runner(ucell, this->kv, this->sf, *this->pw_wfc, this->ppcell, PARAM.inp); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index a7492c707e..3f8579da3e 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -6,7 +6,7 @@ Setup_Psi_pw::Setup_Psi_pw(){} Setup_Psi_pw::~Setup_Psi_pw(){} template -void Setup_Psi_pw::before_runner( +void Setup_Psi_pw::before_runner_impl( const UnitCell &ucell, const K_Vectors &kv, const Structure_Factor &sf, @@ -14,18 +14,15 @@ void Setup_Psi_pw::before_runner( const pseudopot_cell_vnl &ppcell, const Input_para &inp) { - //! Allocate and initialize psi this->p_psi_init = new psi::PSIPrepare(inp.init_wfc, inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell, sf, kv, ppcell, pw_wfc); - //! Allocate memory for cpu version of psi allocate_psi(this->psi_cpu, kv.get_nks(), kv.ngk, PARAM.globalv.nbands_l, pw_wfc.npwk_max); auto* p_psi_init = static_cast*>(this->p_psi_init); p_psi_init->prepare_init(inp.pw_seed); - //! Set runtime type information if (std::is_same::value) { precision_type_ = PrecisionType::Float; } else if (std::is_same::value) { @@ -42,8 +39,6 @@ void Setup_Psi_pw::before_runner( device_type_ = base_device::CpuDevice; } - //! If GPU or single precision, allocate a new psi (psi_t). - //! otherwise, transform psi_cpu to psi_t if (inp.device == "gpu" || inp.precision == "single") { this->psi_t = static_cast(new psi::Psi(this->psi_cpu[0])); } else { @@ -51,6 +46,39 @@ void Setup_Psi_pw::before_runner( } } +void Setup_Psi_pw::before_runner( + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp) +{ + const bool is_gpu = (inp.device == "gpu"); + const bool is_single = (inp.precision == "single"); + +#if ((defined __CUDA) || (defined __ROCM)) + if (is_gpu) { + if (is_single) { + before_runner_impl, base_device::DEVICE_GPU>( + ucell, kv, sf, pw_wfc, ppcell, inp); + } else { + before_runner_impl, base_device::DEVICE_GPU>( + ucell, kv, sf, pw_wfc, ppcell, inp); + } + } else +#endif + { + if (is_single) { + before_runner_impl, base_device::DEVICE_CPU>( + ucell, kv, sf, pw_wfc, ppcell, inp); + } else { + before_runner_impl, base_device::DEVICE_CPU>( + ucell, kv, sf, pw_wfc, ppcell, inp); + } + } +} + template void Setup_Psi_pw::update_psi_d() @@ -132,11 +160,11 @@ void Setup_Psi_pw::clean() template class psi::PSIPrepare, base_device::DEVICE_CPU>; template class psi::PSIPrepare, base_device::DEVICE_CPU>; -template void Setup_Psi_pw::before_runner, base_device::DEVICE_CPU>( +template void Setup_Psi_pw::before_runner_impl, base_device::DEVICE_CPU>( const UnitCell&, const K_Vectors&, const Structure_Factor&, const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); -template void Setup_Psi_pw::before_runner, base_device::DEVICE_CPU>( +template void Setup_Psi_pw::before_runner_impl, base_device::DEVICE_CPU>( const UnitCell&, const K_Vectors&, const Structure_Factor&, const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); @@ -176,11 +204,11 @@ template void Setup_Psi_pw::castmem_d2h_impl, base_device:: template class psi::PSIPrepare, base_device::DEVICE_GPU>; template class psi::PSIPrepare, base_device::DEVICE_GPU>; -template void Setup_Psi_pw::before_runner, base_device::DEVICE_GPU>( +template void Setup_Psi_pw::before_runner_impl, base_device::DEVICE_GPU>( const UnitCell&, const K_Vectors&, const Structure_Factor&, const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); -template void Setup_Psi_pw::before_runner, base_device::DEVICE_GPU>( +template void Setup_Psi_pw::before_runner_impl, base_device::DEVICE_GPU>( const UnitCell&, const K_Vectors&, const Structure_Factor&, const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 8c7bd59295..f2a191a8b7 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -61,7 +61,6 @@ class Setup_Psi_pw // functions //------------ - template void before_runner( const UnitCell &ucell, const K_Vectors &kv, @@ -117,6 +116,15 @@ class Setup_Psi_pw private: + template + void before_runner_impl( + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp); + template void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); From 749c03960be6770e465d6594df3f49c1816cf14d Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 14:29:13 +0800 Subject: [PATCH 33/40] refactor(psi): convert init to non-template function Convert Setup_Psi_pw::init from template function to non-template function with runtime type dispatch based on device_type_ and precision_type_ member variables. Changes: 1. setup_psi_pw.h: - Change init parameter from Hamilt* to HamiltBase* - Add private template function init_impl for internal use 2. setup_psi_pw.cpp: - Rename original init to init_impl (template) - Add new non-template init that dispatches based on: - device_type_ (GpuDevice or CpuDevice) - precision_type_ (ComplexFloat or ComplexDouble) - Update template instantiation from init to init_impl 3. esolver_ks_pw.cpp: - Update call from stp.init(...) to stp.init(...) Design principle: before_runner sets device_type_ and precision_type_, subsequent functions use these member variables for runtime dispatch. --- source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_psi/setup_psi_pw.cpp | 48 +++++++++++++++++++++---- source/source_psi/setup_psi_pw.h | 6 ++-- 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index cb8f658a85..34ae33ddc3 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -131,7 +131,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) this->stp.template get_psi_t(), static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); // setup psi (electronic wave functions) - this->stp.init(static_cast*>(this->p_hamilt)); + this->stp.init(this->p_hamilt); //! Setup EXX helper for Hamiltonian and psi exx_helper.before_scf(this->p_hamilt, this->stp.template get_psi_t(), PARAM.inp); diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index 3f8579da3e..b6a5d38299 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -97,9 +97,8 @@ void Setup_Psi_pw::update_psi_d() } template -void Setup_Psi_pw::init(hamilt::Hamilt* p_hamilt) +void Setup_Psi_pw::init_impl(hamilt::Hamilt* p_hamilt) { - //! Initialize wave functions if (!this->already_initpsi) { auto* p_psi_init = static_cast*>(this->p_psi_init); @@ -108,6 +107,43 @@ void Setup_Psi_pw::init(hamilt::Hamilt* p_hamilt) } } +void Setup_Psi_pw::init(hamilt::HamiltBase* p_hamilt) +{ + if (this->already_initpsi) + { + return; + } + +#if ((defined __CUDA) || (defined __ROCM)) + if (this->device_type_ == base_device::GpuDevice) + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + init_impl, base_device::DEVICE_GPU>( + static_cast, base_device::DEVICE_GPU>*>(p_hamilt)); + } + else + { + init_impl, base_device::DEVICE_GPU>( + static_cast, base_device::DEVICE_GPU>*>(p_hamilt)); + } + } + else +#endif + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + init_impl, base_device::DEVICE_CPU>( + static_cast, base_device::DEVICE_CPU>*>(p_hamilt)); + } + else + { + init_impl, base_device::DEVICE_CPU>( + static_cast, base_device::DEVICE_CPU>*>(p_hamilt)); + } + } +} + // Transfer data from GPU to CPU in pw basis (runtime version) template @@ -168,10 +204,10 @@ template void Setup_Psi_pw::before_runner_impl, base_device const UnitCell&, const K_Vectors&, const Structure_Factor&, const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); -template void Setup_Psi_pw::init, base_device::DEVICE_CPU>( +template void Setup_Psi_pw::init_impl, base_device::DEVICE_CPU>( hamilt::Hamilt, base_device::DEVICE_CPU>*); -template void Setup_Psi_pw::init, base_device::DEVICE_CPU>( +template void Setup_Psi_pw::init_impl, base_device::DEVICE_CPU>( hamilt::Hamilt, base_device::DEVICE_CPU>*); template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_CPU>(); @@ -212,10 +248,10 @@ template void Setup_Psi_pw::before_runner_impl, base_device const UnitCell&, const K_Vectors&, const Structure_Factor&, const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); -template void Setup_Psi_pw::init, base_device::DEVICE_GPU>( +template void Setup_Psi_pw::init_impl, base_device::DEVICE_GPU>( hamilt::Hamilt, base_device::DEVICE_GPU>*); -template void Setup_Psi_pw::init, base_device::DEVICE_GPU>( +template void Setup_Psi_pw::init_impl, base_device::DEVICE_GPU>( hamilt::Hamilt, base_device::DEVICE_GPU>*); template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_GPU>(); diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index f2a191a8b7..26b4397438 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -69,8 +69,7 @@ class Setup_Psi_pw const pseudopot_cell_vnl &ppcell, const Input_para &inp); - template - void init(hamilt::Hamilt* p_hamilt); + void init(hamilt::HamiltBase* p_hamilt); template void update_psi_d(); @@ -125,6 +124,9 @@ class Setup_Psi_pw const pseudopot_cell_vnl &ppcell, const Input_para &inp); + template + void init_impl(hamilt::Hamilt* p_hamilt); + template void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); From caad403ab23656601a59af232111321145d1e9e3 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 14:56:46 +0800 Subject: [PATCH 34/40] refactor(psi): move private member variables to private section Move the following member variables from public to private: - psi_t: accessible via get_psi_t() - psi_d: accessible via get_psi_d() - already_initpsi: internal use only - device_type_: accessible via get_device_type() - precision_type_: accessible via get_precision_type() Keep the following in public: - psi_cpu: directly accessed by 14 external locations - p_psi_init: directly accessed by 3 external locations - PrecisionType enum: used as return type of get_precision_type() This improves encapsulation while maintaining backward compatibility through accessor functions. --- source/source_psi/setup_psi_pw.h | 62 ++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 26b4397438..b55b34831a 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -18,45 +18,30 @@ class Setup_Psi_pw Setup_Psi_pw(); ~Setup_Psi_pw(); + //------------ + // public types + //------------ + + // Precision type: 0 = float, 1 = double, 2 = complex, 3 = complex + enum class PrecisionType { + Float = 0, + Double = 1, + ComplexFloat = 2, + ComplexDouble = 3 + }; + //------------ // variables // psi_cpu, complex on cpu - // psi_t, complex on cpu/gpu - // psi_d, complex on cpu/gpu //------------ // originally, this term is psi // for PW, we have psi_cpu psi::Psi, base_device::DEVICE_CPU>* psi_cpu = nullptr; - // originally, this term is kspw_psi - // if CPU, kspw_psi = psi, otherwise, kspw_psi has a new copy - // psi::Psi* psi_t = nullptr; // Original template version - void* psi_t = nullptr; // Use void* to store pointer, runtime type information records actual type - - // originally, this term is __kspw_psi - // psi::Psi, Device>* psi_d = nullptr; // Original template version - void* psi_d = nullptr; // Use void* to store pointer, runtime type information records actual type - // psi_initializer controller psi::PSIPrepareBase* p_psi_init = nullptr; - bool already_initpsi = false; - - //------------ - // runtime type information - //------------ - base_device::AbacusDevice_t device_type_ = base_device::CpuDevice; - - // Precision type: 0 = float, 1 = double, 2 = complex, 3 = complex - enum class PrecisionType { - Float = 0, - Double = 1, - ComplexFloat = 2, - ComplexDouble = 3 - }; - PrecisionType precision_type_ = PrecisionType::ComplexDouble; - //------------ // functions //------------ @@ -115,6 +100,29 @@ class Setup_Psi_pw private: + //------------ + // private variables + //------------ + + // originally, this term is kspw_psi + // if CPU, kspw_psi = psi, otherwise, kspw_psi has a new copy + void* psi_t = nullptr; // Use void* to store pointer, runtime type information records actual type + + // originally, this term is __kspw_psi + void* psi_d = nullptr; // Use void* to store pointer, runtime type information records actual type + + bool already_initpsi = false; + + //------------ + // runtime type information + //------------ + base_device::AbacusDevice_t device_type_ = base_device::CpuDevice; + PrecisionType precision_type_ = PrecisionType::ComplexDouble; + + //------------ + // private functions + //------------ + template void before_runner_impl( const UnitCell &ucell, From 78ffa5c495cb77d1dc243c458994f945a1b41ba8 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 15:14:23 +0800 Subject: [PATCH 35/40] refactor(psi): convert clean to non-template function Convert Setup_Psi_pw::clean from template function to non-template function with runtime type dispatch based on device_type_ and precision_type_ member variables. Changes: 1. setup_psi_pw.h: - Remove template parameters from clean declaration - Add private template function clean_impl for internal use 2. setup_psi_pw.cpp: - Rename original clean to clean_impl (template) - Add new non-template clean that dispatches based on: - device_type_ (GpuDevice or CpuDevice) - precision_type_ (ComplexFloat or ComplexDouble) - Replace PARAM.inp.device/precision checks with member variables - Update template instantiation from clean to clean_impl 3. esolver_ks_pw.cpp: - Update call from stp.clean() to stp.clean() --- source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_psi/setup_psi_pw.cpp | 44 ++++++++++++++++++++----- source/source_psi/setup_psi_pw.h | 4 ++- 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 34ae33ddc3..36086bc149 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -55,7 +55,7 @@ ESolver_KS_PW::~ESolver_KS_PW() this->deallocate_hamilt(); // mohan add 2025-10-12 - this->stp.clean(); + this->stp.clean(); } template diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index b6a5d38299..c91ade9069 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -175,16 +175,14 @@ void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::comple base_device::memory::cast_memory_op, std::complex, base_device::DEVICE_CPU, Device>()(dst, src, size); } - - template -void Setup_Psi_pw::clean() +void Setup_Psi_pw::clean_impl() { - if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") + if (this->device_type_ == base_device::GpuDevice || this->precision_type_ == PrecisionType::ComplexFloat) { delete this->get_psi_t(); } - if (PARAM.inp.precision == "single") + if (this->precision_type_ == PrecisionType::ComplexFloat) { delete this->get_psi_d(); } @@ -193,6 +191,34 @@ void Setup_Psi_pw::clean() delete this->p_psi_init; } +void Setup_Psi_pw::clean() +{ +#if ((defined __CUDA) || (defined __ROCM)) + if (this->device_type_ == base_device::GpuDevice) + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + clean_impl, base_device::DEVICE_GPU>(); + } + else + { + clean_impl, base_device::DEVICE_GPU>(); + } + } + else +#endif + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + clean_impl, base_device::DEVICE_CPU>(); + } + else + { + clean_impl, base_device::DEVICE_CPU>(); + } + } +} + template class psi::PSIPrepare, base_device::DEVICE_CPU>; template class psi::PSIPrepare, base_device::DEVICE_CPU>; @@ -220,9 +246,9 @@ template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_CP template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_CPU>( const base_device::DeviceContext*); -template void Setup_Psi_pw::clean, base_device::DEVICE_CPU>(); +template void Setup_Psi_pw::clean_impl, base_device::DEVICE_CPU>(); -template void Setup_Psi_pw::clean, base_device::DEVICE_CPU>(); +template void Setup_Psi_pw::clean_impl, base_device::DEVICE_CPU>(); template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_CPU>( std::complex*, const std::complex*, const size_t); @@ -264,9 +290,9 @@ template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_GP template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_GPU>( const base_device::DeviceContext*); -template void Setup_Psi_pw::clean, base_device::DEVICE_GPU>(); +template void Setup_Psi_pw::clean_impl, base_device::DEVICE_GPU>(); -template void Setup_Psi_pw::clean, base_device::DEVICE_GPU>(); +template void Setup_Psi_pw::clean_impl, base_device::DEVICE_GPU>(); template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_GPU>( std::complex*, const std::complex*, const size_t); diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index b55b34831a..d25d590bba 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -63,7 +63,6 @@ class Setup_Psi_pw template void copy_d2h(const base_device::DeviceContext* ctx); - template void clean(); //------------ @@ -135,6 +134,9 @@ class Setup_Psi_pw template void init_impl(hamilt::Hamilt* p_hamilt); + template + void clean_impl(); + template void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); From a03ab9e0dfbb810f9729769a2f05d094a58b4801 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 15:27:12 +0800 Subject: [PATCH 36/40] refactor(psi): convert copy_d2h to non-template function Convert Setup_Psi_pw::copy_d2h from template function to non-template function with runtime type dispatch based on device_type_ and precision_type_ member variables. Changes: 1. setup_psi_pw.h: - Remove template parameters and DeviceContext parameter from copy_d2h - Add private template function copy_d2h_impl for internal use 2. setup_psi_pw.cpp: - Rename original copy_d2h to copy_d2h_impl (template) - Add new non-template copy_d2h that: - Returns early if device_type_ is not GpuDevice - Dispatches based on precision_type_ (ComplexFloat or ComplexDouble) - Update template instantiation from copy_d2h to copy_d2h_impl 3. ctrl_output_pw.cpp: - Simplify call from stp.template copy_d2h(device_ctx) to stp.copy_d2h() - Remove DeviceContext setup code for copy_d2h --- .../source_io/module_ctrl/ctrl_output_pw.cpp | 4 +- source/source_psi/setup_psi_pw.cpp | 41 ++++++++++--------- source/source_psi/setup_psi_pw.h | 8 ++-- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index 6bbf1e4329..92df73310a 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -101,9 +101,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, Device* ctx = nullptr; // Transfer data from device (GPU) to host (CPU) in pw basis - base_device::DeviceContext* device_ctx = &base_device::DeviceContext::instance(); - device_ctx->set_device_type(stp.get_device_type()); - stp.template copy_d2h(device_ctx); + stp.copy_d2h(); //---------------------------------------------------------- //! 4) Compute density of states (DOS) diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index c91ade9069..ae8e7d9f18 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -145,22 +145,33 @@ void Setup_Psi_pw::init(hamilt::HamiltBase* p_hamilt) } -// Transfer data from GPU to CPU in pw basis (runtime version) +// Transfer data from GPU to CPU in pw basis template -void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) +void Setup_Psi_pw::copy_d2h_impl() { - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + auto* psi_t = this->get_psi_t(); + this->castmem_d2h_impl(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + psi_t->get_pointer() - psi_t->get_psi_bias(), + this->psi_cpu[0].size()); +} + +void Setup_Psi_pw::copy_d2h() +{ + if (this->device_type_ != base_device::GpuDevice) + { + return; + } + +#if ((defined __CUDA) || (defined __ROCM)) + if (this->precision_type_ == PrecisionType::ComplexFloat) { - auto* psi_t = this->get_psi_t(); - this->castmem_d2h_impl(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), - psi_t->get_pointer() - psi_t->get_psi_bias(), - this->psi_cpu[0].size()); + copy_d2h_impl, base_device::DEVICE_GPU>(); } else { - // do nothing + copy_d2h_impl, base_device::DEVICE_GPU>(); } - return; +#endif } template @@ -240,12 +251,6 @@ template void Setup_Psi_pw::update_psi_d, base_device::DEVIC template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_CPU>(); -template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_CPU>( - const base_device::DeviceContext*); - -template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_CPU>( - const base_device::DeviceContext*); - template void Setup_Psi_pw::clean_impl, base_device::DEVICE_CPU>(); template void Setup_Psi_pw::clean_impl, base_device::DEVICE_CPU>(); @@ -284,11 +289,9 @@ template void Setup_Psi_pw::update_psi_d, base_device::DEVIC template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_GPU>(); -template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_GPU>( - const base_device::DeviceContext*); +template void Setup_Psi_pw::copy_d2h_impl, base_device::DEVICE_GPU>(); -template void Setup_Psi_pw::copy_d2h, base_device::DEVICE_GPU>( - const base_device::DeviceContext*); +template void Setup_Psi_pw::copy_d2h_impl, base_device::DEVICE_GPU>(); template void Setup_Psi_pw::clean_impl, base_device::DEVICE_GPU>(); diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index d25d590bba..2f8ac674f4 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -59,9 +59,8 @@ class Setup_Psi_pw template void update_psi_d(); - // Transfer data from device to host in pw basis (runtime version) - template - void copy_d2h(const base_device::DeviceContext* ctx); + // Transfer data from device to host in pw basis + void copy_d2h(); void clean(); @@ -137,6 +136,9 @@ class Setup_Psi_pw template void clean_impl(); + template + void copy_d2h_impl(); + template void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); From a4add5aec29068715b0b03c647ec4cb955824e44 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 15:44:20 +0800 Subject: [PATCH 37/40] refactor(psi): convert update_psi_d to non-template function Convert Setup_Psi_pw::update_psi_d from template function to non-template function with runtime type dispatch based on device_type_ and precision_type_ member variables. Changes: 1. setup_psi_pw.h: - Remove template parameters from update_psi_d declaration - Add private template function update_psi_d_impl for internal use 2. setup_psi_pw.cpp: - Rename original update_psi_d to update_psi_d_impl (template) - Add new non-template update_psi_d that dispatches based on: - device_type_ (GpuDevice or CpuDevice) - precision_type_ (ComplexFloat or ComplexDouble) - Replace PARAM.inp.precision checks with precision_type_ member variable - Update template instantiation from update_psi_d to update_psi_d_impl 3. esolver_ks_pw.cpp: - Update calls from stp.update_psi_d() to stp.update_psi_d() 4. ctrl_output_pw.cpp: - Update calls from stp.template update_psi_d() to stp.update_psi_d() This completes the refactoring of all Setup_Psi_pw member functions to use runtime type dispatch instead of template parameters. --- source/source_esolver/esolver_ks_pw.cpp | 4 +- .../source_io/module_ctrl/ctrl_output_pw.cpp | 4 +- source/source_psi/setup_psi_pw.cpp | 42 +++++++++++++++---- source/source_psi/setup_psi_pw.h | 4 +- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 36086bc149..04eb629962 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -268,7 +268,7 @@ void ESolver_KS_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& fo Forces ff(ucell.nat); // mohan add 2025-10-12 - this->stp.update_psi_d(); + this->stp.update_psi_d(); // Calculate forces ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm, @@ -282,7 +282,7 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s Stress_PW ss(this->pelec); // mohan add 2025-10-12 - this->stp.update_psi_d(); + this->stp.update_psi_d(); ss.cal_stress(stress, ucell, this->dftu, this->locpp, this->ppcell, this->pw_rhod, &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.template get_psi_d()); diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index 92df73310a..1020e7aef6 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -164,7 +164,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, if (inp.out_pchg.size() > 0) { // update psi_d - stp.template update_psi_d(); + stp.update_psi_d(); const int nbands = stp.get_nbands(); const int ngmc = chr.ngmc; @@ -303,7 +303,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- if (inp.out_wfc_norm.size() > 0 || inp.out_wfc_re_im.size() > 0) { - stp.template update_psi_d(); + stp.update_psi_d(); ModuleIO::get_wf_pw(inp.out_wfc_norm, inp.out_wfc_re_im, diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index ae8e7d9f18..f5bc240292 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -81,21 +81,49 @@ void Setup_Psi_pw::before_runner( template -void Setup_Psi_pw::update_psi_d() +void Setup_Psi_pw::update_psi_d_impl() { - if (this->psi_d != nullptr && PARAM.inp.precision == "single") + if (this->psi_d != nullptr && this->precision_type_ == PrecisionType::ComplexFloat) { delete this->get_psi_d(); } // Refresh this->psi_d - if (PARAM.inp.precision == "single") { + if (this->precision_type_ == PrecisionType::ComplexFloat) { this->psi_d = static_cast(new psi::Psi, Device>(*this->get_psi_t())); } else { this->psi_d = static_cast(reinterpret_cast, Device>*>(this->psi_t)); } } +void Setup_Psi_pw::update_psi_d() +{ +#if ((defined __CUDA) || (defined __ROCM)) + if (this->device_type_ == base_device::GpuDevice) + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + update_psi_d_impl, base_device::DEVICE_GPU>(); + } + else + { + update_psi_d_impl, base_device::DEVICE_GPU>(); + } + } + else +#endif + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + update_psi_d_impl, base_device::DEVICE_CPU>(); + } + else + { + update_psi_d_impl, base_device::DEVICE_CPU>(); + } + } +} + template void Setup_Psi_pw::init_impl(hamilt::Hamilt* p_hamilt) { @@ -247,9 +275,9 @@ template void Setup_Psi_pw::init_impl, base_device::DEVICE_C template void Setup_Psi_pw::init_impl, base_device::DEVICE_CPU>( hamilt::Hamilt, base_device::DEVICE_CPU>*); -template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_CPU>(); +template void Setup_Psi_pw::update_psi_d_impl, base_device::DEVICE_CPU>(); -template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_CPU>(); +template void Setup_Psi_pw::update_psi_d_impl, base_device::DEVICE_CPU>(); template void Setup_Psi_pw::clean_impl, base_device::DEVICE_CPU>(); @@ -285,9 +313,9 @@ template void Setup_Psi_pw::init_impl, base_device::DEVICE_G template void Setup_Psi_pw::init_impl, base_device::DEVICE_GPU>( hamilt::Hamilt, base_device::DEVICE_GPU>*); -template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_GPU>(); +template void Setup_Psi_pw::update_psi_d_impl, base_device::DEVICE_GPU>(); -template void Setup_Psi_pw::update_psi_d, base_device::DEVICE_GPU>(); +template void Setup_Psi_pw::update_psi_d_impl, base_device::DEVICE_GPU>(); template void Setup_Psi_pw::copy_d2h_impl, base_device::DEVICE_GPU>(); diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 2f8ac674f4..88e9d42bf1 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -56,7 +56,6 @@ class Setup_Psi_pw void init(hamilt::HamiltBase* p_hamilt); - template void update_psi_d(); // Transfer data from device to host in pw basis @@ -133,6 +132,9 @@ class Setup_Psi_pw template void init_impl(hamilt::Hamilt* p_hamilt); + template + void update_psi_d_impl(); + template void clean_impl(); From d9ffbb3cc862c09c86215dcfaaba40047665d4fa Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 16:17:01 +0800 Subject: [PATCH 38/40] refactor(exx): convert Exx_Helper to runtime polymorphism Convert Exx_Helper from template member to runtime polymorphic pointer using a base class pattern similar to Setup_Psi_pw. Changes: 1. exx_helper_base.h (new): - Create pure virtual base class Exx_HelperBase - Use void* for template-dependent parameters 2. exx_helper.h: - Exx_Helper now inherits from Exx_HelperBase - All public methods marked as override 3. exx_helper.cpp: - Update function signatures to use void* parameters - Add static_cast for type conversion 4. esolver_ks_pw.h: - Change Exx_Helper exx_helper to Exx_HelperBase* exx_helper 5. esolver_ks_pw.cpp: - Create concrete Exx_Helper instance based on inp.device and inp.precision - Delete exx_helper in destructor Benefits: - ESolver_KS_PW no longer requires Exx_Helper template parameters - Type determined at runtime from input parameters - Consistent with Setup_Psi_pw refactoring pattern --- source/source_esolver/esolver_ks_pw.cpp | 46 +++++++++++++-- source/source_esolver/esolver_ks_pw.h | 4 +- source/source_pw/module_pwdft/exx_helper.cpp | 16 +++--- source/source_pw/module_pwdft/exx_helper.h | 56 ++++++------------- .../source_pw/module_pwdft/exx_helper_base.h | 41 ++++++++++++++ 5 files changed, 111 insertions(+), 52 deletions(-) create mode 100644 source/source_pw/module_pwdft/exx_helper_base.h diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 04eb629962..9c1f8cd8f9 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -54,6 +54,13 @@ ESolver_KS_PW::~ESolver_KS_PW() // delete Hamilt this->deallocate_hamilt(); + // delete exx_helper + if (this->exx_helper != nullptr) + { + delete this->exx_helper; + this->exx_helper = nullptr; + } + // mohan add 2025-10-12 this->stp.clean(); } @@ -94,8 +101,37 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); + //! Create exx_helper based on device and precision + const bool is_gpu = (inp.device == "gpu"); + const bool is_single = (inp.precision == "single"); + +#if ((defined __CUDA) || (defined __ROCM)) + if (is_gpu) + { + if (is_single) + { + this->exx_helper = new Exx_Helper, base_device::DEVICE_GPU>(); + } + else + { + this->exx_helper = new Exx_Helper, base_device::DEVICE_GPU>(); + } + } + else +#endif + { + if (is_single) + { + this->exx_helper = new Exx_Helper, base_device::DEVICE_CPU>(); + } + else + { + this->exx_helper = new Exx_Helper, base_device::DEVICE_CPU>(); + } + } + //! Initialize exx pw - this->exx_helper.init(ucell, inp, this->pelec->wg); + this->exx_helper->init(ucell, inp, this->pelec->wg); } template @@ -134,7 +170,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) this->stp.init(this->p_hamilt); //! Setup EXX helper for Hamiltonian and psi - exx_helper.before_scf(this->p_hamilt, this->stp.template get_psi_t(), PARAM.inp); + exx_helper->before_scf(this->p_hamilt, this->stp.template get_psi_t(), PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); } @@ -203,9 +239,9 @@ template void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver) { // Related to EXX - if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter) + if (GlobalC::exx_info.info_global.cal_exx && !exx_helper->get_op_first_iter()) { - this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.template get_psi_t())); + this->pelec->set_exx(exx_helper->cal_exx_energy(this->stp.template get_psi_t())); } // deband is calculated from "output" charge density @@ -224,7 +260,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } // Handle EXX-related operations after SCF iteration - exx_helper.iter_finish(this->pelec, &this->chr, this->stp.template get_psi_t(), ucell, PARAM.inp, conv_esolver, iter); + exx_helper->iter_finish(this->pelec, &this->chr, this->stp.template get_psi_t(), ucell, PARAM.inp, conv_esolver, iter); // check if oscillate for delta_spin method pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp); diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 16d6530e2b..f527101d68 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -3,7 +3,7 @@ #include "./esolver_ks.h" #include "source_psi/setup_psi_pw.h" // mohan add 20251012 #include "source_pw/module_pwdft/vsep_pw.h" -#include "source_pw/module_pwdft/exx_helper.h" +#include "source_pw/module_pwdft/exx_helper_base.h" #include "source_pw/module_pwdft/op_pw_vel.h" #include @@ -33,7 +33,7 @@ class ESolver_KS_PW : public ESolver_KS void after_all_runners(UnitCell& ucell) override; - Exx_Helper exx_helper; + Exx_HelperBase* exx_helper = nullptr; protected: virtual void before_scf(UnitCell& ucell, const int istep) override; diff --git a/source/source_pw/module_pwdft/exx_helper.cpp b/source/source_pw/module_pwdft/exx_helper.cpp index 2af13d3173..fa80a7ff10 100644 --- a/source/source_pw/module_pwdft/exx_helper.cpp +++ b/source/source_pw/module_pwdft/exx_helper.cpp @@ -32,7 +32,7 @@ void Exx_Helper::init(const UnitCell& ucell, const Input_para& inp, c } template -void Exx_Helper::before_scf(void* p_hamilt, psi::Psi* psi, const Input_para& inp) +void Exx_Helper::before_scf(void* p_hamilt, void* psi, const Input_para& inp) { /// Return if not a valid calculation type if (inp.calculation != "scf" && inp.calculation != "relax" @@ -56,7 +56,7 @@ void Exx_Helper::before_scf(void* p_hamilt, psi::Psi* psi, } template -bool Exx_Helper::iter_finish(void* p_elec, Charge* p_charge, psi::Psi* psi, +bool Exx_Helper::iter_finish(void* p_elec, Charge* p_charge, void* psi, UnitCell& ucell, const Input_para& inp, bool& conv_esolver, int& iter) { @@ -118,9 +118,9 @@ bool Exx_Helper::iter_finish(void* p_elec, Charge* p_charge, psi::Psi } template -double Exx_Helper::cal_exx_energy(psi::Psi *psi_) +double Exx_Helper::cal_exx_energy(void* psi_) { - return op_exx->cal_exx_energy(psi_); + return op_exx->cal_exx_energy(static_cast*>(psi_)); } @@ -156,11 +156,13 @@ bool Exx_Helper::exx_after_converge(int &iter, bool ene_conv) } template -void Exx_Helper::set_psi(psi::Psi *psi_) +void Exx_Helper::set_psi(void* psi_) { - if (psi_ == nullptr) + auto* psi = static_cast*>(psi_); + if (psi == nullptr) return; - op_exx->set_psi(*psi_); + this->psi = psi; + op_exx->set_psi(*psi); if (PARAM.inp.exxace && GlobalC::exx_info.info_global.separate_loop) { op_exx->construct_ace(); diff --git a/source/source_pw/module_pwdft/exx_helper.h b/source/source_pw/module_pwdft/exx_helper.h index 53056fdbbe..96f0a37a7f 100644 --- a/source/source_pw/module_pwdft/exx_helper.h +++ b/source/source_pw/module_pwdft/exx_helper.h @@ -2,6 +2,7 @@ #include "source_base/matrix.h" #include "source_pw/module_pwdft/op_pw_exx.h" #include "source_io/module_parameter/input_parameter.h" +#include "source_pw/module_pwdft/exx_helper_base.h" #ifndef EXX_HELPER_H #define EXX_HELPER_H @@ -9,64 +10,43 @@ class Charge; template -struct Exx_Helper +struct Exx_Helper : public Exx_HelperBase { using Real = typename GetTypeReal::type; using OperatorEXX = hamilt::OperatorEXXPW; public: Exx_Helper() = default; + virtual ~Exx_Helper() = default; OperatorEXX *op_exx = nullptr; - void init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg); + void init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg) override; - /** - * @brief Setup EXX helper before SCF iteration. - * - * This function sets up the EXX helper for the Hamiltonian and psi - * before each SCF iteration. It checks if the calculation type and - * EXX settings are appropriate. - * - * @param p_hamilt Pointer to the Hamiltonian object (void* to avoid circular dependency). - * @param psi Pointer to the wave function object. - * @param inp The input parameters. - */ - void before_scf(void* p_hamilt, psi::Psi* psi, const Input_para& inp); + void before_scf(void* p_hamilt, void* psi, const Input_para& inp) override; - /** - * @brief Handle EXX-related operations after SCF iteration. - * - * This function handles EXX convergence checking and potential update - * after each SCF iteration. It is called in iter_finish. - * - * @param p_elec Pointer to the ElecState object (void* to avoid circular dependency). - * @param p_charge Pointer to the Charge object. - * @param psi Pointer to the wave function object. - * @param ucell The unit cell (non-const reference for update_pot). - * @param inp The input parameters. - * @param conv_esolver Whether SCF is converged (may be modified). - * @param iter The current iteration number (may be modified). - * @return true if EXX processing was done, false otherwise. - */ - bool iter_finish(void* p_elec, Charge* p_charge, psi::Psi* psi, + bool iter_finish(void* p_elec, Charge* p_charge, void* psi, UnitCell& ucell, const Input_para& inp, - bool& conv_esolver, int& iter); + bool& conv_esolver, int& iter) override; - void set_firstiter(bool flag = true) { first_iter = flag; } - void set_wg(const ModuleBase::matrix *wg_) { wg = wg_; } - void set_psi(psi::Psi *psi_); - void iter_inc() { exx_iter++; } + void set_firstiter(bool flag = true) override { first_iter = flag; } + void set_wg(const ModuleBase::matrix *wg_) override { wg = wg_; } + void set_psi(void* psi_) override; + void iter_inc() override { exx_iter++; } - void set_op() + void set_op() override { op_exx->first_iter = first_iter; set_psi(psi); op_exx->set_wg(wg); } - bool exx_after_converge(int &iter, bool ene_conv); + bool exx_after_converge(int &iter, bool ene_conv) override; - double cal_exx_energy(psi::Psi *psi_); + double cal_exx_energy(void* psi_) override; + + bool get_op_first_iter() const override { return op_exx ? op_exx->first_iter : false; } + void set_op_first_iter(bool flag) override { if (op_exx) op_exx->first_iter = flag; } + void set_op_exx(void* op) override { op_exx = reinterpret_cast(op); } private: bool first_iter = false; diff --git a/source/source_pw/module_pwdft/exx_helper_base.h b/source/source_pw/module_pwdft/exx_helper_base.h new file mode 100644 index 0000000000..60eda82b35 --- /dev/null +++ b/source/source_pw/module_pwdft/exx_helper_base.h @@ -0,0 +1,41 @@ +#ifndef EXX_HELPER_BASE_H +#define EXX_HELPER_BASE_H + +#include "source_base/matrix.h" + +class Charge; +class UnitCell; +struct Input_para; + +class Exx_HelperBase +{ + public: + Exx_HelperBase() = default; + virtual ~Exx_HelperBase() = default; + + virtual void init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg) = 0; + + virtual void before_scf(void* p_hamilt, void* psi, const Input_para& inp) = 0; + + virtual bool iter_finish(void* p_elec, Charge* p_charge, void* psi, + UnitCell& ucell, const Input_para& inp, + bool& conv_esolver, int& iter) = 0; + + virtual void set_firstiter(bool flag = true) = 0; + virtual void set_wg(const ModuleBase::matrix* wg) = 0; + virtual void set_psi(void* psi) = 0; + virtual void iter_inc() = 0; + + virtual void set_op() = 0; + + virtual bool exx_after_converge(int& iter, bool ene_conv) = 0; + + virtual double cal_exx_energy(void* psi) = 0; + + virtual bool get_op_first_iter() const = 0; + virtual void set_op_first_iter(bool flag) = 0; + + virtual void set_op_exx(void* op) = 0; +}; + +#endif // EXX_HELPER_BASE_H From 0c5170afdb30d03f5c766e38c5e2ebdd6298edbd Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 16:49:22 +0800 Subject: [PATCH 39/40] style: use static_cast instead of reinterpret_cast in deallocate_hamilt Replace reinterpret_cast with static_cast when deleting HamiltPW pointer. Since HamiltPW inherits from HamiltBase, static_cast is safer and more appropriate for downcasting in inheritance hierarchies. --- source/source_esolver/esolver_ks_pw.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 9c1f8cd8f9..4dabc88771 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -82,7 +82,7 @@ void ESolver_KS_PW::deallocate_hamilt() { if (this->p_hamilt != nullptr) { - delete reinterpret_cast*>(this->p_hamilt); + delete static_cast*>(this->p_hamilt); this->p_hamilt = nullptr; } } From 0416ccade0809440b498b4c1486e52e3ad3dcab3 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 14 Mar 2026 17:12:26 +0800 Subject: [PATCH 40/40] refactor(estate): convert setup_estate_pw to non-template function Convert setup_estate_pw and teardown_estate_pw from template functions to non-template functions with runtime type dispatch based on inp.device and inp.precision. Changes: 1. setup_estate_pw.h: - Remove template parameters from setup_estate_pw and teardown_estate_pw - Add template implementation functions setup_estate_pw_impl and teardown_estate_pw_impl 2. setup_estate_pw.cpp: - Add non-template setup_estate_pw that dispatches based on: - inp.device (gpu or cpu) - inp.precision (single or double) - Rename original implementations to *_impl - Update template instantiation 3. esolver_ks_pw.cpp: - Update calls from setup_estate_pw(...) to setup_estate_pw(...) - Update calls from teardown_estate_pw(...) to teardown_estate_pw(...) This simplifies the calling code in ESolver_KS_PW by removing template parameters while maintaining the same functionality through runtime dispatch. --- source/source_esolver/esolver_ks_pw.cpp | 4 +- source/source_estate/setup_estate_pw.cpp | 257 +++++++++++++---------- source/source_estate/setup_estate_pw.h | 52 +++-- 3 files changed, 180 insertions(+), 133 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 4dabc88771..4d605b9738 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -93,7 +93,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p ESolver_KS::before_all_runners(ucell, inp); //! setup and allocation for pelec, potentials, etc. - elecstate::setup_estate_pw(ucell, this->kv, this->sf, this->pelec, this->chr, + elecstate::setup_estate_pw(ucell, this->kv, this->sf, this->pelec, this->chr, this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); @@ -342,7 +342,7 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp, this->sf, this->ppcell, this->solvent, this->Pgrid, PARAM.inp); - elecstate::teardown_estate_pw(this->pelec, this->vsep_cell); + elecstate::teardown_estate_pw(this->pelec, this->vsep_cell); } diff --git a/source/source_estate/setup_estate_pw.cpp b/source/source_estate/setup_estate_pw.cpp index 5daa7be1b5..3ca02d79e6 100644 --- a/source/source_estate/setup_estate_pw.cpp +++ b/source/source_estate/setup_estate_pw.cpp @@ -1,32 +1,87 @@ #include "source_estate/setup_estate_pw.h" -#include "source_estate/elecstate_pw.h" // init of pelec -#include "source_estate/elecstate_pw_sdft.h" // init of pelec for sdft -#include "source_estate/elecstate_tools.h" // occupations +#include "source_estate/elecstate_pw.h" +#include "source_estate/elecstate_pw_sdft.h" +#include "source_estate/elecstate_tools.h" -template -void elecstate::setup_estate_pw(UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K* pw_wfc, // pw for wfc - ModulePW::PW_Basis* pw_rho, // pw for rho - ModulePW::PW_Basis* pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp) // input parameters +namespace elecstate +{ + +void setup_estate_pw( + UnitCell& ucell, + K_Vectors& kv, + Structure_Factor& sf, + elecstate::ElecState*& pelec, + Charge& chr, + pseudopot_cell_vl& locpp, + pseudopot_cell_vnl& ppcell, + VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, + const Input_para& inp) { ModuleBase::TITLE("elecstate", "setup_estate_pw"); - //! Initialize ElecState, set pelec pointer + const bool is_gpu = (inp.device == "gpu"); + const bool is_single = (inp.precision == "single"); + +#if ((defined __CUDA) || (defined __ROCM)) + if (is_gpu) + { + if (is_single) + { + setup_estate_pw_impl, base_device::DEVICE_GPU>( + ucell, kv, sf, pelec, chr, locpp, ppcell, vsep_cell, + pw_wfc, pw_rho, pw_rhod, pw_big, solvent, inp); + } + else + { + setup_estate_pw_impl, base_device::DEVICE_GPU>( + ucell, kv, sf, pelec, chr, locpp, ppcell, vsep_cell, + pw_wfc, pw_rho, pw_rhod, pw_big, solvent, inp); + } + } + else +#endif + { + if (is_single) + { + setup_estate_pw_impl, base_device::DEVICE_CPU>( + ucell, kv, sf, pelec, chr, locpp, ppcell, vsep_cell, + pw_wfc, pw_rho, pw_rhod, pw_big, solvent, inp); + } + else + { + setup_estate_pw_impl, base_device::DEVICE_CPU>( + ucell, kv, sf, pelec, chr, locpp, ppcell, vsep_cell, + pw_wfc, pw_rho, pw_rhod, pw_big, solvent, inp); + } + } +} + +template +void setup_estate_pw_impl( + UnitCell& ucell, + K_Vectors& kv, + Structure_Factor& sf, + elecstate::ElecState*& pelec, + Charge& chr, + pseudopot_cell_vl& locpp, + pseudopot_cell_vnl& ppcell, + VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, + const Input_para& inp) +{ if (pelec == nullptr) { if (inp.esolver_type == "sdft") { - //! SDFT only supports double precision currently pelec = new elecstate::ElecStatePW_SDFT, Device>(pw_wfc, &chr, &kv, &ucell, &ppcell, pw_rho, pw_big); } @@ -37,14 +92,12 @@ void elecstate::setup_estate_pw(UnitCell& ucell, // unitcell } } - //! Initialize DFT-1/2 if (PARAM.inp.dfthalf_type > 0) { vsep_cell = new VSep; vsep_cell->init_vsep(*pw_rhod, ucell.sep_cell); } - //! Initialize the potential. if (pelec->pot == nullptr) { pelec->pot = new elecstate::Potential(pw_rhod, @@ -52,16 +105,13 @@ void elecstate::setup_estate_pw(UnitCell& ucell, // unitcell &solvent, &(pelec->f_en.etxc), &(pelec->f_en.vtxc), vsep_cell); } - //! Initalize local pseudopotential locpp.init_vloc(ucell, pw_rhod); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "LOCAL POTENTIAL"); - //! Initalize non-local pseudopotential ppcell.init(ucell, &sf, pw_wfc); ppcell.init_vnl(ucell, pw_rhod); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL"); - //! Setup occupations if (inp.ocp) { elecstate::fixed_weights(inp.ocp_kb, @@ -71,116 +121,95 @@ void elecstate::setup_estate_pw(UnitCell& ucell, // unitcell pelec->wg, pelec->skip_weights); } - - return; } +void teardown_estate_pw(elecstate::ElecState*& pelec, VSep*& vsep_cell) +{ + ModuleBase::TITLE("elecstate", "teardown_estate_pw"); + + if (vsep_cell != nullptr) + { + delete vsep_cell; + } + + if (pelec != nullptr) + { + delete pelec; + pelec = nullptr; + } +} template -void elecstate::teardown_estate_pw(elecstate::ElecState* &pelec, VSep* &vsep_cell) +void teardown_estate_pw_impl(elecstate::ElecState*& pelec, VSep*& vsep_cell) { - ModuleBase::TITLE("elecstate", "teardown_estate_pw"); + ModuleBase::TITLE("elecstate", "teardown_estate_pw_impl"); if (vsep_cell != nullptr) { delete vsep_cell; } - // mohan update 20251005 to increase the security level if (pelec != nullptr) { - auto* pw_elec = dynamic_cast*>(pelec); - if (pw_elec) - { - delete pw_elec; - pelec = nullptr; - } - else - { - ModuleBase::WARNING_QUIT("elecstate::teardown_estate_pw", "Invalid ElecState type"); + auto* pw_elec = dynamic_cast*>(pelec); + if (pw_elec) + { + delete pw_elec; + pelec = nullptr; + } + else + { + ModuleBase::WARNING_QUIT("elecstate::teardown_estate_pw_impl", "Invalid ElecState type"); } } } +template void setup_estate_pw_impl, base_device::DEVICE_CPU>( + UnitCell& ucell, K_Vectors& kv, Structure_Factor& sf, + elecstate::ElecState*& pelec, Charge& chr, + pseudopot_cell_vl& locpp, pseudopot_cell_vnl& ppcell, VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, const Input_para& inp); -template void elecstate::setup_estate_pw, base_device::DEVICE_CPU>( - UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K *pw_wfc, // pw for wfc - ModulePW::PW_Basis *pw_rho, // pw for rho - ModulePW::PW_Basis *pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters - -template void elecstate::setup_estate_pw, base_device::DEVICE_CPU>( - UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K *pw_wfc, // pw for wfc - ModulePW::PW_Basis *pw_rho, // pw for rho - ModulePW::PW_Basis *pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters - - -template void elecstate::teardown_estate_pw, base_device::DEVICE_CPU>( - elecstate::ElecState* &pelec, VSep* &vsep_cell); - -template void elecstate::teardown_estate_pw, base_device::DEVICE_CPU>( - elecstate::ElecState* &pelec, VSep* &vsep_cell); +template void setup_estate_pw_impl, base_device::DEVICE_CPU>( + UnitCell& ucell, K_Vectors& kv, Structure_Factor& sf, + elecstate::ElecState*& pelec, Charge& chr, + pseudopot_cell_vl& locpp, pseudopot_cell_vnl& ppcell, VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, const Input_para& inp); +template void teardown_estate_pw_impl, base_device::DEVICE_CPU>( + elecstate::ElecState*& pelec, VSep*& vsep_cell); + +template void teardown_estate_pw_impl, base_device::DEVICE_CPU>( + elecstate::ElecState*& pelec, VSep*& vsep_cell); #if ((defined __CUDA) || (defined __ROCM)) -template void elecstate::setup_estate_pw, base_device::DEVICE_GPU>( - UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K *pw_wfc, // pw for wfc - ModulePW::PW_Basis *pw_rho, // pw for rho - ModulePW::PW_Basis *pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters - -template void elecstate::setup_estate_pw, base_device::DEVICE_GPU>( - UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K *pw_wfc, // pw for wfc - ModulePW::PW_Basis *pw_rho, // pw for rho - ModulePW::PW_Basis *pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters - -template void elecstate::teardown_estate_pw, base_device::DEVICE_GPU>( - elecstate::ElecState* &pelec, VSep* &vsep_cell); - -template void elecstate::teardown_estate_pw, base_device::DEVICE_GPU>( - elecstate::ElecState* &pelec, VSep* &vsep_cell); +template void setup_estate_pw_impl, base_device::DEVICE_GPU>( + UnitCell& ucell, K_Vectors& kv, Structure_Factor& sf, + elecstate::ElecState*& pelec, Charge& chr, + pseudopot_cell_vl& locpp, pseudopot_cell_vnl& ppcell, VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, const Input_para& inp); + +template void setup_estate_pw_impl, base_device::DEVICE_GPU>( + UnitCell& ucell, K_Vectors& kv, Structure_Factor& sf, + elecstate::ElecState*& pelec, Charge& chr, + pseudopot_cell_vl& locpp, pseudopot_cell_vnl& ppcell, VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, const Input_para& inp); + +template void teardown_estate_pw_impl, base_device::DEVICE_GPU>( + elecstate::ElecState*& pelec, VSep*& vsep_cell); + +template void teardown_estate_pw_impl, base_device::DEVICE_GPU>( + elecstate::ElecState*& pelec, VSep*& vsep_cell); #endif + +} diff --git a/source/source_estate/setup_estate_pw.h b/source/source_estate/setup_estate_pw.h index 81a2261f6d..cd1a388a74 100644 --- a/source/source_estate/setup_estate_pw.h +++ b/source/source_estate/setup_estate_pw.h @@ -1,7 +1,7 @@ #ifndef SETUP_ESTATE_PW_H #define SETUP_ESTATE_PW_H -#include "source_base/module_device/device.h" // use Device +#include "source_base/module_device/device.h" #include "source_cell/unitcell.h" #include "source_cell/klist.h" #include "source_pw/module_pwdft/structure_factor.h" @@ -12,26 +12,44 @@ namespace elecstate { +void setup_estate_pw( + UnitCell& ucell, + K_Vectors& kv, + Structure_Factor& sf, + elecstate::ElecState*& pelec, + Charge& chr, + pseudopot_cell_vl& locpp, + pseudopot_cell_vnl& ppcell, + VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, + const Input_para& inp); + +void teardown_estate_pw(elecstate::ElecState*& pelec, VSep*& vsep_cell); + template -void setup_estate_pw(UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K* pw_wfc, // pw for wfc - ModulePW::PW_Basis* pw_rho, // pw for rho - ModulePW::PW_Basis* pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters +void setup_estate_pw_impl( + UnitCell& ucell, + K_Vectors& kv, + Structure_Factor& sf, + elecstate::ElecState*& pelec, + Charge& chr, + pseudopot_cell_vl& locpp, + pseudopot_cell_vnl& ppcell, + VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, + const Input_para& inp); template -void teardown_estate_pw(elecstate::ElecState* &pelec, VSep* &vsep_cell); +void teardown_estate_pw_impl(elecstate::ElecState*& pelec, VSep*& vsep_cell); } - #endif