From c39d04a6d619a0294a50efd20bc2d2f65c432e42 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Sat, 7 Mar 2026 01:00:01 +0800 Subject: [PATCH 1/3] Refactor cg interface from Tensor to T * --- source/source_hsolver/diago_cg.cpp | 66 +++++++++--- source/source_hsolver/diago_cg.h | 24 +++-- source/source_hsolver/hsolver_pw.cpp | 101 +++++------------- .../test/diago_cg_float_test.cpp | 70 ++++++------ .../test/diago_cg_real_test.cpp | 67 ++++++------ source/source_hsolver/test/diago_cg_test.cpp | 70 ++++++------ source/source_hsolver/test/hsolver_pw_sup.h | 37 ++++--- source/source_lcao/module_lr/hsolver_lrtd.hpp | 26 +++-- 8 files changed, 233 insertions(+), 228 deletions(-) diff --git a/source/source_hsolver/diago_cg.cpp b/source/source_hsolver/diago_cg.cpp index 564c36e74b..d6bd08450e 100644 --- a/source/source_hsolver/diago_cg.cpp +++ b/source/source_hsolver/diago_cg.cpp @@ -122,10 +122,10 @@ void DiagoCG::diag_once(const ct::Tensor& prec_in, { phi_m.sync(psi[m]); // copy psi_in into internal psi, m=0 has been done in Constructor - this->spsi_func_(phi_m, sphi); // sphi = S|psi(m)> + this->spsi_func_(phi_m.data(), sphi.data(), this->n_basis_, 1); // sphi = S|psi(m)> this->schmit_orth(m, psi, sphi, phi_m); - this->spsi_func_(phi_m, sphi); // sphi = S|psi(m)> - this->hpsi_func_(phi_m, hphi); // hphi = H|psi(m)> + this->spsi_func_(phi_m.data(), sphi.data(), this->n_basis_, 1); // sphi = S|psi(m)> + this->hpsi_func_(phi_m.data(), hphi.data(), this->n_basis_, 1); // hphi = H|psi(m)> eigen_pack[m] = dot_real_op()(this->n_basis_, phi_m.data(), hphi.data()); @@ -150,8 +150,8 @@ void DiagoCG::diag_once(const ct::Tensor& prec_in, g0, cg); // Tensor& - this->hpsi_func_(cg, pphi); - this->spsi_func_(cg, scg); + this->hpsi_func_(cg.data(), pphi.data(), this->n_basis_, 1); + this->spsi_func_(cg.data(), scg.data(), this->n_basis_, 1); converged = this->update_psi(pphi, cg, @@ -264,7 +264,7 @@ void DiagoCG::orth_grad(const ct::Tensor& psi, ct::Tensor& scg, ct::Tensor& lagrange) { - this->spsi_func_(grad, scg); // scg = S|grad> + this->spsi_func_(grad.data(), scg.data(), this->n_basis_, 1); // scg = S|grad> ModuleBase::gemv_op()('C', this->n_basis_, m, @@ -576,13 +576,39 @@ bool DiagoCG::test_exit_cond(const int& ntry, const int& notconv) con } template -double DiagoCG::diag(const Func& hpsi_func, - const Func& spsi_func, - ct::Tensor& psi, - ct::Tensor& eigen, - const std::vector& ethr_band, - const ct::Tensor& prec) +double DiagoCG::diag(const HPsiFunc& hpsi_func, + const SPsiFunc& spsi_func, + const int ld_psi, + const int nband, + const int dim, + T* psi_in, + Real* eigenvalue_in, + const std::vector& ethr_band, + const Real* prec) { + REQUIRES_OK(ld_psi >= dim, "DiagoCG::diag: ld_psi must be >= dim"); + REQUIRES_OK(static_cast(ethr_band.size()) >= nband, + "DiagoCG::diag: ethr_band size must be >= nband"); + + auto psi = ct::TensorMap(psi_in, + ct::DataTypeToEnum::value, + ct::DeviceTypeToEnum::value, + ct::TensorShape({nband, ld_psi})); + auto eigen = ct::TensorMap(eigenvalue_in, + ct::DataTypeToEnum::value, + ct::DeviceTypeToEnum::value, + ct::TensorShape({nband})); + + ct::Tensor prec_tensor; + if (prec != nullptr) + { + prec_tensor = ct::TensorMap(const_cast(prec), + ct::DataTypeToEnum::value, + ct::DeviceTypeToEnum::value, + ct::TensorShape({dim})) + .template to_device(); + } + /// record the times of trying iterative diagonalization int ntry = 0; this->notconv_ = 0; @@ -590,7 +616,7 @@ double DiagoCG::diag(const Func& hpsi_func, spsi_func_ = spsi_func; // create a new slice of psi to do cg diagonalization - ct::Tensor psi_temp = psi.slice({0, 0}, {int(psi.shape().dim_size(0)), int(prec.shape().dim_size(0))}); + ct::Tensor psi_temp = psi.slice({0, 0}, {nband, dim}); do { // subspace diagonalization to get a better starting guess @@ -601,21 +627,29 @@ double DiagoCG::diag(const Func& hpsi_func, { ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp); const bool assume_S_orthogonal = true; - this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal); + this->subspace_func_(psi_temp.data(), + psi_map.data(), + dim, + nband, + assume_S_orthogonal); psi_temp.sync(psi_map); } else if (need_subspace_) { ct::TensorMap psi_map = ct::TensorMap(psi.data(), psi_temp); const bool assume_S_orthogonal = false; - this->subspace_func_(psi_temp, psi_map, assume_S_orthogonal); + this->subspace_func_(psi_temp.data(), + psi_map.data(), + dim, + nband, + assume_S_orthogonal); psi_temp.sync(psi_map); } ++ntry; avg_iter_ += 1.0; - this->diag_once(prec, psi_temp, eigen, ethr_band); + this->diag_once(prec_tensor, psi_temp, eigen, ethr_band); } while (this->test_exit_cond(ntry, this->notconv_)); if (this->notconv_ > std::max(5, this->n_band_ / 4)) diff --git a/source/source_hsolver/diago_cg.h b/source/source_hsolver/diago_cg.h index bf03cb5850..99d9369a0a 100644 --- a/source/source_hsolver/diago_cg.h +++ b/source/source_hsolver/diago_cg.h @@ -22,8 +22,9 @@ class DiagoCG final using Real = typename GetTypeReal::type; using ct_Device = typename ct::PsiToContainer::type; public: - using Func = std::function; - using SubspaceFunc = std::function; + using HPsiFunc = std::function; + using SPsiFunc = std::function; + using SubspaceFunc = std::function; // Constructor need: // 1. temporary mock of Hamiltonian "Hamilt_PW" // 2. precondition pointer should point to place of precondition array. @@ -43,12 +44,15 @@ class DiagoCG final // refactor hpsi_info // this is the diag() function for CG method // returns avg_iter - double diag(const Func& hpsi_func, - const Func& spsi_func, - ct::Tensor& psi, - ct::Tensor& eigen, - const std::vector& ethr_band, - const ct::Tensor& prec = {}); + double diag(const HPsiFunc& hpsi_func, + const SPsiFunc& spsi_func, + const int ld_psi, + const int nband, + const int dim, + T* psi_in, + Real* eigenvalue_in, + const std::vector& ethr_band, + const Real* prec = nullptr); private: Device * ctx_ = {}; @@ -77,9 +81,9 @@ class DiagoCG final bool need_subspace_ = false; /// A function object that performs the hPsi calculation. - Func hpsi_func_ = nullptr; + HPsiFunc hpsi_func_ = nullptr; /// A function object that performs the sPsi calculation. - Func spsi_func_ = nullptr; + SPsiFunc spsi_func_ = nullptr; /// A function object that performs the subspace calculation. SubspaceFunc subspace_func_ = nullptr; diff --git a/source/source_hsolver/hsolver_pw.cpp b/source/source_hsolver/hsolver_pw.cpp index 2ea7b92a2b..f6c4706c41 100644 --- a/source/source_hsolver/hsolver_pw.cpp +++ b/source/source_hsolver/hsolver_pw.cpp @@ -254,27 +254,16 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, // wrap the subspace_func into a lambda function // if S_orth is true, then assume psi is S-orthogonal, solve standard eigenproblem // otherwise, solve generalized eigenproblem - auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) { - // psi_in should be a 2D tensor: - // psi_in.shape() = [nbands, nbasis] - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2"); - // Convert a Tensor object to a psi::Psi object - auto psi_in_wrapper = psi::Psi(psi_in.data(), - 1, - psi_in.shape().dim_size(0), - psi_in.shape().dim_size(1), - cur_nbasis); - auto psi_out_wrapper = psi::Psi(psi_out.data(), - 1, - psi_out.shape().dim_size(0), - psi_out.shape().dim_size(1), - cur_nbasis); - auto eigen = ct::Tensor(ct::DataTypeToEnum::value, - ct::DeviceType::CpuDevice, - ct::TensorShape({psi_in.shape().dim_size(0)})); - - DiagoIterAssist::diag_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data()); + auto subspace_func = [hm, cur_nbasis](T* psi_in, + T* psi_out, + const int ld_psi, + const int nband, + const bool S_orth) { + (void)S_orth; + auto psi_in_wrapper = psi::Psi(psi_in, 1, nband, ld_psi, cur_nbasis); + auto psi_out_wrapper = psi::Psi(psi_out, 1, nband, ld_psi, cur_nbasis); + std::vector eigen(nband, 0.0); + DiagoIterAssist::diag_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data()); }; DiagoCG cg(this->basis_type, this->calculation_type, @@ -284,70 +273,38 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, this->diag_iter_max, this->nproc_in_pool); - // wrap the hpsi_func and spsi_func into a lambda function - using ct_Device = typename ct::PsiToContainer::type; - - // wrap the hpsi_func and spsi_func into a lambda function - auto hpsi_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { - // psi_in should be a 2D tensor: - // psi_in.shape() = [nbands, nbasis] - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - // Convert a Tensor object to a psi::Psi object - auto psi_wrapper = psi::Psi(psi_in.data(), - 1, - ndim == 1 ? 1 : psi_in.shape().dim_size(0), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - cur_nbasis); - psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); + // wrap the hpsi_func and spsi_func into lambda functions + auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) { + auto psi_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, cur_nbasis); + psi::Range all_bands_range(true, 0, 0, nvec - 1); using hpsi_info = typename hamilt::Operator::hpsi_info; - hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); + hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out); hm->ops->hPsi(info); }; - auto spsi_func = [this, hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) { - // psi_in should be a 2D tensor: - // psi_in.shape() = [nbands, nbasis] - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - + auto spsi_func = [this, hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) { if (this->use_uspp) { - // Convert a Tensor object to a psi::Psi object - hm->sPsi(psi_in.data(), - spsi_out.data(), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? 1 : psi_in.shape().dim_size(0)); + hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec); } else { base_device::memory::synchronize_memory_op()( - spsi_out.data(), - psi_in.data(), - static_cast((ndim == 1 ? 1 : psi_in.shape().dim_size(0)) - * (ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1)))); + spsi_out, + psi_in, + static_cast(nvec) * static_cast(ld_psi)); } }; - auto psi_tensor = ct::TensorMap(psi.get_pointer(), - ct::DataTypeToEnum::value, - ct::DeviceTypeToEnum::value, - ct::TensorShape({psi.get_nbands(), psi.get_nbasis()})); - - auto eigen_tensor = ct::TensorMap(eigenvalue, - ct::DataTypeToEnum::value, - ct::DeviceTypeToEnum::value, - ct::TensorShape({psi.get_nbands()})); - - auto prec_tensor = ct::TensorMap(pre_condition.data(), - ct::DataTypeToEnum::value, - ct::DeviceTypeToEnum::value, - ct::TensorShape({static_cast(pre_condition.size())})) - .to_device() - .slice({0}, {psi.get_current_ngk()}); - DiagoIterAssist::avg_iter += static_cast( - cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor) + cg.diag(hpsi_func, + spsi_func, + psi.get_nbasis(), + psi.get_nbands(), + psi.get_current_ngk(), + psi.get_pointer(), + eigenvalue, + this->ethr_band, + pre_condition.data()) ); // TODO: Double check tensormap's potential problem // ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor); diff --git a/source/source_hsolver/test/diago_cg_float_test.cpp b/source/source_hsolver/test/diago_cg_float_test.cpp index 4c4278b3f1..1bd7e7877f 100644 --- a/source/source_hsolver/test/diago_cg_float_test.cpp +++ b/source/source_hsolver/test/diago_cg_float_test.cpp @@ -142,7 +142,20 @@ class DiagoCGPrepare // New interface of cg method /**************************************************************/ // warp the subspace_func into a lambda function - auto subspace_func = [ha](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) { /*do nothing*/ }; + auto subspace_func = [ha](std::complex* psi_in, + std::complex* psi_out, + const int ld_psi, + const int nband, + const bool S_orth) { + (void)S_orth; + auto psi_in_wrapper = psi::Psi>(psi_in, 1, nband, ld_psi, true); + auto psi_out_wrapper = psi::Psi>(psi_out, 1, nband, ld_psi, true); + std::vector eigen(nband, 0.0f); + hsolver::DiagoIterAssist>::diag_subspace(ha, + psi_in_wrapper, + psi_out_wrapper, + eigen.data()); + }; hsolver::DiagoCG> cg( PARAM.input.basis_type, PARAM.input.calculation, @@ -156,46 +169,33 @@ class DiagoCGPrepare float start, end; start = MPI_Wtime(); - auto hpsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - auto psi_wrapper = psi::Psi>( - psi_in.data>(), 1, - ndim == 1 ? 1 : psi_in.shape().dim_size(0), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true); - psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); + auto hpsi_func = [ha](std::complex* psi_in, + std::complex* hpsi_out, + const int ld_psi, + const int nvec) { + auto psi_wrapper = psi::Psi>(psi_in, 1, nvec, ld_psi, true); + psi::Range all_bands_range(true, 0, 0, nvec - 1); using hpsi_info = typename hamilt::Operator>::hpsi_info; - hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data>()); + hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out); ha->ops->hPsi(info); }; - auto spsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& spsi_out) { - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - ha->sPsi(psi_in.data>(), spsi_out.data>(), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? 1 : psi_in.shape().dim_size(0)); + auto spsi_func = [ha](std::complex* psi_in, + std::complex* spsi_out, + const int ld_psi, + const int nvec) { + ha->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec); }; - auto psi_tensor = ct::TensorMap( - psi_local.get_pointer(), - ct::DataType::DT_COMPLEX, - ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()}); - auto eigen_tensor = ct::TensorMap( - en, - ct::DataType::DT_FLOAT, - ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands()})); - auto prec_tensor = ct::TensorMap( - precondition_local, - ct::DataType::DT_FLOAT, - ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()}); std::vector ethr_band(nband, 1e-5); - cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); - // TODO: Double check tensormap's potential problem - ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor); + cg.diag(hpsi_func, + spsi_func, + psi_local.get_nbasis(), + psi_local.get_nbands(), + psi_local.get_current_ngk(), + psi_local.get_pointer(), + en, + ethr_band, + precondition_local); /**************************************************************/ end = MPI_Wtime(); diff --git a/source/source_hsolver/test/diago_cg_real_test.cpp b/source/source_hsolver/test/diago_cg_real_test.cpp index dbcbe95ae8..924c724df5 100644 --- a/source/source_hsolver/test/diago_cg_real_test.cpp +++ b/source/source_hsolver/test/diago_cg_real_test.cpp @@ -147,7 +147,17 @@ class DiagoCGPrepare // New interface of cg method /**************************************************************/ // warp the subspace_func into a lambda function - auto subspace_func = [ha](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) { /*do nothing*/ }; + auto subspace_func = [ha](double* psi_in, + double* psi_out, + const int ld_psi, + const int nband, + const bool S_orth) { + (void)S_orth; + auto psi_in_wrapper = psi::Psi(psi_in, 1, nband, ld_psi, true); + auto psi_out_wrapper = psi::Psi(psi_out, 1, nband, ld_psi, true); + std::vector eigen(nband, 0.0); + hsolver::DiagoIterAssist::diag_subspace(ha, psi_in_wrapper, psi_out_wrapper, eigen.data()); + }; hsolver::DiagoCG cg( PARAM.input.basis_type, PARAM.input.calculation, @@ -161,46 +171,33 @@ class DiagoCGPrepare double start, end; start = MPI_Wtime(); - auto hpsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - auto psi_wrapper = psi::Psi( - psi_in.data(), 1, - ndim == 1 ? 1 : psi_in.shape().dim_size(0), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true); - psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); + auto hpsi_func = [ha](double* psi_in, + double* hpsi_out, + const int ld_psi, + const int nvec) { + auto psi_wrapper = psi::Psi(psi_in, 1, nvec, ld_psi, true); + psi::Range all_bands_range(true, 0, 0, nvec - 1); using hpsi_info = typename hamilt::Operator::hpsi_info; - hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); + hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out); ha->ops->hPsi(info); }; - auto spsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& spsi_out) { - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - ha->sPsi(psi_in.data(), spsi_out.data(), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? 1 : psi_in.shape().dim_size(0)); + auto spsi_func = [ha](double* psi_in, + double* spsi_out, + const int ld_psi, + const int nvec) { + ha->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec); }; - auto psi_tensor = ct::TensorMap( - psi_local.get_pointer(), - ct::DataType::DT_DOUBLE, - ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()}); - auto eigen_tensor = ct::TensorMap( - en, - ct::DataType::DT_DOUBLE, - ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands()})); - auto prec_tensor = ct::TensorMap( - precondition_local, - ct::DataType::DT_DOUBLE, - ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()}); std::vector ethr_band(nband, 1e-5); - cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); - // TODO: Double check tensormap's potential problem - ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor); + cg.diag(hpsi_func, + spsi_func, + psi_local.get_nbasis(), + psi_local.get_nbands(), + psi_local.get_current_ngk(), + psi_local.get_pointer(), + en, + ethr_band, + precondition_local); /**************************************************************/ end = MPI_Wtime(); diff --git a/source/source_hsolver/test/diago_cg_test.cpp b/source/source_hsolver/test/diago_cg_test.cpp index b3518205f0..edeefecb0e 100644 --- a/source/source_hsolver/test/diago_cg_test.cpp +++ b/source/source_hsolver/test/diago_cg_test.cpp @@ -136,7 +136,20 @@ class DiagoCGPrepare // New interface of cg method /**************************************************************/ // warp the subspace_func into a lambda function - auto subspace_func = [ha](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) { /*do nothing*/ }; + auto subspace_func = [ha](std::complex* psi_in, + std::complex* psi_out, + const int ld_psi, + const int nband, + const bool S_orth) { + (void)S_orth; + auto psi_in_wrapper = psi::Psi>(psi_in, 1, nband, ld_psi, true); + auto psi_out_wrapper = psi::Psi>(psi_out, 1, nband, ld_psi, true); + std::vector eigen(nband, 0.0); + hsolver::DiagoIterAssist>::diag_subspace(ha, + psi_in_wrapper, + psi_out_wrapper, + eigen.data()); + }; hsolver::DiagoCG> cg( PARAM.input.basis_type, PARAM.input.calculation, @@ -150,46 +163,33 @@ class DiagoCGPrepare double start, end; start = MPI_Wtime(); - auto hpsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - auto psi_wrapper = psi::Psi>( - psi_in.data>(), 1, - ndim == 1 ? 1 : psi_in.shape().dim_size(0), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), true); - psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); + auto hpsi_func = [ha](std::complex* psi_in, + std::complex* hpsi_out, + const int ld_psi, + const int nvec) { + auto psi_wrapper = psi::Psi>(psi_in, 1, nvec, ld_psi, true); + psi::Range all_bands_range(true, 0, 0, nvec - 1); using hpsi_info = typename hamilt::Operator>::hpsi_info; - hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data>()); + hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out); ha->ops->hPsi(info); }; - auto spsi_func = [ha](const ct::Tensor& psi_in, ct::Tensor& spsi_out) { - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - ha->sPsi(psi_in.data>(), spsi_out.data>(), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? 1 : psi_in.shape().dim_size(0)); + auto spsi_func = [ha](std::complex* psi_in, + std::complex* spsi_out, + const int ld_psi, + const int nvec) { + ha->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec); }; - auto psi_tensor = ct::TensorMap( - psi_local.get_pointer(), - ct::DataType::DT_COMPLEX_DOUBLE, - ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands(), psi_local.get_nbasis()})).slice({0, 0}, {psi_local.get_nbands(), psi_local.get_current_ngk()}); - auto eigen_tensor = ct::TensorMap( - en, - ct::DataType::DT_DOUBLE, - ct::DeviceType::CpuDevice, - ct::TensorShape({psi_local.get_nbands()})); - auto prec_tensor = ct::TensorMap( - precondition_local, - ct::DataType::DT_DOUBLE, - ct::DeviceType::CpuDevice, - ct::TensorShape({static_cast(psi_local.get_current_ngk())})).slice({0}, {psi_local.get_current_ngk()}); std::vector ethr_band(nband, 1e-5); - cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor); - // TODO: Double check tensormap's potential problem - ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor); + cg.diag(hpsi_func, + spsi_func, + psi_local.get_nbasis(), + psi_local.get_nbands(), + psi_local.get_current_ngk(), + psi_local.get_pointer(), + en, + ethr_band, + precondition_local); /**************************************************************/ // cg.diag(ha,psi_local,en); diff --git a/source/source_hsolver/test/hsolver_pw_sup.h b/source/source_hsolver/test/hsolver_pw_sup.h index fb3757a08b..0f76793ce9 100644 --- a/source/source_hsolver/test/hsolver_pw_sup.h +++ b/source/source_hsolver/test/hsolver_pw_sup.h @@ -92,24 +92,29 @@ DiagoCG::~DiagoCG() { } template -double DiagoCG::diag(const Func& hpsi_func, - const Func& spsi_func, - ct::Tensor& psi, - ct::Tensor& eigen, - const std::vector& ethr_band, - const ct::Tensor& prec) { - auto n_bands = psi.shape().dim_size(0); - auto n_basis = psi.shape().dim_size(1); - auto psi_pack = psi.accessor(); - auto eigen_pack = eigen.accessor(); +double DiagoCG::diag(const HPsiFunc& hpsi_func, + const SPsiFunc& spsi_func, + const int ld_psi, + const int nband, + const int dim, + T* psi_in, + Real* eigenvalue_in, + const std::vector& ethr_band, + const Real* prec) { + (void)hpsi_func; + (void)spsi_func; + (void)dim; + (void)ethr_band; + (void)prec; // do something - for (int ib = 0; ib < n_bands; ib++) { - eigen_pack[ib] = 0.0; - for (int ig = 0; ig < n_basis; ig++) { - psi_pack[ib][ig] += T(2.0, 0.0); - eigen_pack[ib] += psi_pack[ib][ig].real(); + for (int ib = 0; ib < nband; ib++) { + eigenvalue_in[ib] = 0.0; + T* psi_band = psi_in + static_cast(ib) * static_cast(ld_psi); + for (int ig = 0; ig < ld_psi; ig++) { + psi_band[ig] += T(2.0, 0.0); + eigenvalue_in[ib] += psi_band[ig].real(); } - eigen_pack[ib] /= n_basis; + eigenvalue_in[ib] /= ld_psi; } DiagoIterAssist::avg_iter += 1.0; return avg_iter_; diff --git a/source/source_lcao/module_lr/hsolver_lrtd.hpp b/source/source_lcao/module_lr/hsolver_lrtd.hpp index b481d56936..b81fd78ed1 100644 --- a/source/source_lcao/module_lr/hsolver_lrtd.hpp +++ b/source/source_lcao/module_lr/hsolver_lrtd.hpp @@ -134,20 +134,28 @@ namespace LR ////// why diago_cg depends on basis_type? // hsolver::DiagoCG cg("lcao", "nscf", true, subspace_func, diag_ethr, maxiter, GlobalV::NPROC_IN_POOL); - auto subspace_func = [](const ct::Tensor& psi_in, ct::Tensor& psi_out, const bool S_orth) {}; + auto subspace_func = [](T* psi_in, T* psi_out, const int ld_psi, const int nband, const bool S_orth) { + }; hsolver::DiagoCG cg("lcao", "nscf", false, subspace_func, diag_ethr, maxiter, GlobalV::NPROC_IN_POOL); - auto psi_tensor = ct::TensorMap(psi, ct::DataTypeToEnum::value, ct::DeviceType::CpuDevice, ct::TensorShape({ nband, dim })); - auto eigen_tensor = ct::TensorMap(eigenvalue.data(), ct::DataTypeToEnum>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ nband })); - std::vector> precondition_(precondition); //since TensorMap does not support const pointer - auto precon_tensor = ct::TensorMap(precondition_.data(), ct::DataTypeToEnum>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ dim })); - auto hpsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& hpsi) {hm.hPsi(psi_in.data(), hpsi.data(), psi_in.shape().dim_size(0) /*nbasis_local*/, 1/*band-by-band*/);}; - auto spsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& spsi) { - std::memcpy(spsi.data(), psi_in.data(), sizeof(T) * psi_in.NumElements()); + auto hpsi_func = [&hm](T* psi_in, T* hpsi, const int ld_psi, const int nvec) { + hm.hPsi(psi_in, hpsi, ld_psi, nvec); + }; + auto spsi_func = [&hm](T* psi_in, T* spsi, const int ld_psi, const int nvec) { + (void)hm; + std::memcpy(spsi, psi_in, sizeof(T) * static_cast(ld_psi) * static_cast(nvec)); }; std::vector ethr_band(nband, diag_ethr); - cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, precon_tensor); + cg.diag(hpsi_func, + spsi_func, + dim, + nband, + dim, + psi, + eigenvalue.data(), + ethr_band, + precondition.data()); } else { throw std::runtime_error("HSolverLR::solve: method not implemented"); } } From 36276ab687d520fbf769fbe6a61bbfeff5ad7f86 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Sat, 7 Mar 2026 01:28:43 +0800 Subject: [PATCH 2/3] Remove redundant code --- source/source_hsolver/hsolver_pw.cpp | 1 - source/source_hsolver/test/diago_cg_real_test.cpp | 1 - source/source_hsolver/test/diago_cg_test.cpp | 1 - source/source_hsolver/test/hsolver_pw_sup.h | 5 ----- source/source_lcao/module_lr/hsolver_lrtd.hpp | 3 +-- 5 files changed, 1 insertion(+), 10 deletions(-) diff --git a/source/source_hsolver/hsolver_pw.cpp b/source/source_hsolver/hsolver_pw.cpp index f6c4706c41..0d74f72162 100644 --- a/source/source_hsolver/hsolver_pw.cpp +++ b/source/source_hsolver/hsolver_pw.cpp @@ -259,7 +259,6 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int ld_psi, const int nband, const bool S_orth) { - (void)S_orth; auto psi_in_wrapper = psi::Psi(psi_in, 1, nband, ld_psi, cur_nbasis); auto psi_out_wrapper = psi::Psi(psi_out, 1, nband, ld_psi, cur_nbasis); std::vector eigen(nband, 0.0); diff --git a/source/source_hsolver/test/diago_cg_real_test.cpp b/source/source_hsolver/test/diago_cg_real_test.cpp index 924c724df5..eacb1341eb 100644 --- a/source/source_hsolver/test/diago_cg_real_test.cpp +++ b/source/source_hsolver/test/diago_cg_real_test.cpp @@ -152,7 +152,6 @@ class DiagoCGPrepare const int ld_psi, const int nband, const bool S_orth) { - (void)S_orth; auto psi_in_wrapper = psi::Psi(psi_in, 1, nband, ld_psi, true); auto psi_out_wrapper = psi::Psi(psi_out, 1, nband, ld_psi, true); std::vector eigen(nband, 0.0); diff --git a/source/source_hsolver/test/diago_cg_test.cpp b/source/source_hsolver/test/diago_cg_test.cpp index edeefecb0e..bc373a4d80 100644 --- a/source/source_hsolver/test/diago_cg_test.cpp +++ b/source/source_hsolver/test/diago_cg_test.cpp @@ -141,7 +141,6 @@ class DiagoCGPrepare const int ld_psi, const int nband, const bool S_orth) { - (void)S_orth; auto psi_in_wrapper = psi::Psi>(psi_in, 1, nband, ld_psi, true); auto psi_out_wrapper = psi::Psi>(psi_out, 1, nband, ld_psi, true); std::vector eigen(nband, 0.0); diff --git a/source/source_hsolver/test/hsolver_pw_sup.h b/source/source_hsolver/test/hsolver_pw_sup.h index 0f76793ce9..a5aab01735 100644 --- a/source/source_hsolver/test/hsolver_pw_sup.h +++ b/source/source_hsolver/test/hsolver_pw_sup.h @@ -101,11 +101,6 @@ double DiagoCG::diag(const HPsiFunc& hpsi_func, Real* eigenvalue_in, const std::vector& ethr_band, const Real* prec) { - (void)hpsi_func; - (void)spsi_func; - (void)dim; - (void)ethr_band; - (void)prec; // do something for (int ib = 0; ib < nband; ib++) { eigenvalue_in[ib] = 0.0; diff --git a/source/source_lcao/module_lr/hsolver_lrtd.hpp b/source/source_lcao/module_lr/hsolver_lrtd.hpp index b81fd78ed1..f1d3991955 100644 --- a/source/source_lcao/module_lr/hsolver_lrtd.hpp +++ b/source/source_lcao/module_lr/hsolver_lrtd.hpp @@ -141,8 +141,7 @@ namespace LR auto hpsi_func = [&hm](T* psi_in, T* hpsi, const int ld_psi, const int nvec) { hm.hPsi(psi_in, hpsi, ld_psi, nvec); }; - auto spsi_func = [&hm](T* psi_in, T* spsi, const int ld_psi, const int nvec) { - (void)hm; + auto spsi_func = [](T* psi_in, T* spsi, const int ld_psi, const int nvec) { std::memcpy(spsi, psi_in, sizeof(T) * static_cast(ld_psi) * static_cast(nvec)); }; From d656a5fdaa7bf74261a2c73371fe045435e65141 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Sat, 7 Mar 2026 01:30:58 +0800 Subject: [PATCH 3/3] Remove redundant code --- source/source_hsolver/test/diago_cg_float_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/source/source_hsolver/test/diago_cg_float_test.cpp b/source/source_hsolver/test/diago_cg_float_test.cpp index 1bd7e7877f..32514c92e3 100644 --- a/source/source_hsolver/test/diago_cg_float_test.cpp +++ b/source/source_hsolver/test/diago_cg_float_test.cpp @@ -147,7 +147,6 @@ class DiagoCGPrepare const int ld_psi, const int nband, const bool S_orth) { - (void)S_orth; auto psi_in_wrapper = psi::Psi>(psi_in, 1, nband, ld_psi, true); auto psi_out_wrapper = psi::Psi>(psi_out, 1, nband, ld_psi, true); std::vector eigen(nband, 0.0f);