Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,12 @@ class IpcMessageReader : public ipc::MessageReader {
public:
IpcMessageReader(std::shared_ptr<internal::ClientDataStream> stream,
std::shared_ptr<internal::PeekableFlightDataReader> peekable_reader,
std::shared_ptr<MemoryManager> memory_manager,
std::shared_ptr<Buffer>* app_metadata)
: stream_(std::move(stream)),
peekable_reader_(peekable_reader),
memory_manager_(memory_manager ? std::move(memory_manager)
: CPUDevice::Instance()->default_memory_manager()),
app_metadata_(app_metadata),
stream_finished_(false) {}

Expand All @@ -106,6 +109,9 @@ class IpcMessageReader : public ipc::MessageReader {
stream_finished_ = true;
return stream_->Finish(Status::OK());
}
if (data->body) {
ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_));
}
// Validate IPC message
auto result = data->OpenMessage();
if (!result.ok()) {
Expand All @@ -132,10 +138,12 @@ class IpcMessageReader : public ipc::MessageReader {
class ClientStreamReader : public FlightStreamReader {
public:
ClientStreamReader(std::shared_ptr<internal::ClientDataStream> stream,
const ipc::IpcReadOptions& options, StopToken stop_token)
const ipc::IpcReadOptions& options, StopToken stop_token,
std::shared_ptr<MemoryManager> memory_manager)
: stream_(std::move(stream)),
options_(options),
stop_token_(std::move(stop_token)),
memory_manager_(std::move(memory_manager)),
peekable_reader_(new internal::PeekableFlightDataReader(stream_.get())),
app_metadata_(nullptr) {}

Expand All @@ -149,8 +157,8 @@ class ClientStreamReader : public FlightStreamReader {
FlightStatusCode::Internal, "Server never sent a data message"));
}

