From 5cc636749593b647d464cc7e82ae5ec9679929d6 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 10:50:33 +0800 Subject: [PATCH 01/16] add setup_dm in module_dm --- source/Makefile.Objects | 1 + source/source_esolver/esolver_ks_lcao.cpp | 99 ++++++--------------- source/source_estate/CMakeLists.txt | 1 + source/source_estate/module_dm/setup_dm.cpp | 65 ++++++++++++++ source/source_estate/module_dm/setup_dm.h | 22 +++++ 5 files changed, 116 insertions(+), 72 deletions(-) create mode 100644 source/source_estate/module_dm/setup_dm.cpp create mode 100644 source/source_estate/module_dm/setup_dm.h diff --git a/source/Makefile.Objects b/source/Makefile.Objects index d43b91ee3d..e9e86b02c1 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -253,6 +253,7 @@ OBJS_ELECSTAT=elecstate.o\ OBJS_ELECSTAT_LCAO=elecstate_lcao.o\ elecstate_lcao_cal_tau.o\ + setup_dm.o\ density_matrix.o\ density_matrix_io.o\ cal_dm_psi.o\ diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 456aafee7a..0a7fb82a53 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -1,6 +1,5 @@ #include "esolver_ks_lcao.h" -//#include "source_base/formatter.h" #include "source_base/global_variable.h" #include "source_base/tool_title.h" #include "source_estate/elecstate_tools.h" @@ -8,19 +7,9 @@ #include "source_estate/module_dm/cal_dm_psi.h" #include "source_lcao/module_deltaspin/spin_constrain.h" #include "source_lcao/module_dftu/dftu.h" -//#include "source_io/berryphase.h" #include "source_io/cube_io.h" -//#include "source_io/io_npz.h" -//#include "source_io/output_dmk.h" #include "source_io/output_log.h" -//#include "source_io/output_mat_sparse.h" -//#include "source_io/output_mulliken.h" -//#include "source_io/output_sk.h" #include "source_io/read_wfc_nao.h" -//#include "source_io/to_qo.h" -//#include "source_io/to_wannier90_lcao.h" -//#include "source_io/to_wannier90_lcao_in_pw.h" -//#include "source_io/write_HS.h" #include "source_io/write_elecstat_pot.h" #include "source_io/module_parameter/parameter.h" @@ -60,10 +49,10 @@ #include "source_lcao/module_gint/temp_gint/gint_info.h" #include "source_estate/module_charge/chgmixing.h" // use charge mixing, mohan add 20251006 +#include "source_estate/module_dm/setup_dm.h" // setup dm from electronic wave functions #include "source_io/ctrl_runner_lcao.h" // use ctrl_runner_lcao() #include "source_io/ctrl_iter_lcao.h" // use ctrl_iter_lcao() - namespace ModuleESolver { @@ -164,12 +153,8 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa if (inp.init_wfc == "file" && inp.esolver_type != "tddft") { if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir, - this->pv, - *(this->psi), - this->pelec, - this->pelec->klist->ik2iktot, - this->pelec->klist->get_nkstot(), - inp.nspin)) + this->pv, *(this->psi), this->pelec, this->pelec->klist->ik2iktot, + this->pelec->klist->get_nkstot(), inp.nspin)) { ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "read electronic wave functions failed"); } @@ -409,60 +394,35 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const // call iter_init() of ESolver_KS ESolver_KS::iter_init(ucell, istep, iter); - elecstate::DensityMatrix* dm - = dynamic_cast*>(this->pelec)->get_DM(); + // cast pointers + + auto* estate = dynamic_cast*>(this->pelec); + + if(!estate) + { + ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::iter_init","pelec does not exist"); + } + + elecstate::DensityMatrix* dm = estate->get_DM(); module_charge::chgmixing_ks_lcao(iter, this->p_chgmix, dm->get_DMR_pointer(1)->get_nnr(), PARAM.inp); // mohan update 2012-06-05 - this->pelec->f_en.deband_harris = this->pelec->cal_delta_eband(ucell); + estate->f_en.deband_harris = estate->cal_delta_eband(ucell); - // first need to calculate the weight according to - // electrons number. if (istep == 0 && PARAM.inp.init_wfc == "file") - { - int exx_two_level_step = 0; + { + int exx_two_level_step = 0; #ifdef __EXX - if (GlobalC::exx_info.info_global.cal_exx) - { - // the following steps are only needed in the first outer exx loop - exx_two_level_step - = GlobalC::exx_info.info_ri.real_number ? this->exd->two_level_step : this->exc->two_level_step; - } + if (GlobalC::exx_info.info_global.cal_exx) + { + // the following steps are only needed in the first outer exx loop + exx_two_level_step + = GlobalC::exx_info.info_ri.real_number ? this->exd->two_level_step : this->exc->two_level_step; + } #endif - if (iter == 1 && exx_two_level_step == 0) - { - std::cout << " WAVEFUN -> CHARGE " << std::endl; - - // calculate the density matrix using read in wave functions - // and then calculate the charge density on grid. - - this->pelec->skip_weights = true; - elecstate::calculate_weights(this->pelec->ekb, - this->pelec->wg, - this->pelec->klist, - this->pelec->eferm, - this->pelec->f_en, - this->pelec->nelec_spin, - this->pelec->skip_weights); - - auto _pelec = dynamic_cast*>(this->pelec); - elecstate::calEBand(_pelec->ekb, _pelec->wg, _pelec->f_en); - elecstate::cal_dm_psi(_pelec->DM->get_paraV_pointer(), _pelec->wg, *this->psi, *(_pelec->DM)); - _pelec->DM->cal_DMR(); - - this->pelec->psiToRho(*this->psi); - this->pelec->skip_weights = false; - - elecstate::cal_ux(ucell); - - //! update the potentials by using new electron charge density - this->pelec->pot->update_from_charge(&this->chr, &ucell); - - //! compute the correction energy for metals - this->pelec->f_en.descf = this->pelec->cal_delta_escf(); - } - } + elecstate::setup_dm(ucell, estate, this->psi, this->chr, iter, exx_two_level_step); + } #ifdef __EXX // calculate exact-exchange @@ -523,7 +483,7 @@ void ESolver_KS_LCAO::hamilt2rho_single(UnitCell& ucell, int istep, int { ModuleBase::TITLE("ESolver_KS_LCAO", "hamilt2rho_single"); - // i1) reset energy + // 1) reset energy this->pelec->f_en.eband = 0.0; this->pelec->f_en.demet = 0.0; bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; @@ -618,7 +578,6 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& const std::vector>& dm_vec = estate->get_DM()->get_DMK_vector(); - // 1) calculate the local occupation number matrix and energy correction in DFT+U if (PARAM.inp.dft_plus_u) { @@ -628,12 +587,8 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& { if (GlobalC::dftu.omc != 2) { - ModuleDFTU::dftu_cal_occup_m(iter, - ucell, - dm_vec, - this->kv, - this->p_chgmix->get_mixing_beta(), - hamilt_lcao); + ModuleDFTU::dftu_cal_occup_m(iter, ucell, dm_vec, this->kv, + this->p_chgmix->get_mixing_beta(), hamilt_lcao); } GlobalC::dftu.cal_energy_correction(ucell, istep); } diff --git a/source/source_estate/CMakeLists.txt b/source/source_estate/CMakeLists.txt index f1bd096009..a8380ade4e 100644 --- a/source/source_estate/CMakeLists.txt +++ b/source/source_estate/CMakeLists.txt @@ -46,6 +46,7 @@ if(ENABLE_LCAO) elecstate_lcao.cpp elecstate_lcao_cal_tau.cpp module_pot/H_TDDFT_pw.cpp + module_dm/setup_dm.cpp module_dm/density_matrix.cpp module_dm/density_matrix_io.cpp module_dm/cal_dm_psi.cpp diff --git a/source/source_estate/module_dm/setup_dm.cpp b/source/source_estate/module_dm/setup_dm.cpp new file mode 100644 index 0000000000..beb611b0dc --- /dev/null +++ b/source/source_estate/module_dm/setup_dm.cpp @@ -0,0 +1,65 @@ +#include "source_estate/module_dm/setup_dm.h" +#include "source_estate/module_dm/cal_dm_psi.h" +#include "source_estate/elecstate_tools.h" +#include "source_estate/cal_ux.h" + +template +void elecstate::setup_dm(UnitCell& ucell, + elecstate::ElecStateLCAO* pelec, + psi::Psi* psi, + Charge &chr, + const int iter, + const int exx_two_level_step) +{ + ModuleBase::TITLE("elecstate", "setup_dm"); + + if (iter == 1 && exx_two_level_step == 0) + { + std::cout << " WAVEFUN -> CHARGE " << std::endl; + + // calculate the density matrix using read in wave functions + // and then calculate the charge density on grid. + + pelec->skip_weights = true; + elecstate::calculate_weights(pelec->ekb, + pelec->wg, + pelec->klist, + pelec->eferm, + pelec->f_en, + pelec->nelec_spin, + pelec->skip_weights); + + elecstate::calEBand(pelec->ekb, pelec->wg, pelec->f_en); + elecstate::cal_dm_psi(pelec->DM->get_paraV_pointer(), pelec->wg, *psi, *(pelec->DM)); + pelec->DM->cal_DMR(); + + pelec->psiToRho(*psi); + pelec->skip_weights = false; + + elecstate::cal_ux(ucell); + + //! update the potentials by using new electron charge density + pelec->pot->update_from_charge(&chr, &ucell); + + //! compute the correction energy for metals + pelec->f_en.descf = pelec->cal_delta_escf(); + } + + return; +} + + +template void elecstate::setup_dm(UnitCell& ucell, + elecstate::ElecStateLCAO* pelec, + psi::Psi* psi, + Charge &chr, + const int iter, + const int exx_two_level_step); + +template void elecstate::setup_dm>(UnitCell& ucell, + elecstate::ElecStateLCAO>* pelec, + psi::Psi>* psi, + Charge &chr, + const int iter, + const int exx_two_level_step); + diff --git a/source/source_estate/module_dm/setup_dm.h b/source/source_estate/module_dm/setup_dm.h new file mode 100644 index 0000000000..73a9f8b00f --- /dev/null +++ b/source/source_estate/module_dm/setup_dm.h @@ -0,0 +1,22 @@ +#ifndef SETUP_DM_H +#define SETUP_DM_H + +#include "source_cell/unitcell.h" // use unitcell +#include "source_estate/elecstate_lcao.h"// use ElecStateLCAO +#include "source_psi/psi.h" // use electronic wave functions +#include "source_estate/module_charge/charge.h" // use charge + +namespace elecstate +{ + +template +void setup_dm(UnitCell& ucell, + ElecStateLCAO* pelec, + psi::Psi* psi, + Charge &chr, + const int iter, + const int exx_two_level_step); + +} + +#endif From 74b3500ed1ca9e4b13b8e23cad54194af6715a38 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 11:47:44 +0800 Subject: [PATCH 02/16] update before_all_runners in esolver --- source/source_esolver/esolver_fp.cpp | 14 +-- source/source_esolver/esolver_ks.cpp | 62 ++++------- source/source_esolver/esolver_ks_lcao.cpp | 62 +++-------- source/source_esolver/lcao_before_scf.cpp | 120 +++++++--------------- 4 files changed, 78 insertions(+), 180 deletions(-) diff --git a/source/source_esolver/esolver_fp.cpp b/source/source_esolver/esolver_fp.cpp index f6aff301bd..2084f589dc 100644 --- a/source/source_esolver/esolver_fp.cpp +++ b/source/source_esolver/esolver_fp.cpp @@ -48,21 +48,21 @@ void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp) this->pw_rho, this->pw_rhod, this->pw_big, this->classname, inp); - // setup the structure factors + // setup structure factors this->sf.set(this->pw_rhod, inp.nbspline); + // write geometry file ModuleIO::CifParser::write(PARAM.globalv.global_out_dir + "STRU.cif", ucell, "# Generated by ABACUS ModuleIO::CifParser", "data_?"); - //! initialize the charge extrapolation method if necessary + // init charge extrapolation this->CE.Init_CE(inp.nspin, ucell.nat, this->pw_rhod->nrxx, inp.chg_extrap); return; } -//! Something to do after SCF iterations when SCF is converged or comes to the max iter step. void ESolver_FP::after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) { ModuleBase::TITLE("ESolver_FP", "after_scf"); @@ -125,12 +125,8 @@ void ESolver_FP::before_scf(UnitCell& ucell, const int istep) if (ucell.ionic_position_updated) { this->CE.update_all_dis(ucell); - this->CE.extrapolate_charge(&this->Pgrid, - ucell, - &this->chr, - &this->sf, - GlobalV::ofs_running, - GlobalV::ofs_warning); + this->CE.extrapolate_charge(&this->Pgrid, ucell, &this->chr, &this->sf, + GlobalV::ofs_running, GlobalV::ofs_warning); } //! calculate D2 or D3 vdW diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index a00abc02e6..9e1299907b 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -51,9 +51,11 @@ template void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para& inp) { ModuleBase::TITLE("ESolver_KS", "before_all_runners"); - //! 1) initialize "before_all_runniers" in ESolver_FP + + //! 1) init "before_all_runniers" in ESolver_FP ESolver_FP::before_all_runners(ucell, inp); + //! 2) setup some parameters classname = "ESolver_KS"; basisname = ""; @@ -65,8 +67,8 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para std::string fft_device = inp.device; - // Fast Fourier Transform - // LCAO basis doesn't support GPU acceleration on FFT currently + //! 3) setup pw_wfc + // currently LCAO doesn't support GPU acceleration of FFT if(inp.basis_type == "lcao") { fft_device = "cpu"; @@ -82,71 +84,49 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para pw_wfc = new ModulePW::PW_Basis_K_Big(fft_device, fft_precision); ModulePW::PW_Basis_K_Big* tmp = static_cast(pw_wfc); - // should not use INPUT here, mohan 2024-05-12 tmp->setbxyz(inp.bx, inp.by, inp.bz); - ///---------------------------------------------------------- - /// charge mixing - ///---------------------------------------------------------- + //! 4) setup charge mixing p_chgmix = new Charge_Mixing(); p_chgmix->set_rhopw(this->pw_rho, this->pw_rhod); // cell_factor this->ppcell.cell_factor = inp.cell_factor; + ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SETUP UNITCELL"); - //! 3) it has been established that - // xc_func is same for all elements, therefore - // only the first one if used + //! 5) setup Exc for the first element '0', because all elements have same exc XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SETUP UNITCELL"); - - //! 4) setup the charge mixing parameters - p_chgmix->set_mixing(inp.mixing_mode, - inp.mixing_beta, - inp.mixing_ndim, - inp.mixing_gg0, - inp.mixing_tau, - inp.mixing_beta_mag, - inp.mixing_gg0_mag, - inp.mixing_gg0_min, - inp.mixing_angle, - inp.mixing_dmr, - ucell.omega, - ucell.tpiba); + //! 6) setup the charge mixing parameters + p_chgmix->set_mixing(inp.mixing_mode, inp.mixing_beta, inp.mixing_ndim, + inp.mixing_gg0, inp.mixing_tau, inp.mixing_beta_mag, inp.mixing_gg0_mag, + inp.mixing_gg0_min, inp.mixing_angle, inp.mixing_dmr, ucell.omega, ucell.tpiba); p_chgmix->init_mixing(); - //! 5) ESolver depends on the Symmetry module - // symmetry analysis should be performed every time the cell is changed + //! 7) symmetry analysis should be performed every time the cell is changed if (ModuleSymmetry::Symmetry::symm_flag == 1) { ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY"); } - //! 6) Setup the k points according to symmetry. + //! 8) Setup the k points according to symmetry. this->kv.set(ucell,ucell.symm, inp.kpoint_file, inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS"); - //! 7) print information + //! 9) print information ModuleIO::setup_parameters(ucell, this->kv); - //! 8) setup plane wave for electronic wave functions + //! 10) setup plane wave for electronic wave functions ModuleESolver::pw_setup(inp, ucell, *this->pw_rho, this->kv, *this->pw_wfc); - //! 9) initialize the real-space uniform grid for FFT and parallel - //! distribution of plane waves - Pgrid.init(this->pw_rhod->nx, - this->pw_rhod->ny, - this->pw_rhod->nz, - this->pw_rhod->nplane, - this->pw_rhod->nrxx, - pw_big->nbz, - pw_big->bz); - - //! 10) calculate the structure factor + //! 11) parallel of FFT grid + Pgrid.init(this->pw_rhod->nx, this->pw_rhod->ny, this->pw_rhod->nz, + this->pw_rhod->nplane, this->pw_rhod->nrxx, pw_big->nbz, pw_big->bz); + + //! 12) calculate the structure factor this->sf.setup_structure_factor(&ucell, Pgrid, this->pw_rhod); } diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 0a7fb82a53..9b86df0d45 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -1,53 +1,24 @@ #include "esolver_ks_lcao.h" - -#include "source_base/global_variable.h" -#include "source_base/tool_title.h" #include "source_estate/elecstate_tools.h" - -#include "source_estate/module_dm/cal_dm_psi.h" #include "source_lcao/module_deltaspin/spin_constrain.h" -#include "source_lcao/module_dftu/dftu.h" -#include "source_io/cube_io.h" -#include "source_io/output_log.h" #include "source_io/read_wfc_nao.h" -#include "source_io/write_elecstat_pot.h" -#include "source_io/module_parameter/parameter.h" - -// be careful of hpp, there may be multiple definitions of functions, 20250302, mohan -#include "source_lcao/hs_matrix_k.hpp" - -#include "source_base/global_function.h" -#include "source_cell/module_neighbor/sltk_grid_driver.h" +#include "source_lcao/hs_matrix_k.hpp" // there may be multiple definitions if using hpp #include "source_estate/cal_ux.h" #include "source_estate/module_charge/symmetry_rho.h" -#include "source_estate/occupy.h" #include "source_lcao/LCAO_domain.h" // need DeePKS_init #include "source_lcao/module_dftu/dftu.h" -#include "source_pw/module_pwdft/global.h" -#include "source_io/print_info.h" - -#include - #ifdef __MLALGO #include "source_lcao/module_deepks/LCAO_deepks.h" #include "source_lcao/module_deepks/LCAO_deepks_interface.h" #endif -//-----force& stress------------------- #include "source_lcao/FORCE_STRESS.h" - -//-----HSolver ElecState Hamilt-------- #include "source_estate/elecstate_lcao.h" #include "source_lcao/hamilt_lcao.h" #include "source_hsolver/hsolver_lcao.h" - #ifdef __EXX #include "../source_lcao/module_ri/exx_opt_orb.h" #endif - -// test RDMFT #include "source_lcao/module_rdmft/rdmft.h" -#include "source_lcao/module_gint/temp_gint/gint_info.h" - #include "source_estate/module_charge/chgmixing.h" // use charge mixing, mohan add 20251006 #include "source_estate/module_dm/setup_dm.h" // setup dm from electronic wave functions #include "source_io/ctrl_runner_lcao.h" // use ctrl_runner_lcao() @@ -96,8 +67,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa // 1) before_all_runners in ESolver_KS ESolver_KS::before_all_runners(ucell, inp); - // 2) init ElecState - // autoset nbands in ElecState before init_basis (for Psi 2d division) + // 2) autoset nbands in ElecState before init_basis (for Psi 2d division) if (this->pelec == nullptr) { // TK stands for double and std::complex? @@ -105,23 +75,23 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa this->kv.get_nks(), &(this->GG), &(this->GK), this->pw_rho, this->pw_big); } - // 3) read the LCAO orbitals/projectors and construct the interpolation tables. + // 3) read LCAO orbitals/projectors and construct the interpolation tables. LCAO_domain::init_basis_lcao(this->pv, inp.onsite_radius, inp.lcao_ecut, inp.lcao_dk, inp.lcao_dr, inp.lcao_rmax, ucell, two_center_bundle_, orb_); // 4) setup EXX calculations if (PARAM.inp.calculation == "gen_opt_abfs") { - #ifdef __EXX +#ifdef __EXX Exx_Opt_Orb exx_opt_orb; exx_opt_orb.generate_matrix(GlobalC::exx_info.info_opt_abfs, this->kv, ucell, this->orb_); - #else +#else ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::before_all_runners", "calculation=gen_opt_abfs must compile __EXX"); - #endif +#endif return; } - // 5) initialize electronic wave function psi + // 5) init electronic wave function psi if (this->psi == nullptr) { int nsk = 0; @@ -160,11 +130,10 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa } } - // 7) initialize the density matrix - // DMK are allocated here, but DMR is constructed in before_scf() + // 7) init DMK, but DMR is constructed in before_scf() dynamic_cast*>(this->pelec)->init_DM(&this->kv, &(this->pv), inp.nspin); - // 8) initialize exact exchange calculations + // 8) init exact exchange calculations #ifdef __EXX if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax" || inp.calculation == "md") @@ -198,15 +167,15 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa dftu->init(ucell, &this->pv, this->kv.get_nks(), &orb_); } - // 10) initialize local pseudopotentials + // 10) init local pseudopotentials this->locpp.init_vloc(ucell, this->pw_rho); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "LOCAL POTENTIAL"); - // 11) inititlize the charge density + // 11) init charge density this->chr.allocate(inp.nspin); this->pelec->omega = ucell.omega; - // 12) initialize the potential + // 12) init potentials if (this->pelec->pot == nullptr) { this->pelec->pot = new elecstate::Potential(this->pw_rhod, this->pw_rho, @@ -214,7 +183,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa &(this->pelec->f_en.etxc), &(this->pelec->f_en.vtxc)); } - // 13) initialize deepks + // 13) init deepks #ifdef __MLALGO LCAO_domain::DeePKS_init(ucell, pv, this->kv.get_nks(), orb_, this->ld, GlobalV::ofs_running); if (inp.deepks_scf) @@ -228,8 +197,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa } #endif - // 14) set occupations - // tddft does not need to set occupations in the first scf + // 14) set occupations, tddft does not need to set occupations in the first scf if (inp.ocp && inp.esolver_type != "tddft") { elecstate::fixed_weights(inp.ocp_kb, inp.nbands, inp.nelec, @@ -257,7 +225,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa } } - // 16) initialize rdmft, added by jghan + // 16) init rdmft, added by jghan if (inp.rdmft == true) { rdmft_solver.init(this->GG, this->GK, this->pv, ucell, diff --git a/source/source_esolver/lcao_before_scf.cpp b/source/source_esolver/lcao_before_scf.cpp index 237cb41175..501db574f2 100644 --- a/source/source_esolver/lcao_before_scf.cpp +++ b/source/source_esolver/lcao_before_scf.cpp @@ -2,35 +2,18 @@ #include "source_esolver/esolver_ks_lcao.h" #include "source_lcao/hamilt_lcao.h" #include "source_lcao/module_dftu/dftu.h" -#include "source_pw/module_pwdft/global.h" -// -#include "source_base/timer.h" #include "source_cell/module_neighbor/sltk_atom_arrange.h" #include "source_cell/module_neighbor/sltk_grid_driver.h" -#include "source_io/berryphase.h" -#include "source_io/get_pchg_lcao.h" -#include "source_io/get_wf_lcao.h" -#include "source_io/io_npz.h" -#include "source_io/to_wannier90_lcao.h" -#include "source_io/to_wannier90_lcao_in_pw.h" -#include "source_io/write_HS_R.h" #include "source_io/module_parameter/parameter.h" #include "source_estate/elecstate_tools.h" #ifdef __MLALGO #include "source_lcao/module_deepks/LCAO_deepks.h" #endif -#include "source_base/formatter.h" #include "source_estate/elecstate_lcao.h" -#include "source_estate/module_dm/cal_dm_psi.h" #include "source_lcao/LCAO_domain.h" #include "source_lcao/module_operator_lcao/op_exx_lcao.h" #include "source_lcao/module_operator_lcao/operator_lcao.h" #include "source_lcao/module_deltaspin/spin_constrain.h" -#include "source_io/cube_io.h" -#include "source_io/write_elecstat_pot.h" -#ifdef __EXX -#include "source_io/restart_exx_csr.h" -#endif namespace ModuleESolver { @@ -44,20 +27,27 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) //! 1) call before_scf() of ESolver_KS. ESolver_KS::before_scf(ucell, istep); + auto* estate = dynamic_cast*>(this->pelec); + auto* hamilt_lcao = dynamic_cast*>(this->p_hamilt); + + if(!estate) + { + ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::before_scf","pelec does not exist"); + } + + if(!hamilt_lcao) + { + ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::before_scf","p_hamilt does not exist"); + } + //! 2) find search radius double search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running, - PARAM.inp.out_level, - orb_.get_rcutmax_Phi(), - ucell.infoNL.get_rcutmax_Beta(), - PARAM.globalv.gamma_only_local); + PARAM.inp.out_level, orb_.get_rcutmax_Phi(), ucell.infoNL.get_rcutmax_Beta(), + PARAM.globalv.gamma_only_local); //! 3) use search_radius to search adj atoms - atom_arrange::search(PARAM.globalv.search_pbc, - GlobalV::ofs_running, - this->gd, - ucell, - search_radius, - PARAM.inp.test_atom_input); + atom_arrange::search(PARAM.globalv.search_pbc, GlobalV::ofs_running, + this->gd, ucell, search_radius, PARAM.inp.test_atom_input); //! 4) initialize NAO basis set #ifdef __OLD_GINT @@ -104,6 +94,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) //! 6) prepare grid integral #else + // here new is a unique pointer, which will be deleted automatically gint_info_.reset( new ModuleGint::GintInfo( this->pw_big->nbx, @@ -138,19 +129,13 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) } if (this->p_hamilt == nullptr) { - elecstate::DensityMatrix* DM = dynamic_cast*>(this->pelec)->get_DM(); + elecstate::DensityMatrix* DM = estate->get_DM(); this->p_hamilt = new hamilt::HamiltLCAO( PARAM.globalv.gamma_only_local ? &(this->GG) : nullptr, PARAM.globalv.gamma_only_local ? nullptr : &(this->GK), - ucell, - this->gd, - &this->pv, - this->pelec->pot, - this->kv, - two_center_bundle_, - orb_, - DM + ucell, this->gd, &this->pv, this->pelec->pot, this->kv, + two_center_bundle_, orb_, DM #ifdef __MLALGO , &this->ld @@ -165,9 +150,6 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) ); } - - - #ifdef __MLALGO // 9) for each ionic step, the overlap must be rebuilt // since it depends on ionic positions @@ -177,23 +159,13 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) // allocate , phialpha is different every ion step, so it is allocated here DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->ld.phialpha); // build and save at beginning - DeePKS_domain::build_phialpha(PARAM.inp.cal_force, - ucell, - orb_, - this->gd, - pv, - *(two_center_bundle_.overlap_orb_alpha), - this->ld.phialpha); + DeePKS_domain::build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, + pv, *(two_center_bundle_.overlap_orb_alpha), this->ld.phialpha); if (PARAM.inp.deepks_out_unittest) { - DeePKS_domain::check_phialpha(PARAM.inp.cal_force, - ucell, - orb_, - this->gd, - pv, - this->ld.phialpha, - GlobalV::MY_RANK); + DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, + this->gd, pv, this->ld.phialpha, GlobalV::MY_RANK); } } #endif @@ -202,23 +174,12 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) if (PARAM.inp.sc_mag_switch) { spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); - sc.init_sc(PARAM.inp.sc_thr, - PARAM.inp.nsc, - PARAM.inp.nsc_min, - PARAM.inp.alpha_trial, - PARAM.inp.sccut, - PARAM.inp.sc_drop_thr, - ucell, - &(this->pv), - PARAM.inp.nspin, - this->kv, - this->p_hamilt, - this->psi, - this->pelec); + sc.init_sc(PARAM.inp.sc_thr, PARAM.inp.nsc, PARAM.inp.nsc_min, PARAM.inp.alpha_trial, + PARAM.inp.sccut, PARAM.inp.sc_drop_thr, ucell, &(this->pv), + PARAM.inp.nspin, this->kv, this->p_hamilt, this->psi, this->pelec); } - // 11) set xc type before the first cal of xc in pelec->init_scf - // Peize Lin add 2016-12-03 + // 11) set xc type before the first cal of xc in pelec->init_scf, Peize Lin add 2016-12-03 #ifdef __EXX if (PARAM.inp.calculation != "nscf") { @@ -238,42 +199,35 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) // 13) initalize DMR // DMR should be same size with Hamiltonian(R) - dynamic_cast*>(this->pelec) - ->get_DM() - ->init_DMR(*(dynamic_cast*>(this->p_hamilt)->getHR())); + estate->get_DM()->init_DMR(*hamilt_lcao->getHR()); #ifdef __MLALGO - // initialize DMR of DeePKS + // 14) initialize DMR of DeePKS this->ld.init_DMR(ucell, orb_, this->pv, this->gd); #endif - // 14) two cases are considered: + // 15) two cases are considered: // 1. DMK in DensityMatrix is not empty (istep > 0), then DMR is initialized by DMK // 2. DMK in DensityMatrix is empty (istep == 0), then DMR is initialized by zeros if (istep > 0) { - dynamic_cast*>(this->pelec)->get_DM()->cal_DMR(); + estate->get_DM()->cal_DMR(); } - // 15) the electron charge density should be symmetrized, - // here is the initialization + // 16) the electron charge density should be symmetrized, Symmetry_rho srho; for (int is = 0; is < PARAM.inp.nspin; is++) { srho.begin(is, this->chr, this->pw_rho, ucell.symm); } - // 16) why we need to set this sentence? mohan add 2025-03-10 + // 17) why we need to set this sentence? mohan add 2025-03-10 this->p_hamilt->non_first_scf = istep; - // 17) update of RDMFT, added by jghan + // 18) update of RDMFT, added by jghan if (PARAM.inp.rdmft == true) { - // necessary operation of these parameters have be done with p_esolver->Init() in source/source_main/driver_run.cpp - rdmft_solver.update_ion(ucell, - *(this->pw_rho), - this->locpp.vloc, - this->sf.strucFac); + rdmft_solver.update_ion(ucell, *(this->pw_rho), this->locpp.vloc, this->sf.strucFac); } ModuleBase::timer::tick("ESolver_KS_LCAO", "before_scf"); From 15fe01cb8c24d989ef81c711cd6445e6f2a90175 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 13:34:09 +0800 Subject: [PATCH 03/16] change setup_dm back --- source/source_esolver/esolver_ks_lcao.cpp | 42 ++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 9b86df0d45..3b015e6e32 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -24,6 +24,9 @@ #include "source_io/ctrl_runner_lcao.h" // use ctrl_runner_lcao() #include "source_io/ctrl_iter_lcao.h" // use ctrl_iter_lcao() +// tmp +#include "source_estate/module_dm/cal_dm_psi.h" + namespace ModuleESolver { @@ -389,12 +392,49 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const = GlobalC::exx_info.info_ri.real_number ? this->exd->two_level_step : this->exc->two_level_step; } #endif - elecstate::setup_dm(ucell, estate, this->psi, this->chr, iter, exx_two_level_step); +// elecstate::setup_dm(ucell, estate, this->psi, this->chr, iter, exx_two_level_step); + + + if (iter == 1 && exx_two_level_step == 0) + { + std::cout << " WAVEFUN -> CHARGE " << std::endl; + + // calculate the density matrix using read in wave functions + // and then calculate the charge density on grid. + + estate->skip_weights = true; + elecstate::calculate_weights(estate->ekb, + estate->wg, + estate->klist, + estate->eferm, + estate->f_en, + estate->nelec_spin, + estate->skip_weights); + + elecstate::calEBand(estate->ekb, estate->wg, estate->f_en); + elecstate::cal_dm_psi(estate->DM->get_paraV_pointer(), estate->wg, *this->psi, *(estate->DM)); + estate->DM->cal_DMR(); + + estate->psiToRho(*this->psi); + estate->skip_weights = false; + + elecstate::cal_ux(ucell); + + //! update the potentials by using new electron charge density + estate->pot->update_from_charge(&this->chr, &ucell); + + //! compute the correction energy for metals + estate->f_en.descf = estate->cal_delta_escf(); + } + + + } #ifdef __EXX // calculate exact-exchange if (PARAM.inp.calculation != "nscf") +q { if (GlobalC::exx_info.info_ri.real_number) { From 087a9af3067d9e22b3180392b8e7e872b5fa723f Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 13:55:30 +0800 Subject: [PATCH 04/16] add setup_exx --- source/source_esolver/esolver_ks_lcao.cpp | 49 ++-------------------- source/source_esolver/esolver_ks_lcao.h | 36 ++++------------ source/source_lcao/FORCE_STRESS.cpp | 21 ++++------ source/source_lcao/FORCE_STRESS.h | 7 ++-- source/source_lcao/module_ri/setup_exx.cpp | 49 ++++++++++++++++++++++ source/source_lcao/module_ri/setup_exx.h | 35 ++++++++++++++++ 6 files changed, 108 insertions(+), 89 deletions(-) create mode 100644 source/source_lcao/module_ri/setup_exx.cpp create mode 100644 source/source_lcao/module_ri/setup_exx.h diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 3b015e6e32..2f725016b2 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -35,22 +35,7 @@ ESolver_KS_LCAO::ESolver_KS_LCAO() { this->classname = "ESolver_KS_LCAO"; this->basisname = "LCAO"; - -#ifdef __EXX - // 1. currently this initialization must be put in constructor rather than `before_all_runners()` - // because the latter is not reused by ESolver_LCAO_TDDFT, - // which cause the failure of the subsequent procedure reused by ESolver_LCAO_TDDFT - // 2. always construct but only initialize when if(cal_exx) is true - // because some members like two_level_step are used outside if(cal_exx) - if (GlobalC::exx_info.info_ri.real_number) - { - this->exd = std::make_shared>(GlobalC::exx_info.info_ri); - } - else - { - this->exc = std::make_shared>>(GlobalC::exx_info.info_ri); - } -#endif + this->exx_nao.init(); // mohan add 20251008 } template @@ -137,31 +122,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa dynamic_cast*>(this->pelec)->init_DM(&this->kv, &(this->pv), inp.nspin); // 8) init exact exchange calculations -#ifdef __EXX - if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax" - || inp.calculation == "md") - { - if (GlobalC::exx_info.info_global.cal_exx) - { - if (inp.init_wfc != "file") - { // if init_wfc==file, directly enter the EXX loop - XC_Functional::set_xc_first_loop(ucell); - } - - // initialize 2-center radial tables for EXX-LRI - if (GlobalC::exx_info.info_ri.real_number) - { - this->exd->init(MPI_COMM_WORLD, ucell, this->kv, orb_); - this->exd->exx_before_all_runners(this->kv, ucell, this->pv); - } - else - { - this->exc->init(MPI_COMM_WORLD, ucell, this->kv, orb_); - this->exc->exx_before_all_runners(this->kv, ucell, this->pv); - } - } - } -#endif + exx_nao.before_runner(); // 9) initialize DFT+U if (inp.dft_plus_u) @@ -280,10 +241,7 @@ void ESolver_KS_LCAO::cal_force(UnitCell& ucell, ModuleBase::matrix& for this->ld, "tot", #endif -#ifdef __EXX - *this->exd, - *this->exc, -#endif + this->exx_nao, &ucell.symm); // delete RA after cal_force @@ -434,7 +392,6 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const #ifdef __EXX // calculate exact-exchange if (PARAM.inp.calculation != "nscf") -q { if (GlobalC::exx_info.info_ri.real_number) { diff --git a/source/source_esolver/esolver_ks_lcao.h b/source/source_esolver/esolver_ks_lcao.h index 8040ccbbb9..b6ab648e49 100644 --- a/source/source_esolver/esolver_ks_lcao.h +++ b/source/source_esolver/esolver_ks_lcao.h @@ -2,32 +2,17 @@ #define ESOLVER_KS_LCAO_H #include "esolver_ks.h" - -// for adjacent atoms -#include "source_lcao/record_adj.h" - -// for NAO basis -#include "source_basis/module_nao/two_center_bundle.h" - -// for grid integration -#include "source_lcao/module_gint/gint_gamma.h" -#include "source_lcao/module_gint/gint_k.h" -#include "source_lcao/module_gint/temp_gint/gint.h" +#include "source_lcao/record_adj.h" // adjacent atoms +#include "source_basis/module_nao/two_center_bundle.h" // nao basis +#include "source_lcao/module_gint/gint_gamma.h" // gint for gamma-only k-points +#include "source_lcao/module_gint/gint_k.h" // gint for multi k-points +#include "source_lcao/module_gint/temp_gint/gint.h" // gint #include "source_lcao/module_gint/temp_gint/gint_info.h" - -// for DeePKS #ifdef __MLALGO -#include "source_lcao/module_deepks/LCAO_deepks.h" +#include "source_lcao/module_deepks/LCAO_deepks.h" // deepks #endif - -// for EXX -#ifdef __EXX -#include "source_lcao/module_ri/Exx_LRI_interface.h" -#include "source_lcao/module_ri/Mix_DMk_2D.h" -#endif - -// for RDMFT -#include "source_lcao/module_rdmft/rdmft.h" +#include "source_lcao/module_ri/setup_exx.h" // for exx, mohan add 20251008 +#include "source_lcao/module_rdmft/rdmft.h" // rdmft #include @@ -118,10 +103,7 @@ class ESolver_KS_LCAO : public ESolver_KS LCAO_Deepks ld; #endif -#ifdef __EXX - std::shared_ptr> exd = nullptr; - std::shared_ptr>> exc = nullptr; -#endif + Exx_NAO exx_nao; friend class LR::ESolver_LR; friend class LR::ESolver_LR, double>; diff --git a/source/source_lcao/FORCE_STRESS.cpp b/source/source_lcao/FORCE_STRESS.cpp index 03d91e319d..dbfed4aa25 100644 --- a/source/source_lcao/FORCE_STRESS.cpp +++ b/source/source_lcao/FORCE_STRESS.cpp @@ -55,10 +55,7 @@ void Force_Stress_LCAO::getForceStress(UnitCell& ucell, LCAO_Deepks& ld, const std::string& dpks_out_type, #endif -#ifdef __EXX - Exx_LRI_Interface& exd, - Exx_LRI_Interface>& exc, -#endif + Exx_NAO &exx_nao, ModuleSymmetry::Symmetry* symm) { ModuleBase::TITLE("Force_Stress_LCAO", "getForceStress"); @@ -378,26 +375,26 @@ void Force_Stress_LCAO::getForceStress(UnitCell& ucell, { if (GlobalC::exx_info.info_ri.real_number) { - exd.cal_exx_force(ucell.nat); - force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_force(); + exx_nao.exd.cal_exx_force(ucell.nat); + force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exd.get_force(); } else { - exc.cal_exx_force(ucell.nat); - force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_force(); + exx_nao.exc.cal_exx_force(ucell.nat); + force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exc.get_force(); } } if (isstress) { if (GlobalC::exx_info.info_ri.real_number) { - exd.cal_exx_stress(ucell.omega, ucell.lat0); - stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_stress(); + exx_nao.exd.cal_exx_stress(ucell.omega, ucell.lat0); + stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exd.get_stress(); } else { - exc.cal_exx_stress(ucell.omega, ucell.lat0); - stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_stress(); + exx_nao.exc.cal_exx_stress(ucell.omega, ucell.lat0); + stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exc.get_stress(); } } } diff --git a/source/source_lcao/FORCE_STRESS.h b/source/source_lcao/FORCE_STRESS.h index 3c6eb65745..22e17965c1 100644 --- a/source/source_lcao/FORCE_STRESS.h +++ b/source/source_lcao/FORCE_STRESS.h @@ -16,6 +16,8 @@ #include "force_stress_arrays.h" #include "source_lcao/module_gint/gint_gamma.h" #include "source_lcao/module_gint/gint_k.h" +#include "source_lcao/module_ri/setup_exx.h" // for exx, mohan add 20251008 + template class Force_Stress_LCAO @@ -53,10 +55,7 @@ class Force_Stress_LCAO LCAO_Deepks& ld, const std::string& dpks_out_type, #endif -#ifdef __EXX - Exx_LRI_Interface& exd, - Exx_LRI_Interface>& exc, -#endif + Exx_NAO &exx_nao, ModuleSymmetry::Symmetry* symm); private: diff --git a/source/source_lcao/module_ri/setup_exx.cpp b/source/source_lcao/module_ri/setup_exx.cpp new file mode 100644 index 0000000000..d01b3d73cc --- /dev/null +++ b/source/source_lcao/module_ri/setup_exx.cpp @@ -0,0 +1,49 @@ +#include "source_lcao/module_ri/setup_exx.h" + +void Exx_NAO::init0() +{ +#ifdef __EXX + // 1. currently this initialization must be put in constructor rather than `before_all_runners()` + // because the latter is not reused by ESolver_LCAO_TDDFT, + // which cause the failure of the subsequent procedure reused by ESolver_LCAO_TDDFT + // 2. always construct but only initialize when if(cal_exx) is true + // because some members like two_level_step are used outside if(cal_exx) + if (GlobalC::exx_info.info_ri.real_number) + { + this->exd = std::make_shared>(GlobalC::exx_info.info_ri); + } + else + { + this->exc = std::make_shared>>(GlobalC::exx_info.info_ri); + } +#endif +} + +void Exx_NAO::before_runner() +{ +#ifdef __EXX + if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax" + || inp.calculation == "md") + { + if (GlobalC::exx_info.info_global.cal_exx) + { + if (inp.init_wfc != "file") + { // if init_wfc==file, directly enter the EXX loop + XC_Functional::set_xc_first_loop(ucell); + } + + // initialize 2-center radial tables for EXX-LRI + if (GlobalC::exx_info.info_ri.real_number) + { + this->exd->init(MPI_COMM_WORLD, ucell, this->kv, orb_); + this->exd->exx_before_all_runners(this->kv, ucell, this->pv); + } + else + { + this->exc->init(MPI_COMM_WORLD, ucell, this->kv, orb_); + this->exc->exx_before_all_runners(this->kv, ucell, this->pv); + } + } + } +#endif +} diff --git a/source/source_lcao/module_ri/setup_exx.h b/source/source_lcao/module_ri/setup_exx.h new file mode 100644 index 0000000000..6be0b81f52 --- /dev/null +++ b/source/source_lcao/module_ri/setup_exx.h @@ -0,0 +1,35 @@ +#ifndef SETUP_EXX_H +#define SETUP_EXX_H + +/* +#include "source_cell/unitcell.h" // use unitcell +#include "source_estate/elecstate_lcao.h"// use ElecStateLCAO +#include "source_psi/psi.h" // use electronic wave functions +#include "source_estate/module_charge/charge.h" // use charge +*/ + +// for EXX +#ifdef __EXX +#include "source_lcao/module_ri/Exx_LRI_interface.h" +#include "source_lcao/module_ri/Mix_DMk_2D.h" +#endif + +class Exx_NAO +{ + public: + +#ifdef __EXX + std::shared_ptr> exd = nullptr; + std::shared_ptr>> exc = nullptr; +#endif + + void init0(); + + void before_runner(); + + + +}; + + +#endif From 0654086af2b48b8040274d40aff5c44c8de29b8d Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 13:57:00 +0800 Subject: [PATCH 05/16] fix bug --- source/source_esolver/esolver_ks_lcao.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 3b015e6e32..e834c805c0 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -434,7 +434,6 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const #ifdef __EXX // calculate exact-exchange if (PARAM.inp.calculation != "nscf") -q { if (GlobalC::exx_info.info_ri.real_number) { From 8adda0c74f4e1255ab67b44aea642f0f141ec44b Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 14:08:59 +0800 Subject: [PATCH 06/16] add setup_exx in normal LCAO codes --- source/Makefile.Objects | 15 ++++++++------- source/source_esolver/esolver_double_xc.cpp | 7 ++----- source/source_esolver/esolver_ks_lcao.cpp | 2 +- source/source_esolver/esolver_ks_lcao.h | 2 +- source/source_lcao/CMakeLists.txt | 1 + source/source_lcao/FORCE_STRESS.h | 2 +- source/source_lcao/module_ri/CMakeLists.txt | 2 +- source/source_lcao/{module_ri => }/setup_exx.cpp | 4 ++-- source/source_lcao/{module_ri => }/setup_exx.h | 6 +++--- 9 files changed, 20 insertions(+), 21 deletions(-) rename source/source_lcao/{module_ri => }/setup_exx.cpp (96%) rename source/source_lcao/{module_ri => }/setup_exx.h (90%) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index e9e86b02c1..6475fa82d6 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -669,6 +669,7 @@ OBJS_LCAO=evolve_elec.o\ LCAO_allocate.o\ LCAO_set_mat2d.o\ LCAO_init_basis.o\ + setup_exx.o\ center2_orb.o\ center2_orb-orb11.o\ center2_orb-orb21.o\ @@ -677,13 +678,13 @@ OBJS_LCAO=evolve_elec.o\ wavefunc_in_pw.o\ OBJS_MODULE_RI=conv_coulomb_pot_k.o\ - exx_abfs-abfs_index.o \ - exx_abfs-jle.o \ - exx_abfs-io.o \ - exx_abfs-construct_orbs.o \ - ABFs_Construct-PCA.o \ - exx_opt_orb.o \ - exx_opt_orb-print.o \ + exx_abfs-abfs_index.o\ + exx_abfs-jle.o\ + exx_abfs-io.o\ + exx_abfs-construct_orbs.o\ + ABFs_Construct-PCA.o\ + exx_opt_orb-print.o\ + exx_opt_orb.o\ Matrix_Orbs11.o\ Matrix_Orbs21.o\ Matrix_Orbs22.o\ diff --git a/source/source_esolver/esolver_double_xc.cpp b/source/source_esolver/esolver_double_xc.cpp index 6c8eb1cbb2..48568c243f 100644 --- a/source/source_esolver/esolver_double_xc.cpp +++ b/source/source_esolver/esolver_double_xc.cpp @@ -415,11 +415,8 @@ void ESolver_DoubleXC::cal_force(UnitCell& ucell, ModuleBase::matrix& fo this->ld, "base", #endif -#ifdef __EXX - *this->exd, - *this->exc, -#endif - &ucell.symm); + this->exx_nao, + &ucell.symm); // restore to original xc XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 2f725016b2..9cadb28c48 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -122,7 +122,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa dynamic_cast*>(this->pelec)->init_DM(&this->kv, &(this->pv), inp.nspin); // 8) init exact exchange calculations - exx_nao.before_runner(); + this->exx_nao.before_runner(); // 9) initialize DFT+U if (inp.dft_plus_u) diff --git a/source/source_esolver/esolver_ks_lcao.h b/source/source_esolver/esolver_ks_lcao.h index b6ab648e49..4f8e1747ff 100644 --- a/source/source_esolver/esolver_ks_lcao.h +++ b/source/source_esolver/esolver_ks_lcao.h @@ -11,7 +11,7 @@ #ifdef __MLALGO #include "source_lcao/module_deepks/LCAO_deepks.h" // deepks #endif -#include "source_lcao/module_ri/setup_exx.h" // for exx, mohan add 20251008 +#include "source_lcao/setup_exx.h" // for exx, mohan add 20251008 #include "source_lcao/module_rdmft/rdmft.h" // rdmft #include diff --git a/source/source_lcao/CMakeLists.txt b/source/source_lcao/CMakeLists.txt index da170bbb73..d3df201219 100644 --- a/source/source_lcao/CMakeLists.txt +++ b/source/source_lcao/CMakeLists.txt @@ -41,6 +41,7 @@ if(ENABLE_LCAO) LCAO_allocate.cpp LCAO_set_mat2d.cpp LCAO_init_basis.cpp + setup_exx.cpp record_adj.cpp center2_orb.cpp center2_orb-orb11.cpp diff --git a/source/source_lcao/FORCE_STRESS.h b/source/source_lcao/FORCE_STRESS.h index 22e17965c1..c98d36326d 100644 --- a/source/source_lcao/FORCE_STRESS.h +++ b/source/source_lcao/FORCE_STRESS.h @@ -16,7 +16,7 @@ #include "force_stress_arrays.h" #include "source_lcao/module_gint/gint_gamma.h" #include "source_lcao/module_gint/gint_k.h" -#include "source_lcao/module_ri/setup_exx.h" // for exx, mohan add 20251008 +#include "source_lcao/setup_exx.h" // for exx, mohan add 20251008 template diff --git a/source/source_lcao/module_ri/CMakeLists.txt b/source/source_lcao/module_ri/CMakeLists.txt index a201cab1a3..65ab0d12f9 100644 --- a/source/source_lcao/module_ri/CMakeLists.txt +++ b/source/source_lcao/module_ri/CMakeLists.txt @@ -38,4 +38,4 @@ if (ENABLE_LIBRI) if(ENABLE_COVERAGE) add_coverage(ri) endif() -endif() \ No newline at end of file +endif() diff --git a/source/source_lcao/module_ri/setup_exx.cpp b/source/source_lcao/setup_exx.cpp similarity index 96% rename from source/source_lcao/module_ri/setup_exx.cpp rename to source/source_lcao/setup_exx.cpp index d01b3d73cc..a458d2fc36 100644 --- a/source/source_lcao/module_ri/setup_exx.cpp +++ b/source/source_lcao/setup_exx.cpp @@ -1,6 +1,6 @@ -#include "source_lcao/module_ri/setup_exx.h" +#include "source_lcao/setup_exx.h" -void Exx_NAO::init0() +void Exx_NAO::init() { #ifdef __EXX // 1. currently this initialization must be put in constructor rather than `before_all_runners()` diff --git a/source/source_lcao/module_ri/setup_exx.h b/source/source_lcao/setup_exx.h similarity index 90% rename from source/source_lcao/module_ri/setup_exx.h rename to source/source_lcao/setup_exx.h index 6be0b81f52..8abe75ffad 100644 --- a/source/source_lcao/module_ri/setup_exx.h +++ b/source/source_lcao/setup_exx.h @@ -1,5 +1,5 @@ -#ifndef SETUP_EXX_H -#define SETUP_EXX_H +#ifndef SETUP_EXX_NAO_H +#define SETUP_EXX_NAO_H /* #include "source_cell/unitcell.h" // use unitcell @@ -23,7 +23,7 @@ class Exx_NAO std::shared_ptr>> exc = nullptr; #endif - void init0(); + void init(); void before_runner(); From 4070dd0b522048c28a487e3d84306aa2123e429c Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 15:32:45 +0800 Subject: [PATCH 07/16] fix bug about p_hamilt and fix setup_exx --- examples/relax/lcao_output/INPUT | 1 - source/source_esolver/esolver_double_xc.cpp | 6 ++-- source/source_esolver/esolver_ks_lcao.cpp | 20 ++++++------- source/source_esolver/esolver_ks_lcao.h | 2 +- source/source_esolver/lcao_after_scf.cpp | 4 +-- source/source_esolver/lcao_before_scf.cpp | 21 +++++++------- source/source_esolver/lcao_others.cpp | 6 ++-- source/source_lcao/FORCE.h | 4 +-- source/source_lcao/FORCE_STRESS.cpp | 18 ++++++------ source/source_lcao/FORCE_STRESS.h | 2 +- .../module_lr/esolver_lrtd_lcao.cpp | 8 ++--- source/source_lcao/setup_exx.cpp | 29 +++++++++++++++---- source/source_lcao/setup_exx.h | 21 +++++++++----- source/source_lcao/spar_dh.h | 4 +-- 14 files changed, 84 insertions(+), 62 deletions(-) diff --git a/examples/relax/lcao_output/INPUT b/examples/relax/lcao_output/INPUT index 87261f91cb..54175c33fd 100644 --- a/examples/relax/lcao_output/INPUT +++ b/examples/relax/lcao_output/INPUT @@ -40,4 +40,3 @@ out_band 0 out_stru 0 out_app_flag 0 -out_interval 1 diff --git a/source/source_esolver/esolver_double_xc.cpp b/source/source_esolver/esolver_double_xc.cpp index 48568c243f..bf68e20833 100644 --- a/source/source_esolver/esolver_double_xc.cpp +++ b/source/source_esolver/esolver_double_xc.cpp @@ -162,9 +162,9 @@ void ESolver_DoubleXC::before_scf(UnitCell& ucell, const int istep) #ifdef __EXX , istep, - GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step, - GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr, - GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs() + GlobalC::exx_info.info_ri.real_number ? &this->exx_nao.exd->two_level_step : &this->exx_nao.exc->two_level_step, + GlobalC::exx_info.info_ri.real_number ? &this->exx_nao.exd->get_Hexxs() : nullptr, + GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exx_nao.exc->get_Hexxs() #endif ); } diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 9cadb28c48..97c5057935 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -122,7 +122,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa dynamic_cast*>(this->pelec)->init_DM(&this->kv, &(this->pv), inp.nspin); // 8) init exact exchange calculations - this->exx_nao.before_runner(); + this->exx_nao.before_runner(ucell, this->kv, this->orb_, this->pv, PARAM.inp); // 9) initialize DFT+U if (inp.dft_plus_u) @@ -307,8 +307,8 @@ void ESolver_KS_LCAO::after_all_runners(UnitCell& ucell) this->orb_, this->pw_rho, this->pw_rhod, this->sf, this->locpp.vloc, #ifdef __EXX - this->exd, - this->exc, + this->exx_nao.exd, + this->exx_nao.exc, #endif this->solvent); @@ -347,7 +347,7 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const { // the following steps are only needed in the first outer exx loop exx_two_level_step - = GlobalC::exx_info.info_ri.real_number ? this->exd->two_level_step : this->exc->two_level_step; + = GlobalC::exx_info.info_ri.real_number ? this->exx_nao.exd->two_level_step : this->exx_nao.exc->two_level_step; } #endif // elecstate::setup_dm(ucell, estate, this->psi, this->chr, iter, exx_two_level_step); @@ -395,11 +395,11 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const { if (GlobalC::exx_info.info_ri.real_number) { - this->exd->exx_eachiterinit(istep, ucell, *dm, this->kv, iter); + this->exx_nao.exd->exx_eachiterinit(istep, ucell, *dm, this->kv, iter); } else { - this->exc->exx_eachiterinit(istep, ucell, *dm, this->kv, iter); + this->exx_nao.exc->exx_eachiterinit(istep, ucell, *dm, this->kv, iter); } } #endif @@ -486,11 +486,11 @@ void ESolver_KS_LCAO::hamilt2rho_single(UnitCell& ucell, int istep, int { if (GlobalC::exx_info.info_ri.real_number) { - this->exd->exx_hamilt2rho(*this->pelec, this->pv, iter); + this->exx_nao.exd->exx_hamilt2rho(*this->pelec, this->pv, iter); } else { - this->exc->exx_hamilt2rho(*this->pelec, this->pv, iter); + this->exx_nao.exc->exx_hamilt2rho(*this->pelec, this->pv, iter); } } #endif @@ -609,8 +609,8 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& this->ld, #endif #ifdef __EXX - *this->exd, - *this->exc, + *this->exx_nao.exd, + *this->exx_nao.exc, #endif iter, istep, conv_esolver, this->scf_ene_thr); diff --git a/source/source_esolver/esolver_ks_lcao.h b/source/source_esolver/esolver_ks_lcao.h index 4f8e1747ff..3c531b04b2 100644 --- a/source/source_esolver/esolver_ks_lcao.h +++ b/source/source_esolver/esolver_ks_lcao.h @@ -103,7 +103,7 @@ class ESolver_KS_LCAO : public ESolver_KS LCAO_Deepks ld; #endif - Exx_NAO exx_nao; + Exx_NAO exx_nao; friend class LR::ESolver_LR; friend class LR::ESolver_LR, double>; diff --git a/source/source_esolver/lcao_after_scf.cpp b/source/source_esolver/lcao_after_scf.cpp index bf819f8a75..a64d4586c4 100644 --- a/source/source_esolver/lcao_after_scf.cpp +++ b/source/source_esolver/lcao_after_scf.cpp @@ -65,8 +65,8 @@ void ESolver_KS_LCAO::after_scf(UnitCell& ucell, const int istep, const this->ld, #endif #ifdef __EXX - *this->exd, - *this->exc, + *this->exx_nao.exd, + *this->exx_nao.exc, #endif istep); } diff --git a/source/source_esolver/lcao_before_scf.cpp b/source/source_esolver/lcao_before_scf.cpp index 501db574f2..77505899a6 100644 --- a/source/source_esolver/lcao_before_scf.cpp +++ b/source/source_esolver/lcao_before_scf.cpp @@ -28,17 +28,11 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) ESolver_KS::before_scf(ucell, istep); auto* estate = dynamic_cast*>(this->pelec); - auto* hamilt_lcao = dynamic_cast*>(this->p_hamilt); - if(!estate) { ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::before_scf","pelec does not exist"); } - if(!hamilt_lcao) - { - ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::before_scf","p_hamilt does not exist"); - } //! 2) find search radius double search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running, @@ -143,9 +137,9 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) #ifdef __EXX , istep, - GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step, - GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr, - GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs() + GlobalC::exx_info.info_ri.real_number ? &this->exx_nao.exd->two_level_step : &this->exx_nao.exc->two_level_step, + GlobalC::exx_info.info_ri.real_number ? &this->exx_nao.exd->get_Hexxs() : nullptr, + GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exx_nao.exc->get_Hexxs() #endif ); } @@ -185,11 +179,11 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) { if (GlobalC::exx_info.info_ri.real_number) { - this->exd->exx_beforescf(istep, this->kv, *this->p_chgmix, ucell, orb_); + this->exx_nao.exd->exx_beforescf(istep, this->kv, *this->p_chgmix, ucell, orb_); } else { - this->exc->exx_beforescf(istep, this->kv, *this->p_chgmix, ucell, orb_); + this->exx_nao.exc->exx_beforescf(istep, this->kv, *this->p_chgmix, ucell, orb_); } } #endif @@ -199,6 +193,11 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) // 13) initalize DMR // DMR should be same size with Hamiltonian(R) + auto* hamilt_lcao = dynamic_cast*>(this->p_hamilt); + if(!hamilt_lcao) + { + ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::before_scf","p_hamilt does not exist"); + } estate->get_DM()->init_DMR(*hamilt_lcao->getHR()); #ifdef __MLALGO diff --git a/source/source_esolver/lcao_others.cpp b/source/source_esolver/lcao_others.cpp index 2cc85ca253..cc0aa32bcf 100644 --- a/source/source_esolver/lcao_others.cpp +++ b/source/source_esolver/lcao_others.cpp @@ -238,9 +238,9 @@ void ESolver_KS_LCAO::others(UnitCell& ucell, const int istep) #ifdef __EXX , istep, - GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step, - GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr, - GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs() + GlobalC::exx_info.info_ri.real_number ? &this->exx_nao.exd->two_level_step : &this->exx_nao.exc->two_level_step, + GlobalC::exx_info.info_ri.real_number ? &this->exx_nao.exd->get_Hexxs() : nullptr, + GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exx_nao.exc->get_Hexxs() #endif ); } diff --git a/source/source_lcao/FORCE.h b/source/source_lcao/FORCE.h index 0382bdde71..c5649246d8 100644 --- a/source/source_lcao/FORCE.h +++ b/source/source_lcao/FORCE.h @@ -1,5 +1,5 @@ -#ifndef W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_HAMILT_LCAO_HAMILT_LCAODFT_FORCE_H -#define W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_HAMILT_LCAO_HAMILT_LCAODFT_FORCE_H +#ifndef LCAO_FORCE_H +#define LCAO_FORCE_H #include "source_base/global_function.h" #include "source_base/global_variable.h" diff --git a/source/source_lcao/FORCE_STRESS.cpp b/source/source_lcao/FORCE_STRESS.cpp index dbfed4aa25..3188d26310 100644 --- a/source/source_lcao/FORCE_STRESS.cpp +++ b/source/source_lcao/FORCE_STRESS.cpp @@ -55,7 +55,7 @@ void Force_Stress_LCAO::getForceStress(UnitCell& ucell, LCAO_Deepks& ld, const std::string& dpks_out_type, #endif - Exx_NAO &exx_nao, + Exx_NAO &exx_nao, ModuleSymmetry::Symmetry* symm) { ModuleBase::TITLE("Force_Stress_LCAO", "getForceStress"); @@ -375,26 +375,26 @@ void Force_Stress_LCAO::getForceStress(UnitCell& ucell, { if (GlobalC::exx_info.info_ri.real_number) { - exx_nao.exd.cal_exx_force(ucell.nat); - force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exd.get_force(); + exx_nao.exd->cal_exx_force(ucell.nat); + force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exd->get_force(); } else { - exx_nao.exc.cal_exx_force(ucell.nat); - force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exc.get_force(); + exx_nao.exc->cal_exx_force(ucell.nat); + force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exc->get_force(); } } if (isstress) { if (GlobalC::exx_info.info_ri.real_number) { - exx_nao.exd.cal_exx_stress(ucell.omega, ucell.lat0); - stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exd.get_stress(); + exx_nao.exd->cal_exx_stress(ucell.omega, ucell.lat0); + stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exd->get_stress(); } else { - exx_nao.exc.cal_exx_stress(ucell.omega, ucell.lat0); - stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exc.get_stress(); + exx_nao.exc->cal_exx_stress(ucell.omega, ucell.lat0); + stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exc->get_stress(); } } } diff --git a/source/source_lcao/FORCE_STRESS.h b/source/source_lcao/FORCE_STRESS.h index c98d36326d..5febf0a1aa 100644 --- a/source/source_lcao/FORCE_STRESS.h +++ b/source/source_lcao/FORCE_STRESS.h @@ -55,7 +55,7 @@ class Force_Stress_LCAO LCAO_Deepks& ld, const std::string& dpks_out_type, #endif - Exx_NAO &exx_nao, + Exx_NAO &exx_nao, ModuleSymmetry::Symmetry* symm); private: diff --git a/source/source_lcao/module_lr/esolver_lrtd_lcao.cpp b/source/source_lcao/module_lr/esolver_lrtd_lcao.cpp index 6fb0d256eb..6608acb150 100644 --- a/source/source_lcao/module_lr/esolver_lrtd_lcao.cpp +++ b/source/source_lcao/module_lr/esolver_lrtd_lcao.cpp @@ -274,10 +274,10 @@ LR::ESolver_LR::ESolver_LR(ModuleESolver::ESolver_KS_LCAO&& ks_sol { // if the same kernel is calculated in the esolver_ks, move it std::string dft_functional = LR_Util::tolower(input.dft_functional); - if (ks_sol.exd && std::is_same::value && xc_kernel == dft_functional) { - this->move_exx_lri(ks_sol.exd->exx_ptr); - } else if (ks_sol.exc && std::is_same>::value && xc_kernel == dft_functional) { - this->move_exx_lri(ks_sol.exc->exx_ptr); + if (ks_sol.exx_nao.exd && std::is_same::value && xc_kernel == dft_functional) { + this->move_exx_lri(ks_sol.exx_nao.exd->exx_ptr); + } else if (ks_sol.exx_nao.exc && std::is_same>::value && xc_kernel == dft_functional) { + this->move_exx_lri(ks_sol.exx_nao.exc->exx_ptr); } else // construct C, V from scratch { // set ccp_type according to the xc_kernel diff --git a/source/source_lcao/setup_exx.cpp b/source/source_lcao/setup_exx.cpp index a458d2fc36..aeabfaac36 100644 --- a/source/source_lcao/setup_exx.cpp +++ b/source/source_lcao/setup_exx.cpp @@ -1,6 +1,14 @@ #include "source_lcao/setup_exx.h" -void Exx_NAO::init() +template +Exx_NAO::Exx_NAO(){} + +template +Exx_NAO::~Exx_NAO(){} + + +template +void Exx_NAO::init() { #ifdef __EXX // 1. currently this initialization must be put in constructor rather than `before_all_runners()` @@ -19,7 +27,13 @@ void Exx_NAO::init() #endif } -void Exx_NAO::before_runner() +template +void Exx_NAO::before_runner( + UnitCell& ucell, // unitcell + K_Vectors &kv, // k points + const LCAO_Orbitals &orb, // orbital info + const Parallel_Orbitals &pv, // parallel orbitals + const Input_para& inp) { #ifdef __EXX if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax" @@ -35,15 +49,18 @@ void Exx_NAO::before_runner() // initialize 2-center radial tables for EXX-LRI if (GlobalC::exx_info.info_ri.real_number) { - this->exd->init(MPI_COMM_WORLD, ucell, this->kv, orb_); - this->exd->exx_before_all_runners(this->kv, ucell, this->pv); + this->exd->init(MPI_COMM_WORLD, ucell, kv, orb); + this->exd->exx_before_all_runners(kv, ucell, pv); } else { - this->exc->init(MPI_COMM_WORLD, ucell, this->kv, orb_); - this->exc->exx_before_all_runners(this->kv, ucell, this->pv); + this->exc->init(MPI_COMM_WORLD, ucell, kv, orb); + this->exc->exx_before_all_runners(kv, ucell, pv); } } } #endif } + +template class Exx_NAO; +template class Exx_NAO>; diff --git a/source/source_lcao/setup_exx.h b/source/source_lcao/setup_exx.h index 8abe75ffad..bc94be9296 100644 --- a/source/source_lcao/setup_exx.h +++ b/source/source_lcao/setup_exx.h @@ -1,12 +1,11 @@ #ifndef SETUP_EXX_NAO_H #define SETUP_EXX_NAO_H -/* #include "source_cell/unitcell.h" // use unitcell -#include "source_estate/elecstate_lcao.h"// use ElecStateLCAO -#include "source_psi/psi.h" // use electronic wave functions -#include "source_estate/module_charge/charge.h" // use charge -*/ +#include "source_cell/klist.h" // k points +#include "source_io/input_conv.h" // inp +#include "source_basis/module_ao/parallel_orbitals.h" // parallel orbitals +#include "source_basis/module_ao/ORB_read.h" // orb // for EXX #ifdef __EXX @@ -14,10 +13,14 @@ #include "source_lcao/module_ri/Mix_DMk_2D.h" #endif +template class Exx_NAO { public: + Exx_NAO(); + ~Exx_NAO(); + #ifdef __EXX std::shared_ptr> exd = nullptr; std::shared_ptr>> exc = nullptr; @@ -25,8 +28,12 @@ class Exx_NAO void init(); - void before_runner(); - + void before_runner( + UnitCell& ucell, // unitcell + K_Vectors &kv, // k points + const LCAO_Orbitals &orb, // orbital info + const Parallel_Orbitals &pv, // parallel orbitals + const Input_para& inp); }; diff --git a/source/source_lcao/spar_dh.h b/source/source_lcao/spar_dh.h index c972524466..a71ebe4ec2 100644 --- a/source/source_lcao/spar_dh.h +++ b/source/source_lcao/spar_dh.h @@ -1,5 +1,5 @@ -#ifndef W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_HAMILT_LCAO_HAMILT_LCAODFT_SPAR_DH_H -#define W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_HAMILT_LCAO_HAMILT_LCAODFT_SPAR_DH_H +#ifndef SPAR_DH_H +#define SPAR_DH_H #include "source_cell/module_neighbor/sltk_atom_arrange.h" #include "source_cell/module_neighbor/sltk_grid_driver.h" From c2c5951cdc4591a4cb62530be9b9f9a347361181 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 15:37:03 +0800 Subject: [PATCH 08/16] fix bug --- source/source_lcao/setup_exx.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/source_lcao/setup_exx.h b/source/source_lcao/setup_exx.h index bc94be9296..34b4a670bb 100644 --- a/source/source_lcao/setup_exx.h +++ b/source/source_lcao/setup_exx.h @@ -3,7 +3,7 @@ #include "source_cell/unitcell.h" // use unitcell #include "source_cell/klist.h" // k points -#include "source_io/input_conv.h" // inp +#include "source_io/module_parameter/input_parameter.h" // Input_para #include "source_basis/module_ao/parallel_orbitals.h" // parallel orbitals #include "source_basis/module_ao/ORB_read.h" // orb From c7c89766613cdfd288f111bcf668bbc8a3d72f51 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 16:03:26 +0800 Subject: [PATCH 09/16] fix bug --- source/source_esolver/esolver_ks_lcao.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 97c5057935..32e6ce55db 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -350,9 +350,9 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const = GlobalC::exx_info.info_ri.real_number ? this->exx_nao.exd->two_level_step : this->exx_nao.exc->two_level_step; } #endif -// elecstate::setup_dm(ucell, estate, this->psi, this->chr, iter, exx_two_level_step); - + elecstate::setup_dm(ucell, estate, this->psi, this->chr, iter, exx_two_level_step); +/* if (iter == 1 && exx_two_level_step == 0) { std::cout << " WAVEFUN -> CHARGE " << std::endl; @@ -385,6 +385,7 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const estate->f_en.descf = estate->cal_delta_escf(); } +*/ } From 5b32de92720030a4df8b467bafc9cc62f909593f Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 16:25:49 +0800 Subject: [PATCH 10/16] eliminate exd and exc from ctrl_iter_lcao input parameters --- source/source_esolver/esolver_ks_lcao.cpp | 6 +----- source/source_io/ctrl_iter_lcao.cpp | 25 +++++++---------------- source/source_io/ctrl_iter_lcao.h | 9 ++------ source/source_lcao/setup_exx.h | 1 - 4 files changed, 10 insertions(+), 31 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 32e6ce55db..42371c4a5e 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -609,11 +609,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& #ifdef __MLALGO this->ld, #endif -#ifdef __EXX - *this->exx_nao.exd, - *this->exx_nao.exc, -#endif - iter, istep, conv_esolver, this->scf_ene_thr); + exx_nao, iter, istep, conv_esolver, this->scf_ene_thr); } diff --git a/source/source_io/ctrl_iter_lcao.cpp b/source/source_io/ctrl_iter_lcao.cpp index d14ba00f63..cb0dd3cca0 100644 --- a/source/source_io/ctrl_iter_lcao.cpp +++ b/source/source_io/ctrl_iter_lcao.cpp @@ -24,10 +24,7 @@ void ctrl_iter_lcao(UnitCell& ucell, // unit cell * #ifdef __MLALGO LCAO_Deepks& ld, #endif -#ifdef __EXX - Exx_LRI_Interface& exd, - Exx_LRI_Interface>& exc, -#endif + Exx_NAO &exx_nao, int &iter, const int istep, bool &conv_esolver, @@ -53,14 +50,15 @@ void ctrl_iter_lcao(UnitCell& ucell, // unit cell * if (GlobalC::exx_info.info_global.cal_exx) { GlobalC::exx_info.info_ri.real_number ? - exd.exx_iter_finish(kv, ucell, *p_hamilt, *pelec, + exx_nao.exd->exx_iter_finish(kv, ucell, *p_hamilt, *pelec, *p_chgmix, scf_ene_thr, iter, istep, conv_esolver) : - exc.exx_iter_finish(kv, ucell, *p_hamilt, *pelec, + exx_nao.exc->exx_iter_finish(kv, ucell, *p_hamilt, *pelec, *p_chgmix, scf_ene_thr, iter, istep, conv_esolver); } } #endif + // for deepks, output labels during electronic steps (after conv_esolver is renewed) #ifdef __MLALGO if (inp.deepks_out_labels >0 && inp.deepks_out_freq_elec) @@ -96,10 +94,7 @@ template void ctrl_iter_lcao(UnitCell& ucell, // unit cell * #ifdef __MLALGO LCAO_Deepks& ld, #endif -#ifdef __EXX - Exx_LRI_Interface& exd, - Exx_LRI_Interface>& exc, -#endif + Exx_NAO &exx_nao, int &iter, const int istep, bool &conv_esolver, @@ -120,10 +115,7 @@ template void ctrl_iter_lcao, double>(UnitCell& ucell, // u #ifdef __MLALGO LCAO_Deepks>& ld, #endif -#ifdef __EXX - Exx_LRI_Interface, double>& exd, - Exx_LRI_Interface, std::complex>& exc, -#endif + Exx_NAO> &exx_nao, int &iter, const int istep, bool &conv_esolver, @@ -144,10 +136,7 @@ template void ctrl_iter_lcao, std::complex>(UnitCel #ifdef __MLALGO LCAO_Deepks>& ld, #endif -#ifdef __EXX - Exx_LRI_Interface, double>& exd, - Exx_LRI_Interface, std::complex>& exc, -#endif + Exx_NAO> &exx_nao, int &iter, const int istep, bool &conv_esolver, diff --git a/source/source_io/ctrl_iter_lcao.h b/source/source_io/ctrl_iter_lcao.h index 7b6dc32b47..94c3c4364e 100644 --- a/source/source_io/ctrl_iter_lcao.h +++ b/source/source_io/ctrl_iter_lcao.h @@ -8,9 +8,7 @@ #include "source_estate/module_charge/charge.h" // use charge #include "source_estate/module_charge/charge_mixing.h" // use charge mixing #include "source_lcao/hamilt_lcao.h" // use hamilt::HamiltLCAO -#ifdef __EXX -#include "source_lcao/module_ri/Exx_LRI_interface.h" // use EXX codes -#endif +#include "source_lcao/setup_exx.h" // mohan add 20251008 namespace ModuleIO { @@ -30,10 +28,7 @@ void ctrl_iter_lcao(UnitCell& ucell, // unit cell * #ifdef __MLALGO LCAO_Deepks& ld, #endif -#ifdef __EXX - Exx_LRI_Interface& exd, - Exx_LRI_Interface>& exc, -#endif + Exx_NAO &exx_nao, int &iter, const int istep, bool &conv_esolver, diff --git a/source/source_lcao/setup_exx.h b/source/source_lcao/setup_exx.h index 34b4a670bb..33154ba8e6 100644 --- a/source/source_lcao/setup_exx.h +++ b/source/source_lcao/setup_exx.h @@ -35,7 +35,6 @@ class Exx_NAO const Parallel_Orbitals &pv, // parallel orbitals const Input_para& inp); - }; From e25b9238fd7a23d3abeebee76162ba0acb77fae9 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 16:54:52 +0800 Subject: [PATCH 11/16] add setup_deepks --- source/Makefile.Objects | 1 + source/source_esolver/esolver_ks_lcao.cpp | 16 +--------- source/source_esolver/esolver_ks_lcao.h | 10 +++--- source/source_lcao/CMakeLists.txt | 1 + source/source_lcao/setup_deepks.cpp | 31 +++++++++++++++++++ source/source_lcao/setup_deepks.h | 37 +++++++++++++++++++++++ 6 files changed, 75 insertions(+), 21 deletions(-) create mode 100644 source/source_lcao/setup_deepks.cpp create mode 100644 source/source_lcao/setup_deepks.h diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 6475fa82d6..3895a65b9f 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -670,6 +670,7 @@ OBJS_LCAO=evolve_elec.o\ LCAO_set_mat2d.o\ LCAO_init_basis.o\ setup_exx.o\ + setup_deepks.o\ center2_orb.o\ center2_orb-orb11.o\ center2_orb-orb21.o\ diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 42371c4a5e..1f7ca369e3 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -24,9 +24,6 @@ #include "source_io/ctrl_runner_lcao.h" // use ctrl_runner_lcao() #include "source_io/ctrl_iter_lcao.h" // use ctrl_iter_lcao() -// tmp -#include "source_estate/module_dm/cal_dm_psi.h" - namespace ModuleESolver { @@ -148,18 +145,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa } // 13) init deepks -#ifdef __MLALGO - LCAO_domain::DeePKS_init(ucell, pv, this->kv.get_nks(), orb_, this->ld, GlobalV::ofs_running); - if (inp.deepks_scf) - { - // load the DeePKS model from deep neural network - DeePKS_domain::load_model(inp.deepks_model, ld.model_deepks); - // read pdm from file for NSCF or SCF-restart, do it only once in whole calculation - DeePKS_domain::read_pdm((inp.init_chg == "file"), inp.deepks_equiv, - ld.init_pdm, ucell.nat, orb_.Alpha[0].getTotal_nchi() * ucell.nat, - ld.lmaxd, ld.inl2l, *orb_.Alpha, ld.pdm); - } -#endif + this->deepks.before_runner(ucell, this->kv, this->orb_, this->pv, PARAM.inp); // 14) set occupations, tddft does not need to set occupations in the first scf if (inp.ocp && inp.esolver_type != "tddft") diff --git a/source/source_esolver/esolver_ks_lcao.h b/source/source_esolver/esolver_ks_lcao.h index 3c531b04b2..b428e8fc3f 100644 --- a/source/source_esolver/esolver_ks_lcao.h +++ b/source/source_esolver/esolver_ks_lcao.h @@ -8,9 +8,7 @@ #include "source_lcao/module_gint/gint_k.h" // gint for multi k-points #include "source_lcao/module_gint/temp_gint/gint.h" // gint #include "source_lcao/module_gint/temp_gint/gint_info.h" -#ifdef __MLALGO -#include "source_lcao/module_deepks/LCAO_deepks.h" // deepks -#endif +#include "source_lcao/setup_deepks.h" // for deepks, mohan add 20251008 #include "source_lcao/setup_exx.h" // for exx, mohan add 20251008 #include "source_lcao/module_rdmft/rdmft.h" // rdmft @@ -99,10 +97,10 @@ class ESolver_KS_LCAO : public ESolver_KS ModuleBase::matrix scs; bool have_force = false; -#ifdef __MLALGO - LCAO_Deepks ld; -#endif + // deepks method, mohan add 2025-10-08 + DeePKS deepks; + // exact-exchange energy, mohan add 2025-10-08 Exx_NAO exx_nao; friend class LR::ESolver_LR; diff --git a/source/source_lcao/CMakeLists.txt b/source/source_lcao/CMakeLists.txt index d3df201219..30ea5a3118 100644 --- a/source/source_lcao/CMakeLists.txt +++ b/source/source_lcao/CMakeLists.txt @@ -42,6 +42,7 @@ if(ENABLE_LCAO) LCAO_set_mat2d.cpp LCAO_init_basis.cpp setup_exx.cpp + setup_deepks.cpp record_adj.cpp center2_orb.cpp center2_orb-orb11.cpp diff --git a/source/source_lcao/setup_deepks.cpp b/source/source_lcao/setup_deepks.cpp new file mode 100644 index 0000000000..08c1904bd5 --- /dev/null +++ b/source/source_lcao/setup_deepks.cpp @@ -0,0 +1,31 @@ +#include "source_lcao/setup_deepks.h" + +template +DeePKS::DeePKS(){} + +template +DeePKS::~DeePKS(){} + +template +void DeePKS::before_runner(UnitCell& ucell, // unitcell + K_Vectors &kv, // k points + const LCAO_Orbitals &orb, // orbital info + const Parallel_Orbitals &pv, // parallel orbitals + const Input_para& inp) +{ +#ifdef __MLALGO + LCAO_domain::DeePKS_init(ucell, pv, kv.get_nks(), orb_, this->ld, GlobalV::ofs_running); + if (inp.deepks_scf) + { + // load the DeePKS model from deep neural network + DeePKS_domain::load_model(inp.deepks_model, ld.model_deepks); + // read pdm from file for NSCF or SCF-restart, do it only once in whole calculation + DeePKS_domain::read_pdm((inp.init_chg == "file"), inp.deepks_equiv, + ld.init_pdm, ucell.nat, orb_.Alpha[0].getTotal_nchi() * ucell.nat, + ld.lmaxd, ld.inl2l, *orb_.Alpha, ld.pdm); + } +#endif +} + +template class DeePKS; +template class DeePKS>; diff --git a/source/source_lcao/setup_deepks.h b/source/source_lcao/setup_deepks.h new file mode 100644 index 0000000000..fffd1c3d8e --- /dev/null +++ b/source/source_lcao/setup_deepks.h @@ -0,0 +1,37 @@ +#ifndef SETUP_DEEPKS_H +#define SETUP_DEEPKS_H + +#include "source_cell/unitcell.h" // use unitcell +#include "source_cell/klist.h" // k points +#include "source_io/module_parameter/input_parameter.h" // Input_para +#include "source_basis/module_ao/parallel_orbitals.h" // parallel orbitals +#include "source_basis/module_ao/ORB_read.h" // orb + +#ifdef __MLALGO +#include "source_lcao/module_deepks/LCAO_deepks.h" // deepks +#endif + + +template +class DeePKS +{ + public: + + DeePKS(); + ~DeePKS(); + +#ifdef __MLALGO + LCAO_Deepks ld; +#endif + + void before_runner( + UnitCell& ucell, // unitcell + K_Vectors &kv, // k points + const LCAO_Orbitals &orb, // orbital info + const Parallel_Orbitals &pv, // parallel orbitals + const Input_para& inp); + +}; + + +#endif From 3c36487a8995dcc2c8f83d769c71df66b8e14d48 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 17:06:04 +0800 Subject: [PATCH 12/16] update the input variables of ctrl_scf_lcao --- source/source_esolver/lcao_after_scf.cpp | 11 ++---- source/source_io/ctrl_scf_lcao.cpp | 44 +++++++----------------- source/source_io/ctrl_scf_lcao.h | 15 +++----- 3 files changed, 20 insertions(+), 50 deletions(-) diff --git a/source/source_esolver/lcao_after_scf.cpp b/source/source_esolver/lcao_after_scf.cpp index a64d4586c4..4f01542466 100644 --- a/source/source_esolver/lcao_after_scf.cpp +++ b/source/source_esolver/lcao_after_scf.cpp @@ -61,14 +61,9 @@ void ESolver_KS_LCAO::after_scf(UnitCell& ucell, const int istep, const this->orb_, this->pw_wfc, this->pw_rho, this->GridT, this->pw_big, this->sf, this->rdmft_solver, -#ifdef __MLALGO - this->ld, -#endif -#ifdef __EXX - *this->exx_nao.exd, - *this->exx_nao.exc, -#endif - istep); + this->deepks, + this->exx_nao, + istep); } //------------------------------------------------------------------ diff --git a/source/source_io/ctrl_scf_lcao.cpp b/source/source_io/ctrl_scf_lcao.cpp index c981e631e3..c11f954c10 100644 --- a/source/source_io/ctrl_scf_lcao.cpp +++ b/source/source_io/ctrl_scf_lcao.cpp @@ -50,14 +50,9 @@ void ctrl_scf_lcao(UnitCell& ucell, Grid_Technique >, // for berryphase const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 - rdmft::RDMFT &rdmft_solver, // for RDMFT -#ifdef __MLALGO - LCAO_Deepks& ld, -#endif -#ifdef __EXX - Exx_LRI_Interface& exd, - Exx_LRI_Interface>& exc, -#endif + rdmft::RDMFT &rdmft_solver, // for RDMFT + DeePKS &deepks, + Exx_NAO &exx_nao, const int istep) { ModuleBase::TITLE("ModuleIO", "ctrl_scf_lcao"); @@ -162,7 +157,7 @@ void ctrl_scf_lcao(UnitCell& ucell, #ifdef __MLALGO // need control parameter hamilt::HamiltLCAO* p_ham_deepks = p_hamilt; - std::shared_ptr> ld_shared_ptr(&ld, [](LCAO_Deepks*) {}); + std::shared_ptr> ld_shared_ptr(&deepks.ld, [](LCAO_Deepks*) {}); LCAO_Deepks_Interface deepks_interface(ld_shared_ptr); deepks_interface.out_deepks_labels(pelec->f_en.etot, @@ -364,11 +359,11 @@ void ctrl_scf_lcao(UnitCell& ucell, + "HexxR" + std::to_string(GlobalV::MY_RANK); if (GlobalC::exx_info.info_ri.real_number) { - ModuleIO::write_Hexxs_csr(file_name_exx, ucell, exd.get_Hexxs()); + ModuleIO::write_Hexxs_csr(file_name_exx, ucell, exx_nao.exd->get_Hexxs()); } else { - ModuleIO::write_Hexxs_csr(file_name_exx, ucell, exc.get_Hexxs()); + ModuleIO::write_Hexxs_csr(file_name_exx, ucell, exx_nao.exc->get_Hexxs()); } } } @@ -458,13 +453,8 @@ template void ModuleIO::ctrl_scf_lcao(UnitCell& ucell, const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 rdmft::RDMFT &rdmft_solver, // for RDMFT -#ifdef __MLALGO - LCAO_Deepks& ld, -#endif -#ifdef __EXX - Exx_LRI_Interface& exd, - Exx_LRI_Interface>& exc, -#endif + DeePKS &deepks, + Exx_NAO &exx_nao, const int istep); // For multiple k-points @@ -485,13 +475,8 @@ template void ModuleIO::ctrl_scf_lcao, double>(UnitCell& uc const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 rdmft::RDMFT, double> &rdmft_solver, // for RDMFT -#ifdef __MLALGO - LCAO_Deepks>& ld, -#endif -#ifdef __EXX - Exx_LRI_Interface, double>& exd, - Exx_LRI_Interface, std::complex>& exc, -#endif + DeePKS> &deepks, + Exx_NAO> &exx_nao, const int istep); template void ModuleIO::ctrl_scf_lcao, std::complex>(UnitCell& ucell, @@ -511,12 +496,7 @@ template void ModuleIO::ctrl_scf_lcao, std::complex const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 rdmft::RDMFT, std::complex> &rdmft_solver, // for RDMFT -#ifdef __MLALGO - LCAO_Deepks>& ld, -#endif -#ifdef __EXX - Exx_LRI_Interface, double>& exd, - Exx_LRI_Interface, std::complex>& exc, -#endif + DeePKS> &deepks, + Exx_NAO> &exx_nao, const int istep); diff --git a/source/source_io/ctrl_scf_lcao.h b/source/source_io/ctrl_scf_lcao.h index 8e3dcf4d29..107f00595d 100644 --- a/source/source_io/ctrl_scf_lcao.h +++ b/source/source_io/ctrl_scf_lcao.h @@ -13,9 +13,9 @@ #include "source_basis/module_pw/pw_basis_k.h" // use ModulePW::PW_Basis_K and ModulePW::PW_Basis #include "source_pw/module_pwdft/structure_factor.h" // use Structure_Factor #include "source_lcao/module_rdmft/rdmft.h" // use RDMFT codes -#ifdef __EXX -#include "source_lcao/module_ri/Exx_LRI_interface.h" // use EXX codes -#endif + +#include "source_lcao/setup_deepks.h" // for deepks, mohan add 20251008 +#include "source_lcao/setup_exx.h" // for exx, mohan add 20251008 namespace ModuleIO { @@ -38,13 +38,8 @@ namespace ModuleIO const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 rdmft::RDMFT &rdmft_solver, // for RDMFT -#ifdef __MLALGO - LCAO_Deepks& ld, -#endif -#ifdef __EXX - Exx_LRI_Interface& exd, - Exx_LRI_Interface>& exc, -#endif + DeePKS &deepks, + Exx_NAO &exx_nao, const int istep); } #endif From 72e702363643436411293da01dafbb82aa0630f3 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 17:21:37 +0800 Subject: [PATCH 13/16] reconstruct the constructor of p_hamilt --- source/source_esolver/esolver_double_xc.cpp | 7 ++----- source/source_esolver/lcao_before_scf.cpp | 6 +----- source/source_esolver/lcao_others.cpp | 7 ++----- source/source_lcao/hamilt_lcao.cpp | 17 +++++++---------- source/source_lcao/hamilt_lcao.h | 11 +++-------- 5 files changed, 15 insertions(+), 33 deletions(-) diff --git a/source/source_esolver/esolver_double_xc.cpp b/source/source_esolver/esolver_double_xc.cpp index bf68e20833..f6e22baf42 100644 --- a/source/source_esolver/esolver_double_xc.cpp +++ b/source/source_esolver/esolver_double_xc.cpp @@ -154,11 +154,8 @@ void ESolver_DoubleXC::before_scf(UnitCell& ucell, const int istep) this->kv, this->two_center_bundle_, this->orb_, - DM -#ifdef __MLALGO - , - &this->ld -#endif + DM, + this->deepks #ifdef __EXX , istep, diff --git a/source/source_esolver/lcao_before_scf.cpp b/source/source_esolver/lcao_before_scf.cpp index 77505899a6..f8239a1d26 100644 --- a/source/source_esolver/lcao_before_scf.cpp +++ b/source/source_esolver/lcao_before_scf.cpp @@ -129,11 +129,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) PARAM.globalv.gamma_only_local ? &(this->GG) : nullptr, PARAM.globalv.gamma_only_local ? nullptr : &(this->GK), ucell, this->gd, &this->pv, this->pelec->pot, this->kv, - two_center_bundle_, orb_, DM -#ifdef __MLALGO - , - &this->ld -#endif + two_center_bundle_, orb_, DM, this->deepks #ifdef __EXX , istep, diff --git a/source/source_esolver/lcao_others.cpp b/source/source_esolver/lcao_others.cpp index cc0aa32bcf..4d4199ef9b 100644 --- a/source/source_esolver/lcao_others.cpp +++ b/source/source_esolver/lcao_others.cpp @@ -230,11 +230,8 @@ void ESolver_KS_LCAO::others(UnitCell& ucell, const int istep) this->kv, two_center_bundle_, orb_, - DM -#ifdef __MLALGO - , - &this->ld -#endif + DM, + this->deepks #ifdef __EXX , istep, diff --git a/source/source_lcao/hamilt_lcao.cpp b/source/source_lcao/hamilt_lcao.cpp index c0b60e8331..c335e7bc4d 100644 --- a/source/source_lcao/hamilt_lcao.cpp +++ b/source/source_lcao/hamilt_lcao.cpp @@ -79,11 +79,8 @@ HamiltLCAO::HamiltLCAO(Gint_Gamma* GG_in, const K_Vectors& kv_in, const TwoCenterBundle& two_center_bundle, const LCAO_Orbitals& orb, - elecstate::DensityMatrix* DM_in -#ifdef __MLALGO - , - LCAO_Deepks* ld_in -#endif + elecstate::DensityMatrix* DM_in, + DeePKS &deepks #ifdef __EXX , const int istep, @@ -214,7 +211,7 @@ HamiltLCAO::HamiltLCAO(Gint_Gamma* GG_in, &orb, this->kv->get_nks(), DM_in, - ld_in); + deepks.ld); this->getOperator()->add(deepks); this->V_delta_R = dynamic_cast>*>(deepks)->get_V_delta_R(); } @@ -330,7 +327,7 @@ HamiltLCAO::HamiltLCAO(Gint_Gamma* GG_in, #ifdef __MLALGO if (PARAM.inp.deepks_scf) { - Operator* deepks = new DeePKS>(this->hsk, + Operator* deepks_op = new DeePKS>(this->hsk, this->kv->kvec_d, hR, &ucell, @@ -339,9 +336,9 @@ HamiltLCAO::HamiltLCAO(Gint_Gamma* GG_in, &orb, this->kv->get_nks(), DM_in, - ld_in); - this->getOperator()->add(deepks); - this->V_delta_R = dynamic_cast>*>(deepks)->get_V_delta_R(); + deepks.ld); + this->getOperator()->add(deepks_op); + this->V_delta_R = dynamic_cast>*>(deepks_op)->get_V_delta_R(); } #endif // TDDFT_velocity_gauge diff --git a/source/source_lcao/hamilt_lcao.h b/source/source_lcao/hamilt_lcao.h index 16c5c34e0a..cde0269629 100644 --- a/source/source_lcao/hamilt_lcao.h +++ b/source/source_lcao/hamilt_lcao.h @@ -14,9 +14,7 @@ #include -#ifdef __MLALGO -#include "source_lcao/module_deepks/LCAO_deepks.h" -#endif +#include "source_lcao/setup_deepks.h" // mohan add 20251008 #ifdef __EXX #include "source_lcao/module_ri/Exx_LRI.h" @@ -50,11 +48,8 @@ class HamiltLCAO : public Hamilt const K_Vectors& kv_in, const TwoCenterBundle& two_center_bundle, const LCAO_Orbitals& orb, - elecstate::DensityMatrix* DM_in -#ifdef __MLALGO - , - LCAO_Deepks* ld_in -#endif + elecstate::DensityMatrix* DM_in, + DeePKS &deepks #ifdef __EXX , const int istep, From e87c8b9abbdf65533e10412a60f8250317cdb008 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 17:29:41 +0800 Subject: [PATCH 14/16] update ld --- source/source_esolver/esolver_double_xc.cpp | 4 ++-- source/source_esolver/esolver_ks_lcao.cpp | 6 +++--- source/source_esolver/lcao_before_scf.cpp | 8 ++++---- source/source_esolver/lcao_others.cpp | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/source/source_esolver/esolver_double_xc.cpp b/source/source_esolver/esolver_double_xc.cpp index f6e22baf42..c6947b8687 100644 --- a/source/source_esolver/esolver_double_xc.cpp +++ b/source/source_esolver/esolver_double_xc.cpp @@ -235,7 +235,7 @@ void ESolver_DoubleXC::iter_finish(UnitCell& ucell, const int istep, int #ifdef __MLALGO // ---------- output tot and precalc ---------- hamilt::HamiltLCAO* p_ham_deepks = dynamic_cast*>(this->p_hamilt); - std::shared_ptr> ld_shared_ptr(&this->ld, [](LCAO_Deepks*) {}); + std::shared_ptr> ld_shared_ptr(&this->deepks.ld, [](LCAO_Deepks*) {}); LCAO_Deepks_Interface deepks_interface(ld_shared_ptr); deepks_interface.out_deepks_labels(this->pelec->f_en.etot, @@ -409,7 +409,7 @@ void ESolver_DoubleXC::cal_force(UnitCell& ucell, ModuleBase::matrix& fo this->pw_rho, this->solvent, #ifdef __MLALGO - this->ld, + this->deepks.ld, "base", #endif this->exx_nao, diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 1f7ca369e3..ebf57154a3 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -224,7 +224,7 @@ void ESolver_KS_LCAO::cal_force(UnitCell& ucell, ModuleBase::matrix& for this->pw_rho, this->solvent, #ifdef __MLALGO - this->ld, + this->deepks.ld, "tot", #endif this->exx_nao, @@ -413,7 +413,7 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const // if (iter == 1 && istep == 0) // { // // initialize DMR - // this->ld.init_DMR(ucell, orb_, this->pv, this->gd); + // this->deepks.ld.init_DMR(ucell, orb_, this->pv, this->gd); // } #endif @@ -593,7 +593,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& this->pv, this->gd, this->psi, this->chr, this->p_chgmix, hamilt_lcao, this->orb_, #ifdef __MLALGO - this->ld, + this->deepks.ld, #endif exx_nao, iter, istep, conv_esolver, this->scf_ene_thr); diff --git a/source/source_esolver/lcao_before_scf.cpp b/source/source_esolver/lcao_before_scf.cpp index f8239a1d26..e40203a581 100644 --- a/source/source_esolver/lcao_before_scf.cpp +++ b/source/source_esolver/lcao_before_scf.cpp @@ -147,15 +147,15 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) { const Parallel_Orbitals* pv = &this->pv; // allocate , phialpha is different every ion step, so it is allocated here - DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->ld.phialpha); + DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->deepks.ld.phialpha); // build and save at beginning DeePKS_domain::build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, - pv, *(two_center_bundle_.overlap_orb_alpha), this->ld.phialpha); + pv, *(two_center_bundle_.overlap_orb_alpha), this->deepks.ld.phialpha); if (PARAM.inp.deepks_out_unittest) { DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, - this->gd, pv, this->ld.phialpha, GlobalV::MY_RANK); + this->gd, pv, this->deepks.ld.phialpha, GlobalV::MY_RANK); } } #endif @@ -198,7 +198,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) #ifdef __MLALGO // 14) initialize DMR of DeePKS - this->ld.init_DMR(ucell, orb_, this->pv, this->gd); + this->deepks.ld.init_DMR(ucell, orb_, this->pv, this->gd); #endif // 15) two cases are considered: diff --git a/source/source_esolver/lcao_others.cpp b/source/source_esolver/lcao_others.cpp index 4d4199ef9b..8237c3a700 100644 --- a/source/source_esolver/lcao_others.cpp +++ b/source/source_esolver/lcao_others.cpp @@ -249,7 +249,7 @@ void ESolver_KS_LCAO::others(UnitCell& ucell, const int istep) { const Parallel_Orbitals* pv = &this->pv; // allocate , phialpha is different every ion step, so it is allocated here - DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->ld.phialpha); + DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, this->deepks.ld.phialpha); // build and save at beginning DeePKS_domain::build_phialpha(PARAM.inp.cal_force, ucell, @@ -257,7 +257,7 @@ void ESolver_KS_LCAO::others(UnitCell& ucell, const int istep) this->gd, pv, *(two_center_bundle_.overlap_orb_alpha), - this->ld.phialpha); + this->deepks.ld.phialpha); if (PARAM.inp.deepks_out_unittest) { @@ -266,7 +266,7 @@ void ESolver_KS_LCAO::others(UnitCell& ucell, const int istep) orb_, this->gd, pv, - this->ld.phialpha, + this->deepks.ld.phialpha, GlobalV::MY_RANK); } } From 605d2ac6d4b07a26cc6f270c8ae775a2052f45d9 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 21:42:54 +0800 Subject: [PATCH 15/16] fix a few bugs --- source/source_esolver/esolver_ks_lcao.cpp | 2 +- source/source_esolver/esolver_ks_lcao.h | 2 +- source/source_lcao/hamilt_lcao.cpp | 14 +++++++------- source/source_lcao/hamilt_lcao.h | 2 +- source/source_lcao/setup_deepks.cpp | 21 +++++++++++---------- source/source_lcao/setup_deepks.h | 11 +++++------ 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index ebf57154a3..0703a33c74 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -145,7 +145,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa } // 13) init deepks - this->deepks.before_runner(ucell, this->kv, this->orb_, this->pv, PARAM.inp); + this->deepks.before_runner(ucell, kv.get_nks(), this->orb_, this->pv, PARAM.inp); // 14) set occupations, tddft does not need to set occupations in the first scf if (inp.ocp && inp.esolver_type != "tddft") diff --git a/source/source_esolver/esolver_ks_lcao.h b/source/source_esolver/esolver_ks_lcao.h index b428e8fc3f..13359f7c8e 100644 --- a/source/source_esolver/esolver_ks_lcao.h +++ b/source/source_esolver/esolver_ks_lcao.h @@ -98,7 +98,7 @@ class ESolver_KS_LCAO : public ESolver_KS bool have_force = false; // deepks method, mohan add 2025-10-08 - DeePKS deepks; + Setup_DeePKS deepks; // exact-exchange energy, mohan add 2025-10-08 Exx_NAO exx_nao; diff --git a/source/source_lcao/hamilt_lcao.cpp b/source/source_lcao/hamilt_lcao.cpp index c335e7bc4d..389b9812cf 100644 --- a/source/source_lcao/hamilt_lcao.cpp +++ b/source/source_lcao/hamilt_lcao.cpp @@ -1,4 +1,4 @@ -#include "hamilt_lcao.h" +#include "source_lcao/hamilt_lcao.h" #include "source_base/global_variable.h" #include "source_base/memory.h" @@ -80,7 +80,7 @@ HamiltLCAO::HamiltLCAO(Gint_Gamma* GG_in, const TwoCenterBundle& two_center_bundle, const LCAO_Orbitals& orb, elecstate::DensityMatrix* DM_in, - DeePKS &deepks + Setup_DeePKS &deepks #ifdef __EXX , const int istep, @@ -202,7 +202,7 @@ HamiltLCAO::HamiltLCAO(Gint_Gamma* GG_in, #ifdef __MLALGO if (PARAM.inp.deepks_scf) { - Operator* deepks = new DeePKS>(this->hsk, + Operator* deepks_op = new DeePKS>(this->hsk, this->kv->kvec_d, this->hR, // no explicit call yet &ucell, @@ -211,9 +211,9 @@ HamiltLCAO::HamiltLCAO(Gint_Gamma* GG_in, &orb, this->kv->get_nks(), DM_in, - deepks.ld); - this->getOperator()->add(deepks); - this->V_delta_R = dynamic_cast>*>(deepks)->get_V_delta_R(); + &deepks.ld); + this->getOperator()->add(deepks_op); + this->V_delta_R = dynamic_cast>*>(deepks_op)->get_V_delta_R(); } #endif @@ -336,7 +336,7 @@ HamiltLCAO::HamiltLCAO(Gint_Gamma* GG_in, &orb, this->kv->get_nks(), DM_in, - deepks.ld); + &deepks.ld); this->getOperator()->add(deepks_op); this->V_delta_R = dynamic_cast>*>(deepks_op)->get_V_delta_R(); } diff --git a/source/source_lcao/hamilt_lcao.h b/source/source_lcao/hamilt_lcao.h index cde0269629..3a0a41eb7e 100644 --- a/source/source_lcao/hamilt_lcao.h +++ b/source/source_lcao/hamilt_lcao.h @@ -49,7 +49,7 @@ class HamiltLCAO : public Hamilt const TwoCenterBundle& two_center_bundle, const LCAO_Orbitals& orb, elecstate::DensityMatrix* DM_in, - DeePKS &deepks + Setup_DeePKS &deepks #ifdef __EXX , const int istep, diff --git a/source/source_lcao/setup_deepks.cpp b/source/source_lcao/setup_deepks.cpp index 08c1904bd5..36987a8da6 100644 --- a/source/source_lcao/setup_deepks.cpp +++ b/source/source_lcao/setup_deepks.cpp @@ -1,31 +1,32 @@ #include "source_lcao/setup_deepks.h" +#include "source_lcao/LCAO_domain.h" template -DeePKS::DeePKS(){} +Setup_DeePKS::Setup_DeePKS(){} template -DeePKS::~DeePKS(){} +Setup_DeePKS::~Setup_DeePKS(){} template -void DeePKS::before_runner(UnitCell& ucell, // unitcell - K_Vectors &kv, // k points +void Setup_DeePKS::before_runner(const UnitCell& ucell, // unitcell + const int nks, // number of k points const LCAO_Orbitals &orb, // orbital info const Parallel_Orbitals &pv, // parallel orbitals const Input_para& inp) { #ifdef __MLALGO - LCAO_domain::DeePKS_init(ucell, pv, kv.get_nks(), orb_, this->ld, GlobalV::ofs_running); + LCAO_domain::DeePKS_init(ucell, pv, kv.get_nks(), orb, this->ld, GlobalV::ofs_running); if (inp.deepks_scf) { // load the DeePKS model from deep neural network - DeePKS_domain::load_model(inp.deepks_model, ld.model_deepks); + DeePKS_domain::load_model(inp.deepks_model, this->ld.model_deepks); // read pdm from file for NSCF or SCF-restart, do it only once in whole calculation DeePKS_domain::read_pdm((inp.init_chg == "file"), inp.deepks_equiv, - ld.init_pdm, ucell.nat, orb_.Alpha[0].getTotal_nchi() * ucell.nat, - ld.lmaxd, ld.inl2l, *orb_.Alpha, ld.pdm); + this->ld.init_pdm, ucell.nat, orb.Alpha[0].getTotal_nchi() * ucell.nat, + this->ld.lmaxd, this->ld.inl2l, *orb.Alpha, this->ld.pdm); } #endif } -template class DeePKS; -template class DeePKS>; +template class Setup_DeePKS; +template class Setup_DeePKS>; diff --git a/source/source_lcao/setup_deepks.h b/source/source_lcao/setup_deepks.h index fffd1c3d8e..47a37d9992 100644 --- a/source/source_lcao/setup_deepks.h +++ b/source/source_lcao/setup_deepks.h @@ -2,7 +2,6 @@ #define SETUP_DEEPKS_H #include "source_cell/unitcell.h" // use unitcell -#include "source_cell/klist.h" // k points #include "source_io/module_parameter/input_parameter.h" // Input_para #include "source_basis/module_ao/parallel_orbitals.h" // parallel orbitals #include "source_basis/module_ao/ORB_read.h" // orb @@ -13,20 +12,20 @@ template -class DeePKS +class Setup_DeePKS { public: - DeePKS(); - ~DeePKS(); + Setup_DeePKS(); + ~Setup_DeePKS(); #ifdef __MLALGO LCAO_Deepks ld; #endif void before_runner( - UnitCell& ucell, // unitcell - K_Vectors &kv, // k points + const UnitCell& ucell, // unitcell + const int nks, // k points const LCAO_Orbitals &orb, // orbital info const Parallel_Orbitals &pv, // parallel orbitals const Input_para& inp); From e1ebff389545e9a36cc91e82572595954b61b878 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Wed, 8 Oct 2025 22:15:55 +0800 Subject: [PATCH 16/16] DeePKS has been used in operator, fix it --- source/source_esolver/esolver_ks_lcao.cpp | 12 ++++++------ source/source_io/ctrl_scf_lcao.cpp | 8 ++++---- source/source_io/ctrl_scf_lcao.h | 2 +- source/source_lcao/setup_deepks.cpp | 6 +++--- source/source_lcao/setup_deepks.h | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 0703a33c74..c9753b3e8f 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -145,7 +145,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa } // 13) init deepks - this->deepks.before_runner(ucell, kv.get_nks(), this->orb_, this->pv, PARAM.inp); + this->deepks.before_runner(ucell, this->kv.get_nks(), this->orb_, this->pv, PARAM.inp); // 14) set occupations, tddft does not need to set occupations in the first scf if (inp.ocp && inp.esolver_type != "tddft") @@ -403,7 +403,7 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const #ifdef __MLALGO // the density matrixes of DeePKS have been updated in each iter - ld.set_hr_cal(true); + this->deepks.ld.set_hr_cal(true); // HR in HamiltLCAO should be recalculate if (PARAM.inp.deepks_scf) @@ -551,10 +551,10 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& #ifdef __MLALGO if (PARAM.inp.deepks_scf) { - ld.dpks_cal_e_delta_band(dm_vec, this->kv.get_nks()); - DeePKS_domain::update_dmr(this->kv.kvec_d, dm_vec, ucell, orb_, this->pv, this->gd, ld.dm_r); - estate->f_en.edeepks_scf = ld.E_delta - ld.e_delta_band; - estate->f_en.edeepks_delta = ld.E_delta; + this->deepks.ld.dpks_cal_e_delta_band(dm_vec, this->kv.get_nks()); + DeePKS_domain::update_dmr(this->kv.kvec_d, dm_vec, ucell, orb_, this->pv, this->gd, this->deepks.ld.dm_r); + estate->f_en.edeepks_scf = this->deepks.ld.E_delta - this->deepks.ld.e_delta_band; + estate->f_en.edeepks_delta = this->deepks.ld.E_delta; } #endif diff --git a/source/source_io/ctrl_scf_lcao.cpp b/source/source_io/ctrl_scf_lcao.cpp index c11f954c10..baf9257ca4 100644 --- a/source/source_io/ctrl_scf_lcao.cpp +++ b/source/source_io/ctrl_scf_lcao.cpp @@ -51,7 +51,7 @@ void ctrl_scf_lcao(UnitCell& ucell, const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 rdmft::RDMFT &rdmft_solver, // for RDMFT - DeePKS &deepks, + Setup_DeePKS &deepks, Exx_NAO &exx_nao, const int istep) { @@ -453,7 +453,7 @@ template void ModuleIO::ctrl_scf_lcao(UnitCell& ucell, const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 rdmft::RDMFT &rdmft_solver, // for RDMFT - DeePKS &deepks, + Setup_DeePKS &deepks, Exx_NAO &exx_nao, const int istep); @@ -475,7 +475,7 @@ template void ModuleIO::ctrl_scf_lcao, double>(UnitCell& uc const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 rdmft::RDMFT, double> &rdmft_solver, // for RDMFT - DeePKS> &deepks, + Setup_DeePKS> &deepks, Exx_NAO> &exx_nao, const int istep); @@ -496,7 +496,7 @@ template void ModuleIO::ctrl_scf_lcao, std::complex const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 rdmft::RDMFT, std::complex> &rdmft_solver, // for RDMFT - DeePKS> &deepks, + Setup_DeePKS> &deepks, Exx_NAO> &exx_nao, const int istep); diff --git a/source/source_io/ctrl_scf_lcao.h b/source/source_io/ctrl_scf_lcao.h index 107f00595d..c67ea465b0 100644 --- a/source/source_io/ctrl_scf_lcao.h +++ b/source/source_io/ctrl_scf_lcao.h @@ -38,7 +38,7 @@ namespace ModuleIO const ModulePW::PW_Basis_Big* pw_big, // for Wannier90 const Structure_Factor& sf, // for Wannier90 rdmft::RDMFT &rdmft_solver, // for RDMFT - DeePKS &deepks, + Setup_DeePKS &deepks, Exx_NAO &exx_nao, const int istep); } diff --git a/source/source_lcao/setup_deepks.cpp b/source/source_lcao/setup_deepks.cpp index 36987a8da6..8608e65095 100644 --- a/source/source_lcao/setup_deepks.cpp +++ b/source/source_lcao/setup_deepks.cpp @@ -11,11 +11,11 @@ template void Setup_DeePKS::before_runner(const UnitCell& ucell, // unitcell const int nks, // number of k points const LCAO_Orbitals &orb, // orbital info - const Parallel_Orbitals &pv, // parallel orbitals - const Input_para& inp) + Parallel_Orbitals &pv, // parallel orbitals + const Input_para &inp) { #ifdef __MLALGO - LCAO_domain::DeePKS_init(ucell, pv, kv.get_nks(), orb, this->ld, GlobalV::ofs_running); + LCAO_domain::DeePKS_init(ucell, pv, nks, orb, this->ld, GlobalV::ofs_running); if (inp.deepks_scf) { // load the DeePKS model from deep neural network diff --git a/source/source_lcao/setup_deepks.h b/source/source_lcao/setup_deepks.h index 47a37d9992..6299ad581b 100644 --- a/source/source_lcao/setup_deepks.h +++ b/source/source_lcao/setup_deepks.h @@ -27,8 +27,8 @@ class Setup_DeePKS const UnitCell& ucell, // unitcell const int nks, // k points const LCAO_Orbitals &orb, // orbital info - const Parallel_Orbitals &pv, // parallel orbitals - const Input_para& inp); + Parallel_Orbitals &pv, // parallel orbitals + const Input_para &inp); };