Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "arrow/record_batch.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/type.h"
#include "arrow/util/logging.h"
#include "arrow/util/uri.h"
Expand Down Expand Up @@ -92,6 +93,14 @@ std::shared_ptr<FlightWriteSizeStatusDetail> FlightWriteSizeStatusDetail::Unwrap

FlightClientOptions FlightClientOptions::Defaults() { return FlightClientOptions(); }

Status FlightStreamReader::ReadAll(std::shared_ptr<Table>* table,
const StopToken& stop_token) {
std::vector<std::shared_ptr<RecordBatch>> batches;
RETURN_NOT_OK(ReadAll(&batches, stop_token));
ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());
return Table::FromRecordBatches(schema, std::move(batches)).Value(table);
}

struct ClientRpc {
grpc::ClientContext context;

Expand Down Expand Up @@ -484,11 +493,12 @@ template <typename Reader>
class GrpcStreamReader : public FlightStreamReader {
public:
GrpcStreamReader(std::shared_ptr<ClientRpc> rpc, std::shared_ptr<std::mutex> read_mutex,
const ipc::IpcReadOptions& options,
const ipc::IpcReadOptions& options, StopToken stop_token,
std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream)
: rpc_(rpc),
read_mutex_(read_mutex),
options_(options),
stop_token_(std::move(stop_token)),
stream_(stream),
peekable_reader_(new internal::PeekableFlightDataReader<std::shared_ptr<Reader>>(
stream->stream())),
Expand Down Expand Up @@ -552,6 +562,28 @@ class GrpcStreamReader : public FlightStreamReader {
out->app_metadata = std::move(app_metadata_);
return Status::OK();
}
Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches) override {
return ReadAll(batches, stop_token_);
}
Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches,
const StopToken& stop_token) override {
FlightStreamChunk chunk;

while (true) {
if (stop_token.IsStopRequested()) {
Cancel();
return stop_token.Poll();
}
RETURN_NOT_OK(Next(&chunk));
if (!chunk.data) break;
batches->emplace_back(std::move(chunk.data));
}
return Status::OK();
}
Status ReadAll(std::shared_ptr<Table>* table) override {
return ReadAll(table, stop_token_);
}
using FlightStreamReader::ReadAll;
void Cancel() override { rpc_->context.TryCancel(); }

