diff --git a/source/source_base/parallel_reduce.cpp b/source/source_base/parallel_reduce.cpp index 03535573b7..5b833c3ec9 100644 --- a/source/source_base/parallel_reduce.cpp +++ b/source/source_base/parallel_reduce.cpp @@ -117,6 +117,14 @@ void Parallel_Reduce::reduce_pool(double& object) return; } +template <> +void Parallel_Reduce::reduce_pool(int* object, const int n) +{ +#ifdef __MPI + MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_INT, MPI_SUM, POOL_WORLD); +#endif +} + template <> void Parallel_Reduce::reduce_pool(double* object, const int n) { diff --git a/source/source_estate/module_charge/symmetry_rhog.cpp b/source/source_estate/module_charge/symmetry_rhog.cpp index b16ac8476c..7d37df1d80 100644 --- a/source/source_estate/module_charge/symmetry_rhog.cpp +++ b/source/source_estate/module_charge/symmetry_rhog.cpp @@ -1,4 +1,5 @@ #include "symmetry_rho.h" +#include "source_base/parallel_reduce.h" #include "source_base/parallel_global.h" #include "source_hamilt/module_xc/xc_functional.h" @@ -9,7 +10,7 @@ void Symmetry_rho::psymmg(std::complex* rhog_part, const ModulePW::PW_Ba int * fftixy2is = new int [rho_basis->fftnxy]; rho_basis->getfftixy2is(fftixy2is); //current proc #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, fftixy2is, rho_basis->fftnxy, MPI_INT, MPI_SUM, POOL_WORLD); + Parallel_Reduce::reduce_pool(fftixy2is, rho_basis->fftnxy); if(rho_basis->poolnproc>1) for (int i=0;ifftnxy;++i) fftixy2is[i]+=rho_basis->poolnproc-1; diff --git a/source/source_io/output_log.cpp b/source/source_io/output_log.cpp index 7a4471b0a6..8c3ba0d114 100644 --- a/source/source_io/output_log.cpp +++ b/source/source_io/output_log.cpp @@ -4,7 +4,7 @@ #include "source_base/constants.h" #include "source_base/formatter.h" #include "source_base/global_variable.h" - +#include "source_base/parallel_reduce.h" #include "source_base/parallel_comm.h" #ifdef __MPI @@ -154,7 +154,7 @@ void output_vacuum_level(const UnitCell* ucell, } #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, ave, length, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + Parallel_Reduce::reduce_pool(ave, length); #endif int surface = nxyz / length; diff --git a/source/source_pw/module_pwdft/elecond.cpp b/source/source_pw/module_pwdft/elecond.cpp index 068ca01067..34644ef986 100644 --- a/source/source_pw/module_pwdft/elecond.cpp +++ b/source/source_pw/module_pwdft/elecond.cpp @@ -93,9 +93,9 @@ void EleCond::KG(const int& smear_type, jjresponse_ks(ik, nt, dt, decut, wg, velop, ct11.data(), ct12.data(), ct22.data()); } #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, ct11.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); - MPI_Allreduce(MPI_IN_PLACE, ct12.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); - MPI_Allreduce(MPI_IN_PLACE, ct22.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + Parallel_Reduce::reduce_all(ct11.data(), nt); + Parallel_Reduce::reduce_all(ct12.data(), nt); + Parallel_Reduce::reduce_all(ct22.data(), nt); #endif //------------------------------------------------------------------ // Output diff --git a/source/source_pw/module_pwdft/vnl_pw.cpp b/source/source_pw/module_pwdft/vnl_pw.cpp index 52d3d447da..dd017e1f4e 100644 --- a/source/source_pw/module_pwdft/vnl_pw.cpp +++ b/source/source_pw/module_pwdft/vnl_pw.cpp @@ -9,6 +9,7 @@ #include "source_base/math_sphbes.h" #include "source_base/math_ylmreal.h" #include "source_base/memory.h" +#include "source_base/parallel_reduce.h" #include "source_base/module_device/device.h" #include "source_base/timer.h" #include "source_pw/module_pwdft/kernels/vnl_op.h" @@ -684,8 +685,8 @@ void pseudopot_cell_vnl::init_vnl(UnitCell& cell, const ModulePW::PW_Basis* rho_ } #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, this->qq_nt.ptr, this->qq_nt.getSize(), MPI_DOUBLE, MPI_SUM, POOL_WORLD); - MPI_Allreduce(MPI_IN_PLACE, this->qq_so.ptr, this->qq_so.getSize(), MPI_DOUBLE_COMPLEX, MPI_SUM, POOL_WORLD); + Parallel_Reduce::reduce_pool(this->qq_nt.ptr, this->qq_nt.getSize()); + Parallel_Reduce::reduce_pool(this->qq_so.ptr, this->qq_so.getSize()); #endif // set the atomic specific qq_at matrices @@ -1511,7 +1512,7 @@ void pseudopot_cell_vnl::newq(const ModuleBase::matrix& veff, const ModulePW::PW } #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, deeq.ptr, deeq.getSize(), MPI_DOUBLE, MPI_SUM, POOL_WORLD); + Parallel_Reduce::reduce_pool(deeq.ptr,deeq.getSize()); #endif delete[] qnorm; diff --git a/source/source_pw/module_stodft/sto_dos.cpp b/source/source_pw/module_stodft/sto_dos.cpp index dd90224e15..2665dde510 100644 --- a/source/source_pw/module_stodft/sto_dos.cpp +++ b/source/source_pw/module_stodft/sto_dos.cpp @@ -235,9 +235,9 @@ void Sto_DOS::caldos(const double sigmain, const double de, cons } #ifdef __MPI MPI_Allreduce(MPI_IN_PLACE, ks_dos.data(), ndos, MPI_DOUBLE, MPI_SUM, INT_BGROUP); - MPI_Allreduce(MPI_IN_PLACE, sto_dos.data(), ndos, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); - MPI_Allreduce(MPI_IN_PLACE, error.data(), ndos, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); -#endif + Parallel_Reduce::reduce_all(sto_dos.data(), ndos); + Parallel_Reduce::reduce_all(error.data(), ndos); + #endif if (GlobalV::MY_RANK == 0) { std::string dosfile = PARAM.globalv.global_out_dir + "dos_sdft.txt"; diff --git a/source/source_pw/module_stodft/sto_elecond.cpp b/source/source_pw/module_stodft/sto_elecond.cpp index b0fe4d71d7..1a32e9db83 100644 --- a/source/source_pw/module_stodft/sto_elecond.cpp +++ b/source/source_pw/module_stodft/sto_elecond.cpp @@ -1059,9 +1059,9 @@ void Sto_EleCond::sKG(const int& smear_type, } // ik loop ModuleBase::timer::tick("Sto_EleCond", "kloop"); #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, ct11.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); - MPI_Allreduce(MPI_IN_PLACE, ct12.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); - MPI_Allreduce(MPI_IN_PLACE, ct22.data(), nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + Parallel_Reduce::reduce_all(ct11.data(), nt); + Parallel_Reduce::reduce_all(ct12.data(), nt); + Parallel_Reduce::reduce_all(ct22.data(), nt); #endif //------------------------------------------------------------------ diff --git a/source/source_pw/module_stodft/sto_iter.cpp b/source/source_pw/module_stodft/sto_iter.cpp index cd00e2f6f2..b2267917e2 100644 --- a/source/source_pw/module_stodft/sto_iter.cpp +++ b/source/source_pw/module_stodft/sto_iter.cpp @@ -248,7 +248,7 @@ void Stochastic_Iter::check_precision(const double ref, const double } #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &error, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + Parallel_Reduce::reduce_all(error); #endif double relative_error = std::abs(error / ref); GlobalV::ofs_running << info << "Relative Chebyshev Precision: " << relative_error * 1e9 << "E-09" << std::endl; @@ -472,7 +472,7 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes) { MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM, BP_WORLD); } - MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + Parallel_Reduce::reduce_all(sto_ne); #endif totne = KS_ne + sto_ne; @@ -539,7 +539,7 @@ void Stochastic_Iter::sum_stoeband(Stochastic_WF& stowf, { MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.demet, 1, MPI_DOUBLE, MPI_SUM, BP_WORLD); } - MPI_Allreduce(MPI_IN_PLACE, &stodemet, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + Parallel_Reduce::reduce_all(stodemet); #endif pes->f_en.demet += stodemet; this->check_precision(pes->f_en.demet, 1e-4, "TS"); @@ -580,7 +580,7 @@ void Stochastic_Iter::sum_stoeband(Stochastic_WF& stowf, } } #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &sto_eband, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); + Parallel_Reduce::reduce_all(sto_eband); #endif pes->f_en.eband += sto_eband; ModuleBase::timer::tick("Stochastic_Iter", "sum_stoeband"); @@ -694,7 +694,7 @@ void Stochastic_Iter::cal_storho(const UnitCell& ucell, sto_ne *= ucell.omega / wfc_basis->nxyz; #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + Parallel_Reduce::reduce_pool(sto_ne); #endif double factor = targetne / (KS_ne + sto_ne); if (std::abs(factor - 1) > 1e-10)