diff --git a/source/source_esolver/esolver_gets.cpp b/source/source_esolver/esolver_gets.cpp index 7eff0f537a..e03e7b8bdc 100644 --- a/source/source_esolver/esolver_gets.cpp +++ b/source/source_esolver/esolver_gets.cpp @@ -108,8 +108,9 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep) this->kv, *(two_center_bundle_.overlap_orb), orb_.cutoffs()); - dynamic_cast, std::complex>*>(this->p_hamilt->ops) - ->contributeHR(); + auto* hamilt_ptr = static_cast>*>(this->p_hamilt); + auto* ops_ptr = dynamic_cast, std::complex>*>(hamilt_ptr->ops); + ops_ptr->contributeHR(); } else { @@ -119,13 +120,16 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep) this->kv, *(two_center_bundle_.overlap_orb), orb_.cutoffs()); - dynamic_cast, double>*>(this->p_hamilt->ops)->contributeHR(); + auto* hamilt_ptr = static_cast>*>(this->p_hamilt); + auto* ops_ptr = dynamic_cast, double>*>(hamilt_ptr->ops); + ops_ptr->contributeHR(); } } const std::string fn = PARAM.globalv.global_out_dir + "sr_nao.csr"; - ModuleIO::output_SR(pv, gd, this->p_hamilt, fn); + auto* hamilt_ptr = static_cast>*>(this->p_hamilt); + ModuleIO::output_SR(pv, gd, hamilt_ptr, fn); if (PARAM.inp.out_mat_r) { diff --git a/source/source_esolver/esolver_ks.h b/source/source_esolver/esolver_ks.h index 787b58ba74..1913aa4101 100644 --- a/source/source_esolver/esolver_ks.h +++ b/source/source_esolver/esolver_ks.h @@ -7,6 +7,7 @@ #include "source_estate/module_charge/charge_mixing.h" // use charge mixing #include "source_psi/psi.h" // use electronic wave functions #include "source_hamilt/hamilt.h" // use Hamiltonian +#include "source_hamilt/hamilt_base.h" // use Hamiltonian base class #include "source_lcao/module_dftu/dftu.h" // mohan add 20251107 namespace ModuleESolver @@ -47,8 +48,8 @@ class ESolver_KS : public ESolver_FP //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. virtual void after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) override; - //! Hamiltonian - hamilt::Hamilt* p_hamilt = nullptr; + //! Hamiltonian (base class pointer, actual type determined at runtime) + hamilt::HamiltBase* p_hamilt = nullptr; //! PW for wave functions, only used in KSDFT, not in OFDFT ModulePW::PW_Basis_K* pw_wfc = nullptr; diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 43e83aa8df..bad1cec7b6 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -179,7 +179,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) { //! 13.1.2) init charge density from Hamiltonian matrix file LCAO_domain::init_chg_hr(PARAM.globalv.global_readin_dir, PARAM.inp.nspin, - this->p_hamilt, ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm, + static_cast*>(this->p_hamilt), ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm, this->chr, PARAM.inp.ks_solver); } } @@ -382,7 +382,7 @@ void ESolver_KS_LCAO::hamilt2rho_single(UnitCell& ucell, int istep, int if (!skip_solve) { hsolver::HSolverLCAO hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver); - hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, *this->dmat.dm, + hsolver_lcao_obj.solve(static_cast*>(this->p_hamilt), this->psi[0], this->pelec, *this->dmat.dm, this->chr, PARAM.inp.nspin, skip_charge); } diff --git a/source/source_esolver/esolver_ks_lcao_tddft.cpp b/source/source_esolver/esolver_ks_lcao_tddft.cpp index b7641a09fc..05dc8c9233 100644 --- a/source/source_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/source_esolver/esolver_ks_lcao_tddft.cpp @@ -235,7 +235,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, PARAM.inp.nbands, PARAM.globalv.nlocal, this->kv.get_nks(), - this->p_hamilt, + static_cast>*>(this->p_hamilt), this->pv, this->psi, this->psi_laststep, @@ -255,7 +255,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, PARAM.inp.nbands, PARAM.globalv.nlocal, this->kv.get_nks(), - this->p_hamilt, + static_cast>*>(this->p_hamilt), this->pv, this->psi, this->psi_laststep, @@ -277,7 +277,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, { bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLCAO> hsolver_lcao_obj(&this->pv, PARAM.inp.ks_solver); - hsolver_lcao_obj.solve(this->p_hamilt, + hsolver_lcao_obj.solve(static_cast>*>(this->p_hamilt), this->psi[0], this->pelec, *this->dmat.dm, @@ -342,11 +342,11 @@ void ESolver_KS_LCAO_TDDFT::iter_finish(UnitCell& ucell, { if (use_tensor && use_lapack) { - elecstate::cal_edm_tddft_tensor_lapack(this->pv, this->dmat, this->kv, this->p_hamilt); + elecstate::cal_edm_tddft_tensor_lapack(this->pv, this->dmat, this->kv, static_cast>*>(this->p_hamilt)); } else { - elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, this->p_hamilt); + elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, static_cast>*>(this->p_hamilt)); } } } @@ -416,7 +416,7 @@ void ESolver_KS_LCAO_TDDFT::store_h_s_psi(UnitCell& ucell, this->p_hamilt->updateHk(ik); hamilt::MatrixBlock> h_mat; hamilt::MatrixBlock> s_mat; - this->p_hamilt->matrix(h_mat, s_mat); + static_cast>*>(this->p_hamilt)->matrix(h_mat, s_mat); // Store H and S matrices to Hk_laststep and Sk_laststep if (use_tensor && use_lapack) diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index f9700f5b68..db00b6265d 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -146,7 +146,7 @@ namespace ModuleESolver bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLIP hsolver_lip_obj(this->pw_wfc); - hsolver_lip_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, + hsolver_lip_obj.solve(static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); // add exx diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index e255b95b46..b2976733bf 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -128,10 +128,10 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) // init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06 pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid, this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell, - this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); + this->stp.psi_t, static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); // setup psi (electronic wave functions) - this->stp.init(this->p_hamilt); + this->stp.init(static_cast*>(this->p_hamilt)); //! Setup EXX helper for Hamiltonian and psi exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp); @@ -188,7 +188,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste hsolver::DiagoIterAssist::need_subspace, PARAM.inp.use_k_continuity); - hsolver_pw_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, this->pelec->ekb.c, + hsolver_pw_obj.solve(static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, this->pelec->ekb.c, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat); } diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 26118eed21..597825cd6d 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -167,7 +167,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i hsolver::DiagoIterAssist::need_subspace); hsolver_pw_sdft_obj.solve(ucell, - this->p_hamilt, + static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->stp.psi_cpu[0], this->pelec, @@ -291,7 +291,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) this->pw_wfc, this->stp.psi_t, &this->ppcell, - this->p_hamilt, + static_cast, Device>*>(this->p_hamilt), this->stoche, &stowf); sto_elecond.decide_nche(PARAM.inp.cond_dt, 1e-8, this->nche_sto, PARAM.inp.emin_sto, PARAM.inp.emax_sto); diff --git a/source/source_hamilt/hamilt.h b/source/source_hamilt/hamilt.h index 6d732d7a82..3d554c0fe6 100644 --- a/source/source_hamilt/hamilt.h +++ b/source/source_hamilt/hamilt.h @@ -7,21 +7,28 @@ #include "matrixblock.h" #include "source_psi/psi.h" #include "operator.h" +#include "hamilt_base.h" namespace hamilt { template -class Hamilt +class Hamilt : public HamiltBase { public: virtual ~Hamilt(){}; /// for target K point, update consequence of hPsi() and matrix() - virtual void updateHk(const int ik){return;} + void updateHk(const int ik) override { return; } /// refresh status of Hamiltonian, for example, refresh H(R) and S(R) in LCAO case - virtual void refresh(bool yes = true){return;} + void refresh(bool yes = true) override { return; } + + /// get the class name + std::string get_classname() const override { return classname; } + + /// get the operator chain + void* get_ops() override { return static_cast(ops); } /// core function: for solving eigenvalues of Hamiltonian with iterative method virtual void hPsi( diff --git a/source/source_hamilt/hamilt_base.h b/source/source_hamilt/hamilt_base.h new file mode 100644 index 0000000000..06325bf050 --- /dev/null +++ b/source/source_hamilt/hamilt_base.h @@ -0,0 +1,52 @@ +#ifndef HAMILT_BASE_H +#define HAMILT_BASE_H + +#include + +namespace hamilt +{ + +/** + * @brief Base class for Hamiltonian + * + * This is a non-template base class for Hamilt. + * It provides a common interface for all Hamiltonian types, + * allowing ESolver to manage Hamiltonian without template parameters. + */ +class HamiltBase +{ + public: + virtual ~HamiltBase() {} + + /** + * @brief Update Hamiltonian for a specific k-point + * + * @param ik k-point index + */ + virtual void updateHk(const int ik) { return; } + + /** + * @brief Refresh the status of Hamiltonian + * + * @param yes whether to refresh + */ + virtual void refresh(bool yes = true) { return; } + + /** + * @brief Get the class name + * + * @return class name + */ + virtual std::string get_classname() const { return "none"; } + + /** + * @brief Get the operator chain (as void* to avoid template) + * + * @return pointer to operator chain + */ + virtual void* get_ops() { return nullptr; } +}; + +} // namespace hamilt + +#endif // HAMILT_BASE_H