diff --git a/src/scheduler/MpiWorld.cpp b/src/scheduler/MpiWorld.cpp
index a2614fbe6..d61dca062 100644
--- a/src/scheduler/MpiWorld.cpp
+++ b/src/scheduler/MpiWorld.cpp
@@ -1106,6 +1106,12 @@ void MpiWorld::reduce(int sendRank,
         // the receiver is not co-located with us, do a reduce with the data of
         // all our local ranks, and then send the result to the receiver
         if (getHostForRank(recvRank) != thisHost) {
+            // In this step we reduce our local ranks' data. It is important
+            // that we do so in a copy of the send buffer, as the application
+            // does not expect said buffer's contents to be modified.
+            auto sendBufferCopy = std::make_unique<uint8_t[]>(bufferSize);
+            memcpy(sendBufferCopy.get(), sendBuffer, bufferSize);
+
             for (const int r : localRanks) {
                 if (r == sendRank) {
                     continue;
@@ -1120,21 +1126,28 @@ void MpiWorld::reduce(int sendRank,
                      nullptr,
                      faabric::MPIMessage::REDUCE);
 
-                // Note that we accumulate the reuce operation on the send
-                // buffer, not the receive one, as we later need to send all
-                // the reduced data (including ours) to the root rank
-                op_reduce(
-                  operation, datatype, count, rankData.get(), sendBuffer);
+                op_reduce(operation,
+                          datatype,
+                          count,
+                          rankData.get(),
+                          sendBufferCopy.get());
             }
-        }
 
-        // Send to the receiver rank
-        send(sendRank,
-             recvRank,
-             sendBuffer,
-             datatype,
-             count,
-             faabric::MPIMessage::REDUCE);
+            send(sendRank,
+                 recvRank,
+                 sendBufferCopy.get(),
+                 datatype,
+                 count,
+                 faabric::MPIMessage::REDUCE);
+        } else {
+            // Send to the receiver rank
+            send(sendRank,
+                 recvRank,
+                 sendBuffer,
+                 datatype,
+                 count,
+                 faabric::MPIMessage::REDUCE);
+        }
     } else {
         // If we are neither the receiver of the reduce nor a local leader, we
         // send our data for reduction either to our local leader or the
diff --git a/tests/test/scheduler/test_remote_mpi_worlds.cpp b/tests/test/scheduler/test_remote_mpi_worlds.cpp
index faf6c9b75..a7870874d 100644
--- a/tests/test/scheduler/test_remote_mpi_worlds.cpp
+++ b/tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -411,6 +411,79 @@ TEST_CASE_METHOD(RemoteCollectiveTestFixture,
     thisWorld.destroy();
 }
 
+TEST_CASE_METHOD(RemoteCollectiveTestFixture,
+                 "Test reduce across hosts multiple times",
+                 "[mpi]")
+{
+    MpiWorld& thisWorld = setUpThisWorld();
+
+    std::vector<int> messageData = { 0, 1, 2 };
+    int recvRank = 0;
+    int numRepeats = 10;
+
+    std::thread otherWorldThread([this, recvRank, numRepeats, &messageData] {
+        otherWorld.initialiseFromMsg(msg);
+
+        std::vector<int> sendRanksInOrder = { 4, 5, 3 };
+        for (int i = 0; i < numRepeats; i++) {
+            for (int r : sendRanksInOrder) {
+                otherWorld.reduce(r,
+                                  recvRank,
+                                  BYTES(messageData.data()),
+                                  nullptr,
+                                  MPI_INT,
+                                  messageData.size(),
+                                  MPI_SUM);
+            }
+        }
+
+        // Give the other host time to receive the reduce
+        testLatch->wait();
+        otherWorld.destroy();
+    });
+
+    std::vector<int> actual = messageData;
+    std::vector<int> sendRanksInOrder = { 1, 2, 0 };
+    for (int i = 0; i < numRepeats; i++) {
+        for (int r : sendRanksInOrder) {
+            if (r != recvRank) {
+                thisWorld.reduce(r,
+                                 recvRank,
+                                 BYTES(messageData.data()),
+                                 nullptr,
+                                 MPI_INT,
+                                 messageData.size(),
+                                 MPI_SUM);
+            } else {
+                thisWorld.reduce(r,
+                                 recvRank,
+                                 BYTES(actual.data()),
+                                 BYTES(actual.data()),
+                                 MPI_INT,
+                                 messageData.size(),
+                                 MPI_SUM);
+            }
+        }
+    }
+
+    // The world size is hardcoded in the test fixture
+    int worldSize = 6;
+    std::vector<int> expected(messageData.size());
+    for (int i = 0; i < messageData.size(); i++) {
+        expected.at(i) = messageData.at(i) * worldSize +
+                         messageData.at(i) * (worldSize - 1) * (numRepeats - 1);
+    }
+    REQUIRE(actual == expected);
+
+    // Clean up
+    testLatch->wait();
+    if (otherWorldThread.joinable()) {
+        otherWorldThread.join();
+    }
+
+    thisWorld.destroy();
+}
+
 TEST_CASE_METHOD(RemoteCollectiveTestFixture,
                  "Test scatter across hosts",
                  "[mpi]")
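
For context on the bug this patch addresses, the standalone sketch below is not faabric code: `localReduce` and its harness are hypothetical stand-ins for the local-leader step of `MpiWorld::reduce`. Under that assumption, it shows why accumulating directly into the caller's send buffer breaks repeated reductions (the second call starts from already-reduced data), and why reducing into a private copy, as the patch does with `sendBufferCopy`, keeps every repetition deterministic, which is what the new multi-repeat test exercises.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the local-leader step of a remote reduce: the
// leader folds the other local ranks' data onto its own and returns the
// result it would send to the remote root. `reuseSendBuffer` toggles between
// the old behaviour (accumulate straight into the caller's buffer) and the
// patched one (accumulate into a private copy).
std::vector<int> localReduce(std::vector<int>& sendBuffer,
                             const std::vector<std::vector<int>>& otherRanks,
                             bool reuseSendBuffer)
{
    // Mirror the patch: take a copy of the send buffer before reducing
    std::vector<int> sendBufferCopy = sendBuffer;
    std::vector<int>& accumulator =
      reuseSendBuffer ? sendBuffer : sendBufferCopy;

    for (const auto& rankData : otherRanks) {
        for (size_t i = 0; i < rankData.size(); i++) {
            accumulator.at(i) += rankData.at(i);
        }
    }

    return accumulator;
}

int main()
{
    std::vector<std::vector<int>> otherRanks = { { 1, 1, 1 }, { 2, 2, 2 } };

    // Old behaviour: the caller's buffer is mutated, so a repeated reduce
    // starts from already-reduced data and the results drift apart.
    std::vector<int> buggy = { 0, 1, 2 };
    auto firstBuggy = localReduce(buggy, otherRanks, true);
    auto secondBuggy = localReduce(buggy, otherRanks, true);
    assert(firstBuggy != secondBuggy);

    // Patched behaviour: the caller's buffer is left untouched, so every
    // repetition sees the same input and produces the same result.
    std::vector<int> clean = { 0, 1, 2 };
    auto first = localReduce(clean, otherRanks, false);
    auto second = localReduce(clean, otherRanks, false);
    assert(clean == (std::vector<int>{ 0, 1, 2 }));
    assert(first == second);

    return 0;
}
```

The copy costs one extra allocation and memcpy per remote reduce, and only on the local leader, which seems cheap next to the cross-host send that follows it.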