diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index ff13c800cb4..dbf840db050 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -270,6 +270,25 @@ class ClientMetadataReader : public FlightMetadataReader {
   std::shared_ptr stream_;
 };
 
+/// This status detail indicates the write failed in the transport
+/// (due to the server) and that we should finish the call at a higher
+/// level (to get the server error); otherwise the client should pass
+/// through the status (it may be recoverable) instead of finishing
+/// the call (which may inadvertently make the server think the client
+/// intended to end the call successfully) or canceling the call
+/// (which may generate an unexpected error message on the client
+/// side).
+constexpr const char* kTagDetailTypeId = "flight::ServerErrorTagStatusDetail";
+class ServerErrorTagStatusDetail : public arrow::StatusDetail {
+ public:
+  const char* type_id() const override { return kTagDetailTypeId; }
+  std::string ToString() const override { return type_id(); }
+
+  static bool UnwrapStatus(const arrow::Status& status) {
+    return status.detail() && status.detail()->type_id() == kTagDetailTypeId;
+  }
+};
+
 /// \brief An IpcPayloadWriter for any ClientDataStream.
 ///
 /// To support app_metadata and reuse the existing IPC infrastructure,
@@ -321,8 +340,8 @@ class ClientPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
     }
     ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
     if (!success) {
-      return MakeFlightError(
-          FlightStatusCode::Internal,
+      return Status::FromDetailAndArgs(
+          StatusCode::IOError, std::make_shared<ServerErrorTagStatusDetail>(),
           "Could not write record batch to stream (server disconnect?)");
     }
     return Status::OK();
@@ -397,9 +416,7 @@ class ClientStreamWriter : public FlightStreamWriter {
     RETURN_NOT_OK(internal::ToPayload(descriptor_, &payload.descriptor));
     ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
     if (!success) {
-      return MakeFlightError(
-          FlightStatusCode::Internal,
-          "Could not write record batch to stream (server disconnect?)");
+      return Close();
     }
     return Status::OK();
   }
@@ -414,8 +431,7 @@ class ClientStreamWriter : public FlightStreamWriter {
     payload.app_metadata = app_metadata;
     ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
     if (!success) {
-      return MakeFlightError(FlightStatusCode::Internal,
-                             "Could not write metadata to stream (server disconnect?)");
+      return Close();
     }
     return Status::OK();
   }
@@ -424,7 +440,13 @@ class ClientStreamWriter : public FlightStreamWriter {
                            std::shared_ptr app_metadata) override {
     RETURN_NOT_OK(CheckStarted());
     app_metadata_ = app_metadata;
-    return batch_writer_->WriteRecordBatch(batch);
+    auto status = batch_writer_->WriteRecordBatch(batch);
+    if (!status.ok() &&
+        // Only want to Close() if server error, not for client error
+        ServerErrorTagStatusDetail::UnwrapStatus(status)) {
+      return Close();
+    }
+    return status;
   }
 
   Status DoneWriting() override {
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index 61873ba36da..1ede500795c 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -1437,6 +1437,18 @@ class ErrorHandlingTestServer : public FlightServerBase {
     }
     return Status::NotImplemented("NYI");
   }
+
+  Status DoPut(const ServerCallContext& context,
+               std::unique_ptr<FlightMessageReader> reader,
+               std::unique_ptr<FlightMetadataWriter> writer) override {
+    return MakeFlightError(FlightStatusCode::Unauthorized, "Unauthorized", "extra info");
+  }
+
+  Status DoExchange(const ServerCallContext& context,
+                    std::unique_ptr<FlightMessageReader> reader,
+                    std::unique_ptr<FlightMessageWriter> writer) override {
+    return MakeFlightError(FlightStatusCode::Unauthorized, "Unauthorized", "extra info");
+  }
 };
 
 }  // namespace
@@ -1487,5 +1499,74 @@ void ErrorHandlingTest::TestGetFlightInfo() {
   }
 }
 
