Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 29 additions & 89 deletions source/source_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,18 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,

const int cur_nbasis = psi.get_current_nbas();

// Shared matrix-blockvector operators used by all iterative solvers.
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
auto psi_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);
psi::Range bands_range(true, 0, 0, nvec - 1);
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_wrapper, bands_range, hpsi_out);
hm->ops->hPsi(info);
};
auto spsi_func = [hm](const T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
};

if (this->method == "cg")
{
// wrap the subspace_func into a lambda function
Expand All @@ -272,28 +284,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
this->diag_iter_max,
this->nproc_in_pool);

// wrap the hpsi_func and spsi_func into lambda functions
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
auto psi_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);
psi::Range all_bands_range(true, 0, 0, nvec - 1);
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out);
hm->ops->hPsi(info);
};
auto spsi_func = [this, hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
if (this->use_uspp)
{
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
}
else
{
base_device::memory::synchronize_memory_op<T, Device, Device>()(
spsi_out,
psi_in,
static_cast<size_t>(nvec) * static_cast<size_t>(ld_psi));
}
};

DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
cg.diag(hpsi_func,
spsi_func,
Expand All @@ -313,42 +303,14 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
const int nband_l = psi.get_nbands();
const int nbasis = psi.get_nbasis();
const int ndim = psi.get_current_ngk();
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {

// Convert "pointer data stucture" to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);

psi::Range bands_range(true, 0, 0, nvec - 1);

using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
hm->ops->hPsi(info);
};
DiagoBPCG<T, Device> bpcg(pre_condition.data());
bpcg.init_iter(PARAM.inp.nbands, nband_l, nbasis, ndim);
bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band);
}
else if (this->method == "dav_subspace")
{
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {

// Convert "pointer data stucture" to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);

psi::Range bands_range(true, 0, 0, nvec - 1);

using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
hm->ops->hPsi(info);
};
bool scf = this->calculation_type == "nscf" ? false : true;

auto spsi_func = [hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
};

Diago_DavSubspace<T, Device> dav_subspace(pre_condition,
psi.get_nbands(),
psi.get_k_first() ? psi.get_current_ngk()
Expand All @@ -361,8 +323,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
PARAM.inp.nb2d);

DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
dav_subspace
.diag(hpsi_func, spsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf));
dav_subspace.diag(hpsi_func,
spsi_func,
psi.get_pointer(),
psi.get_nbasis(),
eigenvalue,
this->ethr_band,
scf));
}
else if (this->method == "dav")
{
Expand All @@ -383,45 +350,18 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
const int nband = psi.get_nbands(); /// number of eigenpairs sought
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi

// Davidson matrix-blockvector functions
/// wrap hpsi into lambda function, Matrix \times blockvector
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {

// Convert pointer of psi_in to a psi::Psi object
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);

psi::Range bands_range(true, 0, 0, nvec - 1);

using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
hm->ops->hPsi(info);
};

/// wrap spsi into lambda function, Matrix \times blockvector
/// spsi(X, SX, ld, nvec)
/// ld is leading dimension of psi and spsi
auto spsi_func = [hm](const T* psi_in,
T* spsi_out,
const int ld_psi, // Leading dimension of psi and spsi.
const int nvec // Number of vectors(bands)
) {
// sPsi determines S=I or not by PARAM.globalv.use_uspp inside
// sPsi(psi, spsi, nrow, npw, nbands)
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
};

DiagoDavid<T, Device> david(pre_condition.data(), nband, dim, PARAM.inp.pw_diag_ndim, this->use_paw, comm_info);
// do diag and add davidson iteration counts up to avg_iter
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(david.diag(hpsi_func,
spsi_func,
ld_psi,
psi.get_pointer(),
eigenvalue,
this->ethr_band,
david_maxiter,
ntry_max,
notconv_max));
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
david.diag(hpsi_func,
spsi_func,
ld_psi,
psi.get_pointer(),
eigenvalue,
this->ethr_band,
david_maxiter,
ntry_max,
notconv_max));
}
ModuleBase::timer::tick("HSolverPW", "solve_psik");
return;
Expand Down
Loading