Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions source/source_esolver/esolver_gets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep)
this->kv,
*(two_center_bundle_.overlap_orb),
orb_.cutoffs());
dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, std::complex<double>>*>(this->p_hamilt->ops)
->contributeHR();
auto* hamilt_ptr = static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt);
auto* ops_ptr = dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, std::complex<double>>*>(hamilt_ptr->ops);
ops_ptr->contributeHR();
}
else
{
Expand All @@ -119,13 +120,16 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep)
this->kv,
*(two_center_bundle_.overlap_orb),
orb_.cutoffs());
dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, double>*>(this->p_hamilt->ops)->contributeHR();
auto* hamilt_ptr = static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt);
auto* ops_ptr = dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, double>*>(hamilt_ptr->ops);
ops_ptr->contributeHR();
}
}

const std::string fn = PARAM.globalv.global_out_dir + "sr_nao.csr";

ModuleIO::output_SR(pv, gd, this->p_hamilt, fn);
auto* hamilt_ptr = static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt);
ModuleIO::output_SR(pv, gd, hamilt_ptr, fn);

if (PARAM.inp.out_mat_r)
{
Expand Down
5 changes: 3 additions & 2 deletions source/source_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "source_estate/module_charge/charge_mixing.h" // use charge mixing
#include "source_psi/psi.h" // use electronic wave functions
#include "source_hamilt/hamilt.h" // use Hamiltonian
#include "source_hamilt/hamilt_base.h" // use Hamiltonian base class
#include "source_lcao/module_dftu/dftu.h" // mohan add 20251107

namespace ModuleESolver
Expand Down Expand Up @@ -47,8 +48,8 @@ class ESolver_KS : public ESolver_FP
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
virtual void after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) override;

//! Hamiltonian
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
//! Hamiltonian (base class pointer, actual type determined at runtime)
hamilt::HamiltBase* p_hamilt = nullptr;

//! PW for wave functions, only used in KSDFT, not in OFDFT
ModulePW::PW_Basis_K* pw_wfc = nullptr;
Expand Down
4 changes: 2 additions & 2 deletions source/source_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
{
//! 13.1.2) init charge density from Hamiltonian matrix file
LCAO_domain::init_chg_hr<TK, TR>(PARAM.globalv.global_readin_dir, PARAM.inp.nspin,
this->p_hamilt, ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm,
static_cast<hamilt::Hamilt<TK>*>(this->p_hamilt), ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm,
this->chr, PARAM.inp.ks_solver);
}
}
Expand Down Expand Up @@ -382,7 +382,7 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2rho_single(UnitCell& ucell, int istep, int
if (!skip_solve)
{
hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver);
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, *this->dmat.dm,
hsolver_lcao_obj.solve(static_cast<hamilt::Hamilt<TK>*>(this->p_hamilt), this->psi[0], this->pelec, *this->dmat.dm,
this->chr, PARAM.inp.nspin, skip_charge);
}

Expand Down
12 changes: 6 additions & 6 deletions source/source_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
PARAM.inp.nbands,
PARAM.globalv.nlocal,
this->kv.get_nks(),
this->p_hamilt,
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt),
this->pv,
this->psi,
this->psi_laststep,
Expand All @@ -255,7 +255,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
PARAM.inp.nbands,
PARAM.globalv.nlocal,
this->kv.get_nks(),
this->p_hamilt,
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt),
this->pv,
this->psi,
this->psi_laststep,
Expand All @@ -277,7 +277,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
{
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
hsolver::HSolverLCAO<std::complex<double>> hsolver_lcao_obj(&this->pv, PARAM.inp.ks_solver);
hsolver_lcao_obj.solve(this->p_hamilt,
hsolver_lcao_obj.solve(static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt),
this->psi[0],
this->pelec,
*this->dmat.dm,
Expand Down Expand Up @@ -342,11 +342,11 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::iter_finish(UnitCell& ucell,
{
if (use_tensor && use_lapack)
{
elecstate::cal_edm_tddft_tensor_lapack<Device>(this->pv, this->dmat, this->kv, this->p_hamilt);
elecstate::cal_edm_tddft_tensor_lapack<Device>(this->pv, this->dmat, this->kv, static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt));
}
else
{
elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, this->p_hamilt);
elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt));
}
}
}
Expand Down Expand Up @@ -416,7 +416,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::store_h_s_psi(UnitCell& ucell,
this->p_hamilt->updateHk(ik);
hamilt::MatrixBlock<std::complex<double>> h_mat;
hamilt::MatrixBlock<std::complex<double>> s_mat;
this->p_hamilt->matrix(h_mat, s_mat);
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt)->matrix(h_mat, s_mat);

