diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index 42b2e41b66..7726c4ff72 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -107,7 +107,7 @@ namespace ModuleESolver ucell.symm, &this->kv, this->psi_local, - this->stp.get_psi_t(), + this->stp.template get_psi_t(), this->pw_wfc, this->pw_rho, this->sf, @@ -148,7 +148,7 @@ namespace ModuleESolver bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLIP hsolver_lip_obj(this->pw_wfc); - hsolver_lip_obj.solve(static_cast*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec, + hsolver_lip_obj.solve(static_cast*>(this->p_hamilt), *this->stp.template get_psi_t(), this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); // add exx @@ -242,7 +242,7 @@ namespace ModuleESolver ModuleIO::write_Vxc(PARAM.inp.nspin, PARAM.globalv.nlocal, GlobalV::DRANK, - *this->stp.get_psi_t(), + *this->stp.template get_psi_t(), ucell, this->sf, this->solvent, diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 713b5b46c5..4d605b9738 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -54,6 +54,13 @@ ESolver_KS_PW::~ESolver_KS_PW() // delete Hamilt this->deallocate_hamilt(); + // delete exx_helper + if (this->exx_helper != nullptr) + { + delete this->exx_helper; + this->exx_helper = nullptr; + } + // mohan add 2025-10-12 this->stp.clean(); } @@ -75,7 +82,7 @@ void ESolver_KS_PW::deallocate_hamilt() { if (this->p_hamilt != nullptr) { - delete reinterpret_cast*>(this->p_hamilt); + delete static_cast*>(this->p_hamilt); this->p_hamilt = nullptr; } } @@ -86,7 +93,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p ESolver_KS::before_all_runners(ucell, inp); //! setup and allocation for pelec, potentials, etc. - elecstate::setup_estate_pw(ucell, this->kv, this->sf, this->pelec, this->chr, + elecstate::setup_estate_pw(ucell, this->kv, this->sf, this->pelec, this->chr, this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); @@ -94,8 +101,37 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); + //! Create exx_helper based on device and precision + const bool is_gpu = (inp.device == "gpu"); + const bool is_single = (inp.precision == "single"); + +#if ((defined __CUDA) || (defined __ROCM)) + if (is_gpu) + { + if (is_single) + { + this->exx_helper = new Exx_Helper, base_device::DEVICE_GPU>(); + } + else + { + this->exx_helper = new Exx_Helper, base_device::DEVICE_GPU>(); + } + } + else +#endif + { + if (is_single) + { + this->exx_helper = new Exx_Helper, base_device::DEVICE_CPU>(); + } + else + { + this->exx_helper = new Exx_Helper, base_device::DEVICE_CPU>(); + } + } + //! Initialize exx pw - this->exx_helper.init(ucell, inp, this->pelec->wg); + this->exx_helper->init(ucell, inp, this->pelec->wg); } template @@ -128,13 +164,13 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) // init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06 pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid, this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell, - this->stp.get_psi_t(), static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); + this->stp.template get_psi_t(), static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); // setup psi (electronic wave functions) this->stp.init(this->p_hamilt); //! Setup EXX helper for Hamiltonian and psi - exx_helper.before_scf(this->p_hamilt, this->stp.get_psi_t(), PARAM.inp); + exx_helper->before_scf(this->p_hamilt, this->stp.template get_psi_t(), PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); } @@ -152,7 +188,7 @@ void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const // update local occupations for DFT+U // should before lambda loop in DeltaSpin - pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.get_psi_t(), this->pelec->wg, ucell, PARAM.inp); + pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.template get_psi_t(), this->pelec->wg, ucell, PARAM.inp); } // Temporary, it should be replaced by hsolver later. @@ -188,7 +224,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste hsolver::DiagoIterAssist::need_subspace, PARAM.inp.use_k_continuity); - hsolver_pw_obj.solve(static_cast*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec, this->pelec->ekb.c, + hsolver_pw_obj.solve(static_cast*>(this->p_hamilt), *this->stp.template get_psi_t(), this->pelec, this->pelec->ekb.c, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat); } @@ -203,9 +239,9 @@ template void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver) { // Related to EXX - if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter) + if (GlobalC::exx_info.info_global.cal_exx && !exx_helper->get_op_first_iter()) { - this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.get_psi_t())); + this->pelec->set_exx(exx_helper->cal_exx_energy(this->stp.template get_psi_t())); } // deband is calculated from "output" charge density @@ -224,7 +260,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } // Handle EXX-related operations after SCF iteration - exx_helper.iter_finish(this->pelec, &this->chr, this->stp.get_psi_t(), ucell, PARAM.inp, conv_esolver, iter); + exx_helper->iter_finish(this->pelec, &this->chr, this->stp.template get_psi_t(), ucell, PARAM.inp, conv_esolver, iter); // check if oscillate for delta_spin method pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp); @@ -273,7 +309,7 @@ void ESolver_KS_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& fo // Calculate forces ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm, &this->sf, this->solvent, &this->dftu, &this->locpp, &this->ppcell, - &this->kv, this->pw_wfc, this->stp.get_psi_d()); + &this->kv, this->pw_wfc, this->stp.template get_psi_d()); } template @@ -285,7 +321,7 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s this->stp.update_psi_d(); ss.cal_stress(stress, ucell, this->dftu, this->locpp, this->ppcell, this->pw_rhod, - &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.get_psi_d()); + &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.template get_psi_d()); // external stress double unit_transform = 0.0; @@ -306,7 +342,7 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp, this->sf, this->ppcell, this->solvent, this->Pgrid, PARAM.inp); - elecstate::teardown_estate_pw(this->pelec, this->vsep_cell); + elecstate::teardown_estate_pw(this->pelec, this->vsep_cell); } diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 5ab756647d..f527101d68 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -3,7 +3,7 @@ #include "./esolver_ks.h" #include "source_psi/setup_psi_pw.h" // mohan add 20251012 #include "source_pw/module_pwdft/vsep_pw.h" -#include "source_pw/module_pwdft/exx_helper.h" +#include "source_pw/module_pwdft/exx_helper_base.h" #include "source_pw/module_pwdft/op_pw_vel.h" #include @@ -33,7 +33,7 @@ class ESolver_KS_PW : public ESolver_KS void after_all_runners(UnitCell& ucell) override; - Exx_Helper exx_helper; + Exx_HelperBase* exx_helper = nullptr; protected: virtual void before_scf(UnitCell& ucell, const int istep) override; @@ -52,7 +52,7 @@ class ESolver_KS_PW : public ESolver_KS virtual void deallocate_hamilt(); // Electronic wave function psi - Setup_Psi_pw stp; + Setup_Psi_pw stp; // DFT-1/2 method VSep* vsep_cell = nullptr; diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 86658eb645..c3bbdad210 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -168,7 +168,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i hsolver_pw_sdft_obj.solve(ucell, static_cast*>(this->p_hamilt), - *this->stp.get_psi_t(), + *this->stp.template get_psi_t(), this->stp.psi_cpu[0], this->pelec, this->pw_wfc, @@ -221,7 +221,7 @@ void ESolver_SDFT_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& this->locpp, this->ppcell, ucell, - *this->stp.get_psi_t(), + *this->stp.template get_psi_t(), this->stowf); } @@ -236,7 +236,7 @@ void ESolver_SDFT_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& &this->sf, &this->kv, this->pw_wfc, - *this->stp.get_psi_t(), + *this->stp.template get_psi_t(), this->stowf, &this->chr, &this->locpp, @@ -289,7 +289,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) &this->kv, this->pelec, this->pw_wfc, - this->stp.get_psi_t(), + this->stp.template get_psi_t(), &this->ppcell, static_cast, Device>*>(this->p_hamilt), this->stoche, diff --git a/source/source_estate/setup_estate_pw.cpp b/source/source_estate/setup_estate_pw.cpp index 5daa7be1b5..3ca02d79e6 100644 --- a/source/source_estate/setup_estate_pw.cpp +++ b/source/source_estate/setup_estate_pw.cpp @@ -1,32 +1,87 @@ #include "source_estate/setup_estate_pw.h" -#include "source_estate/elecstate_pw.h" // init of pelec -#include "source_estate/elecstate_pw_sdft.h" // init of pelec for sdft -#include "source_estate/elecstate_tools.h" // occupations +#include "source_estate/elecstate_pw.h" +#include "source_estate/elecstate_pw_sdft.h" +#include "source_estate/elecstate_tools.h" -template -void elecstate::setup_estate_pw(UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K* pw_wfc, // pw for wfc - ModulePW::PW_Basis* pw_rho, // pw for rho - ModulePW::PW_Basis* pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp) // input parameters +namespace elecstate +{ + +void setup_estate_pw( + UnitCell& ucell, + K_Vectors& kv, + Structure_Factor& sf, + elecstate::ElecState*& pelec, + Charge& chr, + pseudopot_cell_vl& locpp, + pseudopot_cell_vnl& ppcell, + VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, + const Input_para& inp) { ModuleBase::TITLE("elecstate", "setup_estate_pw"); - //! Initialize ElecState, set pelec pointer + const bool is_gpu = (inp.device == "gpu"); + const bool is_single = (inp.precision == "single"); + +#if ((defined __CUDA) || (defined __ROCM)) + if (is_gpu) + { + if (is_single) + { + setup_estate_pw_impl, base_device::DEVICE_GPU>( + ucell, kv, sf, pelec, chr, locpp, ppcell, vsep_cell, + pw_wfc, pw_rho, pw_rhod, pw_big, solvent, inp); + } + else + { + setup_estate_pw_impl, base_device::DEVICE_GPU>( + ucell, kv, sf, pelec, chr, locpp, ppcell, vsep_cell, + pw_wfc, pw_rho, pw_rhod, pw_big, solvent, inp); + } + } + else +#endif + { + if (is_single) + { + setup_estate_pw_impl, base_device::DEVICE_CPU>( + ucell, kv, sf, pelec, chr, locpp, ppcell, vsep_cell, + pw_wfc, pw_rho, pw_rhod, pw_big, solvent, inp); + } + else + { + setup_estate_pw_impl, base_device::DEVICE_CPU>( + ucell, kv, sf, pelec, chr, locpp, ppcell, vsep_cell, + pw_wfc, pw_rho, pw_rhod, pw_big, solvent, inp); + } + } +} + +template +void setup_estate_pw_impl( + UnitCell& ucell, + K_Vectors& kv, + Structure_Factor& sf, + elecstate::ElecState*& pelec, + Charge& chr, + pseudopot_cell_vl& locpp, + pseudopot_cell_vnl& ppcell, + VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, + const Input_para& inp) +{ if (pelec == nullptr) { if (inp.esolver_type == "sdft") { - //! SDFT only supports double precision currently pelec = new elecstate::ElecStatePW_SDFT, Device>(pw_wfc, &chr, &kv, &ucell, &ppcell, pw_rho, pw_big); } @@ -37,14 +92,12 @@ void elecstate::setup_estate_pw(UnitCell& ucell, // unitcell } } - //! Initialize DFT-1/2 if (PARAM.inp.dfthalf_type > 0) { vsep_cell = new VSep; vsep_cell->init_vsep(*pw_rhod, ucell.sep_cell); } - //! Initialize the potential. if (pelec->pot == nullptr) { pelec->pot = new elecstate::Potential(pw_rhod, @@ -52,16 +105,13 @@ void elecstate::setup_estate_pw(UnitCell& ucell, // unitcell &solvent, &(pelec->f_en.etxc), &(pelec->f_en.vtxc), vsep_cell); } - //! Initalize local pseudopotential locpp.init_vloc(ucell, pw_rhod); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "LOCAL POTENTIAL"); - //! Initalize non-local pseudopotential ppcell.init(ucell, &sf, pw_wfc); ppcell.init_vnl(ucell, pw_rhod); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL"); - //! Setup occupations if (inp.ocp) { elecstate::fixed_weights(inp.ocp_kb, @@ -71,116 +121,95 @@ void elecstate::setup_estate_pw(UnitCell& ucell, // unitcell pelec->wg, pelec->skip_weights); } - - return; } +void teardown_estate_pw(elecstate::ElecState*& pelec, VSep*& vsep_cell) +{ + ModuleBase::TITLE("elecstate", "teardown_estate_pw"); + + if (vsep_cell != nullptr) + { + delete vsep_cell; + } + + if (pelec != nullptr) + { + delete pelec; + pelec = nullptr; + } +} template -void elecstate::teardown_estate_pw(elecstate::ElecState* &pelec, VSep* &vsep_cell) +void teardown_estate_pw_impl(elecstate::ElecState*& pelec, VSep*& vsep_cell) { - ModuleBase::TITLE("elecstate", "teardown_estate_pw"); + ModuleBase::TITLE("elecstate", "teardown_estate_pw_impl"); if (vsep_cell != nullptr) { delete vsep_cell; } - // mohan update 20251005 to increase the security level if (pelec != nullptr) { - auto* pw_elec = dynamic_cast*>(pelec); - if (pw_elec) - { - delete pw_elec; - pelec = nullptr; - } - else - { - ModuleBase::WARNING_QUIT("elecstate::teardown_estate_pw", "Invalid ElecState type"); + auto* pw_elec = dynamic_cast*>(pelec); + if (pw_elec) + { + delete pw_elec; + pelec = nullptr; + } + else + { + ModuleBase::WARNING_QUIT("elecstate::teardown_estate_pw_impl", "Invalid ElecState type"); } } } +template void setup_estate_pw_impl, base_device::DEVICE_CPU>( + UnitCell& ucell, K_Vectors& kv, Structure_Factor& sf, + elecstate::ElecState*& pelec, Charge& chr, + pseudopot_cell_vl& locpp, pseudopot_cell_vnl& ppcell, VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, const Input_para& inp); -template void elecstate::setup_estate_pw, base_device::DEVICE_CPU>( - UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K *pw_wfc, // pw for wfc - ModulePW::PW_Basis *pw_rho, // pw for rho - ModulePW::PW_Basis *pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters - -template void elecstate::setup_estate_pw, base_device::DEVICE_CPU>( - UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K *pw_wfc, // pw for wfc - ModulePW::PW_Basis *pw_rho, // pw for rho - ModulePW::PW_Basis *pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters - - -template void elecstate::teardown_estate_pw, base_device::DEVICE_CPU>( - elecstate::ElecState* &pelec, VSep* &vsep_cell); - -template void elecstate::teardown_estate_pw, base_device::DEVICE_CPU>( - elecstate::ElecState* &pelec, VSep* &vsep_cell); +template void setup_estate_pw_impl, base_device::DEVICE_CPU>( + UnitCell& ucell, K_Vectors& kv, Structure_Factor& sf, + elecstate::ElecState*& pelec, Charge& chr, + pseudopot_cell_vl& locpp, pseudopot_cell_vnl& ppcell, VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, const Input_para& inp); +template void teardown_estate_pw_impl, base_device::DEVICE_CPU>( + elecstate::ElecState*& pelec, VSep*& vsep_cell); + +template void teardown_estate_pw_impl, base_device::DEVICE_CPU>( + elecstate::ElecState*& pelec, VSep*& vsep_cell); #if ((defined __CUDA) || (defined __ROCM)) -template void elecstate::setup_estate_pw, base_device::DEVICE_GPU>( - UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K *pw_wfc, // pw for wfc - ModulePW::PW_Basis *pw_rho, // pw for rho - ModulePW::PW_Basis *pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters - -template void elecstate::setup_estate_pw, base_device::DEVICE_GPU>( - UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K *pw_wfc, // pw for wfc - ModulePW::PW_Basis *pw_rho, // pw for rho - ModulePW::PW_Basis *pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters - -template void elecstate::teardown_estate_pw, base_device::DEVICE_GPU>( - elecstate::ElecState* &pelec, VSep* &vsep_cell); - -template void elecstate::teardown_estate_pw, base_device::DEVICE_GPU>( - elecstate::ElecState* &pelec, VSep* &vsep_cell); +template void setup_estate_pw_impl, base_device::DEVICE_GPU>( + UnitCell& ucell, K_Vectors& kv, Structure_Factor& sf, + elecstate::ElecState*& pelec, Charge& chr, + pseudopot_cell_vl& locpp, pseudopot_cell_vnl& ppcell, VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, const Input_para& inp); + +template void setup_estate_pw_impl, base_device::DEVICE_GPU>( + UnitCell& ucell, K_Vectors& kv, Structure_Factor& sf, + elecstate::ElecState*& pelec, Charge& chr, + pseudopot_cell_vl& locpp, pseudopot_cell_vnl& ppcell, VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, const Input_para& inp); + +template void teardown_estate_pw_impl, base_device::DEVICE_GPU>( + elecstate::ElecState*& pelec, VSep*& vsep_cell); + +template void teardown_estate_pw_impl, base_device::DEVICE_GPU>( + elecstate::ElecState*& pelec, VSep*& vsep_cell); #endif + +} diff --git a/source/source_estate/setup_estate_pw.h b/source/source_estate/setup_estate_pw.h index 81a2261f6d..cd1a388a74 100644 --- a/source/source_estate/setup_estate_pw.h +++ b/source/source_estate/setup_estate_pw.h @@ -1,7 +1,7 @@ #ifndef SETUP_ESTATE_PW_H #define SETUP_ESTATE_PW_H -#include "source_base/module_device/device.h" // use Device +#include "source_base/module_device/device.h" #include "source_cell/unitcell.h" #include "source_cell/klist.h" #include "source_pw/module_pwdft/structure_factor.h" @@ -12,26 +12,44 @@ namespace elecstate { +void setup_estate_pw( + UnitCell& ucell, + K_Vectors& kv, + Structure_Factor& sf, + elecstate::ElecState*& pelec, + Charge& chr, + pseudopot_cell_vl& locpp, + pseudopot_cell_vnl& ppcell, + VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, + const Input_para& inp); + +void teardown_estate_pw(elecstate::ElecState*& pelec, VSep*& vsep_cell); + template -void setup_estate_pw(UnitCell& ucell, // unitcell - K_Vectors &kv, // kpoints - Structure_Factor &sf, // structure factors - elecstate::ElecState* &pelec, // pointer of electrons - Charge &chr, // charge density - pseudopot_cell_vl &locpp, // local pseudopotentials - pseudopot_cell_vnl &ppcell, // non-local pseudopotentials - VSep* &vsep_cell, // U-1/2 method - ModulePW::PW_Basis_K* pw_wfc, // pw for wfc - ModulePW::PW_Basis* pw_rho, // pw for rho - ModulePW::PW_Basis* pw_rhod, // pw for rhod - ModulePW::PW_Basis_Big* pw_big, // pw for big grid - surchem &solvent, // solvent - const Input_para& inp); // input parameters +void setup_estate_pw_impl( + UnitCell& ucell, + K_Vectors& kv, + Structure_Factor& sf, + elecstate::ElecState*& pelec, + Charge& chr, + pseudopot_cell_vl& locpp, + pseudopot_cell_vnl& ppcell, + VSep*& vsep_cell, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis_Big* pw_big, + surchem& solvent, + const Input_para& inp); template -void teardown_estate_pw(elecstate::ElecState* &pelec, VSep* &vsep_cell); +void teardown_estate_pw_impl(elecstate::ElecState*& pelec, VSep*& vsep_cell); } - #endif diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index ed69182a79..1020e7aef6 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -90,7 +90,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp) { @@ -101,9 +101,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, Device* ctx = nullptr; // Transfer data from device (GPU) to host (CPU) in pw basis - base_device::DeviceContext* device_ctx = &base_device::DeviceContext::instance(); - device_ctx->set_device_type(stp.get_device_type()); - stp.copy_d2h(device_ctx); + stp.copy_d2h(); //---------------------------------------------------------- //! 4) Compute density of states (DOS) @@ -177,7 +175,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, pw_rhod->nxyz, ngmc, &ucell, - stp.get_psi_d(), + stp.template get_psi_d(), pw_rhod, pw_wfc, ctx, @@ -239,7 +237,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, if (inp.onsite_radius > 0) { // float type has not been implemented auto* onsite_p = projectors::OnsiteProjector::get_instance(); - onsite_p->cal_occupations(reinterpret_cast, Device>*>(stp.get_psi_t()), + onsite_p->cal_occupations(reinterpret_cast, Device>*>(stp.template get_psi_t()), pelec->wg); } @@ -255,7 +253,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -313,7 +311,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, inp.nspin, pw_rhod->nxyz, &ucell, - stp.get_psi_d(), + stp.template get_psi_d(), pw_wfc, ctx, para_grid, @@ -329,7 +327,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, if (inp.cal_cond) { using Real = typename GetTypeReal::type; - EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, stp.get_psi_t(), &ppcell); + EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, stp.template get_psi_t(), &ppcell); elec_cond.KG(inp.cond_smear, inp.cond_fwhm, inp.cond_wcut, @@ -366,7 +364,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, pw_rho); write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, - stp.get_psi_t(), + stp.template get_psi_t(), pelec, pw_wfc, pw_rho, @@ -389,7 +387,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw, base_device::DEVICE_CPU> &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -404,7 +402,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw, base_device::DEVICE_CPU> &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -420,7 +418,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw, base_device::DEVICE_GPU> &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -435,7 +433,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw, base_device::DEVICE_GPU> &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); #endif @@ -449,7 +447,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw, base_device::DEVICE_CPU> &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -465,7 +463,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw, base_device::DEVICE_CPU> &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -482,7 +480,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw, base_device::DEVICE_GPU> &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -498,7 +496,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw, base_device::DEVICE_GPU> &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, diff --git a/source/source_io/module_ctrl/ctrl_output_pw.h b/source/source_io/module_ctrl/ctrl_output_pw.h index 554e5d8adb..00b7509990 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.h +++ b/source/source_io/module_ctrl/ctrl_output_pw.h @@ -29,7 +29,7 @@ void ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -42,7 +42,7 @@ void ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -58,7 +58,7 @@ void ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - Setup_Psi_pw &stp, + Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index fb7c9d4edf..f5bc240292 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -1,14 +1,12 @@ #include "source_psi/setup_psi_pw.h" #include "source_io/module_parameter/parameter.h" // use parameter -template -Setup_Psi_pw::Setup_Psi_pw(){} +Setup_Psi_pw::Setup_Psi_pw(){} -template -Setup_Psi_pw::~Setup_Psi_pw(){} +Setup_Psi_pw::~Setup_Psi_pw(){} template -void Setup_Psi_pw::before_runner( +void Setup_Psi_pw::before_runner_impl( const UnitCell &ucell, const K_Vectors &kv, const Structure_Factor &sf, @@ -16,18 +14,15 @@ void Setup_Psi_pw::before_runner( const pseudopot_cell_vnl &ppcell, const Input_para &inp) { - //! Allocate and initialize psi this->p_psi_init = new psi::PSIPrepare(inp.init_wfc, inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell, sf, kv, ppcell, pw_wfc); - //! Allocate memory for cpu version of psi allocate_psi(this->psi_cpu, kv.get_nks(), kv.ngk, PARAM.globalv.nbands_l, pw_wfc.npwk_max); auto* p_psi_init = static_cast*>(this->p_psi_init); p_psi_init->prepare_init(inp.pw_seed); - //! Set runtime type information if (std::is_same::value) { precision_type_ = PrecisionType::Float; } else if (std::is_same::value) { @@ -44,8 +39,6 @@ void Setup_Psi_pw::before_runner( device_type_ = base_device::CpuDevice; } - //! If GPU or single precision, allocate a new psi (psi_t). - //! otherwise, transform psi_cpu to psi_t if (inp.device == "gpu" || inp.precision == "single") { this->psi_t = static_cast(new psi::Psi(this->psi_cpu[0])); } else { @@ -53,88 +46,294 @@ void Setup_Psi_pw::before_runner( } } +void Setup_Psi_pw::before_runner( + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp) +{ + const bool is_gpu = (inp.device == "gpu"); + const bool is_single = (inp.precision == "single"); + +#if ((defined __CUDA) || (defined __ROCM)) + if (is_gpu) { + if (is_single) { + before_runner_impl, base_device::DEVICE_GPU>( + ucell, kv, sf, pw_wfc, ppcell, inp); + } else { + before_runner_impl, base_device::DEVICE_GPU>( + ucell, kv, sf, pw_wfc, ppcell, inp); + } + } else +#endif + { + if (is_single) { + before_runner_impl, base_device::DEVICE_CPU>( + ucell, kv, sf, pw_wfc, ppcell, inp); + } else { + before_runner_impl, base_device::DEVICE_CPU>( + ucell, kv, sf, pw_wfc, ppcell, inp); + } + } +} + template -void Setup_Psi_pw::update_psi_d() +void Setup_Psi_pw::update_psi_d_impl() { - if (this->psi_d != nullptr && PARAM.inp.precision == "single") + if (this->psi_d != nullptr && this->precision_type_ == PrecisionType::ComplexFloat) { - delete this->get_psi_d(); + delete this->get_psi_d(); } // Refresh this->psi_d - if (PARAM.inp.precision == "single") { - this->psi_d = static_cast(new psi::Psi, Device>(*this->get_psi_t())); + if (this->precision_type_ == PrecisionType::ComplexFloat) { + this->psi_d = static_cast(new psi::Psi, Device>(*this->get_psi_t())); } else { this->psi_d = static_cast(reinterpret_cast, Device>*>(this->psi_t)); } } +void Setup_Psi_pw::update_psi_d() +{ +#if ((defined __CUDA) || (defined __ROCM)) + if (this->device_type_ == base_device::GpuDevice) + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + update_psi_d_impl, base_device::DEVICE_GPU>(); + } + else + { + update_psi_d_impl, base_device::DEVICE_GPU>(); + } + } + else +#endif + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + update_psi_d_impl, base_device::DEVICE_CPU>(); + } + else + { + update_psi_d_impl, base_device::DEVICE_CPU>(); + } + } +} + template -void Setup_Psi_pw::init(hamilt::HamiltBase* p_hamilt) +void Setup_Psi_pw::init_impl(hamilt::Hamilt* p_hamilt) { - //! Initialize wave functions if (!this->already_initpsi) { auto* p_psi_init = static_cast*>(this->p_psi_init); - auto* hamilt = static_cast*>(p_hamilt); - p_psi_init->initialize_psi(this->psi_cpu, this->get_psi_t(), hamilt, GlobalV::ofs_running); + p_psi_init->initialize_psi(this->psi_cpu, this->get_psi_t(), p_hamilt, GlobalV::ofs_running); this->already_initpsi = true; } } +void Setup_Psi_pw::init(hamilt::HamiltBase* p_hamilt) +{ + if (this->already_initpsi) + { + return; + } -// Transfer data from GPU to CPU in pw basis (runtime version) +#if ((defined __CUDA) || (defined __ROCM)) + if (this->device_type_ == base_device::GpuDevice) + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + init_impl, base_device::DEVICE_GPU>( + static_cast, base_device::DEVICE_GPU>*>(p_hamilt)); + } + else + { + init_impl, base_device::DEVICE_GPU>( + static_cast, base_device::DEVICE_GPU>*>(p_hamilt)); + } + } + else +#endif + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + init_impl, base_device::DEVICE_CPU>( + static_cast, base_device::DEVICE_CPU>*>(p_hamilt)); + } + else + { + init_impl, base_device::DEVICE_CPU>( + static_cast, base_device::DEVICE_CPU>*>(p_hamilt)); + } + } +} + + +// Transfer data from GPU to CPU in pw basis template -void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) +void Setup_Psi_pw::copy_d2h_impl() { - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + auto* psi_t = this->get_psi_t(); + this->castmem_d2h_impl(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + psi_t->get_pointer() - psi_t->get_psi_bias(), + this->psi_cpu[0].size()); +} + +void Setup_Psi_pw::copy_d2h() +{ + if (this->device_type_ != base_device::GpuDevice) + { + return; + } + +#if ((defined __CUDA) || (defined __ROCM)) + if (this->precision_type_ == PrecisionType::ComplexFloat) { - auto* psi_t = this->get_psi_t(); - this->castmem_d2h_impl(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), - psi_t->get_pointer() - psi_t->get_psi_bias(), - this->psi_cpu[0].size()); + copy_d2h_impl, base_device::DEVICE_GPU>(); } else { - // do nothing + copy_d2h_impl, base_device::DEVICE_GPU>(); } - return; +#endif } template -void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) +void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) { base_device::memory::cast_memory_op, std::complex, base_device::DEVICE_CPU, Device>()(dst, src, size); } template -void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) +void Setup_Psi_pw::castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size) { base_device::memory::cast_memory_op, std::complex, base_device::DEVICE_CPU, Device>()(dst, src, size); } - - template -void Setup_Psi_pw::clean() +void Setup_Psi_pw::clean_impl() { - if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") + if (this->device_type_ == base_device::GpuDevice || this->precision_type_ == PrecisionType::ComplexFloat) { - delete this->get_psi_t(); + delete this->get_psi_t(); } - if (PARAM.inp.precision == "single") + if (this->precision_type_ == PrecisionType::ComplexFloat) { - delete this->get_psi_d(); + delete this->get_psi_d(); } delete this->psi_cpu; delete this->p_psi_init; } -template class Setup_Psi_pw, base_device::DEVICE_CPU>; -template class Setup_Psi_pw, base_device::DEVICE_CPU>; +void Setup_Psi_pw::clean() +{ #if ((defined __CUDA) || (defined __ROCM)) -template class Setup_Psi_pw, base_device::DEVICE_GPU>; -template class Setup_Psi_pw, base_device::DEVICE_GPU>; + if (this->device_type_ == base_device::GpuDevice) + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + clean_impl, base_device::DEVICE_GPU>(); + } + else + { + clean_impl, base_device::DEVICE_GPU>(); + } + } + else +#endif + { + if (this->precision_type_ == PrecisionType::ComplexFloat) + { + clean_impl, base_device::DEVICE_CPU>(); + } + else + { + clean_impl, base_device::DEVICE_CPU>(); + } + } +} + +template class psi::PSIPrepare, base_device::DEVICE_CPU>; +template class psi::PSIPrepare, base_device::DEVICE_CPU>; + +template void Setup_Psi_pw::before_runner_impl, base_device::DEVICE_CPU>( + const UnitCell&, const K_Vectors&, const Structure_Factor&, + const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); + +template void Setup_Psi_pw::before_runner_impl, base_device::DEVICE_CPU>( + const UnitCell&, const K_Vectors&, const Structure_Factor&, + const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); + +template void Setup_Psi_pw::init_impl, base_device::DEVICE_CPU>( + hamilt::Hamilt, base_device::DEVICE_CPU>*); + +template void Setup_Psi_pw::init_impl, base_device::DEVICE_CPU>( + hamilt::Hamilt, base_device::DEVICE_CPU>*); + +template void Setup_Psi_pw::update_psi_d_impl, base_device::DEVICE_CPU>(); + +template void Setup_Psi_pw::update_psi_d_impl, base_device::DEVICE_CPU>(); + +template void Setup_Psi_pw::clean_impl, base_device::DEVICE_CPU>(); + +template void Setup_Psi_pw::clean_impl, base_device::DEVICE_CPU>(); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_CPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_CPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_CPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_CPU>( + std::complex*, const std::complex*, const size_t); + +#if ((defined __CUDA) || (defined __ROCM)) +template class psi::PSIPrepare, base_device::DEVICE_GPU>; +template class psi::PSIPrepare, base_device::DEVICE_GPU>; + +template void Setup_Psi_pw::before_runner_impl, base_device::DEVICE_GPU>( + const UnitCell&, const K_Vectors&, const Structure_Factor&, + const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); + +template void Setup_Psi_pw::before_runner_impl, base_device::DEVICE_GPU>( + const UnitCell&, const K_Vectors&, const Structure_Factor&, + const ModulePW::PW_Basis_K&, const pseudopot_cell_vnl&, const Input_para&); + +template void Setup_Psi_pw::init_impl, base_device::DEVICE_GPU>( + hamilt::Hamilt, base_device::DEVICE_GPU>*); + +template void Setup_Psi_pw::init_impl, base_device::DEVICE_GPU>( + hamilt::Hamilt, base_device::DEVICE_GPU>*); + +template void Setup_Psi_pw::update_psi_d_impl, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::update_psi_d_impl, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::copy_d2h_impl, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::copy_d2h_impl, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::clean_impl, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::clean_impl, base_device::DEVICE_GPU>(); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_GPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_GPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_GPU>( + std::complex*, const std::complex*, const size_t); + +template void Setup_Psi_pw::castmem_d2h_impl, base_device::DEVICE_GPU>( + std::complex*, const std::complex*, const size_t); #endif diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 29461e41a1..88e9d42bf1 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -11,7 +11,6 @@ #include "source_base/module_device/device.h" #include "source_hamilt/hamilt.h" -template class Setup_Psi_pw { public: @@ -19,45 +18,30 @@ class Setup_Psi_pw Setup_Psi_pw(); ~Setup_Psi_pw(); + //------------ + // public types + //------------ + + // Precision type: 0 = float, 1 = double, 2 = complex, 3 = complex + enum class PrecisionType { + Float = 0, + Double = 1, + ComplexFloat = 2, + ComplexDouble = 3 + }; + //------------ // variables // psi_cpu, complex on cpu - // psi_t, complex on cpu/gpu - // psi_d, complex on cpu/gpu //------------ // originally, this term is psi // for PW, we have psi_cpu psi::Psi, base_device::DEVICE_CPU>* psi_cpu = nullptr; - // originally, this term is kspw_psi - // if CPU, kspw_psi = psi, otherwise, kspw_psi has a new copy - // psi::Psi* psi_t = nullptr; // Original template version - void* psi_t = nullptr; // Use void* to store pointer, runtime type information records actual type - - // originally, this term is __kspw_psi - // psi::Psi, Device>* psi_d = nullptr; // Original template version - void* psi_d = nullptr; // Use void* to store pointer, runtime type information records actual type - // psi_initializer controller psi::PSIPrepareBase* p_psi_init = nullptr; - bool already_initpsi = false; - - //------------ - // runtime type information - //------------ - base_device::AbacusDevice_t device_type_ = base_device::CpuDevice; - - // Precision type: 0 = float, 1 = double, 2 = complex, 3 = complex - enum class PrecisionType { - Float = 0, - Double = 1, - ComplexFloat = 2, - ComplexDouble = 3 - }; - PrecisionType precision_type_ = PrecisionType::ComplexDouble; - //------------ // functions //------------ @@ -74,8 +58,8 @@ class Setup_Psi_pw void update_psi_d(); - // Transfer data from device to host in pw basis (runtime version) - void copy_d2h(const base_device::DeviceContext* ctx); + // Transfer data from device to host in pw basis + void copy_d2h(); void clean(); @@ -94,20 +78,73 @@ class Setup_Psi_pw PrecisionType get_precision_type() const { return precision_type_; } // Get psi_t pointer (template version, for backward compatibility) + template psi::Psi* get_psi_t() { return static_cast*>(psi_t); } + + template const psi::Psi* get_psi_t() const { return static_cast*>(psi_t); } // Get psi_d pointer (template version, for backward compatibility) + template psi::Psi, Device>* get_psi_d() { return static_cast, Device>*>(psi_d); } + + template const psi::Psi, Device>* get_psi_d() const { return static_cast, Device>*>(psi_d); } private: + //------------ + // private variables + //------------ + + // originally, this term is kspw_psi + // if CPU, kspw_psi = psi, otherwise, kspw_psi has a new copy + void* psi_t = nullptr; // Use void* to store pointer, runtime type information records actual type + + // originally, this term is __kspw_psi + void* psi_d = nullptr; // Use void* to store pointer, runtime type information records actual type + + bool already_initpsi = false; + + //------------ + // runtime type information + //------------ + base_device::AbacusDevice_t device_type_ = base_device::CpuDevice; + PrecisionType precision_type_ = PrecisionType::ComplexDouble; + + //------------ + // private functions + //------------ + + template + void before_runner_impl( + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp); + + template + void init_impl(hamilt::Hamilt* p_hamilt); + + template + void update_psi_d_impl(); + + template + void clean_impl(); + + template + void copy_d2h_impl(); + + template void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); + + template void castmem_d2h_impl(std::complex* dst, const std::complex* src, const size_t size); }; diff --git a/source/source_pw/module_pwdft/exx_helper.cpp b/source/source_pw/module_pwdft/exx_helper.cpp index 2af13d3173..fa80a7ff10 100644 --- a/source/source_pw/module_pwdft/exx_helper.cpp +++ b/source/source_pw/module_pwdft/exx_helper.cpp @@ -32,7 +32,7 @@ void Exx_Helper::init(const UnitCell& ucell, const Input_para& inp, c } template -void Exx_Helper::before_scf(void* p_hamilt, psi::Psi* psi, const Input_para& inp) +void Exx_Helper::before_scf(void* p_hamilt, void* psi, const Input_para& inp) { /// Return if not a valid calculation type if (inp.calculation != "scf" && inp.calculation != "relax" @@ -56,7 +56,7 @@ void Exx_Helper::before_scf(void* p_hamilt, psi::Psi* psi, } template -bool Exx_Helper::iter_finish(void* p_elec, Charge* p_charge, psi::Psi* psi, +bool Exx_Helper::iter_finish(void* p_elec, Charge* p_charge, void* psi, UnitCell& ucell, const Input_para& inp, bool& conv_esolver, int& iter) { @@ -118,9 +118,9 @@ bool Exx_Helper::iter_finish(void* p_elec, Charge* p_charge, psi::Psi } template -double Exx_Helper::cal_exx_energy(psi::Psi *psi_) +double Exx_Helper::cal_exx_energy(void* psi_) { - return op_exx->cal_exx_energy(psi_); + return op_exx->cal_exx_energy(static_cast*>(psi_)); } @@ -156,11 +156,13 @@ bool Exx_Helper::exx_after_converge(int &iter, bool ene_conv) } template -void Exx_Helper::set_psi(psi::Psi *psi_) +void Exx_Helper::set_psi(void* psi_) { - if (psi_ == nullptr) + auto* psi = static_cast*>(psi_); + if (psi == nullptr) return; - op_exx->set_psi(*psi_); + this->psi = psi; + op_exx->set_psi(*psi); if (PARAM.inp.exxace && GlobalC::exx_info.info_global.separate_loop) { op_exx->construct_ace(); diff --git a/source/source_pw/module_pwdft/exx_helper.h b/source/source_pw/module_pwdft/exx_helper.h index 53056fdbbe..96f0a37a7f 100644 --- a/source/source_pw/module_pwdft/exx_helper.h +++ b/source/source_pw/module_pwdft/exx_helper.h @@ -2,6 +2,7 @@ #include "source_base/matrix.h" #include "source_pw/module_pwdft/op_pw_exx.h" #include "source_io/module_parameter/input_parameter.h" +#include "source_pw/module_pwdft/exx_helper_base.h" #ifndef EXX_HELPER_H #define EXX_HELPER_H @@ -9,64 +10,43 @@ class Charge; template -struct Exx_Helper +struct Exx_Helper : public Exx_HelperBase { using Real = typename GetTypeReal::type; using OperatorEXX = hamilt::OperatorEXXPW; public: Exx_Helper() = default; + virtual ~Exx_Helper() = default; OperatorEXX *op_exx = nullptr; - void init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg); + void init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg) override; - /** - * @brief Setup EXX helper before SCF iteration. - * - * This function sets up the EXX helper for the Hamiltonian and psi - * before each SCF iteration. It checks if the calculation type and - * EXX settings are appropriate. - * - * @param p_hamilt Pointer to the Hamiltonian object (void* to avoid circular dependency). - * @param psi Pointer to the wave function object. - * @param inp The input parameters. - */ - void before_scf(void* p_hamilt, psi::Psi* psi, const Input_para& inp); + void before_scf(void* p_hamilt, void* psi, const Input_para& inp) override; - /** - * @brief Handle EXX-related operations after SCF iteration. - * - * This function handles EXX convergence checking and potential update - * after each SCF iteration. It is called in iter_finish. - * - * @param p_elec Pointer to the ElecState object (void* to avoid circular dependency). - * @param p_charge Pointer to the Charge object. - * @param psi Pointer to the wave function object. - * @param ucell The unit cell (non-const reference for update_pot). - * @param inp The input parameters. - * @param conv_esolver Whether SCF is converged (may be modified). - * @param iter The current iteration number (may be modified). - * @return true if EXX processing was done, false otherwise. - */ - bool iter_finish(void* p_elec, Charge* p_charge, psi::Psi* psi, + bool iter_finish(void* p_elec, Charge* p_charge, void* psi, UnitCell& ucell, const Input_para& inp, - bool& conv_esolver, int& iter); + bool& conv_esolver, int& iter) override; - void set_firstiter(bool flag = true) { first_iter = flag; } - void set_wg(const ModuleBase::matrix *wg_) { wg = wg_; } - void set_psi(psi::Psi *psi_); - void iter_inc() { exx_iter++; } + void set_firstiter(bool flag = true) override { first_iter = flag; } + void set_wg(const ModuleBase::matrix *wg_) override { wg = wg_; } + void set_psi(void* psi_) override; + void iter_inc() override { exx_iter++; } - void set_op() + void set_op() override { op_exx->first_iter = first_iter; set_psi(psi); op_exx->set_wg(wg); } - bool exx_after_converge(int &iter, bool ene_conv); + bool exx_after_converge(int &iter, bool ene_conv) override; - double cal_exx_energy(psi::Psi *psi_); + double cal_exx_energy(void* psi_) override; + + bool get_op_first_iter() const override { return op_exx ? op_exx->first_iter : false; } + void set_op_first_iter(bool flag) override { if (op_exx) op_exx->first_iter = flag; } + void set_op_exx(void* op) override { op_exx = reinterpret_cast(op); } private: bool first_iter = false; diff --git a/source/source_pw/module_pwdft/exx_helper_base.h b/source/source_pw/module_pwdft/exx_helper_base.h new file mode 100644 index 0000000000..60eda82b35 --- /dev/null +++ b/source/source_pw/module_pwdft/exx_helper_base.h @@ -0,0 +1,41 @@ +#ifndef EXX_HELPER_BASE_H +#define EXX_HELPER_BASE_H + +#include "source_base/matrix.h" + +class Charge; +class UnitCell; +struct Input_para; + +class Exx_HelperBase +{ + public: + Exx_HelperBase() = default; + virtual ~Exx_HelperBase() = default; + + virtual void init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg) = 0; + + virtual void before_scf(void* p_hamilt, void* psi, const Input_para& inp) = 0; + + virtual bool iter_finish(void* p_elec, Charge* p_charge, void* psi, + UnitCell& ucell, const Input_para& inp, + bool& conv_esolver, int& iter) = 0; + + virtual void set_firstiter(bool flag = true) = 0; + virtual void set_wg(const ModuleBase::matrix* wg) = 0; + virtual void set_psi(void* psi) = 0; + virtual void iter_inc() = 0; + + virtual void set_op() = 0; + + virtual bool exx_after_converge(int& iter, bool ene_conv) = 0; + + virtual double cal_exx_energy(void* psi) = 0; + + virtual bool get_op_first_iter() const = 0; + virtual void set_op_first_iter(bool flag) = 0; + + virtual void set_op_exx(void* op) = 0; +}; + +#endif // EXX_HELPER_BASE_H