Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,25 @@ class ClientMetadataReader : public FlightMetadataReader {
std::shared_ptr<internal::ClientDataStream> stream_;
};

/// This status detail indicates the write failed in the transport
/// (due to the server) and that we should finish the call at a higher
/// level (to get the server error); otherwise the client should pass
/// through the status (it may be recoverable) instead of finishing
/// the call (which may inadvertently make the server think the client
/// intended to end the call successfully) or canceling the call
/// (which may generate an unexpected error message on the client
/// side).
const char* kTagDetailTypeId = "flight::ServerErrorTagStatusDetail";
class ServerErrorTagStatusDetail : public arrow::StatusDetail {
public:
const char* type_id() const override { return kTagDetailTypeId; }
std::string ToString() const override { return type_id(); };

static bool UnwrapStatus(const arrow::Status& status) {
return status.detail() && status.detail()->type_id() == kTagDetailTypeId;
}
};

/// \brief An IpcPayloadWriter for any ClientDataStream.
///
/// To support app_metadata and reuse the existing IPC infrastructure,
Expand Down Expand Up @@ -321,8 +340,8 @@ class ClientPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
}
ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
if (!success) {
return MakeFlightError(
FlightStatusCode::Internal,
return Status::FromDetailAndArgs(
StatusCode::IOError, std::make_shared<ServerErrorTagStatusDetail>(),
"Could not write record batch to stream (server disconnect?)");
}
return Status::OK();
Expand Down Expand Up @@ -397,9 +416,7 @@ class ClientStreamWriter : public FlightStreamWriter {
RETURN_NOT_OK(internal::ToPayload(descriptor_, &payload.descriptor));
ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
if (!success) {
return MakeFlightError(
FlightStatusCode::Internal,
"Could not write record batch to stream (server disconnect?)");
return Close();
}
return Status::OK();
}
Expand All @@ -414,8 +431,7 @@ class ClientStreamWriter : public FlightStreamWriter {
payload.app_metadata = app_metadata;
ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload));
if (!success) {
return MakeFlightError(FlightStatusCode::Internal,
"Could not write metadata to stream (server disconnect?)");
return Close();
}
return Status::OK();
}
Expand All @@ -424,7 +440,13 @@ class ClientStreamWriter : public FlightStreamWriter {
std::shared_ptr<Buffer> app_metadata) override {
RETURN_NOT_OK(CheckStarted());
app_metadata_ = app_metadata;
return batch_writer_->WriteRecordBatch(batch);
auto status = batch_writer_->WriteRecordBatch(batch);
if (!status.ok() &&
// Only want to Close() if server error, not for client error
ServerErrorTagStatusDetail::UnwrapStatus(status)) {
return Close();
}
return status;
}

Status DoneWriting() override {
Expand Down
75 changes: 75 additions & 0 deletions cpp/src/arrow/flight/test_definitions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,18 @@ class ErrorHandlingTestServer : public FlightServerBase {
}
return Status::NotImplemented("NYI");
}

Status DoPut(const ServerCallContext& context,
std::unique_ptr<FlightMessageReader> reader,
std::unique_ptr<FlightMetadataWriter> writer) override {
return MakeFlightError(FlightStatusCode::Unauthorized, "Unauthorized", "extra info");
}

Status DoExchange(const ServerCallContext& context,
std::unique_ptr<FlightMessageReader> reader,
std::unique_ptr<FlightMessageWriter> writer) override {
return MakeFlightError(FlightStatusCode::Unauthorized, "Unauthorized", "extra info");
}
};
} // namespace

Expand Down Expand Up @@ -1487,5 +1499,68 @@ void ErrorHandlingTest::TestGetFlightInfo() {
}
}

void CheckErrorDetail(const Status& status) {
auto detail = FlightStatusDetail::UnwrapStatus(status);
ASSERT_NE(detail, nullptr) << status.ToString();
ASSERT_EQ(detail->code(), FlightStatusCode::Unauthorized);
ASSERT_EQ(detail->extra_info(), "extra info");
}