// Store H and S matrices to Hk_laststep and Sk_laststep
if (use_tensor && use_lapack)
Expand Down
2 changes: 1 addition & 1 deletion source/source_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ namespace ModuleESolver
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;

hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
hsolver_lip_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec,
hsolver_lip_obj.solve(static_cast<hamilt::Hamilt<T>*>(this->p_hamilt), this->stp.psi_t[0], this->pelec,
*this->psi_local, skip_charge,ucell.tpiba,ucell.nat);

// add exx
Expand Down
6 changes: 3 additions & 3 deletions source/source_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
// init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06
pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid,
this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell,
this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp);
this->stp.psi_t, static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp);

// setup psi (electronic wave functions)
this->stp.init(this->p_hamilt);
this->stp.init(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt));

//! Setup EXX helper for Hamiltonian and psi
exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp);
Expand Down Expand Up @@ -188,7 +188,7 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
hsolver::DiagoIterAssist<T, Device>::need_subspace,
PARAM.inp.use_k_continuity);

hsolver_pw_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, this->pelec->ekb.c,
hsolver_pw_obj.solve(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, this->pelec->ekb.c,
GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat);
}

Expand Down
4 changes: 2 additions & 2 deletions source/source_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, int istep, i
hsolver::DiagoIterAssist<T, Device>::need_subspace);

hsolver_pw_sdft_obj.solve(ucell,
this->p_hamilt,
static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt),
this->stp.psi_t[0],
this->stp.psi_cpu[0],
this->pelec,
Expand Down Expand Up @@ -291,7 +291,7 @@ void ESolver_SDFT_PW<T, Device>::after_all_runners(UnitCell& ucell)
this->pw_wfc,
this->stp.psi_t,
&this->ppcell,
this->p_hamilt,
static_cast<hamilt::Hamilt<std::complex<double>, Device>*>(this->p_hamilt),
this->stoche,
&stowf);
sto_elecond.decide_nche(PARAM.inp.cond_dt, 1e-8, this->nche_sto, PARAM.inp.emin_sto, PARAM.inp.emax_sto);
Expand Down
13 changes: 10 additions & 3 deletions source/source_hamilt/hamilt.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,28 @@
#include "matrixblock.h"
#include "source_psi/psi.h"
#include "operator.h"
#include "hamilt_base.h"

namespace hamilt
{

template <typename T, typename Device = base_device::DEVICE_CPU>
class Hamilt
class Hamilt : public HamiltBase
{
public:
virtual ~Hamilt(){};

/// for target K point, update consequence of hPsi() and matrix()
virtual void updateHk(const int ik){return;}
void updateHk(const int ik) override { return; }

/// refresh status of Hamiltonian, for example, refresh H(R) and S(R) in LCAO case
virtual void refresh(bool yes = true){return;}
void refresh(bool yes = true) override { return; }

/// get the class name
std::string get_classname() const override { return classname; }

/// get the operator chain
void* get_ops() override { return static_cast<void*>(ops); }

/// core function: for solving eigenvalues of Hamiltonian with iterative method
virtual void hPsi(
Expand Down
52 changes: 52 additions & 0 deletions source/source_hamilt/hamilt_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#ifndef HAMILT_BASE_H
#define HAMILT_BASE_H

#include <string>

namespace hamilt
{

/**
* @brief Base class for Hamiltonian
*
* This is a non-template base class for Hamilt<T, Device>.
* It provides a common interface for all Hamiltonian types,
* allowing ESolver to manage Hamiltonian without template parameters.
*/
class HamiltBase
{
public:
virtual ~HamiltBase() {}

/**
* @brief Update Hamiltonian for a specific k-point
*
* @param ik k-point index
*/
virtual void updateHk(const int ik) { return; }

/**
* @brief Refresh the status of Hamiltonian
*
* @param yes whether to refresh
*/
virtual void refresh(bool yes = true) { return; }

/**
* @brief Get the class name
*
* @return class name
*/
virtual std::string get_classname() const { return "none"; }

/**
* @brief Get the operator chain (as void* to avoid template)
*
* @return pointer to operator chain
*/
virtual void* get_ops() { return nullptr; }
};

} // namespace hamilt

#endif // HAMILT_BASE_H
Loading