Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions source/source_base/parallel_global.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,10 @@ void Parallel_Global::divide_pools(const int& NPROC,
// and MY_BNDGROUP will be the same as well.
if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0)
{
std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups ("
<< BNDPAR * KPAR << ")." << std::endl;
exit(1);
std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC
<< ") must be divisible by the number of groups (" << BNDPAR * KPAR << ")." << std::endl;
ModuleBase::WARNING_QUIT("ParallelGlobal::divide_pools",
"When BNDPAR > 1, number of processes NPROC must be divisible by the number of groups BNDPAR * KPAR.");
}
// k-point parallelization
MPICommGroup kpar_group(MPI_COMM_WORLD);
Expand Down
197 changes: 180 additions & 17 deletions source/source_base/test_parallel/parallel_global_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
#include <complex>
#include <cstring>
#include <string>
#include <unistd.h>

#include "source_base/tool_quit.h"
#include "source_base/global_variable.h"

/************************************************
* unit test of functions in parallel_global.cpp
Expand Down Expand Up @@ -66,6 +67,7 @@ class MPIContext
int _size;
};

// --- Normal Test ---
class ParaGlobal : public ::testing::Test
{
protected:
Expand All @@ -79,6 +81,7 @@ class ParaGlobal : public ::testing::Test
}
};


TEST_F(ParaGlobal, SplitGrid)
{
// NPROC is set to 4 in parallel_global_test.sh
Expand Down Expand Up @@ -162,14 +165,126 @@ TEST_F(ParaGlobal, MyProd)
EXPECT_EQ(inout[1], std::complex<double>(-3.0, -3.0));
}

TEST_F(ParaGlobal, InitPools)


TEST_F(ParaGlobal, DivideMPIPools)
{
this->nproc = 12;
mpi.kpar = 3;
this->my_rank = 5;
Parallel_Global::divide_mpi_groups(this->nproc,
mpi.kpar,
this->my_rank,
mpi.nproc_in_pool,
mpi.my_pool,
mpi.rank_in_pool);
EXPECT_EQ(mpi.nproc_in_pool, 4);
EXPECT_EQ(mpi.my_pool, 1);
EXPECT_EQ(mpi.rank_in_pool, 1);
}


class FakeMPIContext
{
public:
FakeMPIContext()
{
_rank = 0;
_size = 1;
}

int GetRank() const
{
return _rank;
}
int GetSize() const
{
return _size;
}

int drank;
int dsize;
int dcolor;

int grank;
int gsize;

int kpar;
int nproc_in_pool;
int my_pool;
int rank_in_pool;

int nstogroup;
int MY_BNDGROUP;
int rank_in_stogroup;
int nproc_in_stogroup;

private:
int _rank;
int _size;
};

// --- DeathTest: Single thread ---
// Since these precondition checks cause the processes to die, we call such tests death tests.
// convention of naming the test suite: *DeathTest
// Death tests should be run in a single-threaded context.
// Such DeathTest will be run before all other tests.
class ParaGlobalDeathTest : public ::testing::Test
{
protected:
FakeMPIContext mpi;
int nproc;
int my_rank;
int real_rank;

// DeathTest SetUp:
// Init variable, single thread
void SetUp() override
{
int is_init = 0;
MPI_Initialized(&is_init);
if (is_init) {
MPI_Comm_rank(MPI_COMM_WORLD, &real_rank);
} else {
real_rank = 0;
}

if (real_rank != 0) return;

nproc = mpi.GetSize();
my_rank = mpi.GetRank();

// init log file needed by WARNING_QUIT
GlobalV::ofs_warning.open("warning.log");


}

// clean log file
void TearDown() override
{
if (real_rank != 0) return;

GlobalV::ofs_warning.close();
remove("warning.log");
}
};

TEST_F(ParaGlobalDeathTest, InitPools)
{
if (real_rank != 0) return;
nproc = 12;
mpi.kpar = 3;
mpi.nstogroup = 3;
my_rank = 5;
testing::internal::CaptureStdout();
EXPECT_EXIT(Parallel_Global::init_pools(nproc,
EXPECT_EXIT(
// This gtest Macro expect that a given `statement` causes the program to exit, with an
// integer exit status that satisfies `predicate`(Here ::testing::ExitedWithCode(1)),
// and emitting error output that matches `matcher`(Here "Error").
{
// redirect stdout to stderr to capture WARNING_QUIT output
dup2(STDERR_FILENO, STDOUT_FILENO);
Parallel_Global::init_pools(nproc,
my_rank,
mpi.nstogroup,
mpi.kpar,
Expand All @@ -178,35 +293,83 @@ TEST_F(ParaGlobal, InitPools)
mpi.MY_BNDGROUP,
mpi.nproc_in_pool,
mpi.rank_in_pool,
mpi.my_pool), ::testing::ExitedWithCode(1), "");
std::string output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("Error:"));
mpi.my_pool);
},
::testing::ExitedWithCode(1),
"Error");
}


TEST_F(ParaGlobal, DivideMPIPools)
TEST_F(ParaGlobalDeathTest, DivideMPIPoolsNgEqZero)
{
if (real_rank != 0) return;
// test for num_groups == 0,
// Num_group Equals 0
// WARNING_QUIT
this->nproc = 12;
mpi.kpar = 3;
this->my_rank = 5;
Parallel_Global::divide_mpi_groups(this->nproc,
mpi.kpar = 0;
EXPECT_EXIT(
{
// redirect stdout to stderr to capture WARNING_QUIT output
dup2(STDERR_FILENO, STDOUT_FILENO);
Parallel_Global::divide_mpi_groups(this->nproc,
mpi.kpar,
this->my_rank,
mpi.nproc_in_pool,
mpi.my_pool,
mpi.rank_in_pool);
EXPECT_EQ(mpi.nproc_in_pool, 4);
EXPECT_EQ(mpi.my_pool, 1);
EXPECT_EQ(mpi.rank_in_pool, 1);
},
::testing::ExitedWithCode(1),
"Number of groups must be greater than 0."
);
}

TEST_F(ParaGlobalDeathTest, DivideMPIPoolsNgGtProc)
{
if (real_rank != 0) return;
// test for procs < num_groups
// Num_group GreaterThan Processors
// WARNING_QUIT
this->nproc = 12;
mpi.kpar = 24;
this->my_rank = 5;
EXPECT_EXIT(
{
// redirect stdout to stderr to capture WARNING_QUIT output
dup2(STDERR_FILENO, STDOUT_FILENO);
Parallel_Global::divide_mpi_groups(this->nproc,
mpi.kpar,
this->my_rank,
mpi.nproc_in_pool,
mpi.my_pool,
mpi.rank_in_pool);
},
testing::ExitedWithCode(1),
"Error: Number of processes.*must be greater than the number of groups"
);
}

int main(int argc, char** argv)
{
bool is_death_test_child = false;
for (int i = 0; i < argc; ++i) {
if (std::string(argv[i]).find("gtest_internal_run_death_test") != std::string::npos) {
is_death_test_child = true;
break;
}
}

if (!is_death_test_child)
{
MPI_Init(&argc, &argv);
}

MPI_Init(&argc, &argv);
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
int result = RUN_ALL_TESTS();
MPI_Finalize();

if (!is_death_test_child) {
MPI_Finalize();
}
return result;
}
#endif // __MPI
Loading