+void CheckErrorDetail(const Status& status) {
+  auto detail = FlightStatusDetail::UnwrapStatus(status);
+  ASSERT_NE(detail, nullptr) << status.ToString();
+  ASSERT_EQ(detail->code(), FlightStatusCode::Unauthorized);
+  ASSERT_EQ(detail->extra_info(), "extra info");
+}
+
+void ErrorHandlingTest::TestDoPut() {
+  // ARROW-16592
+  auto schema = arrow::schema({field("int64", int64())});
+  auto descr = FlightDescriptor::Path({""});
+  FlightClient::DoPutResult stream;
+  auto status = client_->DoPut(descr, schema).Value(&stream);
+  if (!status.ok()) {
+    ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
+    return;
+  }
+
+  std::thread reader_thread([&]() {
+    std::shared_ptr<Buffer> out;
+    while (true) {
+      if (!stream.reader->ReadMetadata(&out).ok()) {
+        return;
+      }
+    }
+  });
+
+  auto batch = RecordBatchFromJSON(schema, "[[0]]");
+  while (true) {
+    status = stream.writer->WriteRecordBatch(*batch);
+    if (!status.ok()) break;
+  }
+
+  // Finish the call and join the reader before asserting: a failed ASSERT_*
+  // returns early, and destroying a joinable std::thread calls std::terminate.
+  const Status close_status = stream.writer->Close();
+  reader_thread.join();
+  ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
+  ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(close_status));
+}
+
+void ErrorHandlingTest::TestDoExchange() {
+  // ARROW-16592
+  FlightClient::DoExchangeResult stream;
+  auto status = client_->DoExchange(FlightDescriptor::Path({""})).Value(&stream);
+  if (!status.ok()) {
+    ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
+    return;
+  }
+
+  std::thread reader_thread([&]() {
+    while (true) {
+      if (!stream.reader->Next().ok()) return;
+    }
+  });
+
+  while (true) {
+    status = stream.writer->WriteMetadata(Buffer::FromString("foo"));
+    if (!status.ok()) break;
+  }
+
+  // Finish the call and join the reader before asserting: a failed ASSERT_*
+  // returns early, and destroying a joinable std::thread calls std::terminate.
+  const Status close_status = stream.writer->Close();
+  reader_thread.join();
+  ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
+  ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(close_status));
+}
+
 }  // namespace flight
 }  // namespace arrow
diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h
index aa0d07857be..5f01caf8373 100644
--- a/cpp/src/arrow/flight/test_definitions.h
+++ b/cpp/src/arrow/flight/test_definitions.h
@@ -263,6 +263,8 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {
 
   // Test methods
   void TestGetFlightInfo();
+  void TestDoPut();
+  void TestDoExchange();
 
  private:
   std::unique_ptr client_;
@@ -272,7 +274,9 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {
 #define ARROW_FLIGHT_TEST_ERROR_HANDLING(FIXTURE)                                  \
   static_assert(std::is_base_of::value,                                           \
                 ARROW_STRINGIFY(FIXTURE) " must inherit from ErrorHandlingTest"); \
-  TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }
+  TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }                      \
+  TEST_F(FIXTURE, TestDoPut) { TestDoPut(); }                                      \
+  TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }
 
 }  // namespace flight
 }  // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