auto message_reader = std::unique_ptr<ipc::MessageReader>(
new IpcMessageReader(stream_, peekable_reader_, &app_metadata_));
auto message_reader = std::unique_ptr<ipc::MessageReader>(new IpcMessageReader(
stream_, peekable_reader_, memory_manager_, &app_metadata_));
auto result =
ipc::RecordBatchStreamReader::Open(std::move(message_reader), options_);
RETURN_NOT_OK(OverrideWithServerError(std::move(result).Value(&batch_reader_)));
Expand Down Expand Up @@ -225,6 +233,7 @@ class ClientStreamReader : public FlightStreamReader {
std::shared_ptr<internal::ClientDataStream> stream_;
ipc::IpcReadOptions options_;
StopToken stop_token_;
std::shared_ptr<MemoryManager> memory_manager_;
std::shared_ptr<internal::PeekableFlightDataReader> peekable_reader_;
std::shared_ptr<ipc::RecordBatchReader> batch_reader_;
std::shared_ptr<Buffer> app_metadata_;
Expand Down Expand Up @@ -541,8 +550,9 @@ Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticke
RETURN_NOT_OK(CheckOpen());
std::unique_ptr<internal::ClientDataStream> remote_stream;
RETURN_NOT_OK(transport_->DoGet(options, ticket, &remote_stream));
*stream = std::unique_ptr<ClientStreamReader>(new ClientStreamReader(
std::move(remote_stream), options.read_options, options.stop_token));
*stream = std::unique_ptr<ClientStreamReader>(
new ClientStreamReader(std::move(remote_stream), options.read_options,
options.stop_token, options.memory_manager));
// Eagerly read the schema
return static_cast<ClientStreamReader*>(stream->get())->EnsureDataStarted();
}
Expand Down Expand Up @@ -573,8 +583,8 @@ Status FlightClient::DoExchange(const FlightCallOptions& options,
std::unique_ptr<internal::ClientDataStream> remote_stream;
RETURN_NOT_OK(transport_->DoExchange(options, &remote_stream));
std::shared_ptr<internal::ClientDataStream> shared_stream = std::move(remote_stream);
*reader = std::unique_ptr<FlightStreamReader>(
new ClientStreamReader(shared_stream, options.read_options, options.stop_token));
*reader = std::unique_ptr<FlightStreamReader>(new ClientStreamReader(
shared_stream, options.read_options, options.stop_token, options.memory_manager));
auto stream_writer = std::unique_ptr<ClientStreamWriter>(
new ClientStreamWriter(std::move(shared_stream), options.write_options,
write_size_limit_bytes_, descriptor));
Expand Down
54 changes: 9 additions & 45 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,75 +77,37 @@ class GrpcConnectivityTest : public ConnectivityTest {
protected:
std::string transport() const override { return "grpc"; }
};
TEST_F(GrpcConnectivityTest, GetPort) { TestGetPort(); }
TEST_F(GrpcConnectivityTest, BuilderHook) { TestBuilderHook(); }
TEST_F(GrpcConnectivityTest, Shutdown) { TestShutdown(); }
TEST_F(GrpcConnectivityTest, ShutdownWithDeadline) { TestShutdownWithDeadline(); }
ARROW_FLIGHT_TEST_CONNECTIVITY(GrpcConnectivityTest);

class GrpcDataTest : public DataTest {
protected:
std::string transport() const override { return "grpc"; }
};
TEST_F(GrpcDataTest, TestDoGetInts) { TestDoGetInts(); }
TEST_F(GrpcDataTest, TestDoGetFloats) { TestDoGetFloats(); }
TEST_F(GrpcDataTest, TestDoGetDicts) { TestDoGetDicts(); }
TEST_F(GrpcDataTest, TestDoGetLargeBatch) { TestDoGetLargeBatch(); }
TEST_F(GrpcDataTest, TestOverflowServerBatch) { TestOverflowServerBatch(); }
TEST_F(GrpcDataTest, TestOverflowClientBatch) { TestOverflowClientBatch(); }
TEST_F(GrpcDataTest, TestDoExchange) { TestDoExchange(); }
TEST_F(GrpcDataTest, TestDoExchangeNoData) { TestDoExchangeNoData(); }
TEST_F(GrpcDataTest, TestDoExchangeWriteOnlySchema) { TestDoExchangeWriteOnlySchema(); }
TEST_F(GrpcDataTest, TestDoExchangeGet) { TestDoExchangeGet(); }
TEST_F(GrpcDataTest, TestDoExchangePut) { TestDoExchangePut(); }
TEST_F(GrpcDataTest, TestDoExchangeEcho) { TestDoExchangeEcho(); }
TEST_F(GrpcDataTest, TestDoExchangeTotal) { TestDoExchangeTotal(); }
TEST_F(GrpcDataTest, TestDoExchangeError) { TestDoExchangeError(); }
TEST_F(GrpcDataTest, TestIssue5095) { TestIssue5095(); }
ARROW_FLIGHT_TEST_DATA(GrpcDataTest);

class GrpcDoPutTest : public DoPutTest {
protected:
std::string transport() const override { return "grpc"; }
};
TEST_F(GrpcDoPutTest, TestInts) { TestInts(); }
TEST_F(GrpcDoPutTest, TestFloats) { TestFloats(); }
TEST_F(GrpcDoPutTest, TestEmptyBatch) { TestEmptyBatch(); }
TEST_F(GrpcDoPutTest, TestDicts) { TestDicts(); }
TEST_F(GrpcDoPutTest, TestLargeBatch) { TestLargeBatch(); }
TEST_F(GrpcDoPutTest, TestSizeLimit) { TestSizeLimit(); }
ARROW_FLIGHT_TEST_DO_PUT(GrpcDoPutTest);

class GrpcAppMetadataTest : public AppMetadataTest {
protected:
std::string transport() const override { return "grpc"; }
};
TEST_F(GrpcAppMetadataTest, TestDoGet) { TestDoGet(); }
TEST_F(GrpcAppMetadataTest, TestDoGetDictionaries) { TestDoGetDictionaries(); }
TEST_F(GrpcAppMetadataTest, TestDoPut) { TestDoPut(); }
TEST_F(GrpcAppMetadataTest, TestDoPutDictionaries) { TestDoPutDictionaries(); }
TEST_F(GrpcAppMetadataTest, TestDoPutReadMetadata) { TestDoPutReadMetadata(); }
ARROW_FLIGHT_TEST_APP_METADATA(GrpcAppMetadataTest);

class GrpcIpcOptionsTest : public IpcOptionsTest {
protected:
std::string transport() const override { return "grpc"; }
};
TEST_F(GrpcIpcOptionsTest, TestDoGetReadOptions) { TestDoGetReadOptions(); }
TEST_F(GrpcIpcOptionsTest, TestDoPutWriteOptions) { TestDoPutWriteOptions(); }
TEST_F(GrpcIpcOptionsTest, TestDoExchangeClientWriteOptions) {
TestDoExchangeClientWriteOptions();
}
TEST_F(GrpcIpcOptionsTest, TestDoExchangeClientWriteOptionsBegin) {
TestDoExchangeClientWriteOptionsBegin();
}
TEST_F(GrpcIpcOptionsTest, TestDoExchangeServerWriteOptions) {
TestDoExchangeServerWriteOptions();
}
ARROW_FLIGHT_TEST_IPC_OPTIONS(GrpcIpcOptionsTest);

class GrpcCudaDataTest : public CudaDataTest {
protected:
std::string transport() const override { return "grpc"; }
};
TEST_F(GrpcCudaDataTest, TestDoGet) { TestDoGet(); }
TEST_F(GrpcCudaDataTest, TestDoPut) { TestDoPut(); }
TEST_F(GrpcCudaDataTest, TestDoExchange) { TestDoExchange(); }
ARROW_FLIGHT_TEST_CUDA_DATA(GrpcCudaDataTest);

//------------------------------------------------------------
// Ad-hoc gRPC-specific tests
Expand Down Expand Up @@ -1262,8 +1224,10 @@ TEST_F(TestBasicAuthHandler, FailUnauthenticatedCalls) {
std::shared_ptr<Schema> schema(
(new arrow::Schema(std::vector<std::shared_ptr<Field>>())));
status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
ASSERT_OK(status);
// May or may not succeed depending on if the transport buffers the write
ARROW_UNUSED(status);
status = writer->Close();
// But this should definitely fail
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
}
Expand Down
Loading