diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index f7b731f01cb..84040a1a476 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -482,5 +482,61 @@ TEST_F(TestCookieParsing, CookieCache) { AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=0; id1=1; id2=2"); } +// ---------------------------------------------------------------------- +// Transport abstraction tests + +TEST(TransportErrorHandling, ReconstructStatus) { + Status current = Status::Invalid("Base error message"); + // Invalid code + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr(". Also, server sent unknown or invalid Arrow status code -1"), + internal::ReconstructStatus("-1", current, util::nullopt, util::nullopt, + util::nullopt, /*detail=*/nullptr)); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr( + ". Also, server sent unknown or invalid Arrow status code foobar"), + internal::ReconstructStatus("foobar", current, util::nullopt, util::nullopt, + util::nullopt, /*detail=*/nullptr)); + + // Override code + EXPECT_RAISES_WITH_MESSAGE_THAT( + AlreadyExists, ::testing::HasSubstr("Base error message"), + internal::ReconstructStatus( + std::to_string(static_cast(StatusCode::AlreadyExists)), current, + util::nullopt, util::nullopt, util::nullopt, /*detail=*/nullptr)); + + // Override message + EXPECT_RAISES_WITH_MESSAGE_THAT( + AlreadyExists, ::testing::HasSubstr("Custom error message"), + internal::ReconstructStatus( + std::to_string(static_cast(StatusCode::AlreadyExists)), current, + "Custom error message", util::nullopt, util::nullopt, /*detail=*/nullptr)); + + // With detail + EXPECT_RAISES_WITH_MESSAGE_THAT( + AlreadyExists, + ::testing::AllOf(::testing::HasSubstr("Custom error message"), + ::testing::HasSubstr(". 
Detail: Detail message")), + internal::ReconstructStatus( + std::to_string(static_cast(StatusCode::AlreadyExists)), current, + "Custom error message", "Detail message", util::nullopt, /*detail=*/nullptr)); + + // With detail and bin + auto reconstructed = internal::ReconstructStatus( + std::to_string(static_cast(StatusCode::AlreadyExists)), current, + "Custom error message", "Detail message", "Binary error details", + /*detail=*/nullptr); + EXPECT_RAISES_WITH_MESSAGE_THAT( + AlreadyExists, + ::testing::AllOf(::testing::HasSubstr("Custom error message"), + ::testing::HasSubstr(". Detail: Detail message")), + reconstructed); + auto detail = FlightStatusDetail::UnwrapStatus(reconstructed); + ASSERT_NE(detail, nullptr); + ASSERT_EQ(detail->extra_info(), "Binary error details"); +} + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 3f0ed7114fa..cf3c30358a3 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -109,6 +109,12 @@ class GrpcCudaDataTest : public CudaDataTest { }; ARROW_FLIGHT_TEST_CUDA_DATA(GrpcCudaDataTest); +class GrpcErrorHandlingTest : public ErrorHandlingTest { + protected: + std::string transport() const override { return "grpc"; } +}; +ARROW_FLIGHT_TEST_ERROR_HANDLING(GrpcErrorHandlingTest); + //------------------------------------------------------------ // Ad-hoc gRPC-specific tests diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index 1ec06a1f004..a152c3c9601 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -1363,5 +1363,128 @@ void CudaDataTest::TestDoExchange() { #endif +//------------------------------------------------------------ +// Test error handling + +namespace { +constexpr std::initializer_list kStatusCodes = { + StatusCode::OutOfMemory, + StatusCode::KeyError, + StatusCode::TypeError, + StatusCode::Invalid, + 
StatusCode::IOError, + StatusCode::CapacityError, + StatusCode::IndexError, + StatusCode::Cancelled, + StatusCode::UnknownError, + StatusCode::NotImplemented, + StatusCode::SerializationError, + StatusCode::RError, + StatusCode::CodeGenError, + StatusCode::ExpressionValidationError, + StatusCode::ExecutionError, + StatusCode::AlreadyExists, +}; + +constexpr std::initializer_list kFlightStatusCodes = { + FlightStatusCode::Internal, FlightStatusCode::TimedOut, + FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated, + FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable, + FlightStatusCode::Failed, +}; +arrow::Result TryConvertStatusCode(int raw_code) { + for (const auto status_code : kStatusCodes) { + if (raw_code == static_cast(status_code)) { + return status_code; + } + } + return Status::Invalid(raw_code); +} +arrow::Result TryConvertFlightStatusCode(int raw_code) { + for (const auto status_code : kFlightStatusCodes) { + if (raw_code == static_cast(status_code)) { + return status_code; + } + } + return Status::Invalid(raw_code); +} + +class TestStatusDetail : public StatusDetail { + public: + const char* type_id() const override { return "test-status-detail"; } + std::string ToString() const override { return "Custom status detail"; } +}; +class ErrorHandlingTestServer : public FlightServerBase { + public: + Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* info) override { + if (request.path.size() >= 2) { + const int raw_code = std::atoi(request.path[0].c_str()); + ARROW_ASSIGN_OR_RAISE(StatusCode code, TryConvertStatusCode(raw_code)); + + if (request.path.size() == 2) { + return Status(code, request.path[1]); + } else if (request.path.size() == 3) { + return Status(code, request.path[1], std::make_shared()); + } else { + const int raw_code = std::atoi(request.path[2].c_str()); + ARROW_ASSIGN_OR_RAISE(FlightStatusCode flight_code, + TryConvertFlightStatusCode(raw_code)); + return 
Status(code, request.path[1], + std::make_shared(flight_code, request.path[3])); + } + } + return Status::NotImplemented("NYI"); + } +}; +} // namespace + +void ErrorHandlingTest::SetUp() { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); + ASSERT_OK(MakeServer( + location, &server_, &client_, + [](FlightServerOptions* options) { return Status::OK(); }, + [](FlightClientOptions* options) { return Status::OK(); })); +} +void ErrorHandlingTest::TearDown() { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); +} + +void ErrorHandlingTest::TestGetFlightInfo() { + std::unique_ptr info; + for (const auto code : kStatusCodes) { + ARROW_SCOPED_TRACE("C++ status code: ", static_cast(code)); + auto descr = FlightDescriptor::Path( + {std::to_string(static_cast(code)), "Expected message"}); + auto status = client_->GetFlightInfo(descr).status(); + EXPECT_EQ(status.code(), code); + EXPECT_THAT(status.message(), ::testing::HasSubstr("Expected message")); + + // Custom status detail + descr = FlightDescriptor::Path( + {std::to_string(static_cast(code)), "Expected message", ""}); + status = client_->GetFlightInfo(descr).status(); + EXPECT_EQ(status.code(), code); + EXPECT_THAT(status.message(), ::testing::HasSubstr("Expected message")); + EXPECT_THAT(status.message(), ::testing::HasSubstr("Detail: Custom status detail")); + + // Flight status detail + for (const auto flight_code : kFlightStatusCodes) { + ARROW_SCOPED_TRACE("Flight status code: ", static_cast(flight_code)); + descr = FlightDescriptor::Path( + {std::to_string(static_cast(code)), "Expected message", + std::to_string(static_cast(flight_code)), "Expected detail message"}); + status = client_->GetFlightInfo(descr).status(); + // Don't check status code, since Flight code may override it + EXPECT_THAT(status.message(), ::testing::HasSubstr("Expected message")); + auto detail = FlightStatusDetail::UnwrapStatus(status); + ASSERT_NE(detail, nullptr); + 
EXPECT_EQ(detail->code(), flight_code); + EXPECT_THAT(detail->extra_info(), ::testing::HasSubstr("Expected detail message")); + } + } +} + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h index 601e8d0b4b1..464631455dd 100644 --- a/cpp/src/arrow/flight/test_definitions.h +++ b/cpp/src/arrow/flight/test_definitions.h @@ -255,5 +255,24 @@ class ARROW_FLIGHT_EXPORT CudaDataTest : public FlightTest { TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \ TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); } +/// \brief Tests of error handling. +class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest { + public: + void SetUp() override; + void TearDown() override; + + // Test methods + void TestGetFlightInfo(); + + private: + std::unique_ptr client_; + std::unique_ptr server_; +}; + +#define ARROW_FLIGHT_TEST_ERROR_HANDLING(FIXTURE) \ + static_assert(std::is_base_of::value, \ + ARROW_STRINGIFY(FIXTURE) " must inherit from ErrorHandlingTest"); \ + TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); } + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/transport.cc b/cpp/src/arrow/flight/transport.cc index 2ccdf82bd76..0da81a567eb 100644 --- a/cpp/src/arrow/flight/transport.cc +++ b/cpp/src/arrow/flight/transport.cc @@ -17,6 +17,7 @@ #include "arrow/flight/transport.h" +#include #include #include "arrow/flight/client_auth.h" @@ -159,6 +160,198 @@ TransportRegistry* GetDefaultTransportRegistry() { return &kRegistry; } +//------------------------------------------------------------ +// Error propagation helpers + +TransportStatus TransportStatus::FromStatus(const Status& arrow_status) { + if (arrow_status.ok()) { + return TransportStatus{TransportStatusCode::kOk, ""}; + } + + TransportStatusCode code = TransportStatusCode::kUnknown; + std::string message = arrow_status.message(); + if (arrow_status.detail()) { + message += ". 
Detail: "; + message += arrow_status.detail()->ToString(); + } + + std::shared_ptr flight_status = + FlightStatusDetail::UnwrapStatus(arrow_status); + if (flight_status) { + switch (flight_status->code()) { + case FlightStatusCode::Internal: + code = TransportStatusCode::kInternal; + break; + case FlightStatusCode::TimedOut: + code = TransportStatusCode::kTimedOut; + break; + case FlightStatusCode::Cancelled: + code = TransportStatusCode::kCancelled; + break; + case FlightStatusCode::Unauthenticated: + code = TransportStatusCode::kUnauthenticated; + break; + case FlightStatusCode::Unauthorized: + code = TransportStatusCode::kUnauthorized; + break; + case FlightStatusCode::Unavailable: + code = TransportStatusCode::kUnavailable; + break; + default: + break; + } + } else if (arrow_status.IsKeyError()) { + code = TransportStatusCode::kNotFound; + } else if (arrow_status.IsInvalid()) { + code = TransportStatusCode::kInvalidArgument; + } else if (arrow_status.IsCancelled()) { + code = TransportStatusCode::kCancelled; + } else if (arrow_status.IsNotImplemented()) { + code = TransportStatusCode::kUnimplemented; + } else if (arrow_status.IsAlreadyExists()) { + code = TransportStatusCode::kAlreadyExists; + } + return TransportStatus{code, std::move(message)}; +} + +TransportStatus TransportStatus::FromCodeStringAndMessage(const std::string& code_str, + std::string message) { + int code_int = 0; + try { + code_int = std::stoi(code_str); + } catch (...) { + return TransportStatus{ + TransportStatusCode::kUnknown, + message + ". 
Also, server sent unknown or invalid Arrow status code " + code_str}; + } + switch (code_int) { + case static_cast(TransportStatusCode::kOk): + case static_cast(TransportStatusCode::kUnknown): + case static_cast(TransportStatusCode::kInternal): + case static_cast(TransportStatusCode::kInvalidArgument): + case static_cast(TransportStatusCode::kTimedOut): + case static_cast(TransportStatusCode::kNotFound): + case static_cast(TransportStatusCode::kAlreadyExists): + case static_cast(TransportStatusCode::kCancelled): + case static_cast(TransportStatusCode::kUnauthenticated): + case static_cast(TransportStatusCode::kUnauthorized): + case static_cast(TransportStatusCode::kUnimplemented): + case static_cast(TransportStatusCode::kUnavailable): + return TransportStatus{static_cast(code_int), + std::move(message)}; + default: { + return TransportStatus{ + TransportStatusCode::kUnknown, + message + ". Also, server sent unknown or invalid Arrow status code " + + code_str}; + } + } +} + +Status TransportStatus::ToStatus() const { + switch (code) { + case TransportStatusCode::kOk: + return Status::OK(); + case TransportStatusCode::kUnknown: { + std::stringstream ss; + ss << "Flight RPC failed with message: " << message; + return Status::UnknownError(ss.str()).WithDetail( + std::make_shared(FlightStatusCode::Failed)); + } + case TransportStatusCode::kInternal: + return Status::IOError("Flight returned internal error, with message: ", message) + .WithDetail(std::make_shared(FlightStatusCode::Internal)); + case TransportStatusCode::kInvalidArgument: + return Status::Invalid("Flight returned invalid argument error, with message: ", + message); + case TransportStatusCode::kTimedOut: + return Status::IOError("Flight returned timeout error, with message: ", message) + .WithDetail(std::make_shared(FlightStatusCode::TimedOut)); + case TransportStatusCode::kNotFound: + return Status::KeyError("Flight returned not found error, with message: ", message); + case 
TransportStatusCode::kAlreadyExists: + return Status::AlreadyExists("Flight returned already exists error, with message: ", + message); + case TransportStatusCode::kCancelled: + return Status::Cancelled("Flight cancelled call, with message: ", message) + .WithDetail(std::make_shared(FlightStatusCode::Cancelled)); + case TransportStatusCode::kUnauthenticated: + return Status::IOError("Flight returned unauthenticated error, with message: ", + message) + .WithDetail( + std::make_shared(FlightStatusCode::Unauthenticated)); + case TransportStatusCode::kUnauthorized: + return Status::IOError("Flight returned unauthorized error, with message: ", + message) + .WithDetail( + std::make_shared(FlightStatusCode::Unauthorized)); + case TransportStatusCode::kUnimplemented: + return Status::NotImplemented("Flight returned unimplemented error, with message: ", + message); + case TransportStatusCode::kUnavailable: + return Status::IOError("Flight returned unavailable error, with message: ", message) + .WithDetail( + std::make_shared(FlightStatusCode::Unavailable)); + default: + return Status::UnknownError("Flight failed with error code ", + static_cast(code), " and message: ", message); + } +} + +Status ReconstructStatus(const std::string& code_str, const Status& current_status, + util::optional message, + util::optional detail_message, + util::optional detail_bin, + std::shared_ptr detail) { + // Bounce through std::string to get a proper null-terminated C string + StatusCode status_code = current_status.code(); + std::stringstream status_message; + try { + const auto code_int = std::stoi(code_str); + switch (code_int) { + case static_cast(StatusCode::OutOfMemory): + case static_cast(StatusCode::KeyError): + case static_cast(StatusCode::TypeError): + case static_cast(StatusCode::Invalid): + case static_cast(StatusCode::IOError): + case static_cast(StatusCode::CapacityError): + case static_cast(StatusCode::IndexError): + case static_cast(StatusCode::Cancelled): + case 
static_cast(StatusCode::UnknownError): + case static_cast(StatusCode::NotImplemented): + case static_cast(StatusCode::SerializationError): + case static_cast(StatusCode::RError): + case static_cast(StatusCode::CodeGenError): + case static_cast(StatusCode::ExpressionValidationError): + case static_cast(StatusCode::ExecutionError): + case static_cast(StatusCode::AlreadyExists): { + status_code = static_cast(code_int); + break; + } + default: { + status_message << ". Also, server sent unknown or invalid Arrow status code " + << code_str; + break; + } + } + } catch (...) { + status_message << ". Also, server sent unknown or invalid Arrow status code " + << code_str; + } + + status_message << (message.has_value() ? *message : current_status.message()); + if (detail_message.has_value()) { + status_message << ". Detail: " << *detail_message; + } + if (detail_bin.has_value()) { + if (!detail) { + detail = std::make_shared(FlightStatusCode::Internal); + } + detail->set_extra_info(std::move(*detail_bin)); + } + return Status(status_code, status_message.str(), std::move(detail)); +} + } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/transport.h b/cpp/src/arrow/flight/transport.h index f02ab05157a..66ded71fbe9 100644 --- a/cpp/src/arrow/flight/transport.h +++ b/cpp/src/arrow/flight/transport.h @@ -65,12 +65,14 @@ #include "arrow/flight/type_fwd.h" #include "arrow/flight/visibility.h" #include "arrow/type_fwd.h" +#include "arrow/util/optional.h" namespace arrow { namespace ipc { class Message; } namespace flight { +class FlightStatusDetail; namespace internal { /// Internal, not user-visible type used for memory-efficient reads @@ -220,6 +222,54 @@ class ARROW_FLIGHT_EXPORT TransportRegistry { ARROW_FLIGHT_EXPORT TransportRegistry* GetDefaultTransportRegistry(); +//------------------------------------------------------------ +// Error propagation helpers + +/// \brief Abstract status code as per the Flight specification. 
+enum class TransportStatusCode { + kOk = 0, + kUnknown = 1, + kInternal = 2, + kInvalidArgument = 3, + kTimedOut = 4, + kNotFound = 5, + kAlreadyExists = 6, + kCancelled = 7, + kUnauthenticated = 8, + kUnauthorized = 9, + kUnimplemented = 10, + kUnavailable = 11, +}; + +/// \brief Abstract error status. +/// +/// Transport implementations may use side channels (e.g. HTTP +/// trailers) to convey additional information to reconstruct the +/// original C++ status for implementations that can use it. +struct ARROW_FLIGHT_EXPORT TransportStatus { + TransportStatusCode code; + std::string message; + + /// \brief Convert a C++ status to an abstract transport status. + static TransportStatus FromStatus(const Status& arrow_status); + + /// \brief Reconstruct a string-encoded TransportStatus. + static TransportStatus FromCodeStringAndMessage(const std::string& code_str, + std::string message); + + /// \brief Convert an abstract transport status to a C++ status. + Status ToStatus() const; +}; + +/// \brief Convert the string representation of an Arrow status code +/// back to an Arrow status. 
+ARROW_FLIGHT_EXPORT +Status ReconstructStatus(const std::string& code_str, const Status& current_status, + util::optional message, + util::optional detail_message, + util::optional detail_bin, + std::shared_ptr detail); + } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/transport/grpc/util_internal.cc b/cpp/src/arrow/flight/transport/grpc/util_internal.cc index 5268df160e9..0455dc119a9 100644 --- a/cpp/src/arrow/flight/transport/grpc/util_internal.cc +++ b/cpp/src/arrow/flight/transport/grpc/util_internal.cc @@ -20,7 +20,6 @@ #include #include #include -#include #include #ifdef GRPCPP_PP_INCLUDE @@ -29,6 +28,7 @@ #include #endif +#include "arrow/flight/transport.h" #include "arrow/flight/types.h" #include "arrow/status.h" @@ -43,110 +43,77 @@ const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin"; const char* kGrpcStatusDetailHeader = "x-arrow-status-detail-bin"; const char* kBinaryErrorDetailsKey = "grpc-status-details-bin"; -static Status StatusCodeFromString(const ::grpc::string_ref& code_ref, StatusCode* code) { - // Bounce through std::string to get a proper null-terminated C string - const auto code_int = std::atoi(std::string(code_ref.data(), code_ref.size()).c_str()); - switch (code_int) { - case static_cast(StatusCode::OutOfMemory): - case static_cast(StatusCode::KeyError): - case static_cast(StatusCode::TypeError): - case static_cast(StatusCode::Invalid): - case static_cast(StatusCode::IOError): - case static_cast(StatusCode::CapacityError): - case static_cast(StatusCode::IndexError): - case static_cast(StatusCode::UnknownError): - case static_cast(StatusCode::NotImplemented): - case static_cast(StatusCode::SerializationError): - case static_cast(StatusCode::RError): - case static_cast(StatusCode::CodeGenError): - case static_cast(StatusCode::ExpressionValidationError): - case static_cast(StatusCode::ExecutionError): - case static_cast(StatusCode::AlreadyExists): { - *code = 
static_cast(code_int); - return Status::OK(); - } - default: - // Code is invalid - return Status::UnknownError("Unknown Arrow status code", code_ref); - } -} - /// Try to extract a status from gRPC trailers. /// Return Status::OK if found, an error otherwise. -static Status FromGrpcContext(const ::grpc::ClientContext& ctx, Status* status, - std::shared_ptr flight_status_detail) { +static bool FromGrpcContext(const ::grpc::ClientContext& ctx, + const Status& current_status, Status* status, + std::shared_ptr flight_status_detail) { const std::multimap<::grpc::string_ref, ::grpc::string_ref>& trailers = ctx.GetServerTrailingMetadata(); - const auto code_val = trailers.find(kGrpcStatusCodeHeader); - if (code_val == trailers.end()) { - return Status::IOError("Status code header not found"); - } - const ::grpc::string_ref code_ref = code_val->second; - StatusCode code = {}; - RETURN_NOT_OK(StatusCodeFromString(code_ref, &code)); + const auto code_val = trailers.find(kGrpcStatusCodeHeader); + if (code_val == trailers.end()) return false; const auto message_val = trailers.find(kGrpcStatusMessageHeader); - if (message_val == trailers.end()) { - return Status::IOError("Status message header not found"); - } + const util::optional message = + message_val == trailers.end() + ? util::nullopt + : util::optional( + std::string(message_val->second.data(), message_val->second.size())); - const ::grpc::string_ref message_ref = message_val->second; - std::string message = std::string(message_ref.data(), message_ref.size()); const auto detail_val = trailers.find(kGrpcStatusDetailHeader); - if (detail_val != trailers.end()) { - const ::grpc::string_ref detail_ref = detail_val->second; - message += ". Detail: "; - message += std::string(detail_ref.data(), detail_ref.size()); - } + const util::optional detail_message = + detail_val == trailers.end() + ? 
util::nullopt + : util::optional( + std::string(detail_val->second.data(), detail_val->second.size())); + const auto grpc_detail_val = trailers.find(kBinaryErrorDetailsKey); - if (grpc_detail_val != trailers.end()) { - const ::grpc::string_ref detail_ref = grpc_detail_val->second; - std::string bin_detail = std::string(detail_ref.data(), detail_ref.size()); - if (!flight_status_detail) { - flight_status_detail = - std::make_shared(FlightStatusCode::Internal); - } - flight_status_detail->set_extra_info(bin_detail); - } - *status = Status(code, message, flight_status_detail); - return Status::OK(); + const util::optional detail_bin = + grpc_detail_val == trailers.end() + ? util::nullopt + : util::optional(std::string(grpc_detail_val->second.data(), + grpc_detail_val->second.size())); + + std::string code_str(code_val->second.data(), code_val->second.size()); + *status = internal::ReconstructStatus(code_str, current_status, std::move(message), + std::move(detail_message), std::move(detail_bin), + std::move(flight_status_detail)); + return true; } /// Convert a gRPC status to an Arrow status, ignoring any /// implementation-defined headers that encode further detail. 
static Status FromGrpcCode(const ::grpc::Status& grpc_status) { + using internal::TransportStatus; + using internal::TransportStatusCode; switch (grpc_status.error_code()) { case ::grpc::StatusCode::OK: return Status::OK(); case ::grpc::StatusCode::CANCELLED: - return Status::IOError("gRPC cancelled call, with message: ", - grpc_status.error_message()) - .WithDetail(std::make_shared(FlightStatusCode::Cancelled)); - case ::grpc::StatusCode::UNKNOWN: { - std::stringstream ss; - ss << "Flight RPC failed with message: " << grpc_status.error_message(); - return Status::UnknownError(ss.str()).WithDetail( - std::make_shared(FlightStatusCode::Failed)); - } + return TransportStatus{TransportStatusCode::kCancelled, grpc_status.error_message()} + .ToStatus(); + case ::grpc::StatusCode::UNKNOWN: + return TransportStatus{TransportStatusCode::kUnknown, grpc_status.error_message()} + .ToStatus(); case ::grpc::StatusCode::INVALID_ARGUMENT: - return Status::Invalid("gRPC returned invalid argument error, with message: ", - grpc_status.error_message()); + return TransportStatus{TransportStatusCode::kInvalidArgument, + grpc_status.error_message()} + .ToStatus(); case ::grpc::StatusCode::DEADLINE_EXCEEDED: - return Status::IOError("gRPC returned deadline exceeded error, with message: ", - grpc_status.error_message()) - .WithDetail(std::make_shared(FlightStatusCode::TimedOut)); + return TransportStatus{TransportStatusCode::kTimedOut, grpc_status.error_message()} + .ToStatus(); case ::grpc::StatusCode::NOT_FOUND: - return Status::KeyError("gRPC returned not found error, with message: ", - grpc_status.error_message()); + return TransportStatus{TransportStatusCode::kNotFound, grpc_status.error_message()} + .ToStatus(); case ::grpc::StatusCode::ALREADY_EXISTS: - return Status::AlreadyExists("gRPC returned already exists error, with message: ", - grpc_status.error_message()); + return TransportStatus{TransportStatusCode::kAlreadyExists, + grpc_status.error_message()} + .ToStatus(); case 
::grpc::StatusCode::PERMISSION_DENIED: - return Status::IOError("gRPC returned permission denied error, with message: ", - grpc_status.error_message()) - .WithDetail( - std::make_shared(FlightStatusCode::Unauthorized)); + return TransportStatus{TransportStatusCode::kUnauthorized, + grpc_status.error_message()} + .ToStatus(); case ::grpc::StatusCode::RESOURCE_EXHAUSTED: return Status::Invalid("gRPC returned resource exhausted error, with message: ", grpc_status.error_message()); @@ -161,26 +128,24 @@ static Status FromGrpcCode(const ::grpc::Status& grpc_status) { return Status::Invalid("gRPC returned out-of-range error, with message: ", grpc_status.error_message()); case ::grpc::StatusCode::UNIMPLEMENTED: - return Status::NotImplemented("gRPC returned unimplemented error, with message: ", - grpc_status.error_message()); + return TransportStatus{TransportStatusCode::kUnimplemented, + grpc_status.error_message()} + .ToStatus(); case ::grpc::StatusCode::INTERNAL: - return Status::IOError("gRPC returned internal error, with message: ", - grpc_status.error_message()) - .WithDetail(std::make_shared(FlightStatusCode::Internal)); + return TransportStatus{TransportStatusCode::kInternal, grpc_status.error_message()} + .ToStatus(); case ::grpc::StatusCode::UNAVAILABLE: - return Status::IOError("gRPC returned unavailable error, with message: ", - grpc_status.error_message()) - .WithDetail( - std::make_shared(FlightStatusCode::Unavailable)); + return TransportStatus{TransportStatusCode::kUnavailable, + grpc_status.error_message()} + .ToStatus(); case ::grpc::StatusCode::DATA_LOSS: return Status::IOError("gRPC returned data loss error, with message: ", grpc_status.error_message()) .WithDetail(std::make_shared(FlightStatusCode::Internal)); case ::grpc::StatusCode::UNAUTHENTICATED: - return Status::IOError("gRPC returned unauthenticated error, with message: ", - grpc_status.error_message()) - .WithDetail( - std::make_shared(FlightStatusCode::Unauthenticated)); + return 
TransportStatus{TransportStatusCode::kUnauthenticated, + grpc_status.error_message()} + .ToStatus(); default: return Status::UnknownError("gRPC failed with error code ", grpc_status.error_code(), @@ -190,70 +155,67 @@ static Status FromGrpcCode(const ::grpc::Status& grpc_status) { Status FromGrpcStatus(const ::grpc::Status& grpc_status, ::grpc::ClientContext* ctx) { const Status status = FromGrpcCode(grpc_status); - if (!status.ok() && ctx) { Status arrow_status; - - if (!FromGrpcContext(*ctx, &arrow_status, FlightStatusDetail::UnwrapStatus(status)) - .ok()) { - // If we fail to decode a more detailed status from the headers, - // proceed normally - return status; + if (FromGrpcContext(*ctx, status, &arrow_status, + FlightStatusDetail::UnwrapStatus(status))) { + return arrow_status; } - - return arrow_status; + // If we fail to decode a more detailed status from the headers, + // proceed normally } return status; } /// Convert an Arrow status to a gRPC status. static ::grpc::Status ToRawGrpcStatus(const Status& arrow_status) { - if (arrow_status.ok()) { - return ::grpc::Status::OK; - } + using internal::TransportStatus; + using internal::TransportStatusCode; + if (arrow_status.ok()) return ::grpc::Status::OK; + TransportStatus transport_status = TransportStatus::FromStatus(arrow_status); ::grpc::StatusCode grpc_code = ::grpc::StatusCode::UNKNOWN; - std::string message = arrow_status.message(); - if (arrow_status.detail()) { - message += ". 
Detail: "; - message += arrow_status.detail()->ToString(); - } - - std::shared_ptr flight_status = - FlightStatusDetail::UnwrapStatus(arrow_status); - if (flight_status) { - switch (flight_status->code()) { - case FlightStatusCode::Internal: - grpc_code = ::grpc::StatusCode::INTERNAL; - break; - case FlightStatusCode::TimedOut: - grpc_code = ::grpc::StatusCode::DEADLINE_EXCEEDED; - break; - case FlightStatusCode::Cancelled: - grpc_code = ::grpc::StatusCode::CANCELLED; - break; - case FlightStatusCode::Unauthenticated: - grpc_code = ::grpc::StatusCode::UNAUTHENTICATED; - break; - case FlightStatusCode::Unauthorized: - grpc_code = ::grpc::StatusCode::PERMISSION_DENIED; - break; - case FlightStatusCode::Unavailable: - grpc_code = ::grpc::StatusCode::UNAVAILABLE; - break; - default: - break; - } - } else if (arrow_status.IsNotImplemented()) { - grpc_code = ::grpc::StatusCode::UNIMPLEMENTED; - } else if (arrow_status.IsInvalid()) { - grpc_code = ::grpc::StatusCode::INVALID_ARGUMENT; - } else if (arrow_status.IsKeyError()) { - grpc_code = ::grpc::StatusCode::NOT_FOUND; - } else if (arrow_status.IsAlreadyExists()) { - grpc_code = ::grpc::StatusCode::ALREADY_EXISTS; + switch (transport_status.code) { + case TransportStatusCode::kOk: + return ::grpc::Status::OK; + case TransportStatusCode::kUnknown: + grpc_code = ::grpc::StatusCode::UNKNOWN; + break; + case TransportStatusCode::kInternal: + grpc_code = ::grpc::StatusCode::INTERNAL; + break; + case TransportStatusCode::kInvalidArgument: + grpc_code = ::grpc::StatusCode::INVALID_ARGUMENT; + break; + case TransportStatusCode::kTimedOut: + grpc_code = ::grpc::StatusCode::DEADLINE_EXCEEDED; + break; + case TransportStatusCode::kNotFound: + grpc_code = ::grpc::StatusCode::NOT_FOUND; + break; + case TransportStatusCode::kAlreadyExists: + grpc_code = ::grpc::StatusCode::ALREADY_EXISTS; + break; + case TransportStatusCode::kCancelled: + grpc_code = ::grpc::StatusCode::CANCELLED; + break; + case TransportStatusCode::kUnauthenticated: 
+ grpc_code = ::grpc::StatusCode::UNAUTHENTICATED; + break; + case TransportStatusCode::kUnauthorized: + grpc_code = ::grpc::StatusCode::PERMISSION_DENIED; + break; + case TransportStatusCode::kUnimplemented: + grpc_code = ::grpc::StatusCode::UNIMPLEMENTED; + break; + case TransportStatusCode::kUnavailable: + grpc_code = ::grpc::StatusCode::UNAVAILABLE; + break; + default: + grpc_code = ::grpc::StatusCode::UNKNOWN; + break; } - return ::grpc::Status(grpc_code, message); + return ::grpc::Status(grpc_code, std::move(transport_status.message)); } /// Convert an Arrow status to a gRPC status, and add extra headers to diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc index 6a580af92fd..a29d498d0be 100644 --- a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc +++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc @@ -86,6 +86,12 @@ class UcxCudaDataTest : public CudaDataTest { }; ARROW_FLIGHT_TEST_CUDA_DATA(UcxCudaDataTest); +class UcxErrorHandlingTest : public ErrorHandlingTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_ERROR_HANDLING(UcxErrorHandlingTest); + //------------------------------------------------------------ // UCX internals tests @@ -203,43 +209,6 @@ TEST(HeadersFrame, Parse) { HeadersFrame::Parse(std::move(buffer))); } } - -TEST(HeadersFrame, RoundTripStatus) { - for (const auto code : kStatusCodes) { - { - Status expected = code == StatusCode::OK ? 
Status() : Status(code, "foo"); - ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {})); - Status status; - ASSERT_OK(headers.GetStatus(&status)); - ASSERT_EQ(status, expected); - } - - if (code == StatusCode::OK) continue; - - // Attach a generic status detail - { - auto detail = std::make_shared(); - Status original(code, "foo", detail); - Status expected(code, "foo", - std::make_shared(FlightStatusCode::Internal, - detail->ToString())); - ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {})); - Status status; - ASSERT_OK(headers.GetStatus(&status)); - ASSERT_EQ(status, expected); - } - - // Attach a Flight status detail - for (const auto flight_code : kFlightStatusCodes) { - Status expected(code, "foo", - std::make_shared(flight_code, "extra")); - ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {})); - Status status; - ASSERT_OK(headers.GetStatus(&status)); - ASSERT_EQ(status, expected); - } - } -} } // namespace ucx } // namespace transport @@ -342,7 +311,9 @@ TEST_F(TestUcx, Errors) { Status expected(code, "Error message"); server->set_error_status(expected); Status actual = client_->GetFlightInfo(descriptor).status(); - ASSERT_EQ(actual, expected); + ASSERT_EQ(actual.code(), expected.code()) << actual.ToString(); + ASSERT_THAT(actual.message(), ::testing::HasSubstr("Error message")) + << actual.ToString(); // Attach a generic status detail { @@ -352,7 +323,10 @@ TEST_F(TestUcx, Errors) { std::make_shared(FlightStatusCode::Internal, detail->ToString())); Status actual = client_->GetFlightInfo(descriptor).status(); - ASSERT_EQ(actual, expected); + ASSERT_EQ(actual.code(), expected.code()) << actual.ToString(); + ASSERT_THAT(actual.message(), ::testing::HasSubstr("foo")) << actual.ToString(); + ASSERT_THAT(actual.message(), ::testing::HasSubstr("Custom status detail")) + << actual.ToString(); } // Attach a Flight status detail @@ -361,7 +335,9 @@ TEST_F(TestUcx, Errors) { std::make_shared(flight_code, "extra")); 
server->set_error_status(expected); Status actual = client_->GetFlightInfo(descriptor).status(); - ASSERT_EQ(actual, expected); + ASSERT_EQ(actual.code(), expected.code()) << actual.ToString(); + ASSERT_THAT(actual.message(), ::testing::HasSubstr("Error message")) + << actual.ToString(); } } } diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc index ab4cc323f4c..abcf7911255 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -36,6 +36,9 @@ namespace flight { namespace transport { namespace ucx { +using internal::TransportStatus; +using internal::TransportStatusCode; + // Defines to test different implementation strategies // Enable the CONTIG path for CPU-only data // #define ARROW_FLIGHT_UCX_SEND_CONTIG @@ -222,17 +225,19 @@ arrow::Result HeadersFrame::Make( const Status& status, const std::vector>& headers) { auto all_headers = headers; + + TransportStatus transport_status = TransportStatus::FromStatus(status); + all_headers.emplace_back(kHeaderStatus, + std::to_string(static_cast(transport_status.code))); + all_headers.emplace_back(kHeaderMessage, std::move(transport_status.message)); all_headers.emplace_back(kHeaderStatusCode, std::to_string(static_cast(status.code()))); all_headers.emplace_back(kHeaderStatusMessage, status.message()); if (status.detail()) { + all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString()); auto fsd = FlightStatusDetail::UnwrapStatus(status); - if (fsd) { - all_headers.emplace_back(kHeaderStatusDetailCode, - std::to_string(static_cast(fsd->code()))); - all_headers.emplace_back(kHeaderStatusDetail, fsd->extra_info()); - } else { - all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString()); + if (fsd && !fsd->extra_info().empty()) { + all_headers.emplace_back(kHeaderStatusDetailBin, fsd->extra_info()); } } return Make(all_headers); @@ -246,118 +251,46 @@ arrow::Result 
HeadersFrame::Get(const std::string& key) { } Status HeadersFrame::GetStatus(Status* out) { + static const std::string kUnknownMessage = "Server did not send status message header"; util::string_view code_str, message_str; - auto status = Get(kHeaderStatusCode).Value(&code_str); + auto status = Get(kHeaderStatus).Value(&code_str); if (!status.ok()) { return Status::KeyError("Server did not send status code header ", kHeaderStatusCode); } - - StatusCode status_code = StatusCode::OK; - auto code = std::strtol(code_str.data(), nullptr, /*base=*/10); - switch (code) { - case 0: - status_code = StatusCode::OK; - break; - case 1: - status_code = StatusCode::OutOfMemory; - break; - case 2: - status_code = StatusCode::KeyError; - break; - case 3: - status_code = StatusCode::TypeError; - break; - case 4: - status_code = StatusCode::Invalid; - break; - case 5: - status_code = StatusCode::IOError; - break; - case 6: - status_code = StatusCode::CapacityError; - break; - case 7: - status_code = StatusCode::IndexError; - break; - case 8: - status_code = StatusCode::Cancelled; - break; - case 9: - status_code = StatusCode::UnknownError; - break; - case 10: - status_code = StatusCode::NotImplemented; - break; - case 11: - status_code = StatusCode::SerializationError; - break; - case 13: - status_code = StatusCode::RError; - break; - case 40: - status_code = StatusCode::CodeGenError; - break; - case 41: - status_code = StatusCode::ExpressionValidationError; - break; - case 42: - status_code = StatusCode::ExecutionError; - break; - case 45: - status_code = StatusCode::AlreadyExists; - break; - default: - status_code = StatusCode::UnknownError; - break; - } - if (status_code == StatusCode::OK) { + if (code_str == "0") { // == std::to_string(TransportStatusCode::kOk) *out = Status::OK(); return Status::OK(); } - status = Get(kHeaderStatusMessage).Value(&message_str); - if (!status.ok()) { - *out = Status(status_code, "Server did not send status message header", nullptr); + status = 
Get(kHeaderMessage).Value(&message_str); + if (!status.ok()) message_str = kUnknownMessage; + + TransportStatus transport_status = TransportStatus::FromCodeStringAndMessage( + std::string(code_str), std::string(message_str)); + if (transport_status.code == TransportStatusCode::kOk) { + *out = Status::OK(); return Status::OK(); } + *out = transport_status.ToStatus(); - util::string_view detail_code_str, detail_str; - FlightStatusCode detail_code = FlightStatusCode::Internal; - - if (Get(kHeaderStatusDetailCode).Value(&detail_code_str).ok()) { - auto detail_code_int = std::strtol(detail_code_str.data(), nullptr, /*base=*/10); - switch (detail_code_int) { - case 1: - detail_code = FlightStatusCode::TimedOut; - break; - case 2: - detail_code = FlightStatusCode::Cancelled; - break; - case 3: - detail_code = FlightStatusCode::Unauthenticated; - break; - case 4: - detail_code = FlightStatusCode::Unauthorized; - break; - case 5: - detail_code = FlightStatusCode::Unavailable; - break; - case 6: - detail_code = FlightStatusCode::Failed; - break; - case 0: - default: - detail_code = FlightStatusCode::Internal; - break; - } + util::string_view detail_str, bin_str; + util::optional message, detail_message, detail_bin; + if (!Get(kHeaderStatusCode).Value(&code_str).ok()) { + // No Arrow status sent, go with the transport status + return Status::OK(); } - ARROW_UNUSED(Get(kHeaderStatusDetail).Value(&detail_str)); - - std::shared_ptr detail = nullptr; - if (!detail_str.empty()) { - detail = std::make_shared(detail_code, std::string(detail_str)); + if (Get(kHeaderStatusMessage).Value(&message_str).ok()) { + message = std::string(message_str); + } + if (Get(kHeaderStatusDetail).Value(&detail_str).ok()) { + detail_message = std::string(detail_str); + } + if (Get(kHeaderStatusDetailBin).Value(&bin_str).ok()) { + detail_bin = std::string(bin_str); } - *out = Status(status_code, std::string(message_str), std::move(detail)); + *out = internal::ReconstructStatus(std::string(code_str), 
*out, std::move(message), + std::move(detail_message), std::move(detail_bin), + FlightStatusDetail::UnwrapStatus(*out)); return Status::OK(); } diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h index bd176e23699..f5b81ab4147 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h @@ -50,10 +50,18 @@ static constexpr char kMethodDoGet[] = "DoGet"; static constexpr char kMethodDoPut[] = "DoPut"; static constexpr char kMethodGetFlightInfo[] = "GetFlightInfo"; +/// The header encoding the transport status. +static constexpr char kHeaderStatus[] = "flight-status"; +/// The header encoding the transport status message. +static constexpr char kHeaderMessage[] = "flight-message"; +/// The header encoding the C++ status code. static constexpr char kHeaderStatusCode[] = "flight-status-code"; +/// The header encoding the C++ status message. static constexpr char kHeaderStatusMessage[] = "flight-status-message"; +/// The header encoding the C++ status detail message. static constexpr char kHeaderStatusDetail[] = "flight-status-detail"; -static constexpr char kHeaderStatusDetailCode[] = "flight-status-detail-code"; +/// The header encoding the C++ status detail binary data.
+static constexpr char kHeaderStatusDetailBin[] = "flight-status-detail-bin"; //------------------------------------------------------------ // UCX Helpers diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 4a169e985c1..efc96bb7756 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -67,6 +67,8 @@ std::string FlightStatusDetail::CodeAsString() const { return "Unauthorized"; case FlightStatusCode::Unavailable: return "Unavailable"; + case FlightStatusCode::Failed: + return "Failed"; default: return "Unknown"; } diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index fa6ef29c072..5821956b295 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -31,7 +31,8 @@ from cython.operator cimport postincrement from libcpp cimport bool as c_bool from pyarrow.lib cimport * -from pyarrow.lib import ArrowException, ArrowInvalid, SignalStopHandler +from pyarrow.lib import (ArrowCancelled, ArrowException, ArrowInvalid, + SignalStopHandler) from pyarrow.lib import as_buffer, frombytes, tobytes from pyarrow.includes.libarrow_flight cimport * from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin @@ -170,7 +171,7 @@ cdef class FlightTimedOutError(FlightError, ArrowException): tobytes(str(self)), self.extra_info) -cdef class FlightCancelledError(FlightError, ArrowException): +cdef class FlightCancelledError(FlightError, ArrowCancelled): cdef CStatus to_status(self): return MakeFlightError(CFlightStatusCancelled, tobytes(str(self)), self.extra_info) diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 12f815dbeac..9c61097251a 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -355,17 +355,20 @@ def slow_stream(): class ErrorFlightServer(FlightServerBase): """A Flight server that uses all the Flight-specific errors.""" + errors = { + "internal": flight.FlightInternalError, + "timedout": 
flight.FlightTimedOutError, + "cancel": flight.FlightCancelledError, + "unauthenticated": flight.FlightUnauthenticatedError, + "unauthorized": flight.FlightUnauthorizedError, + "notimplemented": NotImplementedError, + "invalid": pa.ArrowInvalid, + "key": KeyError, + } + def do_action(self, context, action): - if action.type == "internal": - raise flight.FlightInternalError("foo") - elif action.type == "timedout": - raise flight.FlightTimedOutError("foo") - elif action.type == "cancel": - raise flight.FlightCancelledError("foo") - elif action.type == "unauthenticated": - raise flight.FlightUnauthenticatedError("foo") - elif action.type == "unauthorized": - raise flight.FlightUnauthorizedError("foo") + if action.type in self.errors: + raise self.errors[action.type]("foo") elif action.type == "protobuf": err_msg = b'this is an error message' raise flight.FlightUnauthorizedError("foo", err_msg) @@ -1561,16 +1564,9 @@ def test_roundtrip_errors(): with ErrorFlightServer() as server, \ FlightClient(('localhost', server.port)) as client: - with pytest.raises(flight.FlightInternalError, match=".*foo.*"): - list(client.do_action(flight.Action("internal", b""))) - with pytest.raises(flight.FlightTimedOutError, match=".*foo.*"): - list(client.do_action(flight.Action("timedout", b""))) - with pytest.raises(flight.FlightCancelledError, match=".*foo.*"): - list(client.do_action(flight.Action("cancel", b""))) - with pytest.raises(flight.FlightUnauthenticatedError, match=".*foo.*"): - list(client.do_action(flight.Action("unauthenticated", b""))) - with pytest.raises(flight.FlightUnauthorizedError, match=".*foo.*"): - list(client.do_action(flight.Action("unauthorized", b""))) + for arg, exc_type in ErrorFlightServer.errors.items(): + with pytest.raises(exc_type, match=".*foo.*"): + list(client.do_action(flight.Action(arg, b""))) with pytest.raises(flight.FlightInternalError, match=".*foo.*"): list(client.list_flights())