Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions src/scheduler/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,12 @@ void MpiWorld::reduce(int sendRank,
// the receiver is not co-located with us, do a reduce with the data of
// all our local ranks, and then send the result to the receiver
if (getHostForRank(recvRank) != thisHost) {
// In this step we reduce our local ranks data. It is important
// that we do so in a copy of the send buffer, as the application
// does not expect said buffer's contents to be modified.
auto sendBufferCopy = std::make_unique<uint8_t[]>(bufferSize);
memcpy(sendBufferCopy.get(), sendBuffer, bufferSize);

for (const int r : localRanks) {
if (r == sendRank) {
continue;
Expand All @@ -1120,21 +1126,28 @@ void MpiWorld::reduce(int sendRank,
nullptr,
faabric::MPIMessage::REDUCE);

// Note that we accumulate the reuce operation on the send
// buffer, not the receive one, as we later need to send all
// the reduced data (including ours) to the root rank
op_reduce(
operation, datatype, count, rankData.get(), sendBuffer);
op_reduce(operation,
datatype,
count,
rankData.get(),
sendBufferCopy.get());
}
}

// Send to the receiver rank
send(sendRank,
recvRank,
sendBuffer,
datatype,
count,
faabric::MPIMessage::REDUCE);
send(sendRank,
recvRank,
sendBufferCopy.get(),
datatype,
count,
faabric::MPIMessage::REDUCE);
} else {
// Send to the receiver rank
send(sendRank,
recvRank,
sendBuffer,
datatype,
count,
faabric::MPIMessage::REDUCE);
}
} else {
// If we are neither the receiver of the reduce nor a local leader, we
// send our data for reduction either to our local leader or the
Expand Down
73 changes: 73 additions & 0 deletions tests/test/scheduler/test_remote_mpi_worlds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,79 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture,
thisWorld.destroy();
}

TEST_CASE_METHOD(RemoteCollectiveTestFixture,
                 "Test reduce across hosts multiple times",
                 "[mpi]")
{
    // Regression test: repeated cross-host reduces must not corrupt the
    // senders' buffers. The local leader reduces into a *copy* of its send
    // buffer, so calling reduce again with the same messageData must keep
    // producing the same contribution every iteration.
    MpiWorld& thisWorld = setUpThisWorld();

    std::vector<int> messageData = { 0, 1, 2 };
    const int recvRank = 0;
    const int numRepeats = 10;

    std::thread otherWorldThread([this, recvRank, numRepeats, &messageData] {
        otherWorld.initialiseFromMsg(msg);

        // Remote ranks send in a non-monotonic order to exercise
        // out-of-order message handling on the receiving host
        std::vector<int> sendRanksInOrder = { 4, 5, 3 };
        for (int i = 0; i < numRepeats; i++) {
            for (int r : sendRanksInOrder) {
                otherWorld.reduce(r,
                                  recvRank,
                                  BYTES(messageData.data()),
                                  nullptr,
                                  MPI_INT,
                                  messageData.size(),
                                  MPI_SUM);
            }
        }

        // Give the other host time to receive all the reduce messages
        testLatch->wait();
        otherWorld.destroy();
    });

    // The root rank accumulates repeatedly into `actual`, so it starts from
    // its own contribution and grows with every repeat
    std::vector<int> actual = messageData;
    std::vector<int> sendRanksInOrder = { 1, 2, 0 };
    for (int i = 0; i < numRepeats; i++) {
        for (int r : sendRanksInOrder) {
            if (r != recvRank) {
                // Non-root local ranks only contribute their data
                thisWorld.reduce(r,
                                 recvRank,
                                 BYTES(messageData.data()),
                                 nullptr,
                                 MPI_INT,
                                 messageData.size(),
                                 MPI_SUM);
            } else {
                // The root rank reduces in-place into `actual`
                thisWorld.reduce(r,
                                 recvRank,
                                 BYTES(actual.data()),
                                 BYTES(actual.data()),
                                 MPI_INT,
                                 messageData.size(),
                                 MPI_SUM);
            }
        }
    }

    // The world size is hardcoded in the test fixture
    const int worldSize = 6;

    // First repeat sums all `worldSize` ranks' data; each subsequent repeat
    // adds the other (worldSize - 1) ranks' unchanged contributions on top of
    // the accumulated root value
    std::vector<int> expected(messageData.size());
    for (size_t i = 0; i < messageData.size(); i++) {
        expected.at(i) = messageData.at(i) * worldSize +
                         messageData.at(i) * (worldSize - 1) * (numRepeats - 1);
    }
    REQUIRE(actual == expected);

    // Clean up
    testLatch->wait();
    if (otherWorldThread.joinable()) {
        otherWorldThread.join();
    }

    thisWorld.destroy();
}

TEST_CASE_METHOD(RemoteCollectiveTestFixture,
"Test scatter across hosts",
"[mpi]")
Expand Down