void ErrorHandlingTest::TestDoPut() {
// ARROW-16592
auto schema = arrow::schema({field("int64", int64())});
auto descr = FlightDescriptor::Path({""});
FlightClient::DoPutResult stream;
auto status = client_->DoPut(descr, schema).Value(&stream);
if (!status.ok()) {
ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
return;
}

std::thread reader_thread([&]() {
std::shared_ptr<Buffer> out;
while (true) {
if (!stream.reader->ReadMetadata(&out).ok()) {
return;
}
}
});

auto batch = RecordBatchFromJSON(schema, "[[0]]");
while (true) {
status = stream.writer->WriteRecordBatch(*batch);
if (!status.ok()) break;
}

ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(stream.writer->Close()));
reader_thread.join();
}

void ErrorHandlingTest::TestDoExchange() {
// ARROW-16592
FlightClient::DoExchangeResult stream;
auto status = client_->DoExchange(FlightDescriptor::Path({""})).Value(&stream);
if (!status.ok()) {
ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
return;
}

std::thread reader_thread([&]() {
while (true) {
if (!stream.reader->Next().ok()) return;
}
});

while (true) {
status = stream.writer->WriteMetadata(Buffer::FromString("foo"));
if (!status.ok()) break;
}

ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(status));
ASSERT_NO_FATAL_FAILURE(CheckErrorDetail(stream.writer->Close()));
reader_thread.join();
}

} // namespace flight
} // namespace arrow
6 changes: 5 additions & 1 deletion cpp/src/arrow/flight/test_definitions.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {

// Test methods
void TestGetFlightInfo();
void TestDoPut();
void TestDoExchange();

private:
std::unique_ptr<FlightClient> client_;
Expand All @@ -272,7 +274,9 @@ class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {
#define ARROW_FLIGHT_TEST_ERROR_HANDLING(FIXTURE) \
static_assert(std::is_base_of<ErrorHandlingTest, FIXTURE>::value, \
ARROW_STRINGIFY(FIXTURE) " must inherit from ErrorHandlingTest"); \
TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }
TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); } \
TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \
TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }

} // namespace flight
} // namespace arrow
19 changes: 4 additions & 15 deletions cpp/src/arrow/flight/transport/grpc/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,7 @@ class WritableDataStream : public FinishableDataStream<Stream, ReadPayload> {
public:
using Base = FinishableDataStream<Stream, ReadPayload>;
WritableDataStream(std::shared_ptr<ClientRpc> rpc, std::shared_ptr<Stream> stream)
: Base(std::move(rpc), std::move(stream)),
read_mutex_(),
finish_mutex_(),
done_writing_(false) {}
: Base(std::move(rpc), std::move(stream)), read_mutex_(), done_writing_(false) {}

Status WritesDone() override {
// This is only used by the writer side of a stream, so it need
Expand All @@ -350,16 +347,7 @@ class WritableDataStream : public FinishableDataStream<Stream, ReadPayload> {
Status DoFinish() override {
// This may be used concurrently by reader/writer side of a
// stream, so it needs to be protected.
std::lock_guard<std::mutex> guard(finish_mutex_);

// Now that we're shared between a reader and writer, we need to
// protect ourselves from being called while there's an
// outstanding read.
std::unique_lock<std::mutex> read_guard(read_mutex_, std::try_to_lock);
if (!read_guard.owns_lock()) {
return MakeFlightError(FlightStatusCode::Internal,
"Cannot close stream with pending read operation.");
}
std::lock_guard<std::mutex> guard(read_mutex_);

// Try to flush pending writes. Don't use our WritesDone() to
// avoid recursion.
Expand All @@ -377,7 +365,6 @@ class WritableDataStream : public FinishableDataStream<Stream, ReadPayload> {

using Base::stream_;
std::mutex read_mutex_;
std::mutex finish_mutex_;
bool done_writing_;
};

Expand All @@ -402,6 +389,7 @@ class GrpcClientPutStream

bool ReadPutMetadata(std::shared_ptr<Buffer>* out) override {
std::lock_guard<std::mutex> guard(read_mutex_);
if (finished_) return false;
pb::PutResult message;
if (stream_->Read(&message)) {
*out = Buffer::FromString(std::move(*message.mutable_app_metadata()));
Expand All @@ -427,6 +415,7 @@ class GrpcClientExchangeStream

bool ReadData(internal::FlightData* data) override {
std::lock_guard<std::mutex> guard(read_mutex_);
if (finished_) return false;
return ReadPayload(stream_.get(), data);
}
arrow::Result<bool> WriteData(const FlightPayload& payload) override {
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/transport/ucx/ucx_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ class WriteClientStream : public UcxClientStream {
}
arrow::Result<bool> WriteData(const FlightPayload& payload) override {
std::unique_lock<std::mutex> guard(driver_mutex_);
if (finished_ || writes_done_) return Status::Invalid("Already done writing");
if (finished_ || writes_done_) return false;
outgoing_ = driver_->SendFlightPayload(payload);
working_cv_.notify_all();
completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); });
Expand Down
74 changes: 73 additions & 1 deletion python/pyarrow/tests/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,8 @@ def test_cancel_do_get():
FlightClient(('localhost', server.port)) as client:
reader = client.do_get(flight.Ticket(b'ints'))
reader.cancel()
with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
with pytest.raises(flight.FlightCancelledError,
match="(?i).*cancel.*"):
reader.read_chunk()


Expand Down Expand Up @@ -2085,3 +2086,74 @@ def test_none_action_side_effect():
client.do_action(flight.Action("append", b""))
r = client.do_action(flight.Action("get_value", b""))
assert json.loads(next(r).body.to_pybytes()) == [True]


@pytest.mark.slow # Takes a while for gRPC to "realize" writes fail
def test_write_error_propagation():
"""
Ensure that exceptions during writing preserve error context.

See https://issues.apache.org/jira/browse/ARROW-16592.
"""
expected_message = "foo"
expected_info = b"bar"
exc = flight.FlightCancelledError(
expected_message, extra_info=expected_info)
descriptor = flight.FlightDescriptor.for_command(b"")
schema = pa.schema([("int64", pa.int64())])

class FailServer(flight.FlightServerBase):
def do_put(self, context, descriptor, reader, writer):
raise exc

def do_exchange(self, context, descriptor, reader, writer):
raise exc

with FailServer() as server, \
FlightClient(('localhost', server.port)) as client:
# DoPut
writer, reader = client.do_put(descriptor, schema)

# Set a concurrent reader - ensure this doesn't block the
# writer side from calling Close()
def _reader():
try:
while True:
reader.read()
except flight.FlightError:
return

thread = threading.Thread(target=_reader, daemon=True)
thread.start()

with pytest.raises(flight.FlightCancelledError) as exc_info:
while True:
writer.write_batch(pa.record_batch([[1]], schema=schema))
assert exc_info.value.extra_info == expected_info

with pytest.raises(flight.FlightCancelledError) as exc_info:
writer.close()
assert exc_info.value.extra_info == expected_info
thread.join()

# DoExchange
writer, reader = client.do_exchange(descriptor)

def _reader():
try:
while True:
reader.read_chunk()
except flight.FlightError:
return

thread = threading.Thread(target=_reader, daemon=True)
thread.start()
with pytest.raises(flight.FlightCancelledError) as exc_info:
while True:
writer.write_metadata(b" ")
assert exc_info.value.extra_info == expected_info

with pytest.raises(flight.FlightCancelledError) as exc_info:
writer.close()
assert exc_info.value.extra_info == expected_info
thread.join()