diff --git a/csrc/multidevice/ipc_handle.cpp b/csrc/multidevice/ipc_handle.cpp index 6bb700dc2de..9a5ec4286b8 100644 --- a/csrc/multidevice/ipc_handle.cpp +++ b/csrc/multidevice/ipc_handle.cpp @@ -151,6 +151,12 @@ void IpcHandleCache::exchangeHandles( insert(communication, std::move(ipc_handles)); } + + // a second barrier is needed here to ensure all ranks have received the + // memhandles and the keys are deleted from the store before the next call to + // exchangeHandles + // TODO: precisely select what ranks need to wait on that barrier. + communicator->barrier(); } } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 1b6ce59801c..af0c0719aa7 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -417,7 +417,7 @@ INSTANTIATE_TEST_SUITE_P( using P2PCommunicationTest = MultiDeviceTest; -TEST_F(P2PCommunicationTest, DISABLED_CudaComm) { +TEST_F(P2PCommunicationTest, CudaComm) { static constexpr int kTensorSize = 8; static constexpr int kNumRepetitions = 32; diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 0b6efbd15a4..88286d6e4c0 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -478,7 +478,7 @@ TEST_F(OverlapDistributedMatmulTest, AG_linear) { EXPECT_TRUE(torch::allclose(out_ref, out_at, 1e-1, 1e-1)); } -TEST_F(MultiDeviceTest, DISABLED_ShareIpcMemHandles) { +TEST_F(MultiDeviceTest, ShareIpcMemHandles) { static constexpr int kTensorSize = 4; static constexpr int kNumRepetitions = 10;