From 3b9e7063839b679a13446dd59772c20eb5fbb333 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 18 May 2022 15:54:20 -0400
Subject: [PATCH] ARROW-16592: [C++][Python][FlightRPC] Finish after failed
writes
Currently a failed write (due to the server sending an error,
disconnecting, etc.) will raise an uninformative error on the
client. Prior to the refactoring done in Arrow 8.0.0, this was
silently swallowed (so clients would not get any indication until
they finished writing). In 8.0.0 the error is propagated instead,
but the resulting message is confusing and uninformative. To fix
this, tag this specific error so that the client implementation
knows to finish the call and retrieve the actual server error.
(gRPC doesn't give us the actual error until we explicitly finish
the call, so we can't get the actual error directly.)
---
cpp/src/arrow/flight/client.cc | 38 ++++++++--
cpp/src/arrow/flight/test_definitions.cc | 75 +++++++++++++++++++
cpp/src/arrow/flight/test_definitions.h | 6 +-
.../flight/transport/grpc/grpc_client.cc | 19 +----
.../arrow/flight/transport/ucx/ucx_client.cc | 2 +-
python/pyarrow/tests/test_flight.py | 74 +++++++++++++++++-
6 files changed, 188 insertions(+), 26 deletions(-)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index ff13c800cb4..dbf840db050 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -270,6 +270,25 @@ class ClientMetadataReader : public FlightMetadataReader {
std::shared_ptr stream_;
};
+/// Type id of ServerErrorTagStatusDetail. constexpr: the pointer cannot be
+/// reassigned; UnwrapStatus() below compares this id by address.
+constexpr const char* kTagDetailTypeId = "flight::ServerErrorTagStatusDetail";
+/// This status detail indicates the write failed in the transport (due to the
+/// server) and that we should finish the call at a higher level (to get the
+/// server error); otherwise the client should pass the status through (it may
+/// be recoverable) instead of finishing the call (the server may then think the
+/// client ended the call successfully) or canceling it (unexpected client error).
+class ServerErrorTagStatusDetail : public arrow::StatusDetail {
+ public:
+  const char* type_id() const override { return kTagDetailTypeId; }
+  std::string ToString() const override { return type_id(); }
+
+  /// Returns true iff `status` carries a detail whose type id is this tag.
+  static bool UnwrapStatus(const arrow::Status& status) {
+    return status.detail() && status.detail()->type_id() == kTagDetailTypeId;
+  }
+};
+
/// \brief An IpcPayloadWriter for any ClientDataStream.
///
/// To support app_metadata and reuse the existing IPC infrastructure,
@@ -321,8 +340,8 @@ class ClientPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
}
ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
if (!success) {
- return MakeFlightError(
- FlightStatusCode::Internal,
+ return Status::FromDetailAndArgs(
+ StatusCode::IOError, std::make_shared(),
"Could not write record batch to stream (server disconnect?)");
}
return Status::OK();
@@ -397,9 +416,7 @@ class ClientStreamWriter : public FlightStreamWriter {
RETURN_NOT_OK(internal::ToPayload(descriptor_, &payload.descriptor));
ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
if (!success) {
- return MakeFlightError(
- FlightStatusCode::Internal,
- "Could not write record batch to stream (server disconnect?)");
+ return Close();
}
return Status::OK();
}
@@ -414,8 +431,7 @@ class ClientStreamWriter : public FlightStreamWriter {
payload.app_metadata = app_metadata;
ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
if (!success) {
- return MakeFlightError(FlightStatusCode::Internal,
- "Could not write metadata to stream (server disconnect?)");
+ return Close();
}
return Status::OK();
}
@@ -424,7 +440,13 @@ class ClientStreamWriter : public FlightStreamWriter {
std::shared_ptr app_metadata) override {
RETURN_NOT_OK(CheckStarted());
app_metadata_ = app_metadata;
- return batch_writer_->WriteRecordBatch(batch);
+ auto status = batch_writer_->WriteRecordBatch(batch);
+ if (!status.ok() &&
+ // Only want to Close() if server error, not for client error
+ ServerErrorTagStatusDetail::UnwrapStatus(status)) {
+ return Close();
+ }
+ return status;
}
Status DoneWriting() override {
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index 61873ba36da..1ede500795c 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -1437,6 +1437,18 @@ class ErrorHandlingTestServer : public FlightServerBase {
}
return Status::NotImplemented("NYI");
}
+
+  Status DoPut(const ServerCallContext& context,
+               std::unique_ptr<FlightMessageReader> reader,
+               std::unique_ptr<FlightMetadataWriter> writer) override {
+    return MakeFlightError(FlightStatusCode::Unauthorized, "Unauthorized", "extra info");
+  }  // Always fails with Unauthorized + extra info for the client to observe.
+
+  Status DoExchange(const ServerCallContext& context,
+                    std::unique_ptr<FlightMessageReader> reader,
+                    std::unique_ptr<FlightMessageWriter> writer) override {
+    return MakeFlightError(FlightStatusCode::Unauthorized, "Unauthorized", "extra info");
+  }  // Always fails with Unauthorized + extra info for the client to observe.
};
} // namespace
@@ -1487,5 +1499,68 @@ void ErrorHandlingTest::TestGetFlightInfo() {
}
}
+void CheckErrorDetail(const Status& status) {  // expects Unauthorized + "extra info"
+  const auto flight_detail = FlightStatusDetail::UnwrapStatus(status);
+  ASSERT_NE(flight_detail, nullptr) << status.ToString();
+  ASSERT_EQ(flight_detail->code(), FlightStatusCode::Unauthorized);
+  ASSERT_EQ(flight_detail->extra_info(), "extra info");
+}
+
+void ErrorHandlingTest::TestDoPut() {
+  // ARROW-16592: a failed DoPut write must report the server's own error.
+  auto schema = arrow::schema({field("int64", int64())});
+  auto descr = FlightDescriptor::Path({""});
+  FlightClient::DoPutResult stream;
+  auto status = client_->DoPut(descr, schema).Value(&stream);
+  if (!status.ok()) {
+    ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
+    return;
+  }
+
+  // Read metadata concurrently with the writes until a read fails.
+  std::thread reader_thread([&]() {
+    std::shared_ptr<Buffer> out;
+    while (stream.reader->ReadMetadata(&out).ok()) {
+    }
+  });
+
+  auto batch = RecordBatchFromJSON(schema, "[[0]]");
+  do {
+    status = stream.writer->WriteRecordBatch(*batch);
+  } while (status.ok());
+
+  // Close and join before asserting: a fatal failure returns early, and
+  // destroying a still-joinable std::thread calls std::terminate.
+  const auto close_status = stream.writer->Close();
+  reader_thread.join();
+  ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
+  ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(close_status));
+}
+
+void ErrorHandlingTest::TestDoExchange() {
+  // ARROW-16592: a failed DoExchange write must report the server's own error.
+  FlightClient::DoExchangeResult stream;
+  auto status = client_->DoExchange(FlightDescriptor::Path({""})).Value(&stream);
+  if (!status.ok()) {
+    ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
+    return;
+  }
+
+  std::thread reader_thread([&]() {
+    while (stream.reader->Next().ok()) {
+    }
+  });
+
+  do {
+    status = stream.writer->WriteMetadata(Buffer::FromString("foo"));
+  } while (status.ok());
+
+  // Join before asserting: an early return with a joinable thread terminates.
+  const auto close_status = stream.writer->Close();
+  reader_thread.join();
+  ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
+  ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(close_status));
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h
index aa0d07857be..5f01caf8373 100644
--- a/cpp/src/arrow/flight/test_definitions.h
+++ b/cpp/src/arrow/flight/test_definitions.h
@@ -263,6 +263,8 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {
// Test methods
void TestGetFlightInfo();
+ void TestDoPut();
+ void TestDoExchange();
private:
std::unique_ptr client_;
@@ -272,7 +274,9 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {
#define ARROW_FLIGHT_TEST_ERROR_HANDLING(FIXTURE) \
 static_assert(std::is_base_of<ErrorHandlingTest, FIXTURE>::value, \
 ARROW_STRINGIFY(FIXTURE) " must inherit from ErrorHandlingTest"); \
- TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }
+ TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); } \
+ TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \
+ TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
index 1c0ac2d31fa..8fe1e1bae79 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
@@ -326,10 +326,7 @@ class WritableDataStream : public FinishableDataStream {
public:
using Base = FinishableDataStream;
WritableDataStream(std::shared_ptr<ClientRpc> rpc, std::shared_ptr<Stream> stream)
- : Base(std::move(rpc), std::move(stream)),
- read_mutex_(),
- finish_mutex_(),
- done_writing_(false) {}
+ : Base(std::move(rpc), std::move(stream)), read_mutex_(), done_writing_(false) {}
Status WritesDone() override {
// This is only used by the writer side of a stream, so it need
@@ -350,16 +347,7 @@ class WritableDataStream : public FinishableDataStream {
Status DoFinish() override {
// This may be used concurrently by reader/writer side of a
// stream, so it needs to be protected.
- std::lock_guard guard(finish_mutex_);
-
- // Now that we're shared between a reader and writer, we need to
- // protect ourselves from being called while there's an
- // outstanding read.
- std::unique_lock read_guard(read_mutex_, std::try_to_lock);
- if (!read_guard.owns_lock()) {
- return MakeFlightError(FlightStatusCode::Internal,
- "Cannot close stream with pending read operation.");
- }
+ std::lock_guard guard(read_mutex_);
// Try to flush pending writes. Don't use our WritesDone() to
// avoid recursion.
@@ -377,7 +365,6 @@ class WritableDataStream : public FinishableDataStream {
using Base::stream_;
std::mutex read_mutex_;
- std::mutex finish_mutex_;
bool done_writing_;
};
@@ -402,6 +389,7 @@ class GrpcClientPutStream
bool ReadPutMetadata(std::shared_ptr* out) override {
std::lock_guard guard(read_mutex_);
+ if (finished_) return false;
pb::PutResult message;
if (stream_->Read(&message)) {
*out = Buffer::FromString(std::move(*message.mutable_app_metadata()));
@@ -427,6 +415,7 @@ class GrpcClientExchangeStream
bool ReadData(internal::FlightData* data) override {
  std::lock_guard<std::mutex> guard(read_mutex_);
+ if (finished_) return false;  // finished_ is set: report no data, skip the read
  return ReadPayload(stream_.get(), data);
}
arrow::Result WriteData(const FlightPayload& payload) override {
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
index 173132062e5..810f2c482a3 100644
--- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
@@ -257,7 +257,7 @@ class WriteClientStream : public UcxClientStream {
}
arrow::Result WriteData(const FlightPayload& payload) override {
std::unique_lock guard(driver_mutex_);
- if (finished_ || writes_done_) return Status::Invalid("Already done writing");
+ if (finished_ || writes_done_) return false;
outgoing_ = driver_->SendFlightPayload(payload);
working_cv_.notify_all();
completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); });
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index fbf478a7cd6..1805453e8ab 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -1517,7 +1517,8 @@ def test_cancel_do_get():
FlightClient(('localhost', server.port)) as client:
reader = client.do_get(flight.Ticket(b'ints'))
reader.cancel()
- with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
+ with pytest.raises(flight.FlightCancelledError,
+ match="(?i).*cancel.*"):
reader.read_chunk()
@@ -2085,3 +2086,74 @@ def test_none_action_side_effect():
client.do_action(flight.Action("append", b""))
r = client.do_action(flight.Action("get_value", b""))
assert json.loads(next(r).body.to_pybytes()) == [True]
+
+
+@pytest.mark.slow  # Takes a while for gRPC to "realize" writes fail
+def test_write_error_propagation():
+    """
+    Check that a server-side failure is what the writing client sees.
+
+    See https://issues.apache.org/jira/browse/ARROW-16592.
+    """
+    expected_message = "foo"
+    expected_info = b"bar"
+    server_error = flight.FlightCancelledError(
+        expected_message, extra_info=expected_info)
+    descriptor = flight.FlightDescriptor.for_command(b"")
+    schema = pa.schema([("int64", pa.int64())])
+
+    class FailServer(flight.FlightServerBase):
+        def do_put(self, context, descriptor, reader, writer):
+            raise server_error
+
+        def do_exchange(self, context, descriptor, reader, writer):
+            raise server_error
+
+    with FailServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
+        # --- DoPut ---
+        put_writer, put_reader = client.do_put(descriptor, schema)
+
+        # Read on a second thread while writing: the concurrent read
+        # must not stop the writer side from calling Close()
+        def _drain_put():
+            try:
+                while True:
+                    put_reader.read()
+            except flight.FlightError:
+                return
+
+        put_thread = threading.Thread(target=_drain_put, daemon=True)
+        put_thread.start()
+
+        with pytest.raises(flight.FlightCancelledError) as exc_info:
+            while True:
+                put_writer.write_batch(pa.record_batch([[1]], schema=schema))
+        assert exc_info.value.extra_info == expected_info
+
+        with pytest.raises(flight.FlightCancelledError) as exc_info:
+            put_writer.close()
+        assert exc_info.value.extra_info == expected_info
+        put_thread.join()
+
+        # --- DoExchange ---
+        exchange_writer, exchange_reader = client.do_exchange(descriptor)
+
+        def _drain_exchange():
+            try:
+                while True:
+                    exchange_reader.read_chunk()
+            except flight.FlightError:
+                return
+
+        exchange_thread = threading.Thread(target=_drain_exchange, daemon=True)
+        exchange_thread.start()
+        with pytest.raises(flight.FlightCancelledError) as exc_info:
+            while True:
+                exchange_writer.write_metadata(b" ")
+        assert exc_info.value.extra_info == expected_info
+
+        with pytest.raises(flight.FlightCancelledError) as exc_info:
+            exchange_writer.close()
+        assert exc_info.value.extra_info == expected_info
+        exchange_thread.join()