From dac547a1beb257a89a40c4f465dbeaa709e25f65 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 29 Mar 2022 17:06:08 -0400
Subject: [PATCH 1/7] ARROW-16069: [C++][FlightRPC] Refactor out gRPC error
code handling
---
cpp/src/arrow/flight/flight_test.cc | 6 +
cpp/src/arrow/flight/test_definitions.cc | 123 ++++++++++++++
cpp/src/arrow/flight/test_definitions.h | 19 +++
cpp/src/arrow/flight/transport.cc | 102 ++++++++++++
cpp/src/arrow/flight/transport.h | 32 ++++
.../flight/transport/grpc/util_internal.cc | 156 +++++++++---------
cpp/src/arrow/flight/types.cc | 2 +
7 files changed, 361 insertions(+), 79 deletions(-)
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 3f0ed7114fa..cf3c30358a3 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -109,6 +109,12 @@ class GrpcCudaDataTest : public CudaDataTest {
};
ARROW_FLIGHT_TEST_CUDA_DATA(GrpcCudaDataTest);
+class GrpcErrorHandlingTest : public ErrorHandlingTest {
+ protected:
+ std::string transport() const override { return "grpc"; }
+};
+ARROW_FLIGHT_TEST_ERROR_HANDLING(GrpcErrorHandlingTest);
+
//------------------------------------------------------------
// Ad-hoc gRPC-specific tests
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index 1ec06a1f004..a152c3c9601 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -1363,5 +1363,128 @@ void CudaDataTest::TestDoExchange() {
#endif
+//------------------------------------------------------------
+// Test error handling
+
+namespace {
+constexpr std::initializer_list kStatusCodes = {
+ StatusCode::OutOfMemory,
+ StatusCode::KeyError,
+ StatusCode::TypeError,
+ StatusCode::Invalid,
+ StatusCode::IOError,
+ StatusCode::CapacityError,
+ StatusCode::IndexError,
+ StatusCode::Cancelled,
+ StatusCode::UnknownError,
+ StatusCode::NotImplemented,
+ StatusCode::SerializationError,
+ StatusCode::RError,
+ StatusCode::CodeGenError,
+ StatusCode::ExpressionValidationError,
+ StatusCode::ExecutionError,
+ StatusCode::AlreadyExists,
+};
+
+constexpr std::initializer_list kFlightStatusCodes = {
+ FlightStatusCode::Internal, FlightStatusCode::TimedOut,
+ FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated,
+ FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable,
+ FlightStatusCode::Failed,
+};
+arrow::Result TryConvertStatusCode(int raw_code) {
+ for (const auto status_code : kStatusCodes) {
+ if (raw_code == static_cast(status_code)) {
+ return status_code;
+ }
+ }
+ return Status::Invalid(raw_code);
+}
+arrow::Result TryConvertFlightStatusCode(int raw_code) {
+ for (const auto status_code : kFlightStatusCodes) {
+ if (raw_code == static_cast(status_code)) {
+ return status_code;
+ }
+ }
+ return Status::Invalid(raw_code);
+}
+
+class TestStatusDetail : public StatusDetail {
+ public:
+ const char* type_id() const override { return "test-status-detail"; }
+ std::string ToString() const override { return "Custom status detail"; }
+};
+class ErrorHandlingTestServer : public FlightServerBase {
+ public:
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr* info) override {
+ if (request.path.size() >= 2) {
+ const int raw_code = std::atoi(request.path[0].c_str());
+ ARROW_ASSIGN_OR_RAISE(StatusCode code, TryConvertStatusCode(raw_code));
+
+ if (request.path.size() == 2) {
+ return Status(code, request.path[1]);
+ } else if (request.path.size() == 3) {
+ return Status(code, request.path[1], std::make_shared());
+ } else {
+ const int raw_code = std::atoi(request.path[2].c_str());
+ ARROW_ASSIGN_OR_RAISE(FlightStatusCode flight_code,
+ TryConvertFlightStatusCode(raw_code));
+ return Status(code, request.path[1],
+ std::make_shared(flight_code, request.path[3]));
+ }
+ }
+ return Status::NotImplemented("NYI");
+ }
+};
+} // namespace
+
+void ErrorHandlingTest::SetUp() {
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
+ ASSERT_OK(MakeServer(
+ location, &server_, &client_,
+ [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+}
+void ErrorHandlingTest::TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+}
+
+void ErrorHandlingTest::TestGetFlightInfo() {
+ std::unique_ptr info;
+ for (const auto code : kStatusCodes) {
+ ARROW_SCOPED_TRACE("C++ status code: ", static_cast(code));
+ auto descr = FlightDescriptor::Path(
+ {std::to_string(static_cast(code)), "Expected message"});
+ auto status = client_->GetFlightInfo(descr).status();
+ EXPECT_EQ(status.code(), code);
+ EXPECT_THAT(status.message(), ::testing::HasSubstr("Expected message"));
+
+ // Custom status detail
+ descr = FlightDescriptor::Path(
+ {std::to_string(static_cast(code)), "Expected message", ""});
+ status = client_->GetFlightInfo(descr).status();
+ EXPECT_EQ(status.code(), code);
+ EXPECT_THAT(status.message(), ::testing::HasSubstr("Expected message"));
+ EXPECT_THAT(status.message(), ::testing::HasSubstr("Detail: Custom status detail"));
+
+ // Flight status detail
+ for (const auto flight_code : kFlightStatusCodes) {
+ ARROW_SCOPED_TRACE("Flight status code: ", static_cast(flight_code));
+ descr = FlightDescriptor::Path(
+ {std::to_string(static_cast(code)), "Expected message",
+ std::to_string(static_cast(flight_code)), "Expected detail message"});
+ status = client_->GetFlightInfo(descr).status();
+ // Don't check status code, since Flight code may override it
+ EXPECT_THAT(status.message(), ::testing::HasSubstr("Expected message"));
+ auto detail = FlightStatusDetail::UnwrapStatus(status);
+ ASSERT_NE(detail, nullptr);
+ EXPECT_EQ(detail->code(), flight_code);
+ EXPECT_THAT(detail->extra_info(), ::testing::HasSubstr("Expected detail message"));
+ }
+ }
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h
index 601e8d0b4b1..464631455dd 100644
--- a/cpp/src/arrow/flight/test_definitions.h
+++ b/cpp/src/arrow/flight/test_definitions.h
@@ -255,5 +255,24 @@ class ARROW_FLIGHT_EXPORT CudaDataTest : public FlightTest {
TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \
TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }
+/// \brief Tests of error handling.
+class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {
+ public:
+ void SetUp() override;
+ void TearDown() override;
+
+ // Test methods
+ void TestGetFlightInfo();
+
+ private:
+ std::unique_ptr client_;
+ std::unique_ptr server_;
+};
+
+#define ARROW_FLIGHT_TEST_ERROR_HANDLING(FIXTURE) \
+ static_assert(std::is_base_of::value, \
+ ARROW_STRINGIFY(FIXTURE) " must inherit from ErrorHandlingTest"); \
+ TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport.cc b/cpp/src/arrow/flight/transport.cc
index 2ccdf82bd76..a9fb65f90da 100644
--- a/cpp/src/arrow/flight/transport.cc
+++ b/cpp/src/arrow/flight/transport.cc
@@ -17,6 +17,7 @@
#include "arrow/flight/transport.h"
+#include
#include
#include "arrow/flight/client_auth.h"
@@ -35,6 +36,107 @@ ::arrow::Result> FlightData::OpenMessage() {
return ipc::Message::Open(metadata, body);
}
+TransportStatus TransportStatus::FromStatus(const Status& arrow_status) {
+ if (arrow_status.ok()) {
+ return TransportStatus{TransportStatusCode::kOk, ""};
+ }
+
+ TransportStatusCode code = TransportStatusCode::kUnknown;
+ std::string message = arrow_status.message();
+ if (arrow_status.detail()) {
+ message += ". Detail: ";
+ message += arrow_status.detail()->ToString();
+ }
+
+ std::shared_ptr flight_status =
+ FlightStatusDetail::UnwrapStatus(arrow_status);
+ if (flight_status) {
+ switch (flight_status->code()) {
+ case FlightStatusCode::Internal:
+ code = TransportStatusCode::kInternal;
+ break;
+ case FlightStatusCode::TimedOut:
+ code = TransportStatusCode::kTimedOut;
+ break;
+ case FlightStatusCode::Cancelled:
+ code = TransportStatusCode::kCancelled;
+ break;
+ case FlightStatusCode::Unauthenticated:
+ code = TransportStatusCode::kUnauthenticated;
+ break;
+ case FlightStatusCode::Unauthorized:
+ code = TransportStatusCode::kUnauthorized;
+ break;
+ case FlightStatusCode::Unavailable:
+ code = TransportStatusCode::kUnavailable;
+ break;
+ default:
+ break;
+ }
+ } else if (arrow_status.IsKeyError()) {
+ code = TransportStatusCode::kNotFound;
+ } else if (arrow_status.IsInvalid()) {
+ code = TransportStatusCode::kInvalidArgument;
+ } else if (arrow_status.IsCancelled()) {
+ code = TransportStatusCode::kCancelled;
+ } else if (arrow_status.IsNotImplemented()) {
+ code = TransportStatusCode::kUnimplemented;
+ } else if (arrow_status.IsAlreadyExists()) {
+ code = TransportStatusCode::kAlreadyExists;
+ }
+ return TransportStatus{code, std::move(message)};
+}
+
+Status TransportStatus::ToStatus() const {
+ switch (code) {
+ case TransportStatusCode::kOk:
+ return Status::OK();
+ case TransportStatusCode::kUnknown: {
+ std::stringstream ss;
+ ss << "Flight RPC failed with message: " << message;
+ return Status::UnknownError(ss.str()).WithDetail(
+ std::make_shared(FlightStatusCode::Failed));
+ }
+ case TransportStatusCode::kInternal:
+ return Status::IOError("Flight returned internal error, with message: ", message)
+ .WithDetail(std::make_shared(FlightStatusCode::Internal));
+ case TransportStatusCode::kInvalidArgument:
+ return Status::Invalid("Flight returned invalid argument error, with message: ",
+ message);
+ case TransportStatusCode::kTimedOut:
+ return Status::IOError("Flight returned timeout error, with message: ", message)
+ .WithDetail(std::make_shared(FlightStatusCode::TimedOut));
+ case TransportStatusCode::kNotFound:
+ return Status::KeyError("Flight returned not found error, with message: ", message);
+ case TransportStatusCode::kAlreadyExists:
+ return Status::AlreadyExists("Flight returned already exists error, with message: ",
+ message);
+ case TransportStatusCode::kCancelled:
+ return Status::Cancelled("Flight cancelled call, with message: ", message)
+ .WithDetail(std::make_shared(FlightStatusCode::Cancelled));
+ case TransportStatusCode::kUnauthenticated:
+ return Status::IOError("Flight returned unauthenticated error, with message: ",
+ message)
+ .WithDetail(
+ std::make_shared(FlightStatusCode::Unauthenticated));
+ case TransportStatusCode::kUnauthorized:
+ return Status::IOError("Flight returned unauthorized error, with message: ",
+ message)
+ .WithDetail(
+ std::make_shared(FlightStatusCode::Unauthorized));
+ case TransportStatusCode::kUnimplemented:
+ return Status::NotImplemented("Flight returned unimplemented error, with message: ",
+ message);
+ case TransportStatusCode::kUnavailable:
+ return Status::IOError("Flight returned unavailable error, with message: ", message)
+ .WithDetail(
+ std::make_shared(FlightStatusCode::Unavailable));
+ default:
+ return Status::UnknownError("Flight failed with error code ",
+ static_cast(code), " and message: ", message);
+ }
+}
+
bool TransportDataStream::ReadData(internal::FlightData*) { return false; }
arrow::Result TransportDataStream::WriteData(const FlightPayload&) {
return Status::NotImplemented("Writing data for this stream");
diff --git a/cpp/src/arrow/flight/transport.h b/cpp/src/arrow/flight/transport.h
index f02ab05157a..04a5ef0b034 100644
--- a/cpp/src/arrow/flight/transport.h
+++ b/cpp/src/arrow/flight/transport.h
@@ -91,6 +91,38 @@ struct FlightData {
::arrow::Result> OpenMessage();
};
+/// \brief Abstract status code as per the Flight specification.
+enum class TransportStatusCode {
+ kOk = 0,
+ kUnknown = 1,
+ kInternal = 2,
+ kInvalidArgument = 3,
+ kTimedOut = 4,
+ kNotFound = 5,
+ kAlreadyExists = 6,
+ kCancelled = 7,
+ kUnauthenticated = 8,
+ kUnauthorized = 9,
+ kUnimplemented = 10,
+ kUnavailable = 11,
+};
+
+/// \brief Abstract error status.
+///
+/// Transport implementations may use side channels (e.g. HTTP
+/// trailers) to convey additional information to reconstruct the
+/// original C++ status for implementations that can use it.
+struct TransportStatus {
+ TransportStatusCode code;
+ std::string message;
+
+ /// \brief Convert a C++ status to an abstract transport status.
+ static TransportStatus FromStatus(const Status& arrow_status);
+
+ /// \brief Convert an abstract transport status to a C++ status.
+ Status ToStatus() const;
+};
+
/// \brief A transport-specific interface for reading/writing Arrow data.
///
/// New transports will implement this to read/write IPC payloads to
diff --git a/cpp/src/arrow/flight/transport/grpc/util_internal.cc b/cpp/src/arrow/flight/transport/grpc/util_internal.cc
index 5268df160e9..d979f44edcb 100644
--- a/cpp/src/arrow/flight/transport/grpc/util_internal.cc
+++ b/cpp/src/arrow/flight/transport/grpc/util_internal.cc
@@ -20,7 +20,6 @@
#include
#include