Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 50 additions & 16 deletions source/source_hsolver/diago_cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ void DiagoCG<T, Device>::diag_once(const ct::Tensor& prec_in,
{
phi_m.sync(psi[m]);
// copy psi_in into internal psi, m=0 has been done in Constructor
this->spsi_func_(phi_m, sphi); // sphi = S|psi(m)>
this->spsi_func_(phi_m.data<T>(), sphi.data<T>(), this->n_basis_, 1); // sphi = S|psi(m)>
this->schmit_orth(m, psi, sphi, phi_m);
this->spsi_func_(phi_m, sphi); // sphi = S|psi(m)>
this->hpsi_func_(phi_m, hphi); // hphi = H|psi(m)>
this->spsi_func_(phi_m.data<T>(), sphi.data<T>(), this->n_basis_, 1); // sphi = S|psi(m)>
this->hpsi_func_(phi_m.data<T>(), hphi.data<T>(), this->n_basis_, 1); // hphi = H|psi(m)>

eigen_pack[m] = dot_real_op()(this->n_basis_, phi_m.data<T>(), hphi.data<T>());

Expand All @@ -150,8 +150,8 @@ void DiagoCG<T, Device>::diag_once(const ct::Tensor& prec_in,
g0,
cg); // Tensor&

this->hpsi_func_(cg, pphi);
this->spsi_func_(cg, scg);
this->hpsi_func_(cg.data<T>(), pphi.data<T>(), this->n_basis_, 1);
this->spsi_func_(cg.data<T>(), scg.data<T>(), this->n_basis_, 1);

converged = this->update_psi(pphi,
cg,
Expand Down Expand Up @@ -264,7 +264,7 @@ void DiagoCG<T, Device>::orth_grad(const ct::Tensor& psi,
ct::Tensor& scg,
ct::Tensor& lagrange)
{
this->spsi_func_(grad, scg); // scg = S|grad>
this->spsi_func_(grad.data<T>(), scg.data<T>(), this->n_basis_, 1); // scg = S|grad>
ModuleBase::gemv_op<T, Device>()('C',
this->n_basis_,
m,
Expand Down Expand Up @@ -576,21 +576,47 @@ bool DiagoCG<T, Device>::test_exit_cond(const int& ntry, const int& notconv) con
}

template <typename T, typename Device>
double DiagoCG<T, Device>::diag(const Func& hpsi_func,
const Func& spsi_func,
ct::Tensor& psi,
ct::Tensor& eigen,
const std::vector<double>& ethr_band,
const ct::Tensor& prec)
double DiagoCG<T, Device>::diag(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
const int ld_psi,
const int nband,
const int dim,
T* psi_in,
Real* eigenvalue_in,
const std::vector<double>& ethr_band,
const Real* prec)
{
REQUIRES_OK(ld_psi >= dim, "DiagoCG::diag: ld_psi must be >= dim");
REQUIRES_OK(static_cast<int>(ethr_band.size()) >= nband,
"DiagoCG::diag: ethr_band size must be >= nband");

auto psi = ct::TensorMap(psi_in,
ct::DataTypeToEnum<T>::value,
ct::DeviceTypeToEnum<ct_Device>::value,
ct::TensorShape({nband, ld_psi}));
auto eigen = ct::TensorMap(eigenvalue_in,
ct::DataTypeToEnum<Real>::value,
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
ct::TensorShape({nband}));

ct::Tensor prec_tensor;
if (prec != nullptr)
{
prec_tensor = ct::TensorMap(const_cast<Real*>(prec),
ct::DataTypeToEnum<Real>::value,
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
ct::TensorShape({dim}))
.template to_device<ct_Device>();
}

/// record the times of trying iterative diagonalization
int ntry = 0;
this->notconv_ = 0;
hpsi_func_ = hpsi_func;
spsi_func_ = spsi_func;

// create a new slice of psi to do cg diagonalization
ct::Tensor psi_temp = psi.slice({0, 0}, {int(psi.shape().dim_size(0)), int(prec.shape().dim_size(0))});
ct::Tensor psi_temp = psi.slice({0, 0}, {nband, dim});
do
{
// subspace diagonalization to get a better starting guess
Expand All @@ -601,21 +627,29 @@ double DiagoCG<T, Device>::diag(const Func& hpsi_func,
{
ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp);
const bool assume_S_orthogonal = true;
this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal);
this->subspace_func_(psi_temp.data<T>(),
psi_map.data<T>(),
dim,
nband,
assume_S_orthogonal);
psi_temp.sync(psi_map);
}
else if (need_subspace_)
{
ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp);
const bool assume_S_orthogonal = false;
this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal);
this->subspace_func_(psi_temp.data<T>(),
psi_map.data<T>(),
dim,
nband,
assume_S_orthogonal);
psi_temp.sync(psi_map);
}


++ntry;
avg_iter_ += 1.0;
this->diag_once(prec, psi_temp, eigen, ethr_band);
this->diag_once(prec_tensor, psi_temp, eigen, ethr_band);
} while (this->test_exit_cond(ntry, this->notconv_));

if (this->notconv_ > std::max(5, this->n_band_ / 4))
Expand Down
24 changes: 14 additions & 10 deletions source/source_hsolver/diago_cg.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ class DiagoCG final
using Real = typename GetTypeReal<T>::type;
using ct_Device = typename ct::PsiToContainer<Device>::type;
public:
using Func = std::function<void(const ct::Tensor&, ct::Tensor&)>;
using SubspaceFunc = std::function<void(const ct::Tensor&, ct::Tensor&, const bool)>;
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
using SPsiFunc = std::function<void(T*, T*, const int, const int)>;
using SubspaceFunc = std::function<void(T*, T*, const int, const int, const bool)>;
// Constructor need:
// 1. temporary mock of Hamiltonian "Hamilt_PW"
// 2. precondition pointer should point to place of precondition array.
Expand All @@ -43,12 +44,15 @@ class DiagoCG final
// refactor hpsi_info
// this is the diag() function for CG method
// returns avg_iter
double diag(const Func& hpsi_func,
const Func& spsi_func,
ct::Tensor& psi,
ct::Tensor& eigen,
const std::vector<double>& ethr_band,
const ct::Tensor& prec = {});
double diag(const HPsiFunc& hpsi_func,
const SPsiFunc& spsi_func,
const int ld_psi,
const int nband,
const int dim,
T* psi_in,
Real* eigenvalue_in,
const std::vector<double>& ethr_band,
const Real* prec = nullptr);

private:
Device * ctx_ = {};
Expand Down Expand Up @@ -77,9 +81,9 @@ class DiagoCG final

bool need_subspace_ = false;
/// A function object that performs the hPsi calculation.
Func hpsi_func_ = nullptr;
HPsiFunc hpsi_func_ = nullptr;
/// A function object that performs the sPsi calculation.
Func spsi_func_ = nullptr;
SPsiFunc spsi_func_ = nullptr;
/// A function object that performs the subspace calculation.
SubspaceFunc subspace_func_ = nullptr;

Expand Down
100 changes: 28 additions & 72 deletions source/source_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,27 +254,15 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
// wrap the subspace_func into a lambda function
// if S_orth is true, then assume psi is S-orthogonal, solve standard eigenproblem
// otherwise, solve generalized eigenproblem
auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) {
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
const auto ndim = psi_in.shape().ndim();
REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2");
// Convert a Tensor object to a psi::Psi object
auto psi_in_wrapper = psi::Psi<T, Device>(psi_in.data<T>(),
1,
psi_in.shape().dim_size(0),
psi_in.shape().dim_size(1),
cur_nbasis);
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data<T>(),
1,
psi_out.shape().dim_size(0),
psi_out.shape().dim_size(1),
cur_nbasis);
auto eigen = ct::Tensor(ct::DataTypeToEnum<Real>::value,
ct::DeviceType::CpuDevice,
ct::TensorShape({psi_in.shape().dim_size(0)}));

DiagoIterAssist<T, Device>::diag_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>());
auto subspace_func = [hm, cur_nbasis](T* psi_in,
T* psi_out,
const int ld_psi,
const int nband,
const bool S_orth) {
auto psi_in_wrapper = psi::Psi<T, Device>(psi_in, 1, nband, ld_psi, cur_nbasis);
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out, 1, nband, ld_psi, cur_nbasis);
std::vector<Real> eigen(nband, 0.0);
DiagoIterAssist<T, Device>::diag_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data());
};
DiagoCG<T, Device> cg(this->basis_type,
this->calculation_type,
Expand All @@ -284,70 +272,38 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
this->diag_iter_max,
this->nproc_in_pool);

