From 05c1f507ad40f3d09a48d14f30cf6ca5849228cf Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sun, 12 Oct 2025 17:59:37 +0800 Subject: [PATCH 01/14] add psi interface for pw --- source/source_esolver/esolver_ks_pw.cpp | 94 ++++++++----------------- source/source_esolver/esolver_ks_pw.h | 15 +--- source/source_psi/setup_psi.cpp | 71 +++++++++++++++++++ source/source_psi/setup_psi.h | 52 ++++++++++++++ 4 files changed, 154 insertions(+), 78 deletions(-) create mode 100644 source/source_psi/setup_psi.cpp create mode 100644 source/source_psi/setup_psi.h diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 97d4f61533..d6ecd4dae2 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -49,17 +49,9 @@ ESolver_KS_PW::~ESolver_KS_PW() // delete Hamilt this->deallocate_hamilt(); - if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") - { - delete this->kspw_psi; - } - if (PARAM.inp.precision == "single") - { - delete this->__kspw_psi; - } + // mohan add 2025-10-12 + this->setup_psi.clean(); - delete this->psi; - delete this->p_psi_init; } template @@ -89,18 +81,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); - //! Allocate and initialize psi - this->p_psi_init = new psi::PSIInit(inp.init_wfc, - inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell, - this->sf, this->kv, this->ppcell, *this->pw_wfc); - - allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); - - this->p_psi_init->prepare_init(inp.pw_seed); - - this->kspw_psi = inp.device == "gpu" || inp.precision == "single" - ? new psi::Psi(this->psi[0]) - : reinterpret_cast*>(this->psi); + this->setup_psi.before_runner(); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); @@ -142,7 +123,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma); - this->p_psi_init->prepare_init(PARAM.inp.pw_seed); + this->setup_psi.p_psi_init->prepare_init(PARAM.inp.pw_seed); } //! Init Hamiltonian (cell changed) @@ -156,14 +137,10 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) //! Setup potentials (local, non-local, sc, +U, DFT-1/2) pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid, this->chr, this->locpp, this->ppcell, this->vsep_cell, - this->kspw_psi, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); + this->setup_psi.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); - //! Initialize wave functions - if (!this->already_initpsi) - { - this->p_psi_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running); - this->already_initpsi = true; - } + + this->setup_psi.init(); //! Exx calculations if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" @@ -173,7 +150,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) { auto hamilt_pw = reinterpret_cast*>(this->p_hamilt); hamilt_pw->set_exx_helper(exx_helper); - exx_helper.set_psi(kspw_psi); + exx_helper.set_psi(this->setup_psi.psi_t); } } @@ -202,7 +179,7 @@ void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const // new DFT+U method will calculate energy when evaluating the Hamiltonian if (dftu->omc != 2) { - dftu->cal_occ_pw(iter, this->kspw_psi, this->pelec->wg, ucell, PARAM.inp.mixing_beta); + dftu->cal_occ_pw(iter, this->setup_psi.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta); } dftu->output(ucell); } @@ -271,7 +248,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste PARAM.inp.use_k_continuity); hsolver_pw_obj.solve(this->p_hamilt, - this->kspw_psi[0], + this->setup_psi.psi_t[0], this->pelec, this->pelec->ekb.c, GlobalV::RANK_IN_POOL, @@ -316,7 +293,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int // Related to EXX if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter) { - this->pelec->set_exx(exx_helper.cal_exx_energy(kspw_psi)); + this->pelec->set_exx(exx_helper.cal_exx_energy(this->setup_psi.psi_t)); } // deband is calculated from "output" charge density @@ -347,12 +324,12 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int double dexx = 0.0; if (PARAM.inp.exx_thr_type == "energy") { - dexx = exx_helper.cal_exx_energy(this->kspw_psi); + dexx = exx_helper.cal_exx_energy(this->setup_psi.psi_t); } - exx_helper.set_psi(this->kspw_psi); + exx_helper.set_psi(this->setup_psi.psi_t); if (PARAM.inp.exx_thr_type == "energy") { - dexx -= exx_helper.cal_exx_energy(this->kspw_psi); + dexx -= exx_helper.cal_exx_energy(this->setup_psi.psi_t); // std::cout << "dexx = " << dexx << std::endl; } bool conv_ene = std::abs(dexx) < PARAM.inp.exx_ene_thr; @@ -373,7 +350,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } else { - exx_helper.set_psi(this->kspw_psi); + exx_helper.set_psi(this->setup_psi.psi_t); } } @@ -418,15 +395,15 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // Transfer data from GPU to CPU in pw basis if (this->device == base_device::GpuDevice) { - castmem_2d_d2h_op()(this->psi[0].get_pointer() - this->psi[0].get_psi_bias(), - this->kspw_psi[0].get_pointer() - this->kspw_psi[0].get_psi_bias(), - this->psi[0].size()); + castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), + this->psi_cpu[0].size()); } // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, - this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi, - this->__kspw_psi, this->ctx, this->Pgrid, PARAM.inp); + this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->setup_psi, + this->ctx, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } @@ -442,20 +419,13 @@ void ESolver_KS_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& fo { Forces ff(ucell.nat); - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") - { - delete reinterpret_cast, Device>*>(this->__kspw_psi); - } - - // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); + // mohan add 2025-10-12 + this->setup_psi.update_psi_d(); // Calculate forces ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm, &this->sf, this->solvent, &this->locpp, &this->ppcell, - &this->kv, this->pw_wfc, this->__kspw_psi); + &this->kv, this->pw_wfc, this->setup_psi.psi_d); } template @@ -463,18 +433,11 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s { Stress_PW ss(this->pelec); - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") - { - delete reinterpret_cast, Device>*>(this->__kspw_psi); - } - - // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); + // mohan add 2025-10-12 + this->setup_psi.update_psi_d(); ss.cal_stress(stress, ucell, this->locpp, this->ppcell, this->pw_rhod, - &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->__kspw_psi); + &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->setup_psi.psi_d); // external stress double unit_transform = 0.0; @@ -492,9 +455,8 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) ESolver_KS::after_all_runners(ucell); ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, - this->pw_rho, this->pw_rhod, this->chr, this->kv, this->psi, - this->kspw_psi, this->__kspw_psi, this->sf, - this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); + this->pw_rho, this->pw_rhod, this->chr, this->kv, this->setup_psi, + this->sf, this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); elecstate::teardown_estate_pw(this->pelec, this->vsep_cell); diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 523ae91939..da1af69d0a 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -1,7 +1,7 @@ #ifndef ESOLVER_KS_PW_H #define ESOLVER_KS_PW_H #include "./esolver_ks.h" -#include "source_psi/psi_init.h" +#include "source_psi/setup_psi.h" // mohan add 20251012 #include "source_pw/module_pwdft/VSep_in_pw.h" #include "source_pw/module_pwdft/global.h" #include "source_pw/module_pwdft/module_exx_helper/exx_helper.h" @@ -54,11 +54,8 @@ class ESolver_KS_PW : public ESolver_KS virtual void allocate_hamilt(const UnitCell& ucell); virtual void deallocate_hamilt(); - //! hide the psi in ESolver_KS for tmp use - psi::Psi, base_device::DEVICE_CPU>* psi = nullptr; - - // psi_initializer controller - psi::PSIInit* p_psi_init = nullptr; + // Electronic wave function: wf + Setup_Psi &setup_psi; // DFT-1/2 method VSep* vsep_cell = nullptr; @@ -67,12 +64,6 @@ class ESolver_KS_PW : public ESolver_KS base_device::AbacusDevice_t device = {}; - psi::Psi* kspw_psi = nullptr; - - psi::Psi, Device>* __kspw_psi = nullptr; - - bool already_initpsi = false; - using castmem_2d_d2h_op = base_device::memory::cast_memory_op, T, base_device::DEVICE_CPU, Device>; }; diff --git a/source/source_psi/setup_psi.cpp b/source/source_psi/setup_psi.cpp new file mode 100644 index 0000000000..46bdbf3363 --- /dev/null +++ b/source/source_psi/setup_psi.cpp @@ -0,0 +1,71 @@ +#include "source_lcao/setup_deepks.h" +#include "source_lcao/LCAO_domain.h" +#include "source_io/module_parameter/parameter.h" // use parameter + +template +Setup_Psi::Setup_Psi(){} + +template +Setup_Psi::~Setup_Psi(){} + +template +Setup_Psi::before_runner() +{ + //! Allocate and initialize psi + this->p_psi_init = new psi::PSIInit(inp.init_wfc, + inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell, + this->sf, this->kv, this->ppcell, *this->pw_wfc); + + allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); + + this->p_psi_init->prepare_init(inp.pw_seed); + + this->kspw_psi = inp.device == "gpu" || inp.precision == "single" + ? new psi::Psi(this->psi[0]) + : reinterpret_cast*>(this->psi); +} + + +template +Setup_Psi::update_psi_d() +{ + if (this->psi_d != nullptr && PARAM.inp.precision == "single") + { + delete reinterpret_cast, Device>*>(this->psi_t); + } + + // Refresh this->psi_d + this->psi_d = PARAM.inp.precision == "single" + ? new psi::Psi, Device>(this->psi_t[0]) + : reinterpret_cast, Device>*>(this->psi_t); +} + +template +void Setup_Psi::init() +{ + //! Initialize wave functions + if (!this->already_initpsi) + { + this->p_psi_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running); + this->already_initpsi = true; + } +} + + +template +void Setup_Psi::clean() +{ + if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") + { + delete this->kspw_psi; + } + if (PARAM.inp.precision == "single") + { + delete this->__kspw_psi; + } + + delete this->psi; + delete this->p_psi_init; + + +} diff --git a/source/source_psi/setup_psi.h b/source/source_psi/setup_psi.h new file mode 100644 index 0000000000..1c8d1fe896 --- /dev/null +++ b/source/source_psi/setup_psi.h @@ -0,0 +1,52 @@ +#ifndef SETUP_PSI_H +#define SETUP_PSI_H + +#include "source_psi/psi_init.h" + +template +class Setup_Psi +{ + public: + + Setup_Psi(); + ~Setup_Psi(); + + //------------ + // variables + // psi_cpu, complex on cpu + // psi_t, complex on cpu/gpu + // psi_d, complex on cpu/gpu + //------------ + + // originally, this term is psi + // for PW, we have psi_cpu + psi::Psi, base_device::DEVICE_CPU>* psi_cpu = nullptr; + + // originally, this term is kspw_psi + // if CPU, kspw_psi = psi, otherwise, kspw_psi has a new copy + psi::Psi* psi_t = nullptr; + + // originally, this term is __kspw_psi + psi::Psi, Device>* psi_d = nullptr; + + // psi_initializer controller + psi::PSIInit* p_psi_init = nullptr; + + bool already_initpsi = false; + + //------------ + // functions + //------------ + + void before_runner(); + + void init(); + + void update_psi_d(); + + void clean(); + +}; + + +#endif From 9e10c6b79a3454bf1dd094bbd336860fed1e2d05 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sun, 12 Oct 2025 21:10:31 +0800 Subject: [PATCH 02/14] update setup_psi --- source/source_esolver/esolver_ks_pw.cpp | 6 +++--- source/source_io/ctrl_output_pw.cpp | 22 ++++++---------------- source/source_io/ctrl_output_pw.h | 5 ++--- source/source_psi/setup_psi.cpp | 18 ++++++++---------- 4 files changed, 19 insertions(+), 32 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index d6ecd4dae2..fe4c263b6a 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -371,7 +371,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } // the output quantities - ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi, + ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->setup_psi.psi_cpu, this->kv, this->pw_wfc, PARAM.inp); } @@ -386,7 +386,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // sunliang 2025-04-10 if (PARAM.inp.out_elf[0] > 0) { - this->ESolver_KS::psi = new psi::Psi(this->psi[0]); + this->ESolver_KS::psi = new psi::Psi(this->setup_psi.psi_cpu[0]); } // Call 'after_scf' of ESolver_KS @@ -402,7 +402,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, - this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->setup_psi, + this->pw_rho, this->pw_rhod, this->pw_big, this->setup_psi, this->ctx, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index a8c588996a..a5b706d5bb 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -90,9 +90,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi* kspw_psi, - psi::Psi, Device>* __kspw_psi, + Setup_Psi &setup_psi, const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp) @@ -160,15 +158,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, //------------------------------------------------------------------ if (inp.out_pchg.size() > 0) { - if (__kspw_psi != nullptr && inp.precision == "single") - { - delete reinterpret_cast, Device>*>(__kspw_psi); - } - - // Refresh __kspw_psi - __kspw_psi = inp.precision == "single" - ? new psi::Psi, Device>(kspw_psi[0]) - : reinterpret_cast, Device>*>(kspw_psi); + setup_psi.update_psi_d(); const int nbands = kspw_psi->get_nbands(); const int ngmc = chr.ngmc; @@ -179,7 +169,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, pw_rhod->nxyz, ngmc, &ucell, - __kspw_psi, + setup_psi.psi, pw_rhod, pw_wfc, ctx, @@ -207,7 +197,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, inp.nnkpfile, inp.wannier_spin); wan.set_tpiba_omega(ucell.tpiba, ucell.omega); - wan.calculate(ucell, pelec->ekb, pw_wfc, pw_big, kv, psi); + wan.calculate(ucell, pelec->ekb, pw_wfc, pw_big, kv, psi_cpu); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation"); } @@ -219,7 +209,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, { std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization"); berryphase bp; - bp.Macroscopic_polarization(ucell, pw_wfc->npwk_max, psi, pw_rho, pw_wfc, kv); + bp.Macroscopic_polarization(ucell, pw_wfc->npwk_max, psi_cpu, pw_rho, pw_wfc, kv); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization"); } @@ -241,7 +231,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, if (inp.onsite_radius > 0) { // float type has not been implemented auto* onsite_p = projectors::OnsiteProjector::get_instance(); - onsite_p->cal_occupations(reinterpret_cast, Device>*>(kspw_psi), + onsite_p->cal_occupations(reinterpret_cast, Device>*>(psi_t), pelec->wg); } diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 87fea245b0..4e20e9b97e 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -4,6 +4,7 @@ #include "source_base/module_device/device.h" // use Device #include "source_psi/psi.h" // define psi #include "source_estate/elecstate_lcao.h" // use pelec +#include "source_psi/setup_psi.h" // use Setup_Psi class namespace ModuleIO { @@ -28,9 +29,7 @@ void ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi* kspw_psi, - psi::Psi, Device>* __kspw_psi, + Setup_Psi &setup_psi, const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); diff --git a/source/source_psi/setup_psi.cpp b/source/source_psi/setup_psi.cpp index 46bdbf3363..903e5a4f1e 100644 --- a/source/source_psi/setup_psi.cpp +++ b/source/source_psi/setup_psi.cpp @@ -16,13 +16,13 @@ Setup_Psi::before_runner() inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell, this->sf, this->kv, this->ppcell, *this->pw_wfc); - allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); + allocate_psi(this->psi_cpu, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); this->p_psi_init->prepare_init(inp.pw_seed); - this->kspw_psi = inp.device == "gpu" || inp.precision == "single" - ? new psi::Psi(this->psi[0]) - : reinterpret_cast*>(this->psi); + this->psi_t = inp.device == "gpu" || inp.precision == "single" + ? new psi::Psi(this->psi_cpu[0]) + : reinterpret_cast*>(this->psi_cpu); } @@ -46,7 +46,7 @@ void Setup_Psi::init() //! Initialize wave functions if (!this->already_initpsi) { - this->p_psi_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running); + this->p_psi_init->initialize_psi(this->psi_cpu, this->psi_t, p_hamilt, GlobalV::ofs_running); this->already_initpsi = true; } } @@ -57,15 +57,13 @@ void Setup_Psi::clean() { if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { - delete this->kspw_psi; + delete this->psi_t; } if (PARAM.inp.precision == "single") { - delete this->__kspw_psi; + delete this->psi_d; } - delete this->psi; + delete this->psi_cpu; delete this->p_psi_init; - - } From ec5359899f76e03be14f5172af3ec6aa09b4f834 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sun, 12 Oct 2025 21:29:35 +0800 Subject: [PATCH 03/14] update pis --- source/Makefile.Objects | 3 +- source/source_esolver/esolver_ks_pw.cpp | 49 +++++++++++-------------- source/source_esolver/esolver_ks_pw.h | 2 +- source/source_psi/CMakeLists.txt | 3 +- source/source_psi/setup_psi.cpp | 17 +++++++++ source/source_psi/setup_psi.h | 2 + 6 files changed, 46 insertions(+), 30 deletions(-) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 925298c983..210445adee 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -460,7 +460,8 @@ OBJS_ORBITAL=ORB_atomic.o\ OBJS_PSI=psi.o\ -OBJS_PSI_INITIALIZER=psi_initializer.o\ +OBJS_PSI_INITIALIZER=setup_psi.o\ + psi_initializer.o\ psi_initializer_random.o\ psi_initializer_file.o\ psi_initializer_atomic.o\ diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index fe4c263b6a..51f5909aa2 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -50,7 +50,7 @@ ESolver_KS_PW::~ESolver_KS_PW() this->deallocate_hamilt(); // mohan add 2025-10-12 - this->setup_psi.clean(); + this->stp.clean(); } @@ -81,7 +81,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); - this->setup_psi.before_runner(); + this->stp.before_runner(); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); @@ -123,7 +123,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma); - this->setup_psi.p_psi_init->prepare_init(PARAM.inp.pw_seed); + this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed); } //! Init Hamiltonian (cell changed) @@ -137,10 +137,10 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) //! Setup potentials (local, non-local, sc, +U, DFT-1/2) pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid, this->chr, this->locpp, this->ppcell, this->vsep_cell, - this->setup_psi.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); + this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); - this->setup_psi.init(); + this->stp.init(); //! Exx calculations if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" @@ -150,7 +150,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) { auto hamilt_pw = reinterpret_cast*>(this->p_hamilt); hamilt_pw->set_exx_helper(exx_helper); - exx_helper.set_psi(this->setup_psi.psi_t); + exx_helper.set_psi(this->stp.psi_t); } } @@ -179,7 +179,7 @@ void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const // new DFT+U method will calculate energy when evaluating the Hamiltonian if (dftu->omc != 2) { - dftu->cal_occ_pw(iter, this->setup_psi.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta); + dftu->cal_occ_pw(iter, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta); } dftu->output(ucell); } @@ -248,7 +248,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste PARAM.inp.use_k_continuity); hsolver_pw_obj.solve(this->p_hamilt, - this->setup_psi.psi_t[0], + this->stp.psi_t[0], this->pelec, this->pelec->ekb.c, GlobalV::RANK_IN_POOL, @@ -293,7 +293,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int // Related to EXX if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter) { - this->pelec->set_exx(exx_helper.cal_exx_energy(this->setup_psi.psi_t)); + this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.psi_t)); } // deband is calculated from "output" charge density @@ -324,12 +324,12 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int double dexx = 0.0; if (PARAM.inp.exx_thr_type == "energy") { - dexx = exx_helper.cal_exx_energy(this->setup_psi.psi_t); + dexx = exx_helper.cal_exx_energy(this->stp.psi_t); } - exx_helper.set_psi(this->setup_psi.psi_t); + exx_helper.set_psi(this->stp.psi_t); if (PARAM.inp.exx_thr_type == "energy") { - dexx -= exx_helper.cal_exx_energy(this->setup_psi.psi_t); + dexx -= exx_helper.cal_exx_energy(this->stp.psi_t); // std::cout << "dexx = " << dexx << std::endl; } bool conv_ene = std::abs(dexx) < PARAM.inp.exx_ene_thr; @@ -350,7 +350,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } else { - exx_helper.set_psi(this->setup_psi.psi_t); + exx_helper.set_psi(this->stp.psi_t); } } @@ -371,7 +371,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } // the output quantities - ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->setup_psi.psi_cpu, + ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->stp.psi_cpu, this->kv, this->pw_wfc, PARAM.inp); } @@ -386,23 +386,18 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // sunliang 2025-04-10 if (PARAM.inp.out_elf[0] > 0) { - this->ESolver_KS::psi = new psi::Psi(this->setup_psi.psi_cpu[0]); + this->ESolver_KS::psi = new psi::Psi(this->stp.psi_cpu[0]); } // Call 'after_scf' of ESolver_KS ESolver_KS::after_scf(ucell, istep, conv_esolver); // Transfer data from GPU to CPU in pw basis - if (this->device == base_device::GpuDevice) - { - castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), - this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), - this->psi_cpu[0].size()); - } + this->stp.copy_g2c(); // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, - this->pw_rho, this->pw_rhod, this->pw_big, this->setup_psi, + this->pw_rho, this->pw_rhod, this->pw_big, this->stp, this->ctx, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); @@ -420,12 +415,12 @@ void ESolver_KS_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& fo Forces ff(ucell.nat); // mohan add 2025-10-12 - this->setup_psi.update_psi_d(); + this->stp.update_psi_d(); // Calculate forces ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm, &this->sf, this->solvent, &this->locpp, &this->ppcell, - &this->kv, this->pw_wfc, this->setup_psi.psi_d); + &this->kv, this->pw_wfc, this->stp.psi_d); } template @@ -434,10 +429,10 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s Stress_PW ss(this->pelec); // mohan add 2025-10-12 - this->setup_psi.update_psi_d(); + this->stp.update_psi_d(); ss.cal_stress(stress, ucell, this->locpp, this->ppcell, this->pw_rhod, - &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->setup_psi.psi_d); + &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.psi_d); // external stress double unit_transform = 0.0; @@ -455,7 +450,7 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) ESolver_KS::after_all_runners(ucell); ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, - this->pw_rho, this->pw_rhod, this->chr, this->kv, this->setup_psi, + this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp, this->sf, this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); elecstate::teardown_estate_pw(this->pelec, this->vsep_cell); diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index da1af69d0a..358fc68e57 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -55,7 +55,7 @@ class ESolver_KS_PW : public ESolver_KS virtual void deallocate_hamilt(); // Electronic wave function: wf - Setup_Psi &setup_psi; + Setup_Psi &stp; // DFT-1/2 method VSep* vsep_cell = nullptr; diff --git a/source/source_psi/CMakeLists.txt b/source/source_psi/CMakeLists.txt index a1037885f3..0155d3a4c0 100644 --- a/source/source_psi/CMakeLists.txt +++ b/source/source_psi/CMakeLists.txt @@ -13,6 +13,7 @@ add_library( add_library( psi_initializer OBJECT + setup_psi.cpp psi_initializer.cpp psi_initializer_random.cpp psi_initializer_file.cpp @@ -32,4 +33,4 @@ if (BUILD_TESTING) if(ENABLE_MPI) add_subdirectory(test) endif() -endif() \ No newline at end of file +endif() diff --git a/source/source_psi/setup_psi.cpp b/source/source_psi/setup_psi.cpp index 903e5a4f1e..0d8f5437a4 100644 --- a/source/source_psi/setup_psi.cpp +++ b/source/source_psi/setup_psi.cpp @@ -16,10 +16,13 @@ Setup_Psi::before_runner() inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell, this->sf, this->kv, this->ppcell, *this->pw_wfc); + //! Allocate memory for cpu version of psi allocate_psi(this->psi_cpu, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); this->p_psi_init->prepare_init(inp.pw_seed); + //! If GPU or single precision, allocate a new psi (psi_t). + //! otherwise, transform psi_cpu to psi_t this->psi_t = inp.device == "gpu" || inp.precision == "single" ? new psi::Psi(this->psi_cpu[0]) : reinterpret_cast*>(this->psi_cpu); @@ -52,6 +55,20 @@ void Setup_Psi::init() } +template +void Setup_Psi::copy_g22() +{ + // Transfer data from GPU to CPU in pw basis + if (this->device == base_device::GpuDevice) + { + castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), + this->psi_cpu[0].size()); + } +} + + + template void Setup_Psi::clean() { diff --git a/source/source_psi/setup_psi.h b/source/source_psi/setup_psi.h index 1c8d1fe896..cda2083d6d 100644 --- a/source/source_psi/setup_psi.h +++ b/source/source_psi/setup_psi.h @@ -44,6 +44,8 @@ class Setup_Psi void update_psi_d(); + void copy_g22(); + void clean(); }; From b44e8ec5ba548dd7f864860ba1e89c627c9549e6 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sun, 12 Oct 2025 21:37:35 +0800 Subject: [PATCH 04/14] keep updating psi --- source/source_esolver/esolver_ks_pw.h | 2 +- source/source_esolver/esolver_sdft_pw.cpp | 2 +- source/source_io/ctrl_output_pw.cpp | 43 +++++++---------------- source/source_io/ctrl_output_pw.h | 6 ++-- 4 files changed, 17 insertions(+), 36 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 358fc68e57..831e5bbf43 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -55,7 +55,7 @@ class ESolver_KS_PW : public ESolver_KS virtual void deallocate_hamilt(); // Electronic wave function: wf - Setup_Psi &stp; + Setup_Psi stp; // DFT-1/2 method VSep* vsep_cell = nullptr; diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 170147ba06..dde15a8aaf 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -233,7 +233,7 @@ void ESolver_SDFT_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& this->locpp, this->ppcell, ucell, - *this->kspw_psi, + *this->stp.psi_t, this->stowf); } diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index a5b706d5bb..a55a3dd260 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -90,7 +90,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi &setup_psi, + Setup_Psi &stp, const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp) @@ -158,7 +158,8 @@ void ModuleIO::ctrl_scf_pw(const int istep, //------------------------------------------------------------------ if (inp.out_pchg.size() > 0) { - setup_psi.update_psi_d(); + // update psi_d + stp.update_psi_d(); const int nbands = kspw_psi->get_nbands(); const int ngmc = chr.ngmc; @@ -169,7 +170,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, pw_rhod->nxyz, ngmc, &ucell, - setup_psi.psi, + stp.psi, pw_rhod, pw_wfc, ctx, @@ -247,9 +248,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi* kspw_psi, - psi::Psi, Device>* __kspw_psi, + Setup_Psi &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -389,9 +388,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -407,9 +404,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -426,9 +421,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -444,9 +437,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -461,9 +452,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_CPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -480,9 +469,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_CPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -500,9 +487,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_GPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -519,9 +504,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_GPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 4e20e9b97e..dbb8c9f0b4 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -29,7 +29,7 @@ void ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi &setup_psi, + Setup_Psi &stp, const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -43,9 +43,7 @@ void ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi* kspw_psi, - psi::Psi, Device>* __kspw_psi, + Setup_Psi &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, From 6f47c383b7e50b400f97218c74b1778feb48724b Mon Sep 17 00:00:00 2001 From: mohanchen Date: Mon, 13 Oct 2025 14:11:29 +0800 Subject: [PATCH 05/14] fix things --- source/source_esolver/esolver_ks_lcaopw.cpp | 13 +++++---- source/source_esolver/esolver_sdft_pw.cpp | 6 ++-- source/source_io/ctrl_output_pw.cpp | 32 ++++++++------------- source/source_psi/setup_psi.cpp | 22 +++++++++----- source/source_psi/setup_psi.h | 3 +- 5 files changed, 39 insertions(+), 37 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index 8597fa6847..5858a332ba 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -81,7 +81,7 @@ namespace ModuleESolver void ESolver_KS_LIP::before_scf(UnitCell& ucell, const int istep) { ESolver_KS_PW::before_scf(ucell, istep); - this->p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running); + this->stp.p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running); } template @@ -89,9 +89,9 @@ namespace ModuleESolver { ESolver_KS_PW::before_all_runners(ucell, inp); delete this->psi_local; - this->psi_local = new psi::Psi(this->psi->get_nk(), - this->p_psi_init->psi_initer->nbands_start(), - this->psi->get_nbasis(), + this->psi_local = new psi::Psi(this->stp.psi_cpu->get_nk(), + this->stp.p_psi_init->psi_initer->nbands_start(), + this->stp.psi_cpu->get_nbasis(), this->kv.ngk, true); #ifdef __EXX @@ -147,7 +147,8 @@ namespace ModuleESolver bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLIP hsolver_lip_obj(this->pw_wfc); - hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); + hsolver_lip_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, + *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); // add exx #ifdef __EXX @@ -244,7 +245,7 @@ namespace ModuleESolver ModuleIO::write_Vxc(PARAM.inp.nspin, PARAM.globalv.nlocal, GlobalV::DRANK, - *this->kspw_psi, + *this->stp.psi_t, ucell, this->sf, this->solvent, diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index dde15a8aaf..c1d6f1f11e 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -176,7 +176,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i hsolver_pw_sdft_obj.solve(ucell, this->p_hamilt, - this->kspw_psi[0], + this->stp.psi_t[0], this->psi[0], this->pelec, this->pw_wfc, @@ -248,7 +248,7 @@ void ESolver_SDFT_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& &this->sf, &this->kv, this->pw_wfc, - *this->kspw_psi, + *this->stp.psi_t, this->stowf, &this->chr, &this->locpp, @@ -301,7 +301,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) &this->kv, this->pelec, this->pw_wfc, - this->kspw_psi, + this->stp.psi_t, &this->ppcell, this->p_hamilt, this->stoche, diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index a55a3dd260..d71b7bb369 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -161,7 +161,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, // update psi_d stp.update_psi_d(); - const int nbands = kspw_psi->get_nbands(); + const int nbands = stp.psi_t->get_nbands(); const int ngmc = chr.ngmc; ModuleIO::get_pchg_pw(inp.out_pchg, @@ -170,7 +170,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, pw_rhod->nxyz, ngmc, &ucell, - stp.psi, + stp.psi_d, pw_rhod, pw_wfc, ctx, @@ -198,7 +198,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, inp.nnkpfile, inp.wannier_spin); wan.set_tpiba_omega(ucell.tpiba, ucell.omega); - wan.calculate(ucell, pelec->ekb, pw_wfc, pw_big, kv, psi_cpu); + wan.calculate(ucell, pelec->ekb, pw_wfc, pw_big, kv, stp.psi_cpu); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation"); } @@ -210,7 +210,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, { std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization"); berryphase bp; - bp.Macroscopic_polarization(ucell, pw_wfc->npwk_max, psi_cpu, pw_rho, pw_wfc, kv); + bp.Macroscopic_polarization(ucell, pw_wfc->npwk_max, stp.psi_cpu, pw_rho, pw_wfc, kv); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization"); } @@ -232,7 +232,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, if (inp.onsite_radius > 0) { // float type has not been implemented auto* onsite_p = projectors::OnsiteProjector::get_instance(); - onsite_p->cal_occupations(reinterpret_cast, Device>*>(psi_t), + onsite_p->cal_occupations(reinterpret_cast, Device>*>(stp.psi_t), pelec->wg); } @@ -265,7 +265,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, if (inp.out_ldos[0]) { ModuleIO::cal_ldos_pw(reinterpret_cast>*>(pelec), - psi[0], para_grid, ucell); + stp.psi_cpu[0], para_grid, ucell); } //---------------------------------------------------------- @@ -285,7 +285,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, << " a.u." << std::endl; } Numerical_Basis numerical_basis; - numerical_basis.output_overlap(psi[0], sf, kv, pw_wfc, ucell, i); + numerical_basis.output_overlap(stp.psi_cpu[0], sf, kv, pw_wfc, ucell, i); } ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "BASIS OVERLAP (Q and S) GENERATION."); } @@ -296,23 +296,15 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- if (inp.out_wfc_norm.size() > 0 || inp.out_wfc_re_im.size() > 0) { - if (__kspw_psi != nullptr && inp.precision == "single") - { - delete reinterpret_cast, Device>*>(__kspw_psi); - } - - // Refresh __kspw_psi - __kspw_psi = inp.precision == "single" - ? new psi::Psi, Device>(kspw_psi[0]) - : reinterpret_cast, Device>*>(kspw_psi); + stp.update_psi_d(); ModuleIO::get_wf_pw(inp.out_wfc_norm, inp.out_wfc_re_im, - kspw_psi->get_nbands(), + stp.psi_t->get_nbands(), inp.nspin, pw_rhod->nxyz, &ucell, - __kspw_psi, + stp.psi_d, pw_wfc, ctx, para_grid, @@ -328,7 +320,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, if (inp.cal_cond) { using Real = typename GetTypeReal::type; - EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, kspw_psi, &ppcell); + EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, stp.psi_t, &ppcell); elec_cond.KG(inp.cond_smear, inp.cond_fwhm, inp.cond_wcut, @@ -365,7 +357,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, pw_rho); write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, - kspw_psi, + stp.psi_t, pelec, pw_wfc, pw_rho, diff --git a/source/source_psi/setup_psi.cpp b/source/source_psi/setup_psi.cpp index 0d8f5437a4..6031b10c47 100644 --- a/source/source_psi/setup_psi.cpp +++ b/source/source_psi/setup_psi.cpp @@ -1,3 +1,4 @@ +#include "source_psi/setup_psi.h" #include "source_lcao/setup_deepks.h" #include "source_lcao/LCAO_domain.h" #include "source_io/module_parameter/parameter.h" // use parameter @@ -9,7 +10,7 @@ template Setup_Psi::~Setup_Psi(){} template -Setup_Psi::before_runner() +void Setup_Psi::before_runner() { //! Allocate and initialize psi this->p_psi_init = new psi::PSIInit(inp.init_wfc, @@ -30,11 +31,11 @@ Setup_Psi::before_runner() template -Setup_Psi::update_psi_d() +void Setup_Psi::update_psi_d() { if (this->psi_d != nullptr && PARAM.inp.precision == "single") { - delete reinterpret_cast, Device>*>(this->psi_t); + delete reinterpret_cast, Device>*>(this->psi_d); } // Refresh this->psi_d @@ -44,7 +45,7 @@ Setup_Psi::update_psi_d() } template -void Setup_Psi::init() +void Setup_Psi::init() { //! Initialize wave functions if (!this->already_initpsi) @@ -55,10 +56,10 @@ void Setup_Psi::init() } +// Transfer data from GPU to CPU in pw basis template -void Setup_Psi::copy_g22() +void Setup_Psi::copy_g2c() { - // Transfer data from GPU to CPU in pw basis if (this->device == base_device::GpuDevice) { castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), @@ -70,7 +71,7 @@ void Setup_Psi::copy_g22() template -void Setup_Psi::clean() +void Setup_Psi::clean() { if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { @@ -84,3 +85,10 @@ void Setup_Psi::clean() delete this->psi_cpu; delete this->p_psi_init; } + +template class Setup_Psi, base_device::DEVICE_CPU>; +template class Setup_Psi, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template class Setup_Psi, base_device::DEVICE_GPU>; +template class Setup_Psi, base_device::DEVICE_GPU>; +#endif diff --git a/source/source_psi/setup_psi.h b/source/source_psi/setup_psi.h index cda2083d6d..7684097ec8 100644 --- a/source/source_psi/setup_psi.h +++ b/source/source_psi/setup_psi.h @@ -44,7 +44,8 @@ class Setup_Psi void update_psi_d(); - void copy_g22(); + // Transfer data from GPU to CPU in pw basis + void copy_g2c(); void clean(); From 4311f4b5c7f9db11281860dda451f2b52f450c6d Mon Sep 17 00:00:00 2001 From: mohanchen Date: Mon, 13 Oct 2025 14:24:07 +0800 Subject: [PATCH 06/14] update psi fix bugs --- source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_psi/setup_psi.cpp | 11 ++++++++--- source/source_psi/setup_psi.h | 12 +++++++++++- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 51f5909aa2..a5b2f887a2 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -81,7 +81,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); - this->stp.before_runner(); + this->stp.before_runner(this->kv, this->sf, this->pw_wfc, this->ppcell, PARAM.inp); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); diff --git a/source/source_psi/setup_psi.cpp b/source/source_psi/setup_psi.cpp index 6031b10c47..167cbc8ef5 100644 --- a/source/source_psi/setup_psi.cpp +++ b/source/source_psi/setup_psi.cpp @@ -10,15 +10,20 @@ template Setup_Psi::~Setup_Psi(){} template -void Setup_Psi::before_runner() +void Setup_Psi::before_runner( + const K_Vectors& kv, + const Structure_Factor& sf, + const ModulePW::PW_Basis_K* pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para& inp) { //! Allocate and initialize psi this->p_psi_init = new psi::PSIInit(inp.init_wfc, inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell, - this->sf, this->kv, this->ppcell, *this->pw_wfc); + sf, kv, ppcell, pw_wfc); //! Allocate memory for cpu version of psi - allocate_psi(this->psi_cpu, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); + allocate_psi(this->psi_cpu, kv.get_nks(), kv.ngk, PARAM.globalv.nbands_l, pw_wfc->npwk_max); this->p_psi_init->prepare_init(inp.pw_seed); diff --git a/source/source_psi/setup_psi.h b/source/source_psi/setup_psi.h index 7684097ec8..a70f1de81e 100644 --- a/source/source_psi/setup_psi.h +++ b/source/source_psi/setup_psi.h @@ -2,6 +2,11 @@ #define SETUP_PSI_H #include "source_psi/psi_init.h" +#include "source_cell/klist.h" +#include "source_pw/module_pwdft/structure_factor.h" +#include "source_basis/module_pw/pw_basis_k.h" +#include "source_pw/module_pwdft/VNL_in_pw.h" +#include "source_io/module_parameter/input_parameter.h" template class Setup_Psi @@ -38,7 +43,12 @@ class Setup_Psi // functions //------------ - void before_runner(); + void before_runner( + const K_Vectors& kv, + const Structure_Factor& sf, + const ModulePW::PW_Basis_K* pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para& inp); void init(); From 35ee47ca2686f94c8509e947c5463808182f7b9a Mon Sep 17 00:00:00 2001 From: mohanchen Date: Mon, 13 Oct 2025 17:06:09 +0800 Subject: [PATCH 07/14] update psi --- source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_psi/setup_psi.cpp | 1 + source/source_psi/setup_psi.h | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index a5b2f887a2..ae8fefc5bf 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -81,7 +81,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); - this->stp.before_runner(this->kv, this->sf, this->pw_wfc, this->ppcell, PARAM.inp); + this->stp.before_runner(ucell, this->kv, this->sf, this->pw_wfc, this->ppcell, PARAM.inp); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); diff --git a/source/source_psi/setup_psi.cpp b/source/source_psi/setup_psi.cpp index 167cbc8ef5..477ede15f1 100644 --- a/source/source_psi/setup_psi.cpp +++ b/source/source_psi/setup_psi.cpp @@ -11,6 +11,7 @@ Setup_Psi::~Setup_Psi(){} template void Setup_Psi::before_runner( + const UnitCell& ucell, const K_Vectors& kv, const Structure_Factor& sf, const ModulePW::PW_Basis_K* pw_wfc, diff --git a/source/source_psi/setup_psi.h b/source/source_psi/setup_psi.h index a70f1de81e..871df54fe7 100644 --- a/source/source_psi/setup_psi.h +++ b/source/source_psi/setup_psi.h @@ -2,6 +2,7 @@ #define SETUP_PSI_H #include "source_psi/psi_init.h" +#include "source_cell/unitcell.h" #include "source_cell/klist.h" #include "source_pw/module_pwdft/structure_factor.h" #include "source_basis/module_pw/pw_basis_k.h" @@ -44,6 +45,7 @@ class Setup_Psi //------------ void before_runner( + const UnitCell& ucell, const K_Vectors& kv, const Structure_Factor& sf, const ModulePW::PW_Basis_K* pw_wfc, From 046f97a9a6a4e3bc04f278ce45f237904f839aa7 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Mon, 13 Oct 2025 22:57:53 +0800 Subject: [PATCH 08/14] update psi --- source/source_esolver/esolver_ks_pw.cpp | 6 +++--- source/source_esolver/esolver_ks_pw.h | 2 -- source/source_psi/setup_psi.cpp | 18 +++++++++--------- source/source_psi/setup_psi.h | 19 ++++++++++++------- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index ae8fefc5bf..9248d2a631 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -81,7 +81,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); - this->stp.before_runner(ucell, this->kv, this->sf, this->pw_wfc, this->ppcell, PARAM.inp); + this->stp.before_runner(ucell, this->kv, this->sf, *this->pw_wfc, this->ppcell, PARAM.inp); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); @@ -140,7 +140,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); - this->stp.init(); + this->stp.init(this->p_hamilt); //! Exx calculations if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" @@ -393,7 +393,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const ESolver_KS::after_scf(ucell, istep, conv_esolver); // Transfer data from GPU to CPU in pw basis - this->stp.copy_g2c(); + this->stp.copy_g2c(this->device); // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 831e5bbf43..af3a418095 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -64,8 +64,6 @@ class ESolver_KS_PW : public ESolver_KS base_device::AbacusDevice_t device = {}; - using castmem_2d_d2h_op - = base_device::memory::cast_memory_op, T, base_device::DEVICE_CPU, Device>; }; } // namespace ModuleESolver #endif diff --git a/source/source_psi/setup_psi.cpp b/source/source_psi/setup_psi.cpp index 477ede15f1..a6fa267481 100644 --- a/source/source_psi/setup_psi.cpp +++ b/source/source_psi/setup_psi.cpp @@ -11,12 +11,12 @@ Setup_Psi::~Setup_Psi(){} template void Setup_Psi::before_runner( - const UnitCell& ucell, - const K_Vectors& kv, - const Structure_Factor& sf, - const ModulePW::PW_Basis_K* pw_wfc, + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, const pseudopot_cell_vnl &ppcell, - const Input_para& inp) + const Input_para &inp) { //! Allocate and initialize psi this->p_psi_init = new psi::PSIInit(inp.init_wfc, @@ -24,7 +24,7 @@ void Setup_Psi::before_runner( sf, kv, ppcell, pw_wfc); //! Allocate memory for cpu version of psi - allocate_psi(this->psi_cpu, kv.get_nks(), kv.ngk, PARAM.globalv.nbands_l, pw_wfc->npwk_max); + allocate_psi(this->psi_cpu, kv.get_nks(), kv.ngk, PARAM.globalv.nbands_l, pw_wfc.npwk_max); this->p_psi_init->prepare_init(inp.pw_seed); @@ -51,7 +51,7 @@ void Setup_Psi::update_psi_d() } template -void Setup_Psi::init() +void Setup_Psi::init(hamilt::Hamilt* p_hamilt) { //! Initialize wave functions if (!this->already_initpsi) @@ -64,9 +64,9 @@ void Setup_Psi::init() // Transfer data from GPU to CPU in pw basis template -void Setup_Psi::copy_g2c() +void Setup_Psi::copy_g2c(base_device::AbacusDevice_t &device) { - if (this->device == base_device::GpuDevice) + if (device == base_device::GpuDevice) { castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), diff --git a/source/source_psi/setup_psi.h b/source/source_psi/setup_psi.h index 871df54fe7..d3eb64391f 100644 --- a/source/source_psi/setup_psi.h +++ b/source/source_psi/setup_psi.h @@ -8,6 +8,8 @@ #include "source_basis/module_pw/pw_basis_k.h" #include "source_pw/module_pwdft/VNL_in_pw.h" #include "source_io/module_parameter/input_parameter.h" +#include "source_base/module_device/device.h" +#include "source_hamilt/hamilt.h" template class Setup_Psi @@ -45,22 +47,25 @@ class Setup_Psi //------------ void before_runner( - const UnitCell& ucell, - const K_Vectors& kv, - const Structure_Factor& sf, - const ModulePW::PW_Basis_K* pw_wfc, + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, const pseudopot_cell_vnl &ppcell, - const Input_para& inp); + const Input_para &inp); - void init(); + void init(hamilt::Hamilt* p_hamilt); void update_psi_d(); // Transfer data from GPU to CPU in pw basis - void copy_g2c(); + void copy_g2c(base_device::AbacusDevice_t &device); void clean(); + using castmem_2d_d2h_op + = base_device::memory::cast_memory_op, T, base_device::DEVICE_CPU, Device>; + }; From b1ecf18653f95d036458c6321ed6b7c0aa39f28e Mon Sep 17 00:00:00 2001 From: mohanchen Date: Mon, 13 Oct 2025 23:04:58 +0800 Subject: [PATCH 09/14] update, now can be compiled successfully --- source/source_esolver/esolver_ks_pw.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index af3a418095..22d805b579 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -54,7 +54,7 @@ class ESolver_KS_PW : public ESolver_KS virtual void allocate_hamilt(const UnitCell& ucell); virtual void deallocate_hamilt(); - // Electronic wave function: wf + // Electronic wave function psi Setup_Psi stp; // DFT-1/2 method From 01045969538640ffab360277529b4a05788614d8 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Tue, 14 Oct 2025 08:32:31 +0800 Subject: [PATCH 10/14] fix bug --- source/source_esolver/esolver_ks_lcaopw.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index 5858a332ba..7a6d78e339 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -105,13 +105,12 @@ namespace ModuleESolver ucell.symm, &this->kv, this->psi_local, - this->kspw_psi, + this->stp.psi_t, this->pw_wfc, this->pw_rho, this->sf, &ucell, this->pelec)); - // this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_psi_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec); } } #endif From f06944c8b690d505d47c24f6cdd051ca37dc69d7 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Tue, 14 Oct 2025 09:16:29 +0800 Subject: [PATCH 11/14] update cmake in module_psi --- source/Makefile.Objects | 4 ++-- source/source_psi/CMakeLists.txt | 3 ++- source/source_psi/setup_psi.h | 2 ++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 210445adee..08ad8751f8 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -460,8 +460,7 @@ OBJS_ORBITAL=ORB_atomic.o\ OBJS_PSI=psi.o\ -OBJS_PSI_INITIALIZER=setup_psi.o\ - psi_initializer.o\ +OBJS_PSI_INITIALIZER=psi_initializer.o\ psi_initializer_random.o\ psi_initializer_file.o\ psi_initializer_atomic.o\ @@ -765,6 +764,7 @@ OBJS_SRCPW=H_Ewald_pw.o\ of_stress_pw.o\ symmetry_rho.o\ symmetry_rhog.o\ + setup_psi.o\ psi_init.o\ elecond.o\ sto_tool.o\ diff --git a/source/source_psi/CMakeLists.txt b/source/source_psi/CMakeLists.txt index 0155d3a4c0..f871d2feee 100644 --- a/source/source_psi/CMakeLists.txt +++ b/source/source_psi/CMakeLists.txt @@ -7,13 +7,13 @@ add_library( add_library( psi_overall_init OBJECT + setup_psi.cpp psi_init.cpp ) add_library( psi_initializer OBJECT - setup_psi.cpp psi_initializer.cpp psi_initializer_random.cpp psi_initializer_file.cpp @@ -23,6 +23,7 @@ add_library( psi_initializer_nao_random.cpp ) + if(ENABLE_COVERAGE) add_coverage(psi) add_coverage(psi_initializer) diff --git a/source/source_psi/setup_psi.h b/source/source_psi/setup_psi.h index d3eb64391f..1c5355731c 100644 --- a/source/source_psi/setup_psi.h +++ b/source/source_psi/setup_psi.h @@ -63,6 +63,8 @@ class Setup_Psi void clean(); + private: + using castmem_2d_d2h_op = base_device::memory::cast_memory_op, T, base_device::DEVICE_CPU, Device>; From 6833fe162c418afaf68643d433822ed010f67fd1 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Tue, 14 Oct 2025 09:38:22 +0800 Subject: [PATCH 12/14] update esolver_ks --- source/source_esolver/esolver_ks.cpp | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index 4768f6afd1..a4e8126ad0 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -334,10 +334,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i //---------------------------------------------------------------- // 2) compute magnetization, only for LSDA(spin==2) //---------------------------------------------------------------- - ucell.magnet.compute_mag(ucell.omega, - this->chr.nrxx, - this->chr.nxyz, - this->chr.rho, + ucell.magnet.compute_mag(ucell.omega, this->chr.nrxx, this->chr.nxyz, this->chr.rho, this->pelec->nelec_spin.data()); //---------------------------------------------------------------- @@ -434,20 +431,15 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i MPI_Bcast(this->chr.rho[0], this->pw_rhod->nrxx, MPI_DOUBLE, 0, BP_WORLD); #endif - //---------------------------------------------------------------- // 4) Update potentials (should be done every SF iter) - //---------------------------------------------------------------- - // Hamilt should be used after it is constructed. - // this->phamilt->update(conv_esolver); this->update_pot(ucell, istep, iter, conv_esolver); - //---------------------------------------------------------------- // 5) calculate energies - //---------------------------------------------------------------- // 1 means Harris-Foulkes functional // 2 means Kohn-Sham functional this->pelec->cal_energies(1); this->pelec->cal_energies(2); + if (iter == 1) { this->pelec->f_en.etot_old = this->pelec->f_en.etot; @@ -456,7 +448,6 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i this->pelec->f_en.etot_old = this->pelec->f_en.etot; - //---------------------------------------------------------------- // 6) time and meta-GGA //---------------------------------------------------------------- @@ -481,21 +472,15 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i #ifdef __RAPIDJSON - //---------------------------------------------------------------- // 7) add Json of scf mag - //---------------------------------------------------------------- - Json::add_output_scf_mag(ucell.magnet.tot_mag, - ucell.magnet.abs_mag, + Json::add_output_scf_mag(ucell.magnet.tot_mag, ucell.magnet.abs_mag, this->pelec->f_en.etot * ModuleBase::Ry_to_eV, this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV, - drho, - duration); + drho, duration); #endif //__RAPIDJSON - //---------------------------------------------------------------- // 7) SCF restart information - //---------------------------------------------------------------- if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax) @@ -504,9 +489,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i std::cout << " SCF restart after this step!" << std::endl; } - //---------------------------------------------------------------- // 8) Iter finish - //---------------------------------------------------------------- ESolver_FP::iter_finish(ucell, istep, iter, conv_esolver); } From 151d87e12de73a5bd9487013a6776e045beb7212 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Tue, 14 Oct 2025 09:44:49 +0800 Subject: [PATCH 13/14] fix error introduced by esolver_ks's psi, which should be esolver_ks_pw's new psi (stp.psi) --- source/source_esolver/esolver_sdft_pw.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index c1d6f1f11e..1a9057d178 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -177,7 +177,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i hsolver_pw_sdft_obj.solve(ucell, this->p_hamilt, this->stp.psi_t[0], - this->psi[0], + this->stp.psi_cpu[0], this->pelec, this->pw_wfc, this->stowf, @@ -279,7 +279,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) this->pw_wfc, &this->kv, this->pelec, - reinterpret_cast>*>(this->psi), + reinterpret_cast>*>(this->stp.psi_cpu), reinterpret_cast>*>(this->p_hamilt), this->stoche, reinterpret_cast, base_device::DEVICE_CPU>*>(&stowf)); From e6492cb61a3b6dce5c3a43cb6e240f871205d05a Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 15 Oct 2025 08:34:49 +0800 Subject: [PATCH 14/14] change function name --- source/source_esolver/esolver_ks_pw.cpp | 5 +---- source/source_esolver/esolver_ks_pw.h | 2 ++ source/source_io/ctrl_output_pw.cpp | 8 ++++++++ source/source_io/ctrl_output_pw.h | 1 + source/source_psi/setup_psi.cpp | 7 ++++++- source/source_psi/setup_psi.h | 4 ++-- 6 files changed, 20 insertions(+), 7 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 9248d2a631..d2e2612070 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -392,13 +392,10 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // Call 'after_scf' of ESolver_KS ESolver_KS::after_scf(ucell, istep, conv_esolver); - // Transfer data from GPU to CPU in pw basis - this->stp.copy_g2c(this->device); - // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->stp, - this->ctx, this->Pgrid, PARAM.inp); + this->ctx, this->device, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 22d805b579..acdd7083ee 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -60,8 +60,10 @@ class ESolver_KS_PW : public ESolver_KS // DFT-1/2 method VSep* vsep_cell = nullptr; + // for get_pchg and get_wf, use ctx as input of fft Device* ctx = {}; + // for device to host data transformation base_device::AbacusDevice_t device = {}; }; diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index d71b7bb369..a67d949ade 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -92,12 +92,16 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi &stp, const Device* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp) { ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw"); ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); + // Transfer data from device (GPU) to host (CPU) in pw basis + stp.copy_d2h(device); + //---------------------------------------------------------- //! 4) Compute density of states (DOS) //---------------------------------------------------------- @@ -382,6 +386,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis_Big *pw_big, Setup_Psi, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -398,6 +403,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis_Big *pw_big, Setup_Psi, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -415,6 +421,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis_Big *pw_big, Setup_Psi, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -431,6 +438,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis_Big *pw_big, Setup_Psi, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); #endif diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index dbb8c9f0b4..b10870c9d1 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -31,6 +31,7 @@ void ctrl_scf_pw(const int istep, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi &stp, const Device* ctx, + const base_device::AbacusDevice_t &device, // mohan add 2025-10-15 const Parallel_Grid ¶_grid, const Input_para& inp); diff --git a/source/source_psi/setup_psi.cpp b/source/source_psi/setup_psi.cpp index a6fa267481..84ef806bd9 100644 --- a/source/source_psi/setup_psi.cpp +++ b/source/source_psi/setup_psi.cpp @@ -64,7 +64,7 @@ void Setup_Psi::init(hamilt::Hamilt* p_hamilt) // Transfer data from GPU to CPU in pw basis template -void Setup_Psi::copy_g2c(base_device::AbacusDevice_t &device) +void Setup_Psi::copy_d2h(const base_device::AbacusDevice_t &device) { if (device == base_device::GpuDevice) { @@ -72,6 +72,11 @@ void Setup_Psi::copy_g2c(base_device::AbacusDevice_t &device) this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), this->psi_cpu[0].size()); } + else + { + // do nothing + } + return; } diff --git a/source/source_psi/setup_psi.h b/source/source_psi/setup_psi.h index 1c5355731c..40a68a5020 100644 --- a/source/source_psi/setup_psi.h +++ b/source/source_psi/setup_psi.h @@ -58,8 +58,8 @@ class Setup_Psi void update_psi_d(); - // Transfer data from GPU to CPU in pw basis - void copy_g2c(base_device::AbacusDevice_t &device); + // Transfer data from device to host in pw basis + void copy_d2h(const base_device::AbacusDevice_t &device); void clean();