From 7390d62f119ee993febd99f768fae7d97c42d32a Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 12 May 2021 17:09:54 -0400
Subject: [PATCH 1/3] ARROW-12050: [C++][FlightRPC] Enable cancellation of
long-running requests
---
cpp/src/arrow/flight/client.cc | 53 +++++++++--
cpp/src/arrow/flight/client.h | 10 ++
cpp/src/arrow/flight/flight_test.cc | 142 ++++++++++++++++++++++++++++
cpp/src/arrow/flight/server.cc | 1 +
cpp/src/arrow/flight/server.h | 3 +
cpp/src/arrow/util/cancel.cc | 4 +-
cpp/src/arrow/util/cancel.h | 4 +-
7 files changed, 204 insertions(+), 13 deletions(-)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index c0e8eaaed28..880454bca1b 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -45,6 +45,7 @@
#include "arrow/record_batch.h"
#include "arrow/result.h"
#include "arrow/status.h"
+#include "arrow/table.h"
#include "arrow/type.h"
#include "arrow/util/logging.h"
#include "arrow/util/uri.h"
@@ -484,11 +485,12 @@ template
class GrpcStreamReader : public FlightStreamReader {
public:
GrpcStreamReader(std::shared_ptr rpc, std::shared_ptr read_mutex,
- const ipc::IpcReadOptions& options,
+ const ipc::IpcReadOptions& options, StopToken stop_token,
std::shared_ptr> stream)
: rpc_(rpc),
read_mutex_(read_mutex),
options_(options),
+ stop_token_(std::move(stop_token)),
stream_(stream),
peekable_reader_(new internal::PeekableFlightDataReader>(
stream->stream())),
@@ -552,6 +554,33 @@ class GrpcStreamReader : public FlightStreamReader {
out->app_metadata = std::move(app_metadata_);
return Status::OK();
}
+ Status ReadAll(std::vector>* batches) override {
+ return ReadAll(batches, stop_token_);
+ }
+ Status ReadAll(std::vector>* batches,
+ const StopToken& stop_token) override {
+ FlightStreamChunk chunk;
+
+ while (true) {
+ if (stop_token.IsStopRequested()) {
+ Cancel();
+ return stop_token.Poll();
+ }
+ RETURN_NOT_OK(Next(&chunk));
+ if (!chunk.data) break;
+ batches->emplace_back(std::move(chunk.data));
+ }
+ return Status::OK();
+ }
+ Status ReadAll(std::shared_ptr* table) override {
+ return ReadAll(table, stop_token_);
+ }
+ Status ReadAll(std::shared_ptr* table, const StopToken& stop_token) override {
+ std::vector> batches;
+ RETURN_NOT_OK(ReadAll(&batches, stop_token));
+ ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());
+ return Table::FromRecordBatches(schema, std::move(batches)).Value(table);
+ }
void Cancel() override { rpc_->context.TryCancel(); }
private:
@@ -574,6 +603,7 @@ class GrpcStreamReader : public FlightStreamReader {
// read. Nullable, as DoGet() doesn't need this.
std::shared_ptr read_mutex_;
ipc::IpcReadOptions options_;
+ StopToken stop_token_;
std::shared_ptr> stream_;
std::shared_ptr>>
peekable_reader_;
@@ -1060,12 +1090,13 @@ class FlightClient::FlightClientImpl {
std::vector flights;
pb::FlightInfo pb_info;
- while (stream->Read(&pb_info)) {
+ while (!options.stop_token.IsStopRequested() && stream->Read(&pb_info)) {
FlightInfo::Data info_data;
RETURN_NOT_OK(internal::FromProto(pb_info, &info_data));
flights.emplace_back(std::move(info_data));
}
-
+ if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
+ RETURN_NOT_OK(options.stop_token.Poll());
listing->reset(new SimpleFlightListing(std::move(flights)));
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}
@@ -1083,11 +1114,13 @@ class FlightClient::FlightClientImpl {
pb::Result pb_result;
std::vector materialized_results;
- while (stream->Read(&pb_result)) {
+ while (!options.stop_token.IsStopRequested() && stream->Read(&pb_result)) {
Result result;
RETURN_NOT_OK(internal::FromProto(pb_result, &result));
materialized_results.emplace_back(std::move(result));
}
+ if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
+ RETURN_NOT_OK(options.stop_token.Poll());
*results = std::unique_ptr(
new SimpleResultStream(std::move(materialized_results)));
@@ -1104,10 +1137,12 @@ class FlightClient::FlightClientImpl {
pb::ActionType pb_type;
ActionType type;
- while (stream->Read(&pb_type)) {
+ while (!options.stop_token.IsStopRequested() && stream->Read(&pb_type)) {
RETURN_NOT_OK(internal::FromProto(pb_type, &type));
types->emplace_back(std::move(type));
}
+ if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
+ RETURN_NOT_OK(options.stop_token.Poll());
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}
@@ -1163,8 +1198,8 @@ class FlightClient::FlightClientImpl {
auto finishable_stream = std::make_shared<
FinishableStream, internal::FlightData>>(
rpc, stream);
- *out = std::unique_ptr(
- new StreamReader(rpc, nullptr, options.read_options, finishable_stream));
+ *out = std::unique_ptr(new StreamReader(
+ rpc, nullptr, options.read_options, options.stop_token, finishable_stream));
// Eagerly read the schema
return static_cast(out->get())->EnsureDataStarted();
}
@@ -1208,8 +1243,8 @@ class FlightClient::FlightClientImpl {
auto finishable_stream =
std::make_shared>(
rpc, read_mutex, stream);
- *reader = std::unique_ptr(
- new StreamReader(rpc, read_mutex, options.read_options, finishable_stream));
+ *reader = std::unique_ptr(new StreamReader(
+ rpc, read_mutex, options.read_options, options.stop_token, finishable_stream));
// Do not eagerly read the schema. There may be metadata messages
// before any data is sent, or data may not be sent at all.
return StreamWriter::Open(descriptor, nullptr, options.write_options, rpc,
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index b3c5a96e597..bc803c24e93 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -31,6 +31,7 @@
#include "arrow/ipc/writer.h"
#include "arrow/result.h"
#include "arrow/status.h"
+#include "arrow/util/cancel.h"
#include "arrow/util/variant.h"
#include "arrow/flight/types.h" // IWYU pragma: keep
@@ -69,6 +70,9 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions {
/// \brief Headers for client to add to context.
std::vector> headers;
+
+ /// \brief A token to enable interactive user cancellation of long-running requests.
+ StopToken stop_token;
};
/// \brief Indicate that the client attempted to write a message
@@ -129,6 +133,12 @@ class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader
public:
/// \brief Try to cancel the call.
virtual void Cancel() = 0;
+ using MetadataRecordBatchReader::ReadAll;
+ /// \brief Consume entire stream as a vector of record batches
+ virtual Status ReadAll(std::vector>* batches,
+ const StopToken& stop_token) = 0;
+ /// \brief Consume entire stream as a Table
+ virtual Status ReadAll(std::shared_ptr* table, const StopToken& stop_token) = 0;
};
// Silence warning
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 35993f1eaa1..8264f3e2197 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -2673,5 +2673,147 @@ TEST_F(TestCookieParsing, CookieCache) {
AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=\"0\"; id1=\"1\"; id2=\"2\"");
}
+class ForeverFlightListing : public FlightListing {
+ Status Next(std::unique_ptr* info) override {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ *info = arrow::internal::make_unique(ExampleFlightInfo()[0]);
+ return Status::OK();
+ }
+};
+
+class ForeverResultStream : public ResultStream {
+ Status Next(std::unique_ptr* result) override {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ *result = arrow::internal::make_unique();
+ (*result)->body = Buffer::FromString("foo");
+ return Status::OK();
+ }
+};
+
+class ForeverDataStream : public FlightDataStream {
+ public:
+ ForeverDataStream() : schema_(arrow::schema({})), mapper_(*schema_) {}
+ std::shared_ptr schema() override { return schema_; }
+
+ Status GetSchemaPayload(FlightPayload* payload) override {
+ return ipc::GetSchemaPayload(*schema_, ipc::IpcWriteOptions::Defaults(), mapper_,
+ &payload->ipc_message);
+ }
+
+ Status Next(FlightPayload* payload) override {
+ auto batch = RecordBatch::Make(schema_, 0, ArrayVector{});
+ return ipc::GetRecordBatchPayload(*batch, ipc::IpcWriteOptions::Defaults(),
+ &payload->ipc_message);
+ }
+
+ private:
+ std::shared_ptr schema_;
+ ipc::DictionaryFieldMapper mapper_;
+};
+
+class CancelTestServer : public FlightServerBase {
+ public:
+ Status ListFlights(const ServerCallContext&, const Criteria*,
+ std::unique_ptr* listings) override {
+ *listings = arrow::internal::make_unique();
+ return Status::OK();
+ }
+ Status DoAction(const ServerCallContext&, const Action&,
+ std::unique_ptr* result) override {
+ *result = arrow::internal::make_unique();
+ return Status::OK();
+ }
+ Status ListActions(const ServerCallContext&,
+ std::vector* actions) override {
+ *actions = {};
+ return Status::OK();
+ }
+ Status DoGet(const ServerCallContext&, const Ticket&,
+ std::unique_ptr* data_stream) override {
+ *data_stream = arrow::internal::make_unique();
+ return Status::OK();
+ }
+};
+
+class TestCancel : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK(MakeServer(
+ &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr client_;
+ std::unique_ptr server_;
+};
+
+TEST_F(TestCancel, ListFlights) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::unique_ptr listing;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ client_->ListFlights(options, {}, &listing));
+}
+
+TEST_F(TestCancel, DoAction) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::unique_ptr results;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ client_->DoAction(options, {}, &results));
+}
+
+TEST_F(TestCancel, ListActions) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::vector results;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ client_->ListActions(options, &results));
+}
+
+TEST_F(TestCancel, DoGet) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::unique_ptr results;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ std::unique_ptr stream;
+ ASSERT_OK(client_->DoGet(options, {}, &stream));
+ std::shared_ptr table;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ stream->ReadAll(&table));
+
+ ASSERT_OK(client_->DoGet({}, &stream));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ stream->ReadAll(&table, options.stop_token));
+}
+
+TEST_F(TestCancel, DoExchange) {
+ StopSource stop_source;
+ FlightCallOptions options;
+ options.stop_token = stop_source.token();
+ std::unique_ptr results;
+ stop_source.RequestStop(Status::Cancelled("StopSource"));
+ std::unique_ptr writer;
+ std::unique_ptr stream;
+ ASSERT_OK(
+ client_->DoExchange(options, FlightDescriptor::Command(""), &writer, &stream));
+ std::shared_ptr table;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ stream->ReadAll(&table));
+
+ ASSERT_OK(client_->DoExchange(FlightDescriptor::Command(""), &writer, &stream));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
+ stream->ReadAll(&table, options.stop_token));
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index ce5a07fc3e0..8ed76e78da8 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -383,6 +383,7 @@ class GrpcServerCallContext : public ServerCallContext {
const std::string& peer_identity() const override { return peer_identity_; }
const std::string& peer() const override { return peer_; }
+ bool is_cancelled() const override { return context_->IsCancelled(); }
// Helper method that runs interceptors given the result of an RPC,
// then returns the final gRPC status to send to the client
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index dd95b7536cd..96b2da488ee 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -119,6 +119,9 @@ class ARROW_FLIGHT_EXPORT ServerCallContext {
/// to the object beyond the request body.
/// \return The middleware, or nullptr if not found.
virtual ServerMiddleware* GetMiddleware(const std::string& key) const = 0;
+ /// \brief Check if the current RPC has been cancelled (by the client, by
+ /// a network error, etc.).
+ virtual bool is_cancelled() const = 0;
};
class ARROW_FLIGHT_EXPORT FlightServerOptions {
diff --git a/cpp/src/arrow/util/cancel.cc b/cpp/src/arrow/util/cancel.cc
index 533075a9a64..874b2c2c886 100644
--- a/cpp/src/arrow/util/cancel.cc
+++ b/cpp/src/arrow/util/cancel.cc
@@ -74,14 +74,14 @@ void StopSource::Reset() {
StopToken StopSource::token() { return StopToken(impl_); }
-bool StopToken::IsStopRequested() {
+bool StopToken::IsStopRequested() const {
if (!impl_) {
return false;
}
return impl_->requested_.load() != 0;
}
-Status StopToken::Poll() {
+Status StopToken::Poll() const {
if (!impl_) {
return Status::OK();
}
diff --git a/cpp/src/arrow/util/cancel.h b/cpp/src/arrow/util/cancel.h
index 506a7e16e4f..9e00f673a21 100644
--- a/cpp/src/arrow/util/cancel.h
+++ b/cpp/src/arrow/util/cancel.h
@@ -65,8 +65,8 @@ class ARROW_EXPORT StopToken {
static StopToken Unstoppable() { return StopToken(); }
// Producer API (the side that gets asked to stopped)
- Status Poll();
- bool IsStopRequested();
+ Status Poll() const;
+ bool IsStopRequested() const;
protected:
std::shared_ptr impl_;
From b284b9dadccf05aaa1bf368ef5815da5974923c2 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 13 May 2021 14:45:20 -0400
Subject: [PATCH 2/3] ARROW-12050: [Python][FlightRPC] Enable cancellation of
long-running requests
---
python/pyarrow/_flight.pyx | 69 ++++++++++++++-------
python/pyarrow/includes/libarrow_flight.pxd | 4 ++
python/pyarrow/tests/test_csv.py | 15 +----
python/pyarrow/tests/test_flight.py | 62 ++++++++++++++++++
python/pyarrow/tests/util.py | 17 +++++
5 files changed, 131 insertions(+), 36 deletions(-)
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index e5d80df9380..a84166ce866 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -32,7 +32,7 @@ from cython.operator cimport postincrement
from libcpp cimport bool as c_bool
from pyarrow.lib cimport *
-from pyarrow.lib import ArrowException, ArrowInvalid
+from pyarrow.lib import ArrowException, ArrowInvalid, SignalStopHandler
from pyarrow.lib import as_buffer, frombytes, tobytes
from pyarrow.includes.libarrow_flight cimport *
from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin
@@ -897,6 +897,19 @@ cdef class FlightStreamReader(MetadataRecordBatchReader):
with nogil:
( self.reader.get()).Cancel()
+ def read_all(self):
+ """Read the entire contents of the stream as a Table."""
+ cdef:
+ shared_ptr[CTable] c_table
+ CStopToken stop_token
+ with SignalStopHandler() as stop_handler:
+ stop_token = ( stop_handler.stop_token).stop_token
+ with nogil:
+ check_flight_status(
+ ( self.reader.get())
+ .ReadAllWithStopToken(&c_table, stop_token))
+ return pyarrow_wrap_table(c_table)
+
cdef class MetadataRecordBatchWriter(_CRecordBatchWriter):
"""A RecordBatchWriter that also allows writing application metadata.
@@ -1204,17 +1217,20 @@ cdef class FlightClient(_Weakrefable):
vector[CActionType] results
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
- with nogil:
- check_flight_status(
- self.client.get().ListActions(deref(c_options), &results))
+ with SignalStopHandler() as stop_handler:
+ c_options.stop_token = \
+ ( stop_handler.stop_token).stop_token
+ with nogil:
+ check_flight_status(
+ self.client.get().ListActions(deref(c_options), &results))
- result = []
- for action_type in results:
- py_action = ActionType(frombytes(action_type.type),
- frombytes(action_type.description))
- result.append(py_action)
+ result = []
+ for action_type in results:
+ py_action = ActionType(frombytes(action_type.type),
+ frombytes(action_type.description))
+ result.append(py_action)
- return result
+ return result
def do_action(self, action, options: FlightCallOptions = None):
"""
@@ -1247,9 +1263,8 @@ cdef class FlightClient(_Weakrefable):
cdef CAction c_action = Action.unwrap( action)
with nogil:
check_flight_status(
- self.client.get().DoAction(deref(c_options), c_action,
- &results))
-
+ self.client.get().DoAction(
+ deref(c_options), c_action, &results))
while True:
result = Result.__new__(Result)
with nogil:
@@ -1270,18 +1285,21 @@ cdef class FlightClient(_Weakrefable):
if criteria:
c_criteria.expression = tobytes(criteria)
- with nogil:
- check_flight_status(
- self.client.get().ListFlights(deref(c_options),
- c_criteria, &listing))
-
- while True:
- result = FlightInfo.__new__(FlightInfo)
+ with SignalStopHandler() as stop_handler:
+ c_options.stop_token = \
+ ( stop_handler.stop_token).stop_token
with nogil:
- check_flight_status(listing.get().Next(&result.info))
- if result.info == NULL:
- break
- yield result
+ check_flight_status(
+ self.client.get().ListFlights(deref(c_options),
+ c_criteria, &listing))
+
+ while True:
+ result = FlightInfo.__new__(FlightInfo)
+ with nogil:
+ check_flight_status(listing.get().Next(&result.info))
+ if result.info == NULL:
+ break
+ yield result
def get_flight_info(self, descriptor: FlightDescriptor,
options: FlightCallOptions = None):
@@ -1497,6 +1515,9 @@ cdef class ServerCallContext(_Weakrefable):
# Set safe=True as gRPC on Windows sometimes gives garbage bytes
return frombytes(self.context.peer(), safe=True)
+ def is_cancelled(self):
+ return self.context.is_cancelled()
+
def get_middleware(self, key):
"""
Get a middleware instance by key.
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 737babb3fd5..2ac737abaa0 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -166,6 +166,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CFlightStreamReader \
" arrow::flight::FlightStreamReader"(CMetadataRecordBatchReader):
void Cancel()
+ CStatus ReadAllWithStopToken" ReadAll"\
+ (shared_ptr[CTable]* table, const CStopToken& stop_token)
cdef cppclass CFlightMessageReader \
" arrow::flight::FlightMessageReader"(CMetadataRecordBatchReader):
@@ -211,6 +213,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CServerCallContext" arrow::flight::ServerCallContext":
c_string& peer_identity()
c_string& peer()
+ c_bool is_cancelled()
CServerMiddleware* GetMiddleware(const c_string& key)
cdef cppclass CTimeoutDuration" arrow::flight::TimeoutDuration":
@@ -221,6 +224,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
CTimeoutDuration timeout
CIpcWriteOptions write_options
vector[pair[c_string, c_string]] headers
+ CStopToken stop_token
cdef cppclass CCertKeyPair" arrow::flight::CertKeyPair":
CCertKeyPair()
diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py
index 3fa9ae02e4d..69fa4aae3c8 100644
--- a/python/pyarrow/tests/test_csv.py
+++ b/python/pyarrow/tests/test_csv.py
@@ -27,7 +27,6 @@
import shutil
import signal
import string
-import sys
import tempfile
import threading
import time
@@ -41,6 +40,7 @@
from pyarrow.csv import (
open_csv, read_csv, ReadOptions, ParseOptions, ConvertOptions, ISO8601,
write_csv, WriteOptions)
+from pyarrow.tests import util
def generate_col_names():
@@ -918,17 +918,8 @@ def test_cancellation(self):
if (threading.current_thread().ident !=
threading.main_thread().ident):
pytest.skip("test only works from main Python thread")
-
- if sys.version_info >= (3, 8):
- raise_signal = signal.raise_signal
- elif os.name == 'nt':
- # On Windows, os.kill() doesn't actually send a signal,
- # it just terminates the process with the given exit code.
- pytest.skip("test requires Python 3.8+ on Windows")
- else:
- # On Unix, emulate raise_signal() with os.kill().
- def raise_signal(signum):
- os.kill(os.getpid(), signum)
+ # Skips test if not available
+ raise_signal = util.get_raise_signal()
# Make the interruptible workload large enough to not finish
# before the interrupt comes, even in release mode on fast machines
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 585fdb2a062..ef324cd3380 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -17,7 +17,9 @@
import ast
import base64
+import itertools
import os
+import signal
import struct
import tempfile
import threading
@@ -30,6 +32,7 @@
from pyarrow.lib import tobytes
from pyarrow.util import pathlib, find_free_port
+from pyarrow.tests import util
try:
from pyarrow import flight
@@ -1810,3 +1813,62 @@ def test_generic_options():
generic_options=options)
with pytest.raises(pa.ArrowInvalid):
client.do_get(flight.Ticket(b'ints'))
+
+
+class CancelFlightServer(FlightServerBase):
+ """A server for testing StopToken."""
+
+ def do_get(self, context, ticket):
+ schema = pa.schema([])
+ rb = pa.RecordBatch.from_arrays([], schema=schema)
+ return flight.GeneratorStream(schema, itertools.repeat(rb))
+
+ def do_exchange(self, context, descriptor, reader, writer):
+ schema = pa.schema([])
+ rb = pa.RecordBatch.from_arrays([], schema=schema)
+ writer.begin(schema)
+ while not context.is_cancelled():
+ # TODO: writing schema.empty_table() here hangs/fails
+ writer.write_batch(rb)
+ time.sleep(0.5)
+
+
+def test_interrupt():
+ if threading.current_thread().ident != threading.main_thread().ident:
+ pytest.skip("test only works from main Python thread")
+ # Skips test if not available
+ raise_signal = util.get_raise_signal()
+
+ def signal_from_thread():
+ time.sleep(0.5)
+ raise_signal(signal.SIGINT)
+
+ exc_types = (KeyboardInterrupt, pa.ArrowCancelled)
+
+ def test(read_all):
+ try:
+ try:
+ t = threading.Thread(target=signal_from_thread)
+ with pytest.raises(exc_types) as exc_info:
+ t.start()
+ read_all()
+ finally:
+ t.join()
+ except KeyboardInterrupt:
+ # In case KeyboardInterrupt didn't interrupt read_all
+ # above, at least prevent it from stopping the test suite
+ # pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all")
+ raise
+ e = exc_info.value.__context__
+ assert isinstance(e, pa.ArrowCancelled) or isinstance(
+ e, pa.ArrowCancelled)
+
+ with CancelFlightServer() as server:
+ client = FlightClient(("localhost", server.port))
+
+ reader = client.do_get(flight.Ticket(b""))
+ test(reader.read_all)
+
+ descriptor = flight.FlightDescriptor.for_command(b"echo")
+ writer, reader = client.do_exchange(descriptor)
+ test(reader.read_all)
diff --git a/python/pyarrow/tests/util.py b/python/pyarrow/tests/util.py
index ea43b7c4e64..3425fe01c9b 100644
--- a/python/pyarrow/tests/util.py
+++ b/python/pyarrow/tests/util.py
@@ -25,10 +25,13 @@
import numpy as np
import os
import random
+import signal
import string
import subprocess
import sys
+import pytest
+
import pyarrow as pa
@@ -237,3 +240,17 @@ def __init__(self, path):
def __fspath__(self):
return str(self._path)
+
+
+def get_raise_signal():
+ if sys.version_info >= (3, 8):
+ return signal.raise_signal
+ elif os.name == 'nt':
+ # On Windows, os.kill() doesn't actually send a signal,
+ # it just terminates the process with the given exit code.
+ pytest.skip("test requires Python 3.8+ on Windows")
+ else:
+ # On Unix, emulate raise_signal() with os.kill().
+ def raise_signal(signum):
+ os.kill(os.getpid(), signum)
+ return raise_signal
From 3dd7663076d7eca0e58606a2183ed243119943e1 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 2 Jun 2021 11:54:20 -0500
Subject: [PATCH 3/3] ARROW-12050: [C++][FlightRPC] Address review feedback
---
cpp/src/arrow/flight/client.cc | 15 +++++++++------
cpp/src/arrow/flight/client.h | 2 +-
python/pyarrow/tests/test_flight.py | 8 +++-----
3 files changed, 13 insertions(+), 12 deletions(-)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 880454bca1b..84fc4a28e92 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -93,6 +93,14 @@ std::shared_ptr FlightWriteSizeStatusDetail::Unwrap
FlightClientOptions FlightClientOptions::Defaults() { return FlightClientOptions(); }
+Status FlightStreamReader::ReadAll(std::shared_ptr* table,
+ const StopToken& stop_token) {
+ std::vector> batches;
+ RETURN_NOT_OK(ReadAll(&batches, stop_token));
+ ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());
+ return Table::FromRecordBatches(schema, std::move(batches)).Value(table);
+}
+
struct ClientRpc {
grpc::ClientContext context;
@@ -575,12 +583,7 @@ class GrpcStreamReader : public FlightStreamReader {
Status ReadAll(std::shared_ptr* table) override {
return ReadAll(table, stop_token_);
}
- Status ReadAll(std::shared_ptr* table, const StopToken& stop_token) override {
- std::vector> batches;
- RETURN_NOT_OK(ReadAll(&batches, stop_token));
- ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());
- return Table::FromRecordBatches(schema, std::move(batches)).Value(table);
- }
+ using FlightStreamReader::ReadAll;
void Cancel() override { rpc_->context.TryCancel(); }
private:
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index bc803c24e93..0a35b6d10e8 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -138,7 +138,7 @@ class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader
virtual Status ReadAll(std::vector>* batches,
const StopToken& stop_token) = 0;
/// \brief Consume entire stream as a Table
- virtual Status ReadAll(std::shared_ptr* table, const StopToken& stop_token) = 0;
+ Status ReadAll(std::shared_ptr* table, const StopToken& stop_token);
};
// Silence warning
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index ef324cd3380..1ab01f735e9 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -1828,7 +1828,6 @@ def do_exchange(self, context, descriptor, reader, writer):
rb = pa.RecordBatch.from_arrays([], schema=schema)
writer.begin(schema)
while not context.is_cancelled():
- # TODO: writing schema.empty_table() here hangs/fails
writer.write_batch(rb)
time.sleep(0.5)
@@ -1857,11 +1856,10 @@ def test(read_all):
except KeyboardInterrupt:
# In case KeyboardInterrupt didn't interrupt read_all
# above, at least prevent it from stopping the test suite
- # pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all")
- raise
+ pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all")
e = exc_info.value.__context__
- assert isinstance(e, pa.ArrowCancelled) or isinstance(
- e, pa.ArrowCancelled)
+ assert isinstance(e, pa.ArrowCancelled) or \
+ isinstance(e, KeyboardInterrupt)
with CancelFlightServer() as server:
client = FlightClient(("localhost", server.port))