private:
Expand All @@ -574,6 +606,7 @@ class GrpcStreamReader : public FlightStreamReader {
// read. Nullable, as DoGet() doesn't need this.
std::shared_ptr<std::mutex> read_mutex_;
ipc::IpcReadOptions options_;
StopToken stop_token_;
std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream_;
std::shared_ptr<internal::PeekableFlightDataReader<std::shared_ptr<Reader>>>
peekable_reader_;
Expand Down Expand Up @@ -1060,12 +1093,13 @@ class FlightClient::FlightClientImpl {
std::vector<FlightInfo> flights;

pb::FlightInfo pb_info;
while (stream->Read(&pb_info)) {
while (!options.stop_token.IsStopRequested() && stream->Read(&pb_info)) {
FlightInfo::Data info_data;
RETURN_NOT_OK(internal::FromProto(pb_info, &info_data));
flights.emplace_back(std::move(info_data));
}

if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
RETURN_NOT_OK(options.stop_token.Poll());
listing->reset(new SimpleFlightListing(std::move(flights)));
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}
Expand All @@ -1083,11 +1117,13 @@ class FlightClient::FlightClientImpl {
pb::Result pb_result;

std::vector<Result> materialized_results;
while (stream->Read(&pb_result)) {
while (!options.stop_token.IsStopRequested() && stream->Read(&pb_result)) {
Result result;
RETURN_NOT_OK(internal::FromProto(pb_result, &result));
materialized_results.emplace_back(std::move(result));
}
if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
RETURN_NOT_OK(options.stop_token.Poll());

*results = std::unique_ptr<ResultStream>(
new SimpleResultStream(std::move(materialized_results)));
Expand All @@ -1104,10 +1140,12 @@ class FlightClient::FlightClientImpl {

pb::ActionType pb_type;
ActionType type;
while (stream->Read(&pb_type)) {
while (!options.stop_token.IsStopRequested() && stream->Read(&pb_type)) {
RETURN_NOT_OK(internal::FromProto(pb_type, &type));
types->emplace_back(std::move(type));
}
if (options.stop_token.IsStopRequested()) rpc.context.TryCancel();
RETURN_NOT_OK(options.stop_token.Poll());
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}

Expand Down Expand Up @@ -1163,8 +1201,8 @@ class FlightClient::FlightClientImpl {
auto finishable_stream = std::make_shared<
FinishableStream<grpc::ClientReader<pb::FlightData>, internal::FlightData>>(
rpc, stream);
*out = std::unique_ptr<StreamReader>(
new StreamReader(rpc, nullptr, options.read_options, finishable_stream));
*out = std::unique_ptr<StreamReader>(new StreamReader(
rpc, nullptr, options.read_options, options.stop_token, finishable_stream));
// Eagerly read the schema
return static_cast<StreamReader*>(out->get())->EnsureDataStarted();
}
Expand Down Expand Up @@ -1208,8 +1246,8 @@ class FlightClient::FlightClientImpl {
auto finishable_stream =
std::make_shared<FinishableWritableStream<GrpcStream, internal::FlightData>>(
rpc, read_mutex, stream);
*reader = std::unique_ptr<StreamReader>(
new StreamReader(rpc, read_mutex, options.read_options, finishable_stream));
*reader = std::unique_ptr<StreamReader>(new StreamReader(
rpc, read_mutex, options.read_options, options.stop_token, finishable_stream));
// Do not eagerly read the schema. There may be metadata messages
// before any data is sent, or data may not be sent at all.
return StreamWriter::Open(descriptor, nullptr, options.write_options, rpc,
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "arrow/ipc/writer.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/cancel.h"
#include "arrow/util/variant.h"

#include "arrow/flight/types.h" // IWYU pragma: keep
Expand Down Expand Up @@ -69,6 +70,9 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions {

/// \brief Headers for client to add to context.
std::vector<std::pair<std::string, std::string>> headers;

/// \brief A token to enable interactive user cancellation of long-running requests.
StopToken stop_token;
};

/// \brief Indicate that the client attempted to write a message
Expand Down Expand Up @@ -129,6 +133,12 @@ class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader
public:
/// \brief Try to cancel the call.
virtual void Cancel() = 0;
using MetadataRecordBatchReader::ReadAll;
/// \brief Consume entire stream as a vector of record batches
virtual Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches,
const StopToken& stop_token) = 0;
/// \brief Consume entire stream as a Table
Status ReadAll(std::shared_ptr<Table>* table, const StopToken& stop_token);
};

// Silence warning
Expand Down
142 changes: 142 additions & 0 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2673,5 +2673,147 @@ TEST_F(TestCookieParsing, CookieCache) {
AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=\"0\"; id1=\"1\"; id2=\"2\"");
}

class ForeverFlightListing : public FlightListing {
Status Next(std::unique_ptr<FlightInfo>* info) override {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
*info = arrow::internal::make_unique<FlightInfo>(ExampleFlightInfo()[0]);
return Status::OK();
}
};

class ForeverResultStream : public ResultStream {
Status Next(std::unique_ptr<Result>* result) override {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
*result = arrow::internal::make_unique<Result>();
(*result)->body = Buffer::FromString("foo");
return Status::OK();
}
};

class ForeverDataStream : public FlightDataStream {
public:
ForeverDataStream() : schema_(arrow::schema({})), mapper_(*schema_) {}
std::shared_ptr<Schema> schema() override { return schema_; }

Status GetSchemaPayload(FlightPayload* payload) override {
return ipc::GetSchemaPayload(*schema_, ipc::IpcWriteOptions::Defaults(), mapper_,
&payload->ipc_message);
}

Status Next(FlightPayload* payload) override {
auto batch = RecordBatch::Make(schema_, 0, ArrayVector{});
return ipc::GetRecordBatchPayload(*batch, ipc::IpcWriteOptions::Defaults(),
&payload->ipc_message);
}

private:
std::shared_ptr<Schema> schema_;
ipc::DictionaryFieldMapper mapper_;
};

class CancelTestServer : public FlightServerBase {
public:
Status ListFlights(const ServerCallContext&, const Criteria*,
std::unique_ptr<FlightListing>* listings) override {
*listings = arrow::internal::make_unique<ForeverFlightListing>();
return Status::OK();
}
Status DoAction(const ServerCallContext&, const Action&,
std::unique_ptr<ResultStream>* result) override {
*result = arrow::internal::make_unique<ForeverResultStream>();
return Status::OK();
}
Status ListActions(const ServerCallContext&,
std::vector<ActionType>* actions) override {
*actions = {};
return Status::OK();
}
Status DoGet(const ServerCallContext&, const Ticket&,
std::unique_ptr<FlightDataStream>* data_stream) override {
*data_stream = arrow::internal::make_unique<ForeverDataStream>();
return Status::OK();
}
};

class TestCancel : public ::testing::Test {
public:
void SetUp() {
ASSERT_OK(MakeServer<CancelTestServer>(
&server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
[](FlightClientOptions* options) { return Status::OK(); }));
}
void TearDown() { ASSERT_OK(server_->Shutdown()); }

protected:
std::unique_ptr<FlightClient> client_;
std::unique_ptr<FlightServerBase> server_;
};

TEST_F(TestCancel, ListFlights) {
StopSource stop_source;
FlightCallOptions options;
options.stop_token = stop_source.token();
std::unique_ptr<FlightListing> listing;
stop_source.RequestStop(Status::Cancelled("StopSource"));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
client_->ListFlights(options, {}, &listing));
}

TEST_F(TestCancel, DoAction) {
StopSource stop_source;
FlightCallOptions options;
options.stop_token = stop_source.token();
std::unique_ptr<ResultStream> results;
stop_source.RequestStop(Status::Cancelled("StopSource"));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
client_->DoAction(options, {}, &results));
}

TEST_F(TestCancel, ListActions) {
StopSource stop_source;
FlightCallOptions options;
options.stop_token = stop_source.token();
std::vector<ActionType> results;
stop_source.RequestStop(Status::Cancelled("StopSource"));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
client_->ListActions(options, &results));
}

TEST_F(TestCancel, DoGet) {
StopSource stop_source;
FlightCallOptions options;
options.stop_token = stop_source.token();
std::unique_ptr<ResultStream> results;
stop_source.RequestStop(Status::Cancelled("StopSource"));
std::unique_ptr<FlightStreamReader> stream;
ASSERT_OK(client_->DoGet(options, {}, &stream));
std::shared_ptr<Table> table;
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
stream->ReadAll(&table));

ASSERT_OK(client_->DoGet({}, &stream));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
stream->ReadAll(&table, options.stop_token));
}

