diff --git a/src/support/errno_handling.h b/src/support/errno_handling.h new file mode 100644 index 000000000000..0bdfdfdf022c --- /dev/null +++ b/src/support/errno_handling.h @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file errno_handling.h + * \brief Common error number handling functions for socket.h and pipe.h + */ +#ifndef TVM_SUPPORT_ERRNO_HANDLING_H_ +#define TVM_SUPPORT_ERRNO_HANDLING_H_ +#include + +#include "ssize.h" + +namespace tvm { +namespace support { +/*! + * \brief Call a function and retry if an EINTR error is encountered. + * + * Socket operations can return EINTR when the interrupt handler + * is registered by the execution environment(e.g. python). + * We should retry if there is no KeyboardInterrupt recorded in + * the environment. + * + * \note This function is needed to avoid rare interrupt event + * in long running server code. + * + * \param func The function to retry. + * \return The return code returned by function f or error_value on retry failure. + */ +template +inline ssize_t RetryCallOnEINTR(FuncType func, GetErrorCodeFuncType fgeterrorcode) { + ssize_t ret = func(); + // common path + if (ret != -1) return ret; + // less common path + do { + if (fgeterrorcode() == EINTR) { + // Call into env check signals to see if there are + // environment specific(e.g. python) signal exceptions. + // This function will throw an exception if there is + // if the process received a signal that requires TVM to return immediately (e.g. SIGINT). + runtime::EnvCheckSignals(); + } else { + // other errors + return ret; + } + ret = func(); + } while (ret == -1); + return ret; +} +} // namespace support +} // namespace tvm +#endif // TVM_SUPPORT_ERRNO_HANDLING_H_ diff --git a/src/support/pipe.h b/src/support/pipe.h index d869504dc4e9..557fe89e4670 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -36,6 +36,7 @@ #include #include #endif +#include "errno_handling.h" namespace tvm { namespace support { @@ -52,8 +53,21 @@ class Pipe : public dmlc::Stream { #endif /*! \brief destructor */ ~Pipe() { Flush(); } + using Stream::Read; using Stream::Write; + + /*! + * \return last error of pipe operation + */ + static int GetLastErrorCode() { +#ifdef _WIN32 + return GetLastError(); +#else + return errno; +#endif + } + /*! * \brief reads data from a file descriptor * \param ptr pointer to a memory buffer @@ -63,12 +77,15 @@ class Pipe : public dmlc::Stream { size_t Read(void* ptr, size_t size) final { if (size == 0) return 0; #ifdef _WIN32 - DWORD nread; - ICHECK(ReadFile(handle_, static_cast(ptr), size, &nread, nullptr)) - << "Read Error: " << GetLastError(); + auto fread = [&]() { + DWORD nread; + if (!ReadFile(handle_, static_cast(ptr), size, &nread, nullptr)) return -1; + return nread; + }; + DWORD nread = static_cast(RetryCallOnEINTR(fread, GetLastErrorCode)); + ICHECK_EQ(static_cast(nread), size) << "Read Error: " << GetLastError(); #else - ssize_t nread; - nread = read(handle_, ptr, size); + ssize_t nread = RetryCallOnEINTR([&]() { return read(handle_, ptr, size); }, GetLastErrorCode); ICHECK_GE(nread, 0) << "Write Error: " << strerror(errno); #endif return static_cast(nread); @@ -82,13 +99,16 @@ class Pipe : public dmlc::Stream { void Write(const void* ptr, size_t size) final { if (size == 0) return; #ifdef _WIN32 - DWORD nwrite; - ICHECK(WriteFile(handle_, static_cast(ptr), size, &nwrite, nullptr) && - static_cast(nwrite) == size) - << "Write Error: " << GetLastError(); + auto fwrite = [&]() { + DWORD nwrite; + if (!WriteFile(handle_, static_cast(ptr), size, &nwrite, nullptr)) return -1; + return nwrite; + }; + DWORD nwrite = static_cast(RetryCallOnEINTR(fwrite, GetLastErrorCode)); + ICHECK_EQ(static_cast(nwrite), size) << "Write Error: " << GetLastError(); #else - ssize_t nwrite; - nwrite = write(handle_, ptr, size); + ssize_t nwrite = + RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode); ICHECK_EQ(static_cast(nwrite), size) << "Write Error: " << strerror(errno); #endif } diff --git a/src/support/socket.h b/src/support/socket.h index f62702bbc445..ac13cd3f2d35 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -39,7 +39,6 @@ #endif #else #include -#include #include #include #include @@ -56,8 +55,9 @@ #include #include -#include "../support/ssize.h" -#include "../support/utils.h" +#include "errno_handling.h" +#include "ssize.h" +#include "utils.h" #if defined(_WIN32) static inline int poll(struct pollfd* pfd, int nfds, int timeout) { @@ -310,7 +310,7 @@ class Socket { /*! * \return last error of socket operation */ - static int GetLastError() { + static int GetLastErrorCode() { #ifdef _WIN32 return WSAGetLastError(); #else @@ -319,7 +319,7 @@ class Socket { } /*! \return whether last error was would block */ static bool LastErrorWouldBlock() { - int errsv = GetLastError(); + int errsv = GetLastErrorCode(); #ifdef _WIN32 return errsv == WSAEWOULDBLOCK; #else @@ -355,7 +355,7 @@ class Socket { * \param msg The error message. */ static void Error(const char* msg) { - int errsv = GetLastError(); + int errsv = GetLastErrorCode(); #ifdef _WIN32 LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv; #else @@ -363,42 +363,6 @@ class Socket { #endif } - /*! - * \brief Call a function and retry if an EINTR error is encountered. - * - * Socket operations can return EINTR when the interrupt handler - * is registered by the execution environment(e.g. python). - * We should retry if there is no KeyboardInterrupt recorded in - * the environment. - * - * \note This function is needed to avoid rare interrupt event - * in long running server code. - * - * \param func The function to retry. - * \return The return code returned by function f or error_value on retry failure. - */ - template - ssize_t RetryCallOnEINTR(FuncType func) { - ssize_t ret = func(); - // common path - if (ret != -1) return ret; - // less common path - do { - if (GetLastError() == EINTR) { - // Call into env check signals to see if there are - // environment specific(e.g. python) signal exceptions. - // This function will throw an exception if there is - // if the process received a signal that requires TVM to return immediately (e.g. SIGINT). - runtime::EnvCheckSignals(); - } else { - // other errors - return ret; - } - ret = func(); - } while (ret == -1); - return ret; - } - protected: explicit Socket(SockType sockfd) : sockfd(sockfd) {} }; @@ -445,7 +409,8 @@ class TCPSocket : public Socket { * \return The accepted socket connection. */ TCPSocket Accept() { - SockType newfd = RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); }); + SockType newfd = + RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); }, GetLastErrorCode); if (newfd == INVALID_SOCKET) { Socket::Error("Accept"); } @@ -459,7 +424,8 @@ class TCPSocket : public Socket { TCPSocket Accept(SockAddr* addr) { socklen_t addrlen = sizeof(addr->addr); SockType newfd = RetryCallOnEINTR( - [&]() { return accept(sockfd, reinterpret_cast(&addr->addr), &addrlen); }); + [&]() { return accept(sockfd, reinterpret_cast(&addr->addr), &addrlen); }, + GetLastErrorCode); if (newfd == INVALID_SOCKET) { Socket::Error("Accept"); } @@ -500,7 +466,7 @@ class TCPSocket : public Socket { ssize_t Send(const void* buf_, size_t len, int flag = 0) { const char* buf = reinterpret_cast(buf_); return RetryCallOnEINTR( - [&]() { return send(sockfd, buf, static_cast(len), flag); }); + [&]() { return send(sockfd, buf, static_cast(len), flag); }, GetLastErrorCode); } /*! * \brief receive data using the socket @@ -513,7 +479,8 @@ class TCPSocket : public Socket { ssize_t Recv(void* buf_, size_t len, int flags = 0) { char* buf = reinterpret_cast(buf_); return RetryCallOnEINTR( - [&]() { return recv(sockfd, buf, static_cast(len), flags); }); + [&]() { return recv(sockfd, buf, static_cast(len), flags); }, + GetLastErrorCode); } /*! * \brief perform block write that will attempt to send all data out @@ -527,7 +494,8 @@ class TCPSocket : public Socket { size_t ndone = 0; while (ndone < len) { ssize_t ret = RetryCallOnEINTR( - [&]() { return send(sockfd, buf, static_cast(len - ndone), 0); }); + [&]() { return send(sockfd, buf, static_cast(len - ndone), 0); }, + GetLastErrorCode); if (ret == -1) { if (LastErrorWouldBlock()) return ndone; Socket::Error("SendAll"); @@ -549,7 +517,8 @@ class TCPSocket : public Socket { size_t ndone = 0; while (ndone < len) { ssize_t ret = RetryCallOnEINTR( - [&]() { return recv(sockfd, buf, static_cast(len - ndone), MSG_WAITALL); }); + [&]() { return recv(sockfd, buf, static_cast(len - ndone), MSG_WAITALL); }, + GetLastErrorCode); if (ret == -1) { if (LastErrorWouldBlock()) { LOG(FATAL) << "would block";