From 45f98390b5b8dfeabd73c8e1ceb01983a8613077 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 13 May 2024 14:53:01 +0000 Subject: [PATCH] [Bugfix][Disco] Handle NDArray larger than OS buffer for pipe Prior to this commit, using `disco.Session` methods to transfer `NDArray` instances to workers could raise an exception if the `NDArray` is larger than the buffer allocated by the OS for the controller/worker pipe. In these cases, the first call to the `Read` method of `tvm::support::Pipe` would successfully return, but only with the initial bytes of the `NDArray`. Receiving the full `NDArray` requires repeatedly calling the POSIX `read` function. This commit updates the `Read` and `Write` methods of `tvm::support::Pipe` to repeatedly call the underlying read/write methods, until the full `NDArray` has been transferred. This commit does not add any unit tests, as the existing unit test `tests/python/disco/test_ccl.py::test_attention[nccl-ProcessSession]` requires this change to pass. 
--- src/support/pipe.h | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/support/pipe.h b/src/support/pipe.h index 4babc5b7c422..50ad2b578661 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -86,8 +86,19 @@ class Pipe : public dmlc::Stream { DWORD nread = static_cast(RetryCallOnEINTR(fread, GetLastErrorCode)); ICHECK_EQ(static_cast(nread), size) << "Read Error: " << GetLastError(); #else - ssize_t nread = RetryCallOnEINTR([&]() { return read(handle_, ptr, size); }, GetLastErrorCode); - ICHECK_GE(nread, 0) << "Write Error: " << strerror(errno); + size_t nread = 0; + while (size) { + ssize_t nread_chunk = + RetryCallOnEINTR([&]() { return read(handle_, ptr, size); }, GetLastErrorCode); + ICHECK_NE(nread_chunk, -1) << "Write Error: " << strerror(errno); + + ICHECK_GT(nread_chunk, 0) << "Was unable to read any data from pipe"; + ICHECK_LE(nread_chunk, size) << "Read " << nread_chunk << " bytes, " + << "but only expected to read " << size << " bytes"; + size -= nread_chunk; + ptr = static_cast(ptr) + nread_chunk; + nread += nread_chunk; + } #endif return static_cast(nread); } @@ -109,9 +120,17 @@ class Pipe : public dmlc::Stream { DWORD nwrite = static_cast(RetryCallOnEINTR(fwrite, GetLastErrorCode)); ICHECK_EQ(static_cast(nwrite), size) << "Write Error: " << GetLastError(); #else - ssize_t nwrite = - RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode); - ICHECK_EQ(static_cast(nwrite), size) << "Write Error: " << strerror(errno); + while (size) { + ssize_t nwrite = + RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode); + ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno); + + ICHECK_GT(nwrite, 0) << "Was unable to write any data to pipe"; + ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, " + << "but only expected to write " << size << " bytes"; + size -= nwrite; + ptr = static_cast(ptr) + nwrite; + } #endif } /*!