diff --git a/source/source_estate/module_dm/cal_dm_psi.cpp b/source/source_estate/module_dm/cal_dm_psi.cpp index bd14f21774..ac52854be1 100644 --- a/source/source_estate/module_dm/cal_dm_psi.cpp +++ b/source/source_estate/module_dm/cal_dm_psi.cpp @@ -164,24 +164,24 @@ void psiMulPsiMpi(const psi::Psi& psi1, const int nlocal = desc_dm[2]; const int nbands = desc_psi[3]; - pdgemm_(&N_char, - &T_char, - &nlocal, - &nlocal, - &nbands, - &one_float, + ScalapackConnector::gemm(N_char, + T_char, + nlocal, + nlocal, + nbands, + one_float, psi1.get_pointer(), - &one_int, - &one_int, + one_int, + one_int, desc_psi, psi2.get_pointer(), - &one_int, - &one_int, + one_int, + one_int, desc_psi, - &zero_float, + zero_float, dm_out, - &one_int, - &one_int, + one_int, + one_int, desc_dm); ModuleBase::timer::tick("psiMulPsiMpi", "pdgemm"); } @@ -198,24 +198,24 @@ void psiMulPsiMpi(const psi::Psi>& psi1, const char N_char = 'N', T_char = 'T'; const int nlocal = desc_dm[2]; const int nbands = desc_psi[3]; - pzgemm_(&N_char, - &T_char, - &nlocal, - &nlocal, - &nbands, - &one_complex, + ScalapackConnector::gemm(N_char, + T_char, + nlocal, + nlocal, + nbands, + one_complex, psi1.get_pointer(), - &one_int, - &one_int, + one_int, + one_int, desc_psi, psi2.get_pointer(), - &one_int, - &one_int, + one_int, + one_int, desc_psi, - &zero_complex, + zero_complex, dm_out, - &one_int, - &one_int, + one_int, + one_int, desc_dm); ModuleBase::timer::tick("psiMulPsiMpi", "pdgemm"); } @@ -229,19 +229,19 @@ void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2, doubl const char N_char = 'N', T_char = 'T'; const int nlocal = psi1.get_nbasis(); const int nbands = psi1.get_nbands(); - dgemm_(&N_char, - &T_char, - &nlocal, - &nlocal, - &nbands, - &one_float, + BlasConnector::gemm_cm(N_char, + T_char, + nlocal, + nlocal, + nbands, + one_float, psi1.get_pointer(), - &nlocal, + nlocal, psi2.get_pointer(), - &nlocal, - &zero_float, + nlocal, + zero_float, dm_out, - &nlocal); + nlocal); } void psiMulPsi(const psi::Psi>& psi1, @@ -254,19 +254,19 @@ void psiMulPsi(const psi::Psi>& psi1, const int nbands = psi1.get_nbands(); const std::complex one_complex = {1.0, 0.0}; const std::complex zero_complex = {0.0, 0.0}; - zgemm_(&N_char, - &T_char, - &nlocal, - &nlocal, - &nbands, - &one_complex, + BlasConnector::gemm_cm(N_char, + T_char, + nlocal, + nlocal, + nbands, + one_complex, psi1.get_pointer(), - &nlocal, + nlocal, psi2.get_pointer(), - &nlocal, - &zero_complex, + nlocal, + zero_complex, dm_out, - &nlocal); + nlocal); } } // namespace elecstate diff --git a/source/source_io/output_mulliken.cpp b/source/source_io/output_mulliken.cpp index f5d0c47f65..5cbdfab589 100644 --- a/source/source_io/output_mulliken.cpp +++ b/source/source_io/output_mulliken.cpp @@ -547,24 +547,24 @@ void Output_Mulliken>::cal_orbMulP() const char N_char = 'N'; const int one_int = 1; const std::complex one_float = {1.0, 0.0}, zero_float = {0.0, 0.0}; - pzgemm_(&N_char, - &T_char, - &nw, - &nw, - &nw, - &one_float, + ScalapackConnector::gemm(N_char, + T_char, + nw, + nw, + nw, + one_float, p_DMk, - &one_int, - &one_int, + one_int, + one_int, this->ParaV_->desc, p_Sk, - &one_int, - &one_int, + one_int, + one_int, this->ParaV_->desc, - &zero_float, + zero_float, mud.c, - &one_int, - &one_int, + one_int, + one_int, this->ParaV_->desc); this->collect_MW(MecMulP, mud, nw, this->isk_[ik]); #endif @@ -597,24 +597,24 @@ void Output_Mulliken::cal_orbMulP() const char N_char = 'N'; const int one_int = 1; const double one_float = 1.0, zero_float = 0.0; - pdgemm_(&N_char, - &T_char, - &nw, - &nw, - &nw, - &one_float, + ScalapackConnector::gemm(N_char, + T_char, + nw, + nw, + nw, + one_float, p_DMk, - &one_int, - &one_int, + one_int, + one_int, this->ParaV_->desc, p_Sk, - &one_int, - &one_int, + one_int, + one_int, this->ParaV_->desc, - &zero_float, + zero_float, mud.c, - &one_int, - &one_int, + one_int, + one_int, this->ParaV_->desc); if (this->nspin_ == 1 || this->nspin_ == 2) { diff --git a/source/source_io/to_wannier90_lcao.cpp b/source/source_io/to_wannier90_lcao.cpp index 45414d15b9..f7c5d0ac57 100644 --- a/source/source_io/to_wannier90_lcao.cpp +++ b/source/source_io/to_wannier90_lcao.cpp @@ -423,44 +423,44 @@ void toWannier90_LCAO::unkdotkb(const UnitCell& ucell, ModuleBase::GlobalFunc::ZEROS(out_matrix, nloc); #ifdef __MPI - pzgemm_(&transa, - &transb, - &Bands, - &nlocal, - &nlocal, - &alpha, + ScalapackConnector::gemm(transa, + transb, + Bands, + nlocal, + nlocal, + alpha, &psi_in(ik, 0, 0), - &one, - &one, + one, + one, this->ParaV->desc, midmatrix, - &one, - &one, + one, + one, this->ParaV->desc, - &beta, + beta, C_matrix, - &one, - &one, + one, + one, this->ParaV->desc); - pzgemm_(&transb, - &transb, - &Bands, - &Bands, - &nlocal, - &alpha, + ScalapackConnector::gemm(transb, + transb, + Bands, + Bands, + nlocal, + alpha, C_matrix, - &one, - &one, + one, + one, this->ParaV->desc, &psi_in(ikb, 0, 0), - &one, - &one, + one, + one, this->ParaV->desc, - &beta, + beta, out_matrix, - &one, - &one, + one, + one, this->ParaV->desc); #endif diff --git a/source/source_io/unk_overlap_lcao.cpp b/source/source_io/unk_overlap_lcao.cpp index 4d79d78a13..2a2ec8ac91 100644 --- a/source/source_io/unk_overlap_lcao.cpp +++ b/source/source_io/unk_overlap_lcao.cpp @@ -563,44 +563,44 @@ std::complex unkOverlap_lcao::det_berryphase(const UnitCell& ucell, std::complex alpha = {1.0, 0.0}, beta = {0.0, 0.0}; int one = 1; #ifdef __MPI - pzgemm_(&transa, - &transb, - &occBands, - &nlocal, - &nlocal, - &alpha, + ScalapackConnector::gemm(transa, + transb, + occBands, + nlocal, + nlocal, + alpha, &psi_in[0](ik_L, 0, 0), - &one, - &one, + one, + one, para_orb.desc, midmatrix, - &one, - &one, + one, + one, para_orb.desc, - &beta, + beta, C_matrix, - &one, - &one, + one, + one, para_orb.desc); - pzgemm_(&transb, - &transb, - &occBands, - &occBands, - &nlocal, - &alpha, + ScalapackConnector::gemm(transb, + transb, + occBands, + occBands, + nlocal, + alpha, C_matrix, - &one, - &one, + one, + one, para_orb.desc, &psi_in[0](ik_R, 0, 0), - &one, - &one, + one, + one, para_orb.desc, - &beta, + beta, out_matrix, - &one, - &one, + one, + one, para_orb.desc); assert(para_orb.nrow>0); diff --git a/source/source_lcao/module_deepks/deepks_orbpre.cpp b/source/source_lcao/module_deepks/deepks_orbpre.cpp index 60d3ee16a6..6736a7f9b0 100644 --- a/source/source_lcao/module_deepks/deepks_orbpre.cpp +++ b/source/source_lcao/module_deepks/deepks_orbpre.cpp @@ -213,19 +213,19 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector& dm_hl, gemm_alpha = 2.0; } - dgemm_(&transa, - &transb, - &row_size_nks, - &trace_alpha_size, - &col_size, - &gemm_alpha, - dm_array.data(), - &col_size, + BlasConnector::gemm(transb, + transa, + trace_alpha_size, + row_size_nks, + col_size, + gemm_alpha, s_2t.data(), - &col_size, - &gemm_beta, + col_size, + dm_array.data(), + col_size, + gemm_beta, g_1dmt.data(), - &row_size_nks); + row_size_nks); } // ad2 for (int ik = 0; ik < nks; ik++) diff --git a/source/source_lcao/module_deepks/deepks_pdm.cpp b/source/source_lcao/module_deepks/deepks_pdm.cpp index ae687c2c01..b6551116d0 100644 --- a/source/source_lcao/module_deepks/deepks_pdm.cpp +++ b/source/source_lcao/module_deepks/deepks_pdm.cpp @@ -373,19 +373,19 @@ void DeePKS_domain::cal_pdm(bool& init_pdm, // all the input should be data pointer constexpr char transa = 'T', transb = 'N'; const double gemm_alpha = 1.0, gemm_beta = 1.0; - dgemm_(&transa, - &transb, - &row_size, - &trace_alpha_size, - &col_size, - &gemm_alpha, - dm_current, - &col_size, + BlasConnector::gemm(transb, + transa, + trace_alpha_size, + row_size, + col_size, + gemm_alpha, s_2t.data(), - &col_size, - &gemm_beta, + col_size, + dm_current, + col_size, + gemm_beta, g_1dmt.data(), - &row_size); + row_size); } // ad2 if (!PARAM.inp.deepks_equiv) { diff --git a/source/source_lcao/module_dftu/dftu_force.cpp b/source/source_lcao/module_dftu/dftu_force.cpp index b54026663a..3dc90223d0 100644 --- a/source/source_lcao/module_dftu/dftu_force.cpp +++ b/source/source_lcao/module_dftu/dftu_force.cpp @@ -5,6 +5,7 @@ #include "source_base/constants.h" #include "source_base/global_function.h" #include "source_base/inverse_matrix.h" +#include "source_base/module_external/scalapack_connector.h" #include "source_base/parallel_reduce.h" #include "source_base/timer.h" #include "source_estate/elecstate_lcao.h" @@ -112,11 +113,11 @@ void Plus_U::force_stress(const UnitCell& ucell, this->cal_VU_pot_mat_real(spin, false, VU); #ifdef __MPI - pdgemm_(&transT, &transN, &nlocal, &nlocal, &nlocal, - &alpha, (*dmk_d)[spin].data(), &one_int, &one_int, // important to add () outside *dmk_d, mohan note 20251103 - pv.desc, VU, &one_int, &one_int, - pv.desc, &beta, &rho_VU[0], - &one_int, &one_int, pv.desc); + ScalapackConnector::gemm(transT, transN, nlocal, nlocal, nlocal, + alpha, (*dmk_d)[spin].data(), 1, 1, + pv.desc, VU, 1, 1, + pv.desc, beta, &rho_VU[0], + 1, 1, pv.desc); #endif delete[] VU; @@ -444,24 +445,24 @@ void Plus_U::cal_force_gamma(const UnitCell& ucell, } #ifdef __MPI - pdgemm_(&transN, - &transT, - &PARAM.globalv.nlocal, - &PARAM.globalv.nlocal, - &PARAM.globalv.nlocal, - &one, + ScalapackConnector::gemm(transN, + transT, + PARAM.globalv.nlocal, + PARAM.globalv.nlocal, + PARAM.globalv.nlocal, + one, tmp_ptr, - &one_int, - &one_int, + 1, + 1, pv.desc, rho_VU, - &one_int, - &one_int, + 1, + 1, pv.desc, - &zero, + zero, &dm_VU_dSm[0], - &one_int, - &one_int, + 1, + 1, pv.desc); #endif @@ -482,24 +483,24 @@ void Plus_U::cal_force_gamma(const UnitCell& ucell, } // end ir #ifdef __MPI - pdgemm_(&transN, - &transT, - &PARAM.globalv.nlocal, - &PARAM.globalv.nlocal, - &PARAM.globalv.nlocal, - &one, + ScalapackConnector::gemm(transN, + transT, + PARAM.globalv.nlocal, + PARAM.globalv.nlocal, + PARAM.globalv.nlocal, + one, tmp_ptr, - &one_int, - &one_int, + 1, + 1, pv.desc, rho_VU, - &one_int, - &one_int, + 1, + 1, pv.desc, - &zero, + zero, &dm_VU_dSm[0], - &one_int, - &one_int, + 1, + 1, pv.desc); #endif @@ -582,24 +583,24 @@ void Plus_U::cal_stress_gamma(const UnitCell& ucell, this->fold_dSR_gamma(ucell, pv, gd, dsloc_x, dsloc_y, dsloc_z, dh_r, dim1, dim2, &dSR_gamma[0]); #ifdef __MPI - pdgemm_(&transN, - &transN, - &nlocal, - &nlocal, - &nlocal, - &minus_half, + ScalapackConnector::gemm(transN, + transN, + nlocal, + nlocal, + nlocal, + minus_half, rho_VU, - &one_int, - &one_int, + 1, + 1, pv.desc, &dSR_gamma[0], - &one_int, - &one_int, + 1, + 1, pv.desc, - &zero, + zero, &dm_VU_sover[0], - &one_int, - &one_int, + 1, + 1, pv.desc); #endif diff --git a/source/source_lcao/module_dftu/dftu_hamilt.cpp b/source/source_lcao/module_dftu/dftu_hamilt.cpp index 7731aafdb7..f4d1b79a32 100644 --- a/source/source_lcao/module_dftu/dftu_hamilt.cpp +++ b/source/source_lcao/module_dftu/dftu_hamilt.cpp @@ -36,13 +36,13 @@ void Plus_U::cal_eff_pot_mat_complex(const int ik, this->cal_VU_pot_mat_complex(spin, true, &VU[0]); #ifdef __MPI - pzgemm_(&transN, &transN, - &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, - &half, - ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), &one_int, &one_int, this->paraV->desc, - sk, &one_int, &one_int, this->paraV->desc, - &zero, - eff_pot, &one_int, &one_int, this->paraV->desc); + ScalapackConnector::gemm(transN, transN, + PARAM.globalv.nlocal, PARAM.globalv.nlocal, PARAM.globalv.nlocal, + half, + ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), one_int, one_int, this->paraV->desc, + sk, one_int, one_int, this->paraV->desc, + zero, + eff_pot, one_int, one_int, this->paraV->desc); #endif for (int irc = 0; irc < this->paraV->nloc; irc++) @@ -86,13 +86,13 @@ void Plus_U::cal_eff_pot_mat_real(const int ik, double* eff_pot, const std::vect this->cal_VU_pot_mat_real(spin, 1, &VU[0]); #ifdef __MPI - pdgemm_(&transN, &transN, - &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, - &half, - ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), &one_int, &one_int, this->paraV->desc, - sk, &one_int, &one_int, this->paraV->desc, - &beta, - eff_pot, &one_int, &one_int, this->paraV->desc); + ScalapackConnector::gemm(transN, transN, + PARAM.globalv.nlocal, PARAM.globalv.nlocal, PARAM.globalv.nlocal, + half, + ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), 1, 1, this->paraV->desc, + sk, 1, 1, this->paraV->desc, + beta, + eff_pot, 1, 1, this->paraV->desc); #endif for (int irc = 0; irc < this->paraV->nloc; irc++) @@ -120,21 +120,21 @@ void Plus_U::cal_eff_pot_mat_R_double(const int ispin, double* SR, double* HR) this->cal_VU_pot_mat_real(ispin, 1, &VU[0]); #ifdef __MPI - pdgemm_(&transN, &transN, - &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, - &half, - ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), &one_int, &one_int, this->paraV->desc, - SR, &one_int, &one_int, this->paraV->desc, - &beta, - HR, &one_int, &one_int, this->paraV->desc); - - pdgemm_(&transN, &transN, - &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, - &half, - SR, &one_int, &one_int, this->paraV->desc, - ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), &one_int, &one_int, this->paraV->desc, - &one, - HR, &one_int, &one_int, this->paraV->desc); + ScalapackConnector::gemm(transN, transN, + PARAM.globalv.nlocal, PARAM.globalv.nlocal, PARAM.globalv.nlocal, + half, + ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), 1, 1, this->paraV->desc, + SR, 1, 1, this->paraV->desc, + beta, + HR, 1, 1, this->paraV->desc); + + ScalapackConnector::gemm(transN, transN, + PARAM.globalv.nlocal, PARAM.globalv.nlocal, PARAM.globalv.nlocal, + half, + SR, 1, 1, this->paraV->desc, + ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), 1, 1, this->paraV->desc, + one, + HR, 1, 1, this->paraV->desc); #endif return; @@ -150,21 +150,21 @@ void Plus_U::cal_eff_pot_mat_R_complex_double(const int ispin, std::complexcal_VU_pot_mat_complex(ispin, 1, &VU[0]); #ifdef __MPI - pzgemm_(&transN, &transN, - &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, - &half, - ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), &one_int, &one_int, this->paraV->desc, - SR, &one_int, &one_int, this->paraV->desc, - &zero, - HR, &one_int, &one_int, this->paraV->desc); - - pzgemm_(&transN, &transN, - &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, &PARAM.globalv.nlocal, - &half, - SR, &one_int, &one_int, this->paraV->desc, - ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), &one_int, &one_int, this->paraV->desc, - &one, - HR, &one_int, &one_int, this->paraV->desc); + ScalapackConnector::gemm(transN, transN, + PARAM.globalv.nlocal, PARAM.globalv.nlocal, PARAM.globalv.nlocal, + half, + ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), one_int, one_int, this->paraV->desc, + SR, one_int, one_int, this->paraV->desc, + zero, + HR, one_int, one_int, this->paraV->desc); + + ScalapackConnector::gemm(transN, transN, + PARAM.globalv.nlocal, PARAM.globalv.nlocal, PARAM.globalv.nlocal, + half, + SR, one_int, one_int, this->paraV->desc, + ModuleBase::GlobalFunc::VECTOR_TO_PTR(VU), one_int, one_int, this->paraV->desc, + one, + HR, one_int, one_int, this->paraV->desc); #endif return; diff --git a/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_parallel.cpp b/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_parallel.cpp index 7e614619db..7d936365ce 100644 --- a/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_parallel.cpp +++ b/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_parallel.cpp @@ -51,18 +51,18 @@ namespace LR char transb = 'N'; const double alpha = 1.0; const double beta = add_on ? 1.0 : 0.0; - pdgemm_(&transa, &transb, &naos, &nmo1, &naos, - &alpha, mat_ao[isk].data(), &i1, &i1, pmat_ao.desc, - coeff.get_pointer(), &i1, &imo1, pcoeff.desc, - &beta, Vc.data(), &i1, &i1, pVc.desc); + ScalapackConnector::gemm(transa, transb, naos, nmo1, naos, + alpha, mat_ao[isk].data(), i1, i1, pmat_ao.desc, + coeff.get_pointer(), i1, imo1, pcoeff.desc, + beta, Vc.data(), i1, i1, pVc.desc); transa = 'T'; // mat_mo = c ^ TVc // descC puts M(nvirt) to row - pdgemm_(&transa, &transb, &nmo2, &nmo1, &naos, - &alpha, coeff.get_pointer(), &i1, &imo2, pcoeff.desc, - Vc.data(), &i1, &i1, pVc.desc, - &beta, mat_mo + start, &i1, &i1, pmat_mo.desc); + ScalapackConnector::gemm(transa, transb, nmo2, nmo1, naos, + alpha, coeff.get_pointer(), i1, imo2, pcoeff.desc, + Vc.data(), i1, i1, pVc.desc, + beta, mat_mo + start, i1, i1, pmat_mo.desc); } } @@ -109,18 +109,18 @@ namespace LR char transb = 'N'; const std::complex alpha(1.0, 0.0); const std::complex beta = add_on ? std::complex(1.0, 0.0) : std::complex(0.0, 0.0); - pzgemm_(&transa, &transb, &naos, &nmo1, &naos, - &alpha, mat_ao[isk].data>(), &i1, &i1, pmat_ao.desc, - coeff.get_pointer(), &i1, &imo1, pcoeff.desc, - &beta, Vc.data>(), &i1, &i1, pVc.desc); + ScalapackConnector::gemm(transa, transb, naos, nmo1, naos, + alpha, mat_ao[isk].data>(), i1, i1, pmat_ao.desc, + coeff.get_pointer(), i1, imo1, pcoeff.desc, + beta, Vc.data>(), i1, i1, pVc.desc); transa = 'C'; // mat_mo = c ^ TVc // descC puts M(nvirt) to row - pzgemm_(&transa, &transb, &nmo2, &nmo1, &naos, - &alpha, coeff.get_pointer(), &i1, &imo2, pcoeff.desc, - Vc.data>(), &i1, &i1, pVc.desc, - &beta, mat_mo + start, &i1, &i1, pmat_mo.desc); + ScalapackConnector::gemm(transa, transb, nmo2, nmo1, naos, + alpha, coeff.get_pointer(), i1, imo2, pcoeff.desc, + Vc.data>(), i1, i1, pVc.desc, + beta, mat_mo + start, i1, i1, pmat_mo.desc); } } } diff --git a/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_serial.cpp b/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_serial.cpp index 1ddec7f8da..b65f871f32 100644 --- a/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_serial.cpp +++ b/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_serial.cpp @@ -110,15 +110,15 @@ namespace LR char transb = 'N'; //coeff is col major const double alpha = 1.0; const double beta = add_on ? 1.0 : 0.0; - dgemm_(&transa, &transb, &naos, &nmo1, &naos, &alpha, - mat_ao[isk].data(), &naos, coeff.get_pointer(imo1), &naos, &beta, - Vc.data(), &naos); + BlasConnector::gemm(transb, transa, nmo1, naos, naos, alpha, + coeff.get_pointer(imo1), naos, mat_ao[isk].data(), naos, beta, + Vc.data(), naos); transa = 'T'; //mat_mo=coeff^TVc (nvirt major) - dgemm_(&transa, &transb, &nmo2, &nmo1, &naos, &alpha, - coeff.get_pointer(imo2), &naos, Vc.data(), &naos, &beta, - mat_mo + start, &nmo2); + BlasConnector::gemm(transb, transa, nmo1, nmo2, naos, alpha, + Vc.data(), naos, coeff.get_pointer(imo2), naos, beta, + mat_mo + start, nmo2); } } template<> @@ -151,15 +151,15 @@ namespace LR char transb = 'N'; //coeff is col major const std::complex alpha(1.0, 0.0); const std::complex beta = add_on ? std::complex(1.0, 0.0) : std::complex(0.0, 0.0); - zgemm_(&transa, &transb, &naos, &nmo1, &naos, &alpha, - mat_ao[isk].data>(), &naos, coeff.get_pointer(imo1), &naos, &beta, - Vc.data>(), &naos); + BlasConnector::gemm(transb, transa, nmo1, naos, naos, alpha, + coeff.get_pointer(imo1), naos, mat_ao[isk].data>(), naos, beta, + Vc.data>(), naos); transa = 'C'; //mat_mo=coeff^\dagger Vc (nvirt major) - zgemm_(&transa, &transb, &nmo2, &nmo1, &naos, &alpha, - coeff.get_pointer(imo2), &naos, Vc.data>(), &naos, &beta, - mat_mo + start, &nmo2); + BlasConnector::gemm(transb, transa, nmo1, nmo2, naos, alpha, + Vc.data>(), naos, coeff.get_pointer(imo2), naos, beta, + mat_mo + start, nmo2); } } } \ No newline at end of file diff --git a/source/source_lcao/module_lr/dm_trans/dm_trans_parallel.cpp b/source/source_lcao/module_lr/dm_trans/dm_trans_parallel.cpp index ebc8f2c697..5b36f798b2 100644 --- a/source/source_lcao/module_lr/dm_trans/dm_trans_parallel.cpp +++ b/source/source_lcao/module_lr/dm_trans/dm_trans_parallel.cpp @@ -53,16 +53,16 @@ std::vector cal_dm_trans_pblas(const double* const X_istate, DEV::CpuDevice, {pXc.get_col_size(), pXc.get_row_size()}); // row is "inside"(memory contiguity) for pblas Xc.zero(); - pdgemm_(&transa, &transb, &naos, &nmo2, &nmo1, - &alpha, c.get_pointer(), &i1, &imo1, pc.desc, - X_istate + x_start, &i1, &i1, px.desc, - &beta, Xc.data(), &i1, &i1, pXc.desc); + ScalapackConnector::gemm(transa, transb, naos, nmo2, nmo1, + alpha, c.get_pointer(), 1, imo1, pc.desc, + X_istate + x_start, 1, 1, px.desc, + beta, Xc.data(), 1, 1, pXc.desc); // 2. C_virt*[X*C_occ^T] - pdgemm_(&transa, &transb, &naos, &naos, &nmo2, - &factor, c.get_pointer(), &i1, &imo2, pc.desc, - Xc.data(), &i1, &i1, pXc.desc, - &beta, dm_trans[isk].data(), &i1, &i1, pmat.desc); + ScalapackConnector::gemm(transa, transb, naos, naos, nmo2, + factor, c.get_pointer(), 1, imo2, pc.desc, + Xc.data(), 1, 1, pXc.desc, + beta, dm_trans[isk].data(), 1, 1, pmat.desc); } return dm_trans; } @@ -130,17 +130,17 @@ std::vector cal_dm_trans_pblas(const std::complex* co Xc.zero(); const std::complex alpha(1.0, 0.0); const std::complex beta(0.0, 0.0); - pzgemm_(&transa, &transb, &nmo2, &naos, &nmo1, &alpha, - X_istate + x_start, &i1, &i1, px.desc, - c.get_pointer(), &i1, &imo1, pc.desc, - &beta, Xc.data>(), &i1, &i1, pXc.desc); + ScalapackConnector::gemm(transa, transb, nmo2, naos, nmo1, alpha, + X_istate + x_start, i1, i1, px.desc, + c.get_pointer(), i1, imo1, pc.desc, + beta, Xc.data>(), i1, i1, pXc.desc); // 2. [X*C_occ^\dagger]^TC_virt^T transa = transb = 'T'; - pzgemm_(&transa, &transb, &naos, &naos, &nmo2, - &factor, Xc.data>(), &i1, &i1, pXc.desc, - c.get_pointer(), &i1, &imo2, pc.desc, - &beta, dm_trans[isk].data>(), &i1, &i1, pmat.desc); + ScalapackConnector::gemm(transa, transb, naos, naos, nmo2, + factor, Xc.data>(), i1, i1, pXc.desc, + c.get_pointer(), i1, imo2, pc.desc, + beta, dm_trans[isk].data>(), i1, i1, pmat.desc); } return dm_trans; } diff --git a/source/source_lcao/module_lr/dm_trans/dm_trans_serial.cpp b/source/source_lcao/module_lr/dm_trans/dm_trans_serial.cpp index 7509760345..3a4553e0b4 100644 --- a/source/source_lcao/module_lr/dm_trans/dm_trans_serial.cpp +++ b/source/source_lcao/module_lr/dm_trans/dm_trans_serial.cpp @@ -112,13 +112,13 @@ namespace LR const double alpha = 1.0; const double beta = 0.0; container::Tensor Xc(DAT::DT_DOUBLE, DEV::CpuDevice, { nmo2, naos }); - dgemm_(&transa, &transb, &naos, &nmo2, &nmo1, &alpha, - c.get_pointer(imo1), &naos, X_istate + x_start, &nmo2, - &beta, Xc.data(), &naos); + BlasConnector::gemm(transb, transa, nmo2, naos, nmo1, alpha, + X_istate + x_start, nmo2, c.get_pointer(imo1), naos, + beta, Xc.data(), naos); // 2. C_virt*[X*C_occ^T] - dgemm_(&transa, &transb, &naos, &naos, &nmo2, &factor, - c.get_pointer(imo2), &naos, Xc.data(), &naos, &beta, - dm_trans[isk].data(), &naos); + BlasConnector::gemm(transb, transa, naos, naos, nmo2, factor, + Xc.data(), naos, c.get_pointer(imo2), naos, beta, + dm_trans[isk].data(), naos); } return dm_trans; } @@ -166,14 +166,14 @@ namespace LR // ============== = [C_occ^* * X^T * C_virt^T]^T============= // 1. X*C_occ^\dagger container::Tensor Xc(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { naos, nmo2 }); - zgemm_(&transa, &transb, &nmo2, &naos, &nmo1, &alpha, - X_istate + x_start, &nmo2, c.get_pointer(imo1), &naos, - &beta, Xc.data>(), &nmo2); + BlasConnector::gemm_cm(transa, transb, nmo2, naos, nmo1, alpha, + X_istate + x_start, nmo2, c.get_pointer(imo1), naos, + beta, Xc.data>(), nmo2); // 2. [X*C_occ^\dagger]^TC_virt^T transa = transb = 'T'; - zgemm_(&transa, &transb, &naos, &naos, &nmo2, &factor, - Xc.data>(), &nmo2, c.get_pointer(imo2), &naos, &beta, - dm_trans[isk].data>(), &naos); + BlasConnector::gemm_cm(transa, transb, naos, naos, nmo2, factor, + Xc.data>(), nmo2, c.get_pointer(imo2), naos, beta, + dm_trans[isk].data>(), naos); } return dm_trans; } diff --git a/source/source_lcao/module_operator_lcao/deepks_lcao.cpp b/source/source_lcao/module_operator_lcao/deepks_lcao.cpp index b1dc5f9f12..67a77a0e4a 100644 --- a/source/source_lcao/module_operator_lcao/deepks_lcao.cpp +++ b/source/source_lcao/module_operator_lcao/deepks_lcao.cpp @@ -351,19 +351,19 @@ void hamilt::DeePKS>::calculate_HR() constexpr char transa = 'T', transb = 'N'; const double gemm_alpha = 1.0, gemm_beta = 1.0; - dgemm_(&transa, - &transb, - &col_size, - &row_size, - &trace_alpha_size, - &gemm_alpha, - s_2t.data(), - &trace_alpha_size, + BlasConnector::gemm(transb, + transa, + row_size, + col_size, + trace_alpha_size, + gemm_alpha, s_1t.data(), - &trace_alpha_size, - &gemm_beta, + trace_alpha_size, + s_2t.data(), + trace_alpha_size, + gemm_beta, hr_current.data(), - &col_size); + col_size); // add data of HR to target BaseMatrix #pragma omp critical diff --git a/source/source_lcao/module_rdmft/rdmft_tools.cpp b/source/source_lcao/module_rdmft/rdmft_tools.cpp index 1d3a8a4afd..ffc08c431a 100644 --- a/source/source_lcao/module_rdmft/rdmft_tools.cpp +++ b/source/source_lcao/module_rdmft/rdmft_tools.cpp @@ -44,8 +44,8 @@ void HkPsi(const Parallel_Orbitals* ParaV, const int nbands = ParaV->desc_wfc[3]; //because wfc(bands, basis'), H(basis, basis'), we do wfc*H^T(in the perspective of cpp, not in fortran). And get H_wfc(bands, basis) is correct. - pdgemm_( &C_char, &N_char, &nbasis, &nbands, &nbasis, &one_double, &HK, &one_int, &one_int, ParaV->desc, - &wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_double, &H_wfc, &one_int, &one_int, ParaV->desc_wfc ); + ScalapackConnector::gemm( C_char, N_char, nbasis, nbands, nbasis, one_double, &HK, 1, 1, ParaV->desc, + &wfc, 1, 1, ParaV->desc_wfc, zero_double, &H_wfc, 1, 1, ParaV->desc_wfc ); #endif } @@ -71,8 +71,8 @@ void cal_bra_op_ket(const Parallel_Orbitals* ParaV, const int nbasis = ParaV->desc[2]; const int nbands = ParaV->desc_wfc[3]; - pdgemm_( &T_char, &N_char, &nbands, &nbands, &nbasis, &one_double, &wfc, &one_int, &one_int, ParaV->desc_wfc, - &H_wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_double, &Dmn[0], &one_int, &one_int, para_Eij_in.desc ); + ScalapackConnector::gemm( T_char, N_char, nbands, nbands, nbasis, one_double, &wfc, 1, 1, ParaV->desc_wfc, + &H_wfc, 1, 1, ParaV->desc_wfc, zero_double, &Dmn[0], 1, 1, para_Eij_in.desc ); #endif } diff --git a/source/source_lcao/module_rdmft/rdmft_tools.h b/source/source_lcao/module_rdmft/rdmft_tools.h index 91c69fb8c4..0a4e28f548 100644 --- a/source/source_lcao/module_rdmft/rdmft_tools.h +++ b/source/source_lcao/module_rdmft/rdmft_tools.h @@ -77,8 +77,8 @@ void HkPsi(const Parallel_Orbitals* ParaV, const TK& HK, const TK& wfc, TK& H_wf const int nbands = ParaV->desc_wfc[3]; //because wfc(bands, basis'), H(basis, basis'), we do wfc*H^T(in the perspective of cpp, not in fortran). And get H_wfc(bands, basis) is correct. - pzgemm_( &C_char, &N_char, &nbasis, &nbands, &nbasis, &one_complex, &HK, &one_int, &one_int, ParaV->desc, - &wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_complex, &H_wfc, &one_int, &one_int, ParaV->desc_wfc ); + ScalapackConnector::gemm( C_char, N_char, nbasis, nbands, nbasis, one_complex, &HK, one_int, one_int, ParaV->desc, + &wfc, one_int, one_int, ParaV->desc_wfc, zero_complex, &H_wfc, one_int, one_int, ParaV->desc_wfc ); #endif } @@ -104,8 +104,8 @@ void cal_bra_op_ket(const Parallel_Orbitals* ParaV, const Parallel_2D& para_Eij_ const int nbasis = ParaV->desc[2]; const int nbands = ParaV->desc_wfc[3]; - pzgemm_( &C_char, &N_char, &nbands, &nbands, &nbasis, &one_complex, &wfc, &one_int, &one_int, ParaV->desc_wfc, - &H_wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_complex, &Dmn[0], &one_int, &one_int, para_Eij_in.desc ); + ScalapackConnector::gemm( C_char, N_char, nbands, nbands, nbasis, one_complex, &wfc, one_int, one_int, ParaV->desc_wfc, + &H_wfc, one_int, one_int, ParaV->desc_wfc, zero_complex, &Dmn[0], one_int, one_int, para_Eij_in.desc ); #endif } diff --git a/source/source_lcao/module_ri/module_exx_symmetry/symmetry_rotation.cpp b/source/source_lcao/module_ri/module_exx_symmetry/symmetry_rotation.cpp index 60d27afbf2..53922f42ac 100644 --- a/source/source_lcao/module_ri/module_exx_symmetry/symmetry_rotation.cpp +++ b/source/source_lcao/module_ri/module_exx_symmetry/symmetry_rotation.cpp @@ -417,24 +417,24 @@ namespace ModuleSymmetry if (TRS_conj) { // D^T* = M^T [M^T (D^T)^T]^\dagger - pzgemm_(&transpose, &transpose, &nbasis, &nbasis, &nbasis, - &alpha, this->Ms_[ik_ibz].at(isym).data(), &i1, &i1, pv.desc, DMkibz.data(), &i1, &i1, pv.desc, - &beta, DMkibz_M.data(), &i1, &i1, pv.desc); + ScalapackConnector::gemm(transpose, transpose, nbasis, nbasis, nbasis, + alpha, this->Ms_[ik_ibz].at(isym).data(), i1, i1, pv.desc, DMkibz.data(), i1, i1, pv.desc, + beta, DMkibz_M.data(), i1, i1, pv.desc); alpha.real(1.0 / static_cast(kstar_size)); - pzgemm_(&transpose, &dagger, &nbasis, &nbasis, &nbasis, - &alpha, this->Ms_[ik_ibz].at(isym).data(), &i1, &i1, pv.desc, DMkibz_M.data(), &i1, &i1, pv.desc, - &beta, DMk.data(), &i1, &i1, pv.desc); + ScalapackConnector::gemm(transpose, dagger, nbasis, nbasis, nbasis, + alpha, this->Ms_[ik_ibz].at(isym).data(), i1, i1, pv.desc, DMkibz_M.data(), i1, i1, pv.desc, + beta, DMk.data(), i1, i1, pv.desc); } else { // D^T = M^\daggger D^T M - pzgemm_(&dagger, ¬rans, &nbasis, &nbasis, &nbasis, - &alpha, this->Ms_[ik_ibz].at(isym).data(), &i1, &i1, pv.desc, DMkibz.data(), &i1, &i1, pv.desc, - &beta, DMkibz_M.data(), &i1, &i1, pv.desc); + ScalapackConnector::gemm(dagger, notrans, nbasis, nbasis, nbasis, + alpha, this->Ms_[ik_ibz].at(isym).data(), i1, i1, pv.desc, DMkibz.data(), i1, i1, pv.desc, + beta, DMkibz_M.data(), i1, i1, pv.desc); alpha.real(1.0 / static_cast(kstar_size)); - pzgemm_(¬rans, ¬rans, &nbasis, &nbasis, &nbasis, - &alpha, DMkibz_M.data(), &i1, &i1, pv.desc, this->Ms_[ik_ibz].at(isym).data(), &i1, &i1, pv.desc, - &beta, DMk.data(), &i1, &i1, pv.desc); + ScalapackConnector::gemm(notrans, notrans, nbasis, nbasis, nbasis, + alpha, DMkibz_M.data(), i1, i1, pv.desc, this->Ms_[ik_ibz].at(isym).data(), i1, i1, pv.desc, + beta, DMk.data(), i1, i1, pv.desc); } return DMk; } diff --git a/source/source_pw/module_pwdft/VNL_in_pw.cpp b/source/source_pw/module_pwdft/VNL_in_pw.cpp index e26e09fa3e..544beaf25e 100644 --- a/source/source_pw/module_pwdft/VNL_in_pw.cpp +++ b/source/source_pw/module_pwdft/VNL_in_pw.cpp @@ -1468,19 +1468,19 @@ void pseudopot_cell_vnl::newq(const ModuleBase::matrix& veff, const ModulePW::PW double* qg_ptr = reinterpret_cast(qg.c); double* aux_ptr = reinterpret_cast(aux.c); - dgemm_(&transa, - &transb, - &nij, - &natom, - &complex_npw, - &fact, - qg_ptr, - &complex_npw, + BlasConnector::gemm(transb, + transa, + natom, + nij, + complex_npw, + fact, aux_ptr, - &complex_npw, - &zero, + complex_npw, + qg_ptr, + complex_npw, + zero, deeaux.c, - &nij); + nij); // I'm not sure if this is correct for gamma_only if (rho_basis->gamma_only && rho_basis->ig_gge0 >= 0) { diff --git a/source/source_pw/module_pwdft/forces_us.cpp b/source/source_pw/module_pwdft/forces_us.cpp index f5a6a4444e..908212ca02 100644 --- a/source/source_pw/module_pwdft/forces_us.cpp +++ b/source/source_pw/module_pwdft/forces_us.cpp @@ -98,19 +98,19 @@ void Forces::cal_force_us(ModuleBase::matrix& forcenl, const double zero = 0; for (int ipol = 0; ipol < 3; ipol++) { - dgemm_(&transa, - &transb, - &nij, - &atom->na, - &dim, - &(ucell.omega), - qgm_data, - &dim, + BlasConnector::gemm(transb, + transa, + atom->na, + nij, + dim, + ucell.omega, &aux1_data[ipol * dim * atom->na], - &dim, - &zero, + dim, + qgm_data, + dim, + zero, &ddeeq(is, ipol, 0, 0), - &nij); + nij); } } diff --git a/source/source_pw/module_pwdft/stress_func_us.cpp b/source/source_pw/module_pwdft/stress_func_us.cpp index 1881dacafe..e18ffb9823 100644 --- a/source/source_pw/module_pwdft/stress_func_us.cpp +++ b/source/source_pw/module_pwdft/stress_func_us.cpp @@ -111,19 +111,19 @@ void Stress_PW::stress_us(ModuleBase::matrix& sigma, const int dim = 2 * npw; const double one = 1; const double zero = 0; - dgemm_(&transa, - &transb, - &dim, - &PARAM.inp.nspin, - &nij, - &one, - qgm_data, - &dim, + BlasConnector::gemm(transb, + transa, + PARAM.inp.nspin, + dim, + nij, + one, tbecsum.c, - &nij, - &zero, + nij, + qgm_data, + dim, + zero, aux2_data, - &dim); + dim); for (int is = 0; is < PARAM.inp.nspin; is++) { @@ -148,19 +148,19 @@ void Stress_PW::stress_us(ModuleBase::matrix& sigma, ModuleBase::matrix fac(PARAM.inp.nspin, 3); const char transc = 'T'; const int three = 3; - dgemm_(&transc, - &transb, - &three, - &PARAM.inp.nspin, - &dim, - &one, + BlasConnector::gemm_cm(transc, + transb, + three, + PARAM.inp.nspin, + dim, + one, aux1_data, - &dim, + dim, aux2_data, - &dim, - &zero, + dim, + zero, fac.c, - &three); + three); for (int is = 0; is < PARAM.inp.nspin; is++) { diff --git a/source/source_pw/module_stodft/sto_dos.cpp b/source/source_pw/module_stodft/sto_dos.cpp index 8a8f2cce9b..dd90224e15 100644 --- a/source/source_pw/module_stodft/sto_dos.cpp +++ b/source/source_pw/module_stodft/sto_dos.cpp @@ -157,7 +157,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, cons double* vec_all = (double*)allorderchi.data(); int LDA = npwx * nchipk_new * 2; int M = npwx * nchipk_new * 2; - dgemm_(&trans, &normal, &N, &N, &M, &kweight, vec_all, &LDA, vec_all, &LDA, &one, spolyv.data(), &N); + BlasConnector::gemm(normal, trans, N, N, M, kweight, vec_all, LDA, vec_all, LDA, one, spolyv.data(), N); } } }