diff --git a/source/source_base/parallel_reduce.cpp b/source/source_base/parallel_reduce.cpp index 03535573b7..ab5c8505b5 100644 --- a/source/source_base/parallel_reduce.cpp +++ b/source/source_base/parallel_reduce.cpp @@ -233,85 +233,65 @@ void Parallel_Reduce::gather_int_all(int& v, int* all) return; } -void Parallel_Reduce::gather_min_int_all(const int& nproc, int& v) +template <> +void Parallel_Reduce::reduce_min(int& v) { #ifdef __MPI - std::vector all(nproc, 0); - MPI_Allgather(&v, 1, MPI_INT, all.data(), 1, MPI_INT, MPI_COMM_WORLD); - for (int i = 0; i < nproc; i++) - { - if (v > all[i]) - { - v = all[i]; - } - } + MPI_Allreduce(MPI_IN_PLACE, &v, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD); #endif } -void Parallel_Reduce::gather_max_double_all(const int& nproc, double& v) +template <> +void Parallel_Reduce::reduce_min(float& v) { #ifdef __MPI - std::vector value(nproc, 0.0); - MPI_Allgather(&v, 1, MPI_DOUBLE, value.data(), 1, MPI_DOUBLE, MPI_COMM_WORLD); - for (int i = 0; i < nproc; i++) - { - if (v < value[i]) - { - v = value[i]; - } - } + MPI_Allreduce(MPI_IN_PLACE, &v, 1, MPI_FLOAT, MPI_MIN, MPI_COMM_WORLD); #endif } -void Parallel_Reduce::gather_max_double_pool(const int& nproc_in_pool, double& v) +template <> +void Parallel_Reduce::reduce_min(double& v) { #ifdef __MPI - if (nproc_in_pool == 1) - { - return; - } - std::vector value(nproc_in_pool, 0.0); - MPI_Allgather(&v, 1, MPI_DOUBLE, value.data(), 1, MPI_DOUBLE, POOL_WORLD); - for (int i = 0; i < nproc_in_pool; i++) - { - if (v < value[i]) - { - v = value[i]; - } - } + MPI_Allreduce(MPI_IN_PLACE, &v, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD); +#endif +} + +template <> +void Parallel_Reduce::reduce_max(float& v) +{ +#ifdef __MPI + MPI_Allreduce(MPI_IN_PLACE, &v, 1, MPI_FLOAT, MPI_MAX, MPI_COMM_WORLD); +#endif +} + +template <> +void Parallel_Reduce::reduce_max(double& v) +{ +#ifdef __MPI + MPI_Allreduce(MPI_IN_PLACE, &v, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD); #endif } -void 
Parallel_Reduce::gather_min_double_pool(const int& nproc_in_pool, double& v) +template <> +void Parallel_Reduce::reduce_max_pool(const int& nproc_in_pool, double& v) { #ifdef __MPI if (nproc_in_pool == 1) { return; } - std::vector value(nproc_in_pool, 0.0); - MPI_Allgather(&v, 1, MPI_DOUBLE, value.data(), 1, MPI_DOUBLE, POOL_WORLD); - for (int i = 0; i < nproc_in_pool; i++) - { - if (v > value[i]) - { - v = value[i]; - } - } + MPI_Allreduce(MPI_IN_PLACE, &v, 1, MPI_DOUBLE, MPI_MAX, POOL_WORLD); #endif } - -void Parallel_Reduce::gather_min_double_all(const int& nproc, double& v) +template <> +void Parallel_Reduce::reduce_min_pool(const int& nproc_in_pool, double& v) { #ifdef __MPI - std::vector value(nproc, 0.0); - MPI_Allgather(&v, 1, MPI_DOUBLE, value.data(), 1, MPI_DOUBLE, MPI_COMM_WORLD); - for (int i = 0; i < nproc; i++) + if (nproc_in_pool == 1) { - if (v > value[i]) - { - v = value[i]; - } + return; } + MPI_Allreduce(MPI_IN_PLACE, &v, 1, MPI_DOUBLE, MPI_MIN, POOL_WORLD); #endif } \ No newline at end of file diff --git a/source/source_base/parallel_reduce.h b/source/source_base/parallel_reduce.h index 7ab85be1cb..a781989951 100644 --- a/source/source_base/parallel_reduce.h +++ b/source/source_base/parallel_reduce.h @@ -21,6 +21,14 @@ template void reduce_pool(T& object); template void reduce_pool(T* object, const int n); +template +void reduce_min(T& v); +template +void reduce_max(T& v); +template +void reduce_min_pool(const int& nproc_in_pool, T& v); +template +void reduce_max_pool(const int& nproc_in_pool, T& v); void reduce_int_diag(int& object); // mohan add 2012-01-12 @@ -34,13 +42,6 @@ void reduce_double_diag(double* object, const int n); void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double& object); void reduce_double_allpool(const int& npool, const int& nproc_in_pool, double* object, const int n); -void gather_min_int_all(const int& nproc, int& v); -void gather_max_double_all(const int& nproc, double& v); -void 
gather_min_double_all(const int& nproc, double& v); -void gather_max_double_pool(const int& nproc_in_pool, double& v); -void gather_min_double_pool(const int& nproc_in_pool, double& v); - -// mohan add 2011-04-21 void gather_int_all(int& v, int* all); bool check_if_equal(double& v); // mohan add 2009-11-11 diff --git a/source/source_base/test_parallel/parallel_reduce_test.cpp b/source/source_base/test_parallel/parallel_reduce_test.cpp index 696de3b485..ac980ba24d 100644 --- a/source/source_base/test_parallel/parallel_reduce_test.cpp +++ b/source/source_base/test_parallel/parallel_reduce_test.cpp @@ -30,9 +30,9 @@ * 3. ReduceComplexAll: * Tests two variations of reduce_complex_all() * 4. GatherIntAll: - * Tests gather_int_all() and gather_min_int_all() + * Tests gather_int_all() and reduce_min() * 5. GatherDoubleAll: - * Tests gather_min_double_all() and gather_max_double_all() + * Tests reduce_min() and reduce_max() * 6. ReduceIntDiag: * Tests reduce_int_diag() * 7. ReduceDoubleDiag: @@ -47,7 +47,7 @@ * 11.ReduceComplexPool: * Tests two variations of reduce_pool() * 12.GatherDoublePool: - * Tests gather_min_double_pool() and gather_max_double_pool() + * Tests reduce_min_pool() and reduce_max_pool() * * */ @@ -233,7 +233,7 @@ TEST_F(ParaReduce, GatherIntAll) EXPECT_EQ(local_number, array[my_rank]); // get minimum integer among all processes int min_number = local_number; - Parallel_Reduce::gather_min_int_all(nproc, min_number); + Parallel_Reduce::reduce_min(min_number); for (int i = 0; i < nproc; i++) { EXPECT_LE(min_number, array[i]); @@ -256,10 +256,10 @@ TEST_F(ParaReduce, GatherDoubleAll) EXPECT_EQ(local_number, array[my_rank]); // get minimum integer among all processes double min_number = local_number; - Parallel_Reduce::gather_min_double_all(nproc, min_number); + Parallel_Reduce::reduce_min(min_number); // get maximum integer among all processes double max_number = local_number; - Parallel_Reduce::gather_max_double_all(nproc, max_number); + 
Parallel_Reduce::reduce_max(max_number); for (int i = 0; i < nproc; i++) { EXPECT_LE(min_number, array[i]); @@ -587,10 +587,10 @@ TEST_F(ParaReduce, GatherDoublePool) EXPECT_EQ(local_number, array[mpiContext.rank_in_pool]); // get minimum integer among all processes double min_number = local_number; - Parallel_Reduce::gather_min_double_pool(mpiContext.nproc_in_pool, min_number); + Parallel_Reduce::reduce_min_pool(mpiContext.nproc_in_pool, min_number); // get maximum integer among all processes double max_number = local_number; - Parallel_Reduce::gather_max_double_pool(mpiContext.nproc_in_pool, max_number); + Parallel_Reduce::reduce_max_pool(mpiContext.nproc_in_pool, max_number); for (int i = 0; i < mpiContext.nproc_in_pool; i++) { EXPECT_LE(min_number, array[i]); diff --git a/source/source_basis/module_pw/pw_basis_big.h b/source/source_basis/module_pw/pw_basis_big.h index 2a04720877..987af787b3 100644 --- a/source/source_basis/module_pw/pw_basis_big.h +++ b/source/source_basis/module_pw/pw_basis_big.h @@ -2,6 +2,7 @@ #define PW_BASIS_BIG_H #include "source_base/constants.h" #include "source_base/global_function.h" + #ifdef __MPI #include "mpi.h" #endif diff --git a/source/source_cell/k_vector_utils.cpp b/source/source_cell/k_vector_utils.cpp index 3f5bef0b44..6af9e22835 100644 --- a/source/source_cell/k_vector_utils.cpp +++ b/source/source_cell/k_vector_utils.cpp @@ -245,7 +245,7 @@ void kvec_mpi_k(K_Vectors& kv) ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "Number of k-points in this process", kv.nks); int nks_minimum = kv.nks; - Parallel_Reduce::gather_min_int_all(GlobalV::NPROC, nks_minimum); + Parallel_Reduce::reduce_min(nks_minimum); if (nks_minimum == 0) { diff --git a/source/source_estate/elecstate_energy.cpp b/source/source_estate/elecstate_energy.cpp index 5083ea4621..ea8687f853 100644 --- a/source/source_estate/elecstate_energy.cpp +++ b/source/source_estate/elecstate_energy.cpp @@ -47,11 +47,10 @@ void ElecState::cal_bandgap() { vbm =this->eferm.ef; 
} -#ifdef __MPI - Parallel_Reduce::gather_max_double_all(GlobalV::NPROC, vbm); - Parallel_Reduce::gather_min_double_all(GlobalV::NPROC, cbm); -#endif - + #ifdef __MPI + Parallel_Reduce::reduce_max(vbm); + Parallel_Reduce::reduce_min(cbm); + #endif this->bandgap = cbm - vbm; } @@ -119,14 +118,12 @@ void ElecState::cal_bandgap_updw() { vbm_dw =this->eferm.ef_dw; } - -#ifdef __MPI - Parallel_Reduce::gather_max_double_all(GlobalV::NPROC, vbm_up); - Parallel_Reduce::gather_min_double_all(GlobalV::NPROC, cbm_up); - Parallel_Reduce::gather_max_double_all(GlobalV::NPROC, vbm_dw); - Parallel_Reduce::gather_min_double_all(GlobalV::NPROC, cbm_dw); -#endif - + #ifdef __MPI + Parallel_Reduce::reduce_max(vbm_up); + Parallel_Reduce::reduce_min(cbm_up); + Parallel_Reduce::reduce_max(vbm_dw); + Parallel_Reduce::reduce_min(cbm_dw); + #endif this->bandgap_up = cbm_up - vbm_up; this->bandgap_dw = cbm_dw - vbm_dw; } diff --git a/source/source_estate/occupy.cpp b/source/source_estate/occupy.cpp index 95fd4a83a4..fa50d1520d 100644 --- a/source/source_estate/occupy.cpp +++ b/source/source_estate/occupy.cpp @@ -179,10 +179,9 @@ void Occupy::iweights( } } } -#ifdef __MPI - Parallel_Reduce::gather_max_double_all(GlobalV::NPROC, ef); -#endif - + #ifdef __MPI + Parallel_Reduce::reduce_max(ef); + #endif return; } @@ -306,13 +305,11 @@ void Occupy::efermig(const ModuleBase::matrix& ekb, eup += 2 * smearing_sigma; elw -= 2 * smearing_sigma; - -#ifdef __MPI // find min and max across pools - Parallel_Reduce::gather_max_double_all(GlobalV::NPROC, eup); - Parallel_Reduce::gather_min_double_all(GlobalV::NPROC, elw); - -#endif + #ifdef __MPI + Parallel_Reduce::reduce_max(eup); + Parallel_Reduce::reduce_min(elw); + #endif //================= // Bisection method //================= diff --git a/source/source_io/cal_dos.cpp b/source/source_io/cal_dos.cpp index 6eb53aac8f..2e3e7f869c 100644 --- a/source/source_io/cal_dos.cpp +++ b/source/source_io/cal_dos.cpp @@ -56,8 +56,8 @@ void 
ModuleIO::prepare_dos(std::ofstream& ofs_running, } #ifdef __MPI - Parallel_Reduce::gather_max_double_all(GlobalV::NPROC, emax); - Parallel_Reduce::gather_min_double_all(GlobalV::NPROC, emin); + Parallel_Reduce::reduce_max(emax); + Parallel_Reduce::reduce_min(emin); #endif emax *= ModuleBase::Ry_to_eV; diff --git a/source/source_lcao/module_operator_lcao/op_exx_lcao.hpp b/source/source_lcao/module_operator_lcao/op_exx_lcao.hpp index 1d7e61409f..876b4f6af3 100644 --- a/source/source_lcao/module_operator_lcao/op_exx_lcao.hpp +++ b/source/source_lcao/module_operator_lcao/op_exx_lcao.hpp @@ -3,6 +3,7 @@ #ifdef __EXX #include "op_exx_lcao.h" +#include "source_base/parallel_reduce.h" #include "source_io/module_parameter/parameter.h" #include "source_lcao/module_ri/RI_2D_Comm.h" #include "source_hamilt/module_xc/xc_functional.h" @@ -244,10 +245,9 @@ OperatorEXX>::OperatorEXX(HS_Matrix_K* hsk_in, if (!ifs) { all_exist = 0; break; } } // Add MPI communication to synchronize all_exist across processes -#ifdef __MPI - // don't read in any files if one of the processes doesn't have it - MPI_Allreduce(MPI_IN_PLACE, &all_exist, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD); -#endif + #ifdef __MPI + Parallel_Reduce::reduce_min(all_exist); + #endif if (all_exist) { // Read HexxR in CSR format @@ -264,9 +264,9 @@ OperatorEXX>::OperatorEXX(HS_Matrix_K* hsk_in, const std::string restart_HR_path_cereal = GlobalC::restart.folder + "HexxR_" + std::to_string(PARAM.globalv.myrank); std::ifstream ifs(restart_HR_path_cereal, std::ios::binary); int all_exist_cereal = ifs ? 
1 : 0; -#ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &all_exist_cereal, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD); -#endif + #ifdef __MPI + Parallel_Reduce::reduce_min(all_exist_cereal); + #endif if (!all_exist_cereal) { //no HexxR file in CSR or binary format diff --git a/source/source_pw/module_pwdft/setup_pwwfc.cpp b/source/source_pw/module_pwdft/setup_pwwfc.cpp index 759178638c..cd06c917fc 100644 --- a/source/source_pw/module_pwdft/setup_pwwfc.cpp +++ b/source/source_pw/module_pwdft/setup_pwwfc.cpp @@ -1,5 +1,6 @@ #include "source_pw/module_pwdft/setup_pwwfc.h" // pw_wfc #include "source_base/parallel_comm.h" // POOL_WORLD +#include "source_base/parallel_reduce.h" // Parallel_Reduce #include "source_io/print_info.h" // print information void pw::teardown_pwwfc(ModulePW::PW_Basis_K* &pw_wfc) @@ -52,14 +53,12 @@ void pw::setup_pwwfc(const Input_para& inp, pw_rho.nz); pw_wfc->initparameters(false, inp.ecutwfc, kv.get_nks(), kv.kvec_d.data()); - #ifdef __MPI if (inp.pw_seed > 0) { - MPI_Allreduce(MPI_IN_PLACE, &pw_wfc->ggecut, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD); + Parallel_Reduce::reduce_max( pw_wfc->ggecut); } - // qianrui add 2021-8-13 to make different kpar parameters can get the same - // results + // qianrui add 2021-8-13 to make different kpar parameters can get the same result #endif pw_wfc->fft_bundle.initfftmode(inp.fft_mode); diff --git a/source/source_pw/module_stodft/sto_iter.cpp b/source/source_pw/module_stodft/sto_iter.cpp index 99fd2ef2bb..5e6d1e6d9a 100644 --- a/source/source_pw/module_stodft/sto_iter.cpp +++ b/source/source_pw/module_stodft/sto_iter.cpp @@ -203,8 +203,8 @@ void Stochastic_Iter::checkemm(const int& ik, if (ik == nks - 1) { #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, p_hamilt_sto->emax, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD); - MPI_Allreduce(MPI_IN_PLACE, p_hamilt_sto->emin, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD); + Parallel_Reduce::reduce_max(*p_hamilt_sto->emax); + Parallel_Reduce::reduce_min(*p_hamilt_sto->emin); 
MPI_Allreduce(MPI_IN_PLACE, &change, 1, MPI_CHAR, MPI_LOR, MPI_COMM_WORLD); #endif if (change) diff --git a/source/source_pw/module_stodft/sto_tool.cpp b/source/source_pw/module_stodft/sto_tool.cpp index de1e72e3f1..4ba359310b 100644 --- a/source/source_pw/module_stodft/sto_tool.cpp +++ b/source/source_pw/module_stodft/sto_tool.cpp @@ -2,6 +2,7 @@ #include "source_base/math_chebyshev.h" #include "source_base/parallel_device.h" +#include "source_base/parallel_reduce.h" #include "source_base/timer.h" #include "source_io/module_parameter/parameter.h" #ifdef __MPI @@ -103,8 +104,8 @@ void check_che_op::operator()(const int& nche_in, if (ik == nk - 1) { #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, p_hamilt_sto->emax, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD); - MPI_Allreduce(MPI_IN_PLACE, p_hamilt_sto->emin, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD); + Parallel_Reduce::reduce_max(*p_hamilt_sto->emax); + Parallel_Reduce::reduce_min(*p_hamilt_sto->emin); #endif GlobalV::ofs_running << "New Emax " << *p_hamilt_sto->emax << " Ry; new Emin " << *p_hamilt_sto->emin << " Ry" << std::endl;