// wrap the hpsi_func and spsi_func into a lambda function
using ct_Device = typename ct::PsiToContainer<Device>::type;

// wrap the hpsi_func and spsi_func into a lambda function
auto hpsi_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
const auto ndim = psi_in.shape().ndim();
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
// Convert a Tensor object to a psi::Psi object
auto psi_wrapper = psi::Psi<T, Device>(psi_in.data<T>(),
1,
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
cur_nbasis);
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
// wrap the hpsi_func and spsi_func into lambda functions
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
auto psi_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);
psi::Range all_bands_range(true, 0, 0, nvec - 1);
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<T>());
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out);
hm->ops->hPsi(info);
};
auto spsi_func = [this, hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
const auto ndim = psi_in.shape().ndim();
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");

auto spsi_func = [this, hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
if (this->use_uspp)
{
// Convert a Tensor object to a psi::Psi object
hm->sPsi(psi_in.data<T>(),
spsi_out.data<T>(),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
ndim == 1 ? 1 : psi_in.shape().dim_size(0));
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
}
else
{
base_device::memory::synchronize_memory_op<T, Device, Device>()(
spsi_out.data<T>(),
psi_in.data<T>(),
static_cast<size_t>((ndim == 1 ? 1 : psi_in.shape().dim_size(0))
* (ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1))));
spsi_out,
psi_in,
static_cast<size_t>(nvec) * static_cast<size_t>(ld_psi));
}
};

