From f51e6c137508c226e9ca00ebdfcbee33e111ebef Mon Sep 17 00:00:00 2001 From: AsTonyshment Date: Thu, 6 Nov 2025 16:52:51 +0800 Subject: [PATCH] Refactor cal_edm_tddft to replace raw ScaLAPACK and BLAS calls with ScalapackConnector and BlasConnector interfaces --- .../source_estate/module_dm/cal_edm_tddft.cpp | 264 +++++++++--------- 1 file changed, 131 insertions(+), 133 deletions(-) diff --git a/source/source_estate/module_dm/cal_edm_tddft.cpp b/source/source_estate/module_dm/cal_edm_tddft.cpp index 809d76e1e4..46168a56e0 100644 --- a/source/source_estate/module_dm/cal_edm_tddft.cpp +++ b/source/source_estate/module_dm/cal_edm_tddft.cpp @@ -2,17 +2,17 @@ #include "source_base/module_external/lapack_connector.h" #include "source_base/module_external/scalapack_connector.h" - #include "source_io/module_parameter/parameter.h" // use PARAM.globalv namespace elecstate { // use the original formula (Hamiltonian matrix) to calculate energy density matrix void cal_edm_tddft(Parallel_Orbitals& pv, - LCAO_domain::Setup_DM> &dmat, + LCAO_domain::Setup_DM>& dmat, K_Vectors& kv, hamilt::Hamilt>* p_hamilt) { - // mohan add 2024-03-27 + ModuleBase::timer::tick("elecstate", "cal_edm_tddft"); + const int nlocal = PARAM.globalv.nlocal; assert(nlocal >= 0); @@ -25,10 +25,6 @@ void cal_edm_tddft(Parallel_Orbitals& pv, ModuleBase::ComplexMatrix& tmp_edmk = dmat.dm->EDMK[ik]; #ifdef __MPI - - // mohan add 2024-03-27 - //! be careful, the type of nloc is 'long' - //! whether the long type is safe, needs more discussion const int nloc = pv.nloc; const int ncol = pv.ncol; const int nrow = pv.nrow; @@ -54,14 +50,14 @@ void cal_edm_tddft(Parallel_Orbitals& pv, hamilt::MatrixBlock> s_mat; p_hamilt->matrix(h_mat, s_mat); - zcopy_(&nloc, h_mat.p, &inc, Htmp, &inc); - zcopy_(&nloc, s_mat.p, &inc, Sinv, &inc); + BlasConnector::copy(nloc, h_mat.p, inc, Htmp, inc); + BlasConnector::copy(nloc, s_mat.p, inc, Sinv, inc); vector ipiv(nloc, 0); int info = 0; const int one_int = 1; - pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, pv.desc, ipiv.data(), &info); + ScalapackConnector::getrf(nlocal, nlocal, Sinv, one_int, one_int, pv.desc, ipiv.data(), &info); int lwork = -1; int liwork = -1; @@ -72,136 +68,136 @@ void cal_edm_tddft(Parallel_Orbitals& pv, // if liwork = -1, then the size of iwork is (at least) of length 1. std::vector iwork(1, 0); - pzgetri_(&nlocal, - Sinv, - &one_int, - &one_int, - pv.desc, - ipiv.data(), - work.data(), - &lwork, - iwork.data(), - &liwork, - &info); + ScalapackConnector::getri(nlocal, + Sinv, + one_int, + one_int, + pv.desc, + ipiv.data(), + work.data(), + &lwork, + iwork.data(), + &liwork, + &info); lwork = work[0].real(); work.resize(lwork, 0); liwork = iwork[0]; iwork.resize(liwork, 0); - pzgetri_(&nlocal, - Sinv, - &one_int, - &one_int, - pv.desc, - ipiv.data(), - work.data(), - &lwork, - iwork.data(), - &liwork, - &info); + ScalapackConnector::getri(nlocal, + Sinv, + one_int, + one_int, + pv.desc, + ipiv.data(), + work.data(), + &lwork, + iwork.data(), + &liwork, + &info); const char N_char = 'N'; const char T_char = 'T'; - const std::complex one_float = {1.0, 0.0}; - const std::complex zero_float = {0.0, 0.0}; - const std::complex half_float = {0.5, 0.0}; - - pzgemm_(&N_char, - &N_char, - &nlocal, - &nlocal, - &nlocal, - &one_float, - Htmp, - &one_int, - &one_int, - pv.desc, - Sinv, - &one_int, - &one_int, - pv.desc, - &zero_float, - tmp1, - &one_int, - &one_int, - pv.desc); - - pzgemm_(&T_char, - &N_char, - &nlocal, - &nlocal, - &nlocal, - &one_float, - tmp1, - &one_int, - &one_int, - pv.desc, - tmp_dmk, - &one_int, - &one_int, - pv.desc, - &zero_float, - tmp2, - &one_int, - &one_int, - pv.desc); - - pzgemm_(&N_char, - &N_char, - &nlocal, - &nlocal, - &nlocal, - &one_float, - Sinv, - &one_int, - &one_int, - pv.desc, - Htmp, - &one_int, - &one_int, - pv.desc, - &zero_float, - tmp3, - &one_int, - &one_int, - pv.desc); - - pzgemm_(&N_char, - &T_char, - &nlocal, - &nlocal, - &nlocal, - &one_float, - tmp_dmk, - &one_int, - &one_int, - pv.desc, - tmp3, - &one_int, - &one_int, - pv.desc, - &zero_float, - tmp4, - &one_int, - &one_int, - pv.desc); - - pzgeadd_(&N_char, - &nlocal, - &nlocal, - &half_float, - tmp2, - &one_int, - &one_int, - pv.desc, - &half_float, - tmp4, - &one_int, - &one_int, - pv.desc); - - zcopy_(&nloc, tmp4, &inc, tmp_edmk.c, &inc); + const std::complex one_complex = {1.0, 0.0}; + const std::complex zero_complex = {0.0, 0.0}; + const std::complex half_complex = {0.5, 0.0}; + + ScalapackConnector::gemm(N_char, + N_char, + nlocal, + nlocal, + nlocal, + one_complex, + Htmp, + one_int, + one_int, + pv.desc, + Sinv, + one_int, + one_int, + pv.desc, + zero_complex, + tmp1, + one_int, + one_int, + pv.desc); + + ScalapackConnector::gemm(T_char, + N_char, + nlocal, + nlocal, + nlocal, + one_complex, + tmp1, + one_int, + one_int, + pv.desc, + tmp_dmk, + one_int, + one_int, + pv.desc, + zero_complex, + tmp2, + one_int, + one_int, + pv.desc); + + ScalapackConnector::gemm(N_char, + N_char, + nlocal, + nlocal, + nlocal, + one_complex, + Sinv, + one_int, + one_int, + pv.desc, + Htmp, + one_int, + one_int, + pv.desc, + zero_complex, + tmp3, + one_int, + one_int, + pv.desc); + + ScalapackConnector::gemm(N_char, + T_char, + nlocal, + nlocal, + nlocal, + one_complex, + tmp_dmk, + one_int, + one_int, + pv.desc, + tmp3, + one_int, + one_int, + pv.desc, + zero_complex, + tmp4, + one_int, + one_int, + pv.desc); + + ScalapackConnector::geadd(N_char, + nlocal, + nlocal, + half_complex, + tmp2, + one_int, + one_int, + pv.desc, + half_complex, + tmp4, + one_int, + one_int, + pv.desc); + + BlasConnector::copy(nloc, tmp4, inc, tmp_edmk.c, inc); delete[] Htmp; delete[] Sinv; @@ -219,7 +215,7 @@ void cal_edm_tddft(Parallel_Orbitals& pv, hamilt::MatrixBlock> s_mat; p_hamilt->matrix(h_mat, s_mat); - // cout<<"hmat "<