diff --git a/source/source_base/parallel_global.cpp b/source/source_base/parallel_global.cpp index dcfea10d30..f227e6518f 100644 --- a/source/source_base/parallel_global.cpp +++ b/source/source_base/parallel_global.cpp @@ -328,9 +328,10 @@ void Parallel_Global::divide_pools(const int& NPROC, // and MY_BNDGROUP will be the same as well. if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0) { - std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups (" - << BNDPAR * KPAR << ")." << std::endl; - exit(1); + std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC + << ") must be divisible by the number of groups (" << BNDPAR * KPAR << ")." << std::endl; + ModuleBase::WARNING_QUIT("ParallelGlobal::divide_pools", + "When BNDPAR > 1, number of processes NPROC must be divisible by the number of groups BNDPAR * KPAR."); } // k-point parallelization MPICommGroup kpar_group(MPI_COMM_WORLD); diff --git a/source/source_base/test_parallel/parallel_global_test.cpp b/source/source_base/test_parallel/parallel_global_test.cpp index 9b8799a0c6..3b6bf8491f 100644 --- a/source/source_base/test_parallel/parallel_global_test.cpp +++ b/source/source_base/test_parallel/parallel_global_test.cpp @@ -8,8 +8,9 @@ #include #include #include +#include -#include "source_base/tool_quit.h" +#include "source_base/global_variable.h" /************************************************ * unit test of functions in parallel_global.cpp @@ -66,6 +67,7 @@ class MPIContext int _size; }; +// --- Normal Test --- class ParaGlobal : public ::testing::Test { protected: @@ -79,6 +81,7 @@ class ParaGlobal : public ::testing::Test } }; + TEST_F(ParaGlobal, SplitGrid) { // NPROC is set to 4 in parallel_global_test.sh @@ -162,14 +165,126 @@ TEST_F(ParaGlobal, MyProd) EXPECT_EQ(inout[1], std::complex(-3.0, -3.0)); } -TEST_F(ParaGlobal, InitPools) + + +TEST_F(ParaGlobal, DivideMPIPools) +{ + this->nproc = 12; + 
mpi.kpar = 3; + this->my_rank = 5; + Parallel_Global::divide_mpi_groups(this->nproc, + mpi.kpar, + this->my_rank, + mpi.nproc_in_pool, + mpi.my_pool, + mpi.rank_in_pool); + EXPECT_EQ(mpi.nproc_in_pool, 4); + EXPECT_EQ(mpi.my_pool, 1); + EXPECT_EQ(mpi.rank_in_pool, 1); +} + + +class FakeMPIContext +{ + public: + FakeMPIContext() + { + _rank = 0; + _size = 1; + } + + int GetRank() const + { + return _rank; + } + int GetSize() const + { + return _size; + } + + int drank; + int dsize; + int dcolor; + + int grank; + int gsize; + + int kpar; + int nproc_in_pool; + int my_pool; + int rank_in_pool; + + int nstogroup; + int MY_BNDGROUP; + int rank_in_stogroup; + int nproc_in_stogroup; + + private: + int _rank; + int _size; +}; + +// --- DeathTest: Single thread --- +// Since these precondition checks cause the processes to die, we call such tests death tests. +// Naming convention for the test suite: *DeathTest +// Death tests should be run in a single-threaded context. +// Such death tests are run before all other tests. 
+class ParaGlobalDeathTest : public ::testing::Test +{ + protected: + FakeMPIContext mpi; + int nproc; + int my_rank; + int real_rank; + + // DeathTest SetUp: + // Init variable, single thread + void SetUp() override + { + int is_init = 0; + MPI_Initialized(&is_init); + if (is_init) { + MPI_Comm_rank(MPI_COMM_WORLD, &real_rank); + } else { + real_rank = 0; + } + + if (real_rank != 0) return; + + nproc = mpi.GetSize(); + my_rank = mpi.GetRank(); + + // init log file needed by WARNING_QUIT + GlobalV::ofs_warning.open("warning.log"); + + + } + + // clean log file + void TearDown() override + { + if (real_rank != 0) return; + + GlobalV::ofs_warning.close(); + remove("warning.log"); + } +}; + +TEST_F(ParaGlobalDeathTest, InitPools) { + if (real_rank != 0) return; nproc = 12; mpi.kpar = 3; mpi.nstogroup = 3; my_rank = 5; - testing::internal::CaptureStdout(); - EXPECT_EXIT(Parallel_Global::init_pools(nproc, + EXPECT_EXIT( + // This gtest macro expects that a given `statement` causes the program to exit with an + // integer exit status that satisfies `predicate` (here ::testing::ExitedWithCode(1)), + // and to emit error output that matches `matcher` (here "Error"). 
+ { + // redirect stdout to stderr to capture WARNING_QUIT output + dup2(STDERR_FILENO, STDOUT_FILENO); + Parallel_Global::init_pools(nproc, my_rank, mpi.nstogroup, mpi.kpar, @@ -178,35 +293,83 @@ TEST_F(ParaGlobal, InitPools) mpi.MY_BNDGROUP, mpi.nproc_in_pool, mpi.rank_in_pool, - mpi.my_pool), ::testing::ExitedWithCode(1), ""); - std::string output = testing::internal::GetCapturedStdout(); - EXPECT_THAT(output, testing::HasSubstr("Error:")); + mpi.my_pool); + }, + ::testing::ExitedWithCode(1), + "Error"); } - -TEST_F(ParaGlobal, DivideMPIPools) +TEST_F(ParaGlobalDeathTest, DivideMPIPoolsNgEqZero) { + if (real_rank != 0) return; + // test for num_groups == 0, + // Num_group Equals 0 + // WARNING_QUIT this->nproc = 12; - mpi.kpar = 3; - this->my_rank = 5; - Parallel_Global::divide_mpi_groups(this->nproc, + mpi.kpar = 0; + EXPECT_EXIT( + { + // redirect stdout to stderr to capture WARNING_QUIT output + dup2(STDERR_FILENO, STDOUT_FILENO); + Parallel_Global::divide_mpi_groups(this->nproc, mpi.kpar, this->my_rank, mpi.nproc_in_pool, mpi.my_pool, mpi.rank_in_pool); - EXPECT_EQ(mpi.nproc_in_pool, 4); - EXPECT_EQ(mpi.my_pool, 1); - EXPECT_EQ(mpi.rank_in_pool, 1); + }, + ::testing::ExitedWithCode(1), + "Number of groups must be greater than 0." 
+ ); +} + +TEST_F(ParaGlobalDeathTest, DivideMPIPoolsNgGtProc) +{ + if (real_rank != 0) return; + // test for procs < num_groups + // Num_group GreaterThan Processors + // WARNING_QUIT + this->nproc = 12; + mpi.kpar = 24; + this->my_rank = 5; + EXPECT_EXIT( + { + // redirect stdout to stderr to capture WARNING_QUIT output + dup2(STDERR_FILENO, STDOUT_FILENO); + Parallel_Global::divide_mpi_groups(this->nproc, + mpi.kpar, + this->my_rank, + mpi.nproc_in_pool, + mpi.my_pool, + mpi.rank_in_pool); + }, + testing::ExitedWithCode(1), + "Error: Number of processes.*must be greater than the number of groups" + ); } int main(int argc, char** argv) { + bool is_death_test_child = false; + for (int i = 0; i < argc; ++i) { + if (std::string(argv[i]).find("gtest_internal_run_death_test") != std::string::npos) { + is_death_test_child = true; + break; + } + } + + if (!is_death_test_child) + { + MPI_Init(&argc, &argv); + } - MPI_Init(&argc, &argv); testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; int result = RUN_ALL_TESTS(); - MPI_Finalize(); + + if (!is_death_test_child) { + MPI_Finalize(); + } return result; } #endif // __MPI