auto psi_tensor = ct::TensorMap(psi.get_pointer(),
ct::DataTypeToEnum<T>::value,
ct::DeviceTypeToEnum<ct_Device>::value,
ct::TensorShape({psi.get_nbands(), psi.get_nbasis()}));

auto eigen_tensor = ct::TensorMap(eigenvalue,
ct::DataTypeToEnum<Real>::value,
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
ct::TensorShape({psi.get_nbands()}));

auto prec_tensor = ct::TensorMap(pre_condition.data(),
ct::DataTypeToEnum<Real>::value,
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
ct::TensorShape({static_cast<int>(pre_condition.size())}))
.to_device<ct_Device>()
.slice({0}, {psi.get_current_ngk()});

DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor)
cg.diag(hpsi_func,
spsi_func,
psi.get_nbasis(),
psi.get_nbands(),
psi.get_current_ngk(),
psi.get_pointer(),
eigenvalue,
this->ethr_band,
pre_condition.data())
);
// TODO: Double check tensormap's potential problem
// ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
Expand Down
69 changes: 34 additions & 35 deletions source/source_hsolver/test/diago_cg_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,19 @@ class DiagoCGPrepare
// New interface of cg method
/**************************************************************/
// warp the subspace_func into a lambda function
auto subspace_func = [ha](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) { /*do nothing*/ };
auto subspace_func = [ha](std::complex<float>* psi_in,
std::complex<float>* psi_out,
const int ld_psi,
const int nband,
const bool S_orth) {
auto psi_in_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nband, ld_psi, true);
auto psi_out_wrapper = psi::Psi<std::complex<float>>(psi_out, 1, nband, ld_psi, true);
std::vector<float> eigen(nband, 0.0f);
hsolver::DiagoIterAssist<std::complex<float>>::diag_subspace(ha,
psi_in_wrapper,
psi_out_wrapper,
eigen.data());
};
hsolver::DiagoCG<std::complex<float>> cg(
PARAM.input.basis_type,
PARAM.input.calculation,
Expand All @@ -156,46 +168,33 @@ class DiagoCGPrepare
float start, end;
start = MPI_Wtime();

auto hpsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
const auto ndim = psi_in.shape().ndim();
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
auto psi_wrapper = psi::Psi<std::complex<float>>(
psi_in.data<std::complex<float>>(), 1,
ndim == 1 ? 1 : psi_in.shape().dim_size(0),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true);
psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1);
auto hpsi_func = [ha](std::complex<float>* psi_in,
std::complex<float>* hpsi_out,
const int ld_psi,
const int nvec) {
auto psi_wrapper = psi::Psi<std::complex<float>>(psi_in, 1, nvec, ld_psi, true);
psi::Range all_bands_range(true, 0, 0, nvec - 1);
using hpsi_info = typename hamilt::Operator<std::complex<float>>::hpsi_info;
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data<std::complex<float>>());
hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out);
ha->ops->hPsi(info);
};
auto spsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
const auto ndim = psi_in.shape().ndim();
REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2");
ha->sPsi(psi_in.data<std::complex<float>>(), spsi_out.data<std::complex<float>>(),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1),
ndim == 1 ? 1 : psi_in.shape().dim_size(0));
auto spsi_func = [ha](std::complex<float>* psi_in,
std::complex<float>* spsi_out,
const int ld_psi,
const int nvec) {
ha->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
};
auto psi_tensor = ct::TensorMap(
psi_local.get_pointer(),
ct::DataType::DT_COMPLEX,
ct::DeviceType::CpuDevice,
ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()});
auto eigen_tensor = ct::TensorMap(
en,
ct::DataType::DT_FLOAT,
ct::DeviceType::CpuDevice,
ct::TensorShape({psi_local.get_nbands()}));
auto prec_tensor = ct::TensorMap(
precondition_local,
ct::DataType::DT_FLOAT,
ct::DeviceType::CpuDevice,
ct::TensorShape({static_cast<int>(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()});

std::vector<double> ethr_band(nband, 1e-5);
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
// TODO: Double check tensormap's potential problem
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
cg.diag(hpsi_func,
spsi_func,
psi_local.get_nbasis(),
psi_local.get_nbands(),
psi_local.get_current_ngk(),
psi_local.get_pointer(),
en,
ethr_band,
precondition_local);
/**************************************************************/

end = MPI_Wtime();
Expand Down
Loading
Loading