From 7a3adcc973704021a18ef377ffa2ee946b1e9522 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 14 Mar 2022 16:25:31 -0400
Subject: [PATCH 1/4] ARROW-15932: [C++][FlightRPC] Add more tests to the
common Flight suite
---
cpp/src/arrow/flight/client.cc | 24 +-
cpp/src/arrow/flight/flight_test.cc | 4 +
cpp/src/arrow/flight/test_definitions.cc | 227 +++++++++++++++---
cpp/src/arrow/flight/test_definitions.h | 4 +
cpp/src/arrow/flight/test_util.cc | 3 +
.../flight/transport/grpc/grpc_client.cc | 60 ++---
6 files changed, 239 insertions(+), 83 deletions(-)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 3067d28c9ed..3aabe37ebcf 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -90,9 +90,12 @@ class IpcMessageReader : public ipc::MessageReader {
public:
IpcMessageReader(std::shared_ptr stream,
std::shared_ptr peekable_reader,
+ std::shared_ptr memory_manager,
std::shared_ptr* app_metadata)
: stream_(std::move(stream)),
peekable_reader_(peekable_reader),
+ memory_manager_(memory_manager ? std::move(memory_manager)
+ : CPUDevice::Instance()->default_memory_manager()),
app_metadata_(app_metadata),
stream_finished_(false) {}
@@ -106,6 +109,9 @@ class IpcMessageReader : public ipc::MessageReader {
stream_finished_ = true;
return stream_->Finish(Status::OK());
}
+ if (data->body) {
+ ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_));
+ }
// Validate IPC message
auto result = data->OpenMessage();
if (!result.ok()) {
@@ -132,10 +138,12 @@ class IpcMessageReader : public ipc::MessageReader {
class ClientStreamReader : public FlightStreamReader {
public:
ClientStreamReader(std::shared_ptr stream,
- const ipc::IpcReadOptions& options, StopToken stop_token)
+ const ipc::IpcReadOptions& options, StopToken stop_token,
+ std::shared_ptr memory_manager)
: stream_(std::move(stream)),
options_(options),
stop_token_(std::move(stop_token)),
+ memory_manager_(std::move(memory_manager)),
peekable_reader_(new internal::PeekableFlightDataReader(stream_.get())),
app_metadata_(nullptr) {}
@@ -149,8 +157,8 @@ class ClientStreamReader : public FlightStreamReader {
FlightStatusCode::Internal, "Server never sent a data message"));
}
- auto message_reader = std::unique_ptr(
- new IpcMessageReader(stream_, peekable_reader_, &app_metadata_));
+ auto message_reader = std::unique_ptr(new IpcMessageReader(
+ stream_, peekable_reader_, memory_manager_, &app_metadata_));
auto result =
ipc::RecordBatchStreamReader::Open(std::move(message_reader), options_);
RETURN_NOT_OK(OverrideWithServerError(std::move(result).Value(&batch_reader_)));
@@ -225,6 +233,7 @@ class ClientStreamReader : public FlightStreamReader {
std::shared_ptr stream_;
ipc::IpcReadOptions options_;
StopToken stop_token_;
+ std::shared_ptr memory_manager_;
std::shared_ptr peekable_reader_;
std::shared_ptr batch_reader_;
std::shared_ptr app_metadata_;
@@ -541,8 +550,9 @@ Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticke
RETURN_NOT_OK(CheckOpen());
std::unique_ptr remote_stream;
RETURN_NOT_OK(transport_->DoGet(options, ticket, &remote_stream));
- *stream = std::unique_ptr(new ClientStreamReader(
- std::move(remote_stream), options.read_options, options.stop_token));
+ *stream = std::unique_ptr(
+ new ClientStreamReader(std::move(remote_stream), options.read_options,
+ options.stop_token, options.memory_manager));
// Eagerly read the schema
return static_cast(stream->get())->EnsureDataStarted();
}
@@ -573,8 +583,8 @@ Status FlightClient::DoExchange(const FlightCallOptions& options,
std::unique_ptr remote_stream;
RETURN_NOT_OK(transport_->DoExchange(options, &remote_stream));
std::shared_ptr shared_stream = std::move(remote_stream);
- *reader = std::unique_ptr(
- new ClientStreamReader(shared_stream, options.read_options, options.stop_token));
+ *reader = std::unique_ptr(new ClientStreamReader(
+ shared_stream, options.read_options, options.stop_token, options.memory_manager));
auto stream_writer = std::unique_ptr(
new ClientStreamWriter(std::move(shared_stream), options.write_options,
write_size_limit_bytes_, descriptor));
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 812bd080f18..1fc9bd0952c 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -81,6 +81,7 @@ TEST_F(GrpcConnectivityTest, GetPort) { TestGetPort(); }
TEST_F(GrpcConnectivityTest, BuilderHook) { TestBuilderHook(); }
TEST_F(GrpcConnectivityTest, Shutdown) { TestShutdown(); }
TEST_F(GrpcConnectivityTest, ShutdownWithDeadline) { TestShutdownWithDeadline(); }
+TEST_F(GrpcConnectivityTest, BrokenConnection) { TestBrokenConnection(); }
class GrpcDataTest : public DataTest {
protected:
@@ -100,6 +101,8 @@ TEST_F(GrpcDataTest, TestDoExchangePut) { TestDoExchangePut(); }
TEST_F(GrpcDataTest, TestDoExchangeEcho) { TestDoExchangeEcho(); }
TEST_F(GrpcDataTest, TestDoExchangeTotal) { TestDoExchangeTotal(); }
TEST_F(GrpcDataTest, TestDoExchangeError) { TestDoExchangeError(); }
+TEST_F(GrpcDataTest, TestDoExchangeConcurrency) { TestDoExchangeConcurrency(); }
+TEST_F(GrpcDataTest, TestDoExchangeUndrained) { TestDoExchangeUndrained(); }
TEST_F(GrpcDataTest, TestIssue5095) { TestIssue5095(); }
class GrpcDoPutTest : public DoPutTest {
@@ -112,6 +115,7 @@ TEST_F(GrpcDoPutTest, TestEmptyBatch) { TestEmptyBatch(); }
TEST_F(GrpcDoPutTest, TestDicts) { TestDicts(); }
TEST_F(GrpcDoPutTest, TestLargeBatch) { TestLargeBatch(); }
TEST_F(GrpcDoPutTest, TestSizeLimit) { TestSizeLimit(); }
+TEST_F(GrpcDoPutTest, TestUndrained) { TestUndrained(); }
class GrpcAppMetadataTest : public AppMetadataTest {
protected:
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index 099769ded9d..6dec505e3de 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -91,6 +91,23 @@ void ConnectivityTest::TestShutdownWithDeadline() {
ASSERT_OK(server->Shutdown(&deadline));
ASSERT_OK(server->Wait());
}
+void ConnectivityTest::TestBrokenConnection() {
+ std::unique_ptr server = ExampleTestServer();
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
+
+ std::unique_ptr client;
+ ASSERT_OK_AND_ASSIGN(location,
+ Location::ForScheme(transport(), "localhost", server->port()));
+ ASSERT_OK(FlightClient::Connect(location, &client));
+
+ ASSERT_OK(server->Shutdown());
+ ASSERT_OK(server->Wait());
+
+ std::unique_ptr info;
+ ASSERT_FALSE(client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok());
+}
//------------------------------------------------------------
// Tests of data plane methods
@@ -500,6 +517,60 @@ void DataTest::TestDoExchangeError() {
// buffer writes - a write won't immediately fail even if the server
// would immediately return an error.
}
+void DataTest::TestDoExchangeConcurrency() {
+ // Ensure that we can do reads/writes on separate threads
+ auto descr = FlightDescriptor::Command("echo");
+ std::unique_ptr reader;
+ std::unique_ptr writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+
+ RecordBatchVector batches;
+ ASSERT_OK(ExampleIntBatches(&batches));
+ ASSERT_OK(writer->Begin(ExampleIntSchema()));
+
+ std::thread reader_thread([&reader, &batches]() {
+ FlightStreamChunk chunk;
+ for (size_t i = 0; i < batches.size(); i++) {
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_NE(nullptr, chunk.data);
+ ASSERT_EQ(nullptr, chunk.app_metadata);
+ AssertBatchesEqual(*batches[i], *chunk.data);
+ }
+ ASSERT_OK(reader->Next(&chunk));
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_EQ(nullptr, chunk.app_metadata);
+ });
+
+ for (const auto& batch : batches) {
+ ASSERT_OK(writer->WriteRecordBatch(*batch));
+ }
+ ASSERT_OK(writer->DoneWriting());
+ reader_thread.join();
+ ASSERT_OK(writer->Close());
+}
+void DataTest::TestDoExchangeUndrained() {
+ // Ensure if the application doesn't drain all messages, that the
+ // server/transport does
+
+ auto descr = FlightDescriptor::Command("TestUndrained");
+ auto schema = arrow::schema({arrow::field("ints", int64())});
+ std::unique_ptr reader;
+ std::unique_ptr writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+
+ auto batch = RecordBatchFromJSON(schema, "[[1], [2], [3], [4]]");
+ ASSERT_OK(writer->Begin(schema));
+ // These calls may or may not fail depending on how quickly the
+ // transport reacts, whether it batches, writes, etc.
+ ARROW_UNUSED(writer->WriteRecordBatch(*batch));
+ ARROW_UNUSED(writer->WriteRecordBatch(*batch));
+ ARROW_UNUSED(writer->WriteRecordBatch(*batch));
+ ARROW_UNUSED(writer->WriteRecordBatch(*batch));
+ ASSERT_OK(writer->Close());
+
+ // We should be able to make another call
+ TestDoExchangeGet();
+}
void DataTest::TestIssue5095() {
// Make sure the server-side error message is reflected to the
// client
@@ -518,22 +589,49 @@ void DataTest::TestIssue5095() {
//------------------------------------------------------------
// Specific tests for DoPut
+static constexpr char kExpectedMetadata[] = "foo bar";
+
class DoPutTestServer : public FlightServerBase {
public:
Status DoPut(const ServerCallContext& context,
std::unique_ptr reader,
std::unique_ptr writer) override {
descriptor_ = reader->descriptor();
+
+ if (descriptor_.type == FlightDescriptor::DescriptorType::CMD) {
+ if (descriptor_.cmd == "TestUndrained") {
+ // Don't read all the messages
+ return Status::OK();
+ }
+ }
+
int counter = 0;
+ FlightStreamChunk chunk;
while (true) {
- FlightStreamChunk chunk;
RETURN_NOT_OK(reader->Next(&chunk));
if (!chunk.data) break;
+ if (counter % 2 == 1) {
+ if (!chunk.app_metadata) {
+ return Status::Invalid("Expected app_metadata");
+ } else if (chunk.app_metadata->ToString() != std::to_string(counter)) {
+ return Status::Invalid("Expected app_metadata to be ", counter, " but got ",
+ chunk.app_metadata->ToString());
+ }
+ }
batches_.push_back(std::move(chunk.data));
auto buffer = Buffer::FromString(std::to_string(counter));
RETURN_NOT_OK(writer->WriteMetadata(*buffer));
counter++;
}
+
+ // Expect a metadata-only message
+ if (!chunk.app_metadata) {
+ return Status::Invalid("Expected app_metadata at end of stream (#1)");
+ } else if (chunk.app_metadata->ToString() != kExpectedMetadata) {
+ return Status::Invalid("Expected app_metadata to be ", kExpectedMetadata,
+ " but got ", chunk.app_metadata->ToString());
+ }
+
return Status::OK();
}
@@ -573,17 +671,25 @@ void DoPutTest::CheckDoPut(const FlightDescriptor& descr,
ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
// Ensure that the reader can be used independently of the writer
- auto* reader_ref = reader.get();
- std::thread reader_thread([reader_ref, &batches]() {
+ std::thread reader_thread([&reader, &batches]() {
for (size_t i = 0; i < batches.size(); i++) {
std::shared_ptr out;
- ASSERT_OK(reader_ref->ReadMetadata(&out));
+ ASSERT_OK(reader->ReadMetadata(&out));
}
});
+ int64_t counter = 0;
for (const auto& batch : batches) {
- ASSERT_OK(stream->WriteRecordBatch(*batch));
+ if (counter % 2 == 0) {
+ ASSERT_OK(stream->WriteRecordBatch(*batch));
+ } else {
+ auto buffer = Buffer::FromString(std::to_string(counter));
+ ASSERT_OK(stream->WriteWithMetadata(*batch, std::move(buffer)));
+ }
+ counter++;
}
+ // Write a metadata-only message
+ ASSERT_OK(stream->WriteMetadata(Buffer::FromString(kExpectedMetadata)));
ASSERT_OK(stream->DoneWriting());
reader_thread.join();
ASSERT_OK(stream->Close());
@@ -671,13 +777,12 @@ void DoPutTest::TestSizeLimit() {
std::unique_ptr client;
ASSERT_OK(FlightClient::Connect(location, client_options, &client));
- auto descr = FlightDescriptor::Path({"ints"});
+ auto descr = FlightDescriptor::Command("simple");
// Batch is too large to fit in one message
auto schema = arrow::schema({field("f1", arrow::int64())});
auto batch = arrow::ConstantArrayGenerator::Zeroes(768, schema);
- RecordBatchVector batches;
- batches.push_back(batch->Slice(0, 384));
- batches.push_back(batch->Slice(384));
+ auto batch1 = batch->Slice(0, 384);
+ auto batch2 = batch->Slice(384);
std::unique_ptr stream;
std::unique_ptr reader;
@@ -692,14 +797,37 @@ void DoPutTest::TestSizeLimit() {
ASSERT_EQ(size_limit, detail->limit());
ASSERT_GT(detail->actual(), size_limit);
- // But we can retry with a smaller batch
- for (const auto& batch : batches) {
- ASSERT_OK(stream->WriteRecordBatch(*batch));
- }
+ // But we can retry with smaller batches
+ ASSERT_OK(stream->WriteRecordBatch(*batch1));
+ ASSERT_OK(stream->WriteWithMetadata(*batch2, Buffer::FromString("1")));
+
+ // Write a metadata-only message
+ ASSERT_OK(stream->WriteMetadata(Buffer::FromString(kExpectedMetadata)));
ASSERT_OK(stream->DoneWriting());
ASSERT_OK(stream->Close());
- CheckBatches(descr, batches);
+ CheckBatches(descr, {batch1, batch2});
+}
+void DoPutTest::TestUndrained() {
+ // Ensure if the application doesn't drain all messages, that the
+ // server/transport does
+
+ auto descr = FlightDescriptor::Command("TestUndrained");
+ auto schema = arrow::schema({arrow::field("ints", int64())});
+ std::unique_ptr stream;
+ std::unique_ptr reader;
+ ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
+ auto batch = RecordBatchFromJSON(schema, "[[1], [2], [3], [4]]");
+ // These calls may or may not fail depending on how quickly the
+ // transport reacts, whether it batches, writes, etc.
+ ARROW_UNUSED(stream->WriteRecordBatch(*batch));
+ ARROW_UNUSED(stream->WriteRecordBatch(*batch));
+ ARROW_UNUSED(stream->WriteRecordBatch(*batch));
+ ARROW_UNUSED(stream->WriteRecordBatch(*batch));
+ ASSERT_OK(stream->Close());
+
+ // We should be able to make another call
+ CheckDoPut(FlightDescriptor::Command("foo"), schema, {batch, batch});
}
//------------------------------------------------------------
@@ -1057,26 +1185,21 @@ arrow::Result> CopyBatchToHost(const RecordBatch& b
class CudaTestServer : public FlightServerBase {
public:
- explicit CudaTestServer(std::shared_ptr device) : device_(std::move(device)) {}
+ explicit CudaTestServer(std::shared_ptr device,
+ std::shared_ptr context)
+ : device_(std::move(device)), context_(std::move(context)) {}
Status DoGet(const ServerCallContext&, const Ticket&,
std::unique_ptr* data_stream) override {
- RecordBatchVector batches;
- RETURN_NOT_OK(ExampleIntBatches(&batches));
- ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches));
+ RETURN_NOT_OK(ExampleIntBatches(&batches_));
+ ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches_));
*data_stream = std::unique_ptr(new RecordBatchStream(batch_reader));
return Status::OK();
}
Status DoPut(const ServerCallContext&, std::unique_ptr reader,
std::unique_ptr writer) override {
- RecordBatchVector batches;
- RETURN_NOT_OK(reader->ReadAll(&batches));
- for (const auto& batch : batches) {
- for (const auto& column : batch->columns()) {
- RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_));
- }
- }
+ RETURN_NOT_OK(reader->ReadAll(&batches_));
return Status::OK();
}
@@ -1095,13 +1218,20 @@ class CudaTestServer : public FlightServerBase {
for (const auto& column : chunk.data->columns()) {
RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_));
}
+ // XXX: do not assume transport will synchronize, we must
+ // synchronize or else data will be "missing"
+ RETURN_NOT_OK(context_->Synchronize());
RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
}
return Status::OK();
}
+ const RecordBatchVector& batches() const { return batches_; }
+
private:
+ RecordBatchVector batches_;
std::shared_ptr device_;
+ std::shared_ptr context_;
};
// Store CUDA objects without exposing them in the public header
@@ -1128,7 +1258,8 @@ void CudaDataTest::SetUp() {
options->memory_manager = impl_->device->default_memory_manager();
return Status::OK();
},
- [](FlightClientOptions* options) { return Status::OK(); }, impl_->device));
+ [](FlightClientOptions* options) { return Status::OK(); }, impl_->device,
+ impl_->context));
}
void CudaDataTest::TearDown() {
ASSERT_OK(client_->Close());
@@ -1140,17 +1271,34 @@ void CudaDataTest::TestDoGet() {
FlightCallOptions options;
options.memory_manager = impl_->device->default_memory_manager();
+ const RecordBatchVector& batches =
+ reinterpret_cast(server_.get())->batches();
+
Ticket ticket{""};
std::unique_ptr stream;
ASSERT_OK(client_->DoGet(options, ticket, &stream));
- std::shared_ptr table;
- ASSERT_OK(stream->ReadAll(&table));
- for (const auto& column : table->columns()) {
- for (const auto& chunk : column->chunks()) {
- ASSERT_OK(CheckBuffersOnDevice(*chunk, *impl_->device));
+ size_t idx = 0;
+ while (true) {
+ FlightStreamChunk chunk;
+ ASSERT_OK(stream->Next(&chunk));
+ if (!chunk.data) break;
+
+ for (const auto& column : chunk.data->columns()) {
+ ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device));
+ }
+
+ if (idx >= batches.size()) {
+ FAIL() << "Server returned more than " << batches.size() << " batches";
+ return;
}
+
+ // Bounce record batch back to host memory
+ ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*chunk.data));
+ AssertBatchesEqual(*batches[idx], *host_batch);
+ idx++;
}
+ ASSERT_EQ(idx, batches.size()) << "Server returned too few batches";
}
void CudaDataTest::TestDoPut() {
RecordBatchVector batches;
@@ -1175,6 +1323,23 @@ void CudaDataTest::TestDoPut() {
ASSERT_OK(writer->WriteRecordBatch(*cuda_batch));
}
ASSERT_OK(writer->Close());
+ ASSERT_OK(impl_->context->Synchronize());
+
+ const RecordBatchVector& written =
+ reinterpret_cast(server_.get())->batches();
+ ASSERT_EQ(written.size(), batches.size());
+
+ size_t idx = 0;
+ for (const auto& batch : written) {
+ for (const auto& column : batch->columns()) {
+ ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device));
+ }
+
+ // Bounce record batch back to host memory
+ ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*batch));
+ AssertBatchesEqual(*batches[idx], *host_batch);
+ idx++;
+ }
}
void CudaDataTest::TestDoExchange() {
FlightCallOptions options;
diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h
index 1e256456557..7ef195d045d 100644
--- a/cpp/src/arrow/flight/test_definitions.h
+++ b/cpp/src/arrow/flight/test_definitions.h
@@ -50,6 +50,7 @@ class ARROW_FLIGHT_EXPORT ConnectivityTest : public FlightTest {
void TestBuilderHook();
void TestShutdown();
void TestShutdownWithDeadline();
+ void TestBrokenConnection();
};
/// Common tests of data plane methods
@@ -74,6 +75,8 @@ class ARROW_FLIGHT_EXPORT DataTest : public FlightTest {
void TestDoExchangeEcho();
void TestDoExchangeTotal();
void TestDoExchangeError();
+ void TestDoExchangeConcurrency();
+ void TestDoExchangeUndrained();
void TestIssue5095();
private:
@@ -103,6 +106,7 @@ class ARROW_FLIGHT_EXPORT DoPutTest : public FlightTest {
void TestDicts();
void TestLargeBatch();
void TestSizeLimit();
+ void TestUndrained();
private:
std::unique_ptr client_;
diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc
index 490f53bfb2f..88bbf5977da 100644
--- a/cpp/src/arrow/flight/test_util.cc
+++ b/cpp/src/arrow/flight/test_util.cc
@@ -262,6 +262,9 @@ class FlightTestServer : public FlightServerBase {
return RunExchangeEcho(std::move(reader), std::move(writer));
} else if (cmd == "large_batch") {
return RunExchangeLargeBatch(std::move(reader), std::move(writer));
+ } else if (cmd == "TestUndrained") {
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ return Status::OK();
} else {
return Status::NotImplemented("Scenario not implemented: ", cmd);
}
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
index 9af58ba4d28..d651e805a0c 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
@@ -266,13 +266,8 @@ class GrpcClientAuthReader : public ClientAuthReader {
template
class FinishableDataStream : public internal::ClientDataStream {
public:
- FinishableDataStream(std::shared_ptr rpc, std::shared_ptr stream,
- std::shared_ptr memory_manager)
- : rpc_(std::move(rpc)),
- stream_(std::move(stream)),
- memory_manager_(memory_manager ? std::move(memory_manager)
- : CPUDevice::Instance()->default_memory_manager()),
- finished_(false) {}
+ FinishableDataStream(std::shared_ptr rpc, std::shared_ptr stream)
+ : rpc_(std::move(rpc)), stream_(std::move(stream)), finished_(false) {}
void TryCancel() override { rpc_->context.TryCancel(); }
@@ -316,7 +311,6 @@ class FinishableDataStream : public internal::ClientDataStream {
std::shared_ptr rpc_;
std::shared_ptr stream_;
- std::shared_ptr memory_manager_;
bool finished_;
Status server_status_;
// A transport-side error that needs to get combined with the server status
@@ -330,9 +324,8 @@ template
class WritableDataStream : public FinishableDataStream {
public:
using Base = FinishableDataStream;
- WritableDataStream(std::shared_ptr rpc, std::shared_ptr stream,
- std::shared_ptr memory_manager)
- : Base(std::move(rpc), std::move(stream), std::move(memory_manager)),
+ WritableDataStream(std::shared_ptr rpc, std::shared_ptr stream)
+ : Base(std::move(rpc), std::move(stream)),
read_mutex_(),
finish_mutex_(),
done_writing_(false) {}
@@ -394,16 +387,7 @@ class GrpcClientGetStream
using FinishableDataStream::FinishableDataStream;
bool ReadData(internal::FlightData* data) override {
- bool success = ReadPayload(stream_.get(), data);
- if (ARROW_PREDICT_FALSE(!success)) return false;
- if (data->body) {
- auto status = Buffer::ViewOrCopy(data->body, memory_manager_).Value(&data->body);
- if (!status.ok()) {
- transport_status_ = std::move(status);
- return false;
- }
- }
- return true;
+ return ReadPayload(stream_.get(), data);
}
Status WritesDone() override { return Status::NotImplemented("NYI"); }
};
@@ -413,10 +397,7 @@ class GrpcClientPutStream
pb::PutResult> {
public:
using Stream = ::grpc::ClientReaderWriter;
- GrpcClientPutStream(std::shared_ptr rpc, std::shared_ptr stream,
- std::shared_ptr memory_manager)
- : WritableDataStream(std::move(rpc), std::move(stream), std::move(memory_manager)) {
- }
+ using WritableDataStream::WritableDataStream;
bool ReadPutMetadata(std::shared_ptr* out) override {
std::lock_guard guard(read_mutex_);
@@ -440,23 +421,12 @@ class GrpcClientExchangeStream
internal::FlightData> {
public:
using Stream = ::grpc::ClientReaderWriter;
- GrpcClientExchangeStream(std::shared_ptr rpc, std::shared_ptr stream,
- std::shared_ptr memory_manager)
- : WritableDataStream(std::move(rpc), std::move(stream), std::move(memory_manager)) {
- }
+ GrpcClientExchangeStream(std::shared_ptr rpc, std::shared_ptr stream)
+ : WritableDataStream(std::move(rpc), std::move(stream)) {}
bool ReadData(internal::FlightData* data) override {
std::lock_guard guard(read_mutex_);
- bool success = ReadPayload(stream_.get(), data);
- if (ARROW_PREDICT_FALSE(!success)) return false;
- if (data->body) {
- auto status = Buffer::ViewOrCopy(data->body, memory_manager_).Value(&data->body);
- if (!status.ok()) {
- transport_status_ = std::move(status);
- return false;
- }
- }
- return true;
+ return ReadPayload(stream_.get(), data);
}
arrow::Result WriteData(const FlightPayload& payload) override {
return WritePayload(payload, this->stream_.get());
@@ -859,8 +829,8 @@ class GrpcClientImpl : public internal::ClientTransport {
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
std::shared_ptr<::grpc::ClientReader> stream =
stub_->DoGet(&rpc->context, pb_ticket);
- *out = std::unique_ptr(new GrpcClientGetStream(
- std::move(rpc), std::move(stream), options.memory_manager));
+ *out = std::unique_ptr(
+ new GrpcClientGetStream(std::move(rpc), std::move(stream)));
return Status::OK();
}
@@ -871,8 +841,8 @@ class GrpcClientImpl : public internal::ClientTransport {
auto rpc = std::make_shared(options);
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
std::shared_ptr stream = stub_->DoPut(&rpc->context);
- *out = std::unique_ptr(new GrpcClientPutStream(
- std::move(rpc), std::move(stream), options.memory_manager));
+ *out = std::unique_ptr(
+ new GrpcClientPutStream(std::move(rpc), std::move(stream)));
return Status::OK();
}
@@ -883,8 +853,8 @@ class GrpcClientImpl : public internal::ClientTransport {
auto rpc = std::make_shared(options);
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
std::shared_ptr stream = stub_->DoExchange(&rpc->context);
- *out = std::unique_ptr(new GrpcClientExchangeStream(
- std::move(rpc), std::move(stream), options.memory_manager));
+ *out = std::unique_ptr(
+ new GrpcClientExchangeStream(std::move(rpc), std::move(stream)));
return Status::OK();
}
From d3891bc4a3837dfe2ddb5eb176d283bafc5e0e86 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 15 Mar 2022 09:25:58 -0400
Subject: [PATCH 2/4] ARROW-15932: [C++][FlightRPC] Add convenience macros for
testing
---
cpp/src/arrow/flight/flight_test.cc | 54 ++---------------
cpp/src/arrow/flight/test_definitions.cc | 2 +-
cpp/src/arrow/flight/test_definitions.h | 74 ++++++++++++++++++++++++
3 files changed, 81 insertions(+), 49 deletions(-)
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 1fc9bd0952c..d457490e47c 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -77,79 +77,37 @@ class GrpcConnectivityTest : public ConnectivityTest {
protected:
std::string transport() const override { return "grpc"; }
};
-TEST_F(GrpcConnectivityTest, GetPort) { TestGetPort(); }
-TEST_F(GrpcConnectivityTest, BuilderHook) { TestBuilderHook(); }
-TEST_F(GrpcConnectivityTest, Shutdown) { TestShutdown(); }
-TEST_F(GrpcConnectivityTest, ShutdownWithDeadline) { TestShutdownWithDeadline(); }
-TEST_F(GrpcConnectivityTest, BrokenConnection) { TestBrokenConnection(); }
+ARROW_FLIGHT_TEST_CONNECTIVITY(GrpcConnectivityTest);
class GrpcDataTest : public DataTest {
protected:
std::string transport() const override { return "grpc"; }
};
-TEST_F(GrpcDataTest, TestDoGetInts) { TestDoGetInts(); }
-TEST_F(GrpcDataTest, TestDoGetFloats) { TestDoGetFloats(); }
-TEST_F(GrpcDataTest, TestDoGetDicts) { TestDoGetDicts(); }
-TEST_F(GrpcDataTest, TestDoGetLargeBatch) { TestDoGetLargeBatch(); }
-TEST_F(GrpcDataTest, TestOverflowServerBatch) { TestOverflowServerBatch(); }
-TEST_F(GrpcDataTest, TestOverflowClientBatch) { TestOverflowClientBatch(); }
-TEST_F(GrpcDataTest, TestDoExchange) { TestDoExchange(); }
-TEST_F(GrpcDataTest, TestDoExchangeNoData) { TestDoExchangeNoData(); }
-TEST_F(GrpcDataTest, TestDoExchangeWriteOnlySchema) { TestDoExchangeWriteOnlySchema(); }
-TEST_F(GrpcDataTest, TestDoExchangeGet) { TestDoExchangeGet(); }
-TEST_F(GrpcDataTest, TestDoExchangePut) { TestDoExchangePut(); }
-TEST_F(GrpcDataTest, TestDoExchangeEcho) { TestDoExchangeEcho(); }
-TEST_F(GrpcDataTest, TestDoExchangeTotal) { TestDoExchangeTotal(); }
-TEST_F(GrpcDataTest, TestDoExchangeError) { TestDoExchangeError(); }
-TEST_F(GrpcDataTest, TestDoExchangeConcurrency) { TestDoExchangeConcurrency(); }
-TEST_F(GrpcDataTest, TestDoExchangeUndrained) { TestDoExchangeUndrained(); }
-TEST_F(GrpcDataTest, TestIssue5095) { TestIssue5095(); }
+ARROW_FLIGHT_TEST_DATA(GrpcDataTest);
class GrpcDoPutTest : public DoPutTest {
protected:
std::string transport() const override { return "grpc"; }
};
-TEST_F(GrpcDoPutTest, TestInts) { TestInts(); }
-TEST_F(GrpcDoPutTest, TestFloats) { TestFloats(); }
-TEST_F(GrpcDoPutTest, TestEmptyBatch) { TestEmptyBatch(); }
-TEST_F(GrpcDoPutTest, TestDicts) { TestDicts(); }
-TEST_F(GrpcDoPutTest, TestLargeBatch) { TestLargeBatch(); }
-TEST_F(GrpcDoPutTest, TestSizeLimit) { TestSizeLimit(); }
-TEST_F(GrpcDoPutTest, TestUndrained) { TestUndrained(); }
+ARROW_FLIGHT_TEST_DO_PUT(GrpcDoPutTest);
class GrpcAppMetadataTest : public AppMetadataTest {
protected:
std::string transport() const override { return "grpc"; }
};
-TEST_F(GrpcAppMetadataTest, TestDoGet) { TestDoGet(); }
-TEST_F(GrpcAppMetadataTest, TestDoGetDictionaries) { TestDoGetDictionaries(); }
-TEST_F(GrpcAppMetadataTest, TestDoPut) { TestDoPut(); }
-TEST_F(GrpcAppMetadataTest, TestDoPutDictionaries) { TestDoPutDictionaries(); }
-TEST_F(GrpcAppMetadataTest, TestDoPutReadMetadata) { TestDoPutReadMetadata(); }
+ARROW_FLIGHT_TEST_APP_METADATA(GrpcAppMetadataTest);
class GrpcIpcOptionsTest : public IpcOptionsTest {
protected:
std::string transport() const override { return "grpc"; }
};
-TEST_F(GrpcIpcOptionsTest, TestDoGetReadOptions) { TestDoGetReadOptions(); }
-TEST_F(GrpcIpcOptionsTest, TestDoPutWriteOptions) { TestDoPutWriteOptions(); }
-TEST_F(GrpcIpcOptionsTest, TestDoExchangeClientWriteOptions) {
- TestDoExchangeClientWriteOptions();
-}
-TEST_F(GrpcIpcOptionsTest, TestDoExchangeClientWriteOptionsBegin) {
- TestDoExchangeClientWriteOptionsBegin();
-}
-TEST_F(GrpcIpcOptionsTest, TestDoExchangeServerWriteOptions) {
- TestDoExchangeServerWriteOptions();
-}
+ARROW_FLIGHT_TEST_IPC_OPTIONS(GrpcIpcOptionsTest);
class GrpcCudaDataTest : public CudaDataTest {
protected:
std::string transport() const override { return "grpc"; }
};
-TEST_F(GrpcCudaDataTest, TestDoGet) { TestDoGet(); }
-TEST_F(GrpcCudaDataTest, TestDoPut) { TestDoPut(); }
-TEST_F(GrpcCudaDataTest, TestDoExchange) { TestDoExchange(); }
+ARROW_FLIGHT_TEST_CUDA_DATA(GrpcCudaDataTest);
//------------------------------------------------------------
// Ad-hoc gRPC-specific tests
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index 6dec505e3de..abe06941be9 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -106,7 +106,7 @@ void ConnectivityTest::TestBrokenConnection() {
ASSERT_OK(server->Wait());
std::unique_ptr info;
- ASSERT_FALSE(client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok());
+ ASSERT_RAISES(IOError, client->GetFlightInfo(FlightDescriptor::Command(""), &info));
}
//------------------------------------------------------------
diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h
index 7ef195d045d..601e8d0b4b1 100644
--- a/cpp/src/arrow/flight/test_definitions.h
+++ b/cpp/src/arrow/flight/test_definitions.h
@@ -29,10 +29,12 @@
#include
#include
#include
+#include
#include
#include "arrow/flight/server.h"
#include "arrow/flight/types.h"
+#include "arrow/util/macros.h"
namespace arrow {
namespace flight {
@@ -53,6 +55,15 @@ class ARROW_FLIGHT_EXPORT ConnectivityTest : public FlightTest {
void TestBrokenConnection();
};
+#define ARROW_FLIGHT_TEST_CONNECTIVITY(FIXTURE) \
+ static_assert(std::is_base_of::value, \
+ ARROW_STRINGIFY(FIXTURE) " must inherit from ConnectivityTest"); \
+ TEST_F(FIXTURE, GetPort) { TestGetPort(); } \
+ TEST_F(FIXTURE, BuilderHook) { TestBuilderHook(); } \
+ TEST_F(FIXTURE, Shutdown) { TestShutdown(); } \
+ TEST_F(FIXTURE, ShutdownWithDeadline) { TestShutdownWithDeadline(); } \
+ TEST_F(FIXTURE, BrokenConnection) { TestBrokenConnection(); }
+
/// Common tests of data plane methods
class ARROW_FLIGHT_EXPORT DataTest : public FlightTest {
public:
@@ -89,6 +100,27 @@ class ARROW_FLIGHT_EXPORT DataTest : public FlightTest {
std::unique_ptr server_;
};
+#define ARROW_FLIGHT_TEST_DATA(FIXTURE) \
+ static_assert(std::is_base_of::value, \
+ ARROW_STRINGIFY(FIXTURE) " must inherit from DataTest"); \
+ TEST_F(FIXTURE, TestDoGetInts) { TestDoGetInts(); } \
+ TEST_F(FIXTURE, TestDoGetFloats) { TestDoGetFloats(); } \
+ TEST_F(FIXTURE, TestDoGetDicts) { TestDoGetDicts(); } \
+ TEST_F(FIXTURE, TestDoGetLargeBatch) { TestDoGetLargeBatch(); } \
+ TEST_F(FIXTURE, TestOverflowServerBatch) { TestOverflowServerBatch(); } \
+ TEST_F(FIXTURE, TestOverflowClientBatch) { TestOverflowClientBatch(); } \
+ TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); } \
+ TEST_F(FIXTURE, TestDoExchangeNoData) { TestDoExchangeNoData(); } \
+ TEST_F(FIXTURE, TestDoExchangeWriteOnlySchema) { TestDoExchangeWriteOnlySchema(); } \
+ TEST_F(FIXTURE, TestDoExchangeGet) { TestDoExchangeGet(); } \
+ TEST_F(FIXTURE, TestDoExchangePut) { TestDoExchangePut(); } \
+ TEST_F(FIXTURE, TestDoExchangeEcho) { TestDoExchangeEcho(); } \
+ TEST_F(FIXTURE, TestDoExchangeTotal) { TestDoExchangeTotal(); } \
+ TEST_F(FIXTURE, TestDoExchangeError) { TestDoExchangeError(); } \
+ TEST_F(FIXTURE, TestDoExchangeConcurrency) { TestDoExchangeConcurrency(); } \
+ TEST_F(FIXTURE, TestDoExchangeUndrained) { TestDoExchangeUndrained(); } \
+ TEST_F(FIXTURE, TestIssue5095) { TestIssue5095(); }
+
/// \brief Specific tests of DoPut.
class ARROW_FLIGHT_EXPORT DoPutTest : public FlightTest {
public:
@@ -113,6 +145,17 @@ class ARROW_FLIGHT_EXPORT DoPutTest : public FlightTest {
std::unique_ptr server_;
};
+#define ARROW_FLIGHT_TEST_DO_PUT(FIXTURE) \
+ static_assert(std::is_base_of::value, \
+ ARROW_STRINGIFY(FIXTURE) " must inherit from DoPutTest"); \
+ TEST_F(FIXTURE, TestInts) { TestInts(); } \
+ TEST_F(FIXTURE, TestFloats) { TestFloats(); } \
+ TEST_F(FIXTURE, TestEmptyBatch) { TestEmptyBatch(); } \
+ TEST_F(FIXTURE, TestDicts) { TestDicts(); } \
+ TEST_F(FIXTURE, TestLargeBatch) { TestLargeBatch(); } \
+ TEST_F(FIXTURE, TestSizeLimit) { TestSizeLimit(); } \
+ TEST_F(FIXTURE, TestUndrained) { TestUndrained(); }
+
class ARROW_FLIGHT_EXPORT AppMetadataTestServer : public FlightServerBase {
public:
virtual ~AppMetadataTestServer() = default;
@@ -143,6 +186,15 @@ class ARROW_FLIGHT_EXPORT AppMetadataTest : public FlightTest {
std::unique_ptr server_;
};
+#define ARROW_FLIGHT_TEST_APP_METADATA(FIXTURE) \
+ static_assert(std::is_base_of::value, \
+ ARROW_STRINGIFY(FIXTURE) " must inherit from AppMetadataTest"); \
+ TEST_F(FIXTURE, TestDoGet) { TestDoGet(); } \
+ TEST_F(FIXTURE, TestDoGetDictionaries) { TestDoGetDictionaries(); } \
+ TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \
+ TEST_F(FIXTURE, TestDoPutDictionaries) { TestDoPutDictionaries(); } \
+ TEST_F(FIXTURE, TestDoPutReadMetadata) { TestDoPutReadMetadata(); }
+
/// \brief Tests of IPC options in data plane methods.
class ARROW_FLIGHT_EXPORT IpcOptionsTest : public FlightTest {
public:
@@ -161,6 +213,21 @@ class ARROW_FLIGHT_EXPORT IpcOptionsTest : public FlightTest {
std::unique_ptr server_;
};
+#define ARROW_FLIGHT_TEST_IPC_OPTIONS(FIXTURE) \
+ static_assert(std::is_base_of::value, \
+ ARROW_STRINGIFY(FIXTURE) " must inherit from IpcOptionsTest"); \
+ TEST_F(FIXTURE, TestDoGetReadOptions) { TestDoGetReadOptions(); } \
+ TEST_F(FIXTURE, TestDoPutWriteOptions) { TestDoPutWriteOptions(); } \
+ TEST_F(FIXTURE, TestDoExchangeClientWriteOptions) { \
+ TestDoExchangeClientWriteOptions(); \
+ } \
+ TEST_F(FIXTURE, TestDoExchangeClientWriteOptionsBegin) { \
+ TestDoExchangeClientWriteOptionsBegin(); \
+ } \
+ TEST_F(FIXTURE, TestDoExchangeServerWriteOptions) { \
+ TestDoExchangeServerWriteOptions(); \
+ }
+
/// \brief Tests of data plane methods with CUDA memory.
///
/// If not built with ARROW_CUDA, tests are no-ops.
@@ -181,5 +248,12 @@ class ARROW_FLIGHT_EXPORT CudaDataTest : public FlightTest {
std::shared_ptr impl_;
};
+#define ARROW_FLIGHT_TEST_CUDA_DATA(FIXTURE) \
+ static_assert(std::is_base_of::value, \
+ ARROW_STRINGIFY(FIXTURE) " must inherit from CudaDataTest"); \
+ TEST_F(FIXTURE, TestDoGet) { TestDoGet(); } \
+ TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \
+ TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }
+
} // namespace flight
} // namespace arrow
From e2ee96b2e57ab8be1b2b273e3c37252aa2db1192 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 17 Mar 2022 11:02:59 -0400
Subject: [PATCH 3/4] ARROW-15932: [C++][FlightRPC] Address feedback
---
cpp/src/arrow/flight/test_definitions.cc | 44 +++++++++++-------------
1 file changed, 20 insertions(+), 24 deletions(-)
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index abe06941be9..41bd092f230 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -26,6 +26,7 @@
#include "arrow/flight/test_util.h"
#include "arrow/table.h"
#include "arrow/testing/generator.h"
+#include "arrow/util/checked_cast.h"
#include "arrow/util/config.h"
#include "arrow/util/logging.h"
@@ -36,6 +37,8 @@
namespace arrow {
namespace flight {
+using arrow::internal::checked_cast;
+
//------------------------------------------------------------
// Tests of initialization/shutdown
@@ -617,6 +620,8 @@ class DoPutTestServer : public FlightServerBase {
return Status::Invalid("Expected app_metadata to be ", counter, " but got ",
chunk.app_metadata->ToString());
}
+ } else if (chunk.app_metadata) {
+ return Status::Invalid("Expected no app_metadata");
}
batches_.push_back(std::move(chunk.data));
auto buffer = Buffer::FromString(std::to_string(counter));
@@ -1163,6 +1168,13 @@ Status CheckBuffersOnDevice(const Array& array, const Device& device) {
return Status::OK();
}
+Status CheckBuffersOnDevice(const RecordBatch& batch, const Device& device) {
+ for (const auto& column : batch.columns()) {
+ RETURN_NOT_OK(CheckBuffersOnDevice(*column, device));
+ }
+ return Status::OK();
+}
+
// Copy a record batch to host memory.
arrow::Result> CopyBatchToHost(const RecordBatch& batch) {
auto mm = CPUDevice::Instance()->default_memory_manager();
@@ -1215,9 +1227,7 @@ class CudaTestServer : public FlightServerBase {
begun = true;
RETURN_NOT_OK(writer->Begin(chunk.data->schema()));
}
- for (const auto& column : chunk.data->columns()) {
- RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_));
- }
+ RETURN_NOT_OK(CheckBuffersOnDevice(*chunk.data, *device_));
// XXX: do not assume transport will synchronize, we must
// synchronize or else data will be "missing"
RETURN_NOT_OK(context_->Synchronize());
@@ -1272,7 +1282,7 @@ void CudaDataTest::TestDoGet() {
options.memory_manager = impl_->device->default_memory_manager();
const RecordBatchVector& batches =
- reinterpret_cast(server_.get())->batches();
+ checked_cast(server_.get())->batches();
Ticket ticket{""};
std::unique_ptr stream;
@@ -1284,10 +1294,7 @@ void CudaDataTest::TestDoGet() {
ASSERT_OK(stream->Next(&chunk));
if (!chunk.data) break;
- for (const auto& column : chunk.data->columns()) {
- ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device));
- }
-
+ ASSERT_OK(CheckBuffersOnDevice(*chunk.data, *impl_->device));
if (idx >= batches.size()) {
FAIL() << "Server returned more than " << batches.size() << " batches";
return;
@@ -1316,25 +1323,19 @@ void CudaDataTest::TestDoPut() {
ASSERT_OK_AND_ASSIGN(auto cuda_batch,
cuda::ReadRecordBatch(batch->schema(), &memo, buffer));
- for (const auto& column : cuda_batch->columns()) {
- ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device));
- }
-
+ ASSERT_OK(CheckBuffersOnDevice(*cuda_batch, *impl_->device));
ASSERT_OK(writer->WriteRecordBatch(*cuda_batch));
}
ASSERT_OK(writer->Close());
ASSERT_OK(impl_->context->Synchronize());
const RecordBatchVector& written =
- reinterpret_cast(server_.get())->batches();
+ checked_cast(server_.get())->batches();
ASSERT_EQ(written.size(), batches.size());
size_t idx = 0;
for (const auto& batch : written) {
- for (const auto& column : batch->columns()) {
- ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device));
- }
-
+ ASSERT_OK(CheckBuffersOnDevice(*batch, *impl_->device));
// Bounce record batch back to host memory
ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*batch));
AssertBatchesEqual(*batches[idx], *host_batch);
@@ -1362,17 +1363,12 @@ void CudaDataTest::TestDoExchange() {
ASSERT_OK_AND_ASSIGN(auto cuda_batch,
cuda::ReadRecordBatch(batch->schema(), &write_memo, buffer));
- for (const auto& column : cuda_batch->columns()) {
- ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device));
- }
-
+ ASSERT_OK(CheckBuffersOnDevice(*cuda_batch, *impl_->device));
ASSERT_OK(writer->WriteRecordBatch(*cuda_batch));
FlightStreamChunk chunk;
ASSERT_OK(reader->Next(&chunk));
- for (const auto& column : chunk.data->columns()) {
- ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device));
- }
+ ASSERT_OK(CheckBuffersOnDevice(*chunk.data, *impl_->device));
// Bounce record batch back to host memory
ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*chunk.data));
From ca5a9bff8f3737bd8e720d90b7de5209dd66720f Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 17 Mar 2022 13:16:09 -0400
Subject: [PATCH 4/4] ARROW-15932: [C++][FlightRPC] Improve flaky test
---
cpp/src/arrow/flight/flight_test.cc | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index d457490e47c..f4207a34f15 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -1224,8 +1224,10 @@ TEST_F(TestBasicAuthHandler, FailUnauthenticatedCalls) {
std::shared_ptr schema(
(new arrow::Schema(std::vector>())));
status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
- ASSERT_OK(status);
+ // May or may not succeed depending on if the transport buffers the write
+ ARROW_UNUSED(status);
status = writer->Close();
+ // But this should definitely fail
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
}