TEST_F(TestCancel, DoExchange) {
StopSource stop_source;
FlightCallOptions options;
options.stop_token = stop_source.token();
std::unique_ptr<ResultStream> results;
stop_source.RequestStop(Status::Cancelled("StopSource"));
std::unique_ptr<FlightStreamWriter> writer;
std::unique_ptr<FlightStreamReader> stream;
ASSERT_OK(
client_->DoExchange(options, FlightDescriptor::Command(""), &writer, &stream));
std::shared_ptr<Table> table;
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
stream->ReadAll(&table));

ASSERT_OK(client_->DoExchange(FlightDescriptor::Command(""), &writer, &stream));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
stream->ReadAll(&table, options.stop_token));
}

} // namespace flight
} // namespace arrow
1 change: 1 addition & 0 deletions cpp/src/arrow/flight/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ class GrpcServerCallContext : public ServerCallContext {

const std::string& peer_identity() const override { return peer_identity_; }
const std::string& peer() const override { return peer_; }
bool is_cancelled() const override { return context_->IsCancelled(); }

// Helper method that runs interceptors given the result of an RPC,
// then returns the final gRPC status to send to the client
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/flight/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ class ARROW_FLIGHT_EXPORT ServerCallContext {
/// to the object beyond the request body.
/// \return The middleware, or nullptr if not found.
virtual ServerMiddleware* GetMiddleware(const std::string& key) const = 0;
/// \brief Check if the current RPC has been cancelled (by the client, by
/// a network error, etc.).
virtual bool is_cancelled() const = 0;
};

class ARROW_FLIGHT_EXPORT FlightServerOptions {
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/util/cancel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ void StopSource::Reset() {

StopToken StopSource::token() { return StopToken(impl_); }

bool StopToken::IsStopRequested() {
bool StopToken::IsStopRequested() const {
if (!impl_) {
return false;
}
return impl_->requested_.load() != 0;
}

Status StopToken::Poll() {
Status StopToken::Poll() const {
if (!impl_) {
return Status::OK();
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/util/cancel.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ class ARROW_EXPORT StopToken {
static StopToken Unstoppable() { return StopToken(); }

// Producer API (the side that gets asked to stopped)
Status Poll();
bool IsStopRequested();
Status Poll() const;
bool IsStopRequested() const;

protected:
std::shared_ptr<StopSourceImpl> impl_;
Expand Down
Loading