index 1c0ac2d31fa..8fe1e1bae79 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
@@ -326,10 +326,7 @@ class WritableDataStream : public FinishableDataStream {
  public:
   using Base = FinishableDataStream;
   WritableDataStream(std::shared_ptr rpc, std::shared_ptr stream)
-      : Base(std::move(rpc), std::move(stream)),
-        read_mutex_(),
-        finish_mutex_(),
-        done_writing_(false) {}
+      : Base(std::move(rpc), std::move(stream)), read_mutex_(), done_writing_(false) {}
 
   Status WritesDone() override {
     // This is only used by the writer side of a stream, so it need
@@ -350,16 +347,7 @@ class WritableDataStream : public FinishableDataStream {
   Status DoFinish() override {
     // This may be used concurrently by reader/writer side of a
     // stream, so it needs to be protected.
-    std::lock_guard guard(finish_mutex_);
-
-    // Now that we're shared between a reader and writer, we need to
-    // protect ourselves from being called while there's an
-    // outstanding read.
-    std::unique_lock read_guard(read_mutex_, std::try_to_lock);
-    if (!read_guard.owns_lock()) {
-      return MakeFlightError(FlightStatusCode::Internal,
-                             "Cannot close stream with pending read operation.");
-    }
+    std::lock_guard<std::mutex> guard(read_mutex_);
 
     // Try to flush pending writes. Don't use our WritesDone() to
     // avoid recursion.
@@ -377,7 +365,6 @@ class WritableDataStream : public FinishableDataStream {
   using Base::stream_;
 
   std::mutex read_mutex_;
-  std::mutex finish_mutex_;
   bool done_writing_;
 };
 
@@ -402,6 +389,7 @@ class GrpcClientPutStream
 
   bool ReadPutMetadata(std::shared_ptr* out) override {
     std::lock_guard guard(read_mutex_);
+    if (finished_) return false;
     pb::PutResult message;
     if (stream_->Read(&message)) {
       *out = Buffer::FromString(std::move(*message.mutable_app_metadata()));
@@ -427,6 +415,7 @@ class GrpcClientExchangeStream
 
   bool ReadData(internal::FlightData* data) override {
     std::lock_guard guard(read_mutex_);
+    if (finished_) return false;
     return ReadPayload(stream_.get(), data);
   }
   arrow::Result WriteData(const FlightPayload& payload) override {
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
index 173132062e5..810f2c482a3 100644
--- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
@@ -257,7 +257,7 @@ class WriteClientStream : public UcxClientStream {
   }
   arrow::Result WriteData(const FlightPayload& payload) override {
     std::unique_lock guard(driver_mutex_);
-    if (finished_ || writes_done_) return Status::Invalid("Already done writing");
+    if (finished_ || writes_done_) return false;
     outgoing_ = driver_->SendFlightPayload(payload);
     working_cv_.notify_all();
     completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); });
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index fbf478a7cd6..1805453e8ab 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -1517,7 +1517,8 @@ def test_cancel_do_get():
             FlightClient(('localhost', server.port)) as client:
         reader = client.do_get(flight.Ticket(b'ints'))
         reader.cancel()
-        with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
+        with pytest.raises(flight.FlightCancelledError,
+                           match="(?i).*cancel.*"):
             reader.read_chunk()
 
 
@@ -2085,3 +2086,74 @@ def test_none_action_side_effect():
         client.do_action(flight.Action("append", b""))
         r = client.do_action(flight.Action("get_value", b""))
         assert json.loads(next(r).body.to_pybytes()) == [True]
+
+
+@pytest.mark.slow  # Takes a while for gRPC to "realize" writes fail
+def test_write_error_propagation():
+    """
+    Ensure that exceptions during writing preserve error context.
+
+    See https://issues.apache.org/jira/browse/ARROW-16592.
+    """
+    expected_message = "foo"
+    expected_info = b"bar"
+    exc = flight.FlightCancelledError(
+        expected_message, extra_info=expected_info)
+    descriptor = flight.FlightDescriptor.for_command(b"")
+    schema = pa.schema([("int64", pa.int64())])
+
+    class FailServer(flight.FlightServerBase):
+        def do_put(self, context, descriptor, reader, writer):
+            raise exc
+
+        def do_exchange(self, context, descriptor, reader, writer):
+            raise exc
+
+    with FailServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
+        # DoPut
+        writer, reader = client.do_put(descriptor, schema)
+
+        # Set a concurrent reader - ensure this doesn't block the
+        # writer side from calling Close()
+        def _reader():
+            try:
+                while True:
+                    reader.read()
+            except flight.FlightError:
+                return
+
+        thread = threading.Thread(target=_reader, daemon=True)
+        thread.start()
+
+        with pytest.raises(flight.FlightCancelledError) as exc_info:
+            while True:
+                writer.write_batch(pa.record_batch([[1]], schema=schema))
+        assert exc_info.value.extra_info == expected_info
+
+        with pytest.raises(flight.FlightCancelledError) as exc_info:
+            writer.close()
+        assert exc_info.value.extra_info == expected_info
+        thread.join()
+
+        # DoExchange
+        writer, reader = client.do_exchange(descriptor)
+
+        def _reader():
+            try:
+                while True:
+                    reader.read_chunk()
+            except flight.FlightError:
+                return
+
+        thread = threading.Thread(target=_reader, daemon=True)
+        thread.start()
+        with pytest.raises(flight.FlightCancelledError) as exc_info:
+            while True:
+                writer.write_metadata(b" ")
+        assert exc_info.value.extra_info == expected_info
+
+        with pytest.raises(flight.FlightCancelledError) as exc_info:
+            writer.close()
+        assert exc_info.value.extra_info == expected_info
+        thread.join()