From 7a3adcc973704021a18ef377ffa2ee946b1e9522 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 14 Mar 2022 16:25:31 -0400 Subject: [PATCH 1/4] ARROW-15932: [C++][FlightRPC] Add more tests to the common Flight suite --- cpp/src/arrow/flight/client.cc | 24 +- cpp/src/arrow/flight/flight_test.cc | 4 + cpp/src/arrow/flight/test_definitions.cc | 227 +++++++++++++++--- cpp/src/arrow/flight/test_definitions.h | 4 + cpp/src/arrow/flight/test_util.cc | 3 + .../flight/transport/grpc/grpc_client.cc | 60 ++--- 6 files changed, 239 insertions(+), 83 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 3067d28c9ed..3aabe37ebcf 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -90,9 +90,12 @@ class IpcMessageReader : public ipc::MessageReader { public: IpcMessageReader(std::shared_ptr stream, std::shared_ptr peekable_reader, + std::shared_ptr memory_manager, std::shared_ptr* app_metadata) : stream_(std::move(stream)), peekable_reader_(peekable_reader), + memory_manager_(memory_manager ? std::move(memory_manager) + : CPUDevice::Instance()->default_memory_manager()), app_metadata_(app_metadata), stream_finished_(false) {} @@ -106,6 +109,9 @@ class IpcMessageReader : public ipc::MessageReader { stream_finished_ = true; return stream_->Finish(Status::OK()); } + if (data->body) { + ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_)); + } // Validate IPC message auto result = data->OpenMessage(); if (!result.ok()) { @@ -132,10 +138,12 @@ class IpcMessageReader : public ipc::MessageReader { class ClientStreamReader : public FlightStreamReader { public: ClientStreamReader(std::shared_ptr stream, - const ipc::IpcReadOptions& options, StopToken stop_token) + const ipc::IpcReadOptions& options, StopToken stop_token, + std::shared_ptr memory_manager) : stream_(std::move(stream)), options_(options), stop_token_(std::move(stop_token)), + memory_manager_(std::move(memory_manager)), peekable_reader_(new internal::PeekableFlightDataReader(stream_.get())), app_metadata_(nullptr) {} @@ -149,8 +157,8 @@ class ClientStreamReader : public FlightStreamReader { FlightStatusCode::Internal, "Server never sent a data message")); } - auto message_reader = std::unique_ptr( - new IpcMessageReader(stream_, peekable_reader_, &app_metadata_)); + auto message_reader = std::unique_ptr(new IpcMessageReader( + stream_, peekable_reader_, memory_manager_, &app_metadata_)); auto result = ipc::RecordBatchStreamReader::Open(std::move(message_reader), options_); RETURN_NOT_OK(OverrideWithServerError(std::move(result).Value(&batch_reader_))); @@ -225,6 +233,7 @@ class ClientStreamReader : public FlightStreamReader { std::shared_ptr stream_; ipc::IpcReadOptions options_; StopToken stop_token_; + std::shared_ptr memory_manager_; std::shared_ptr peekable_reader_; std::shared_ptr batch_reader_; std::shared_ptr app_metadata_; @@ -541,8 +550,9 @@ Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticke RETURN_NOT_OK(CheckOpen()); std::unique_ptr remote_stream; RETURN_NOT_OK(transport_->DoGet(options, ticket, &remote_stream)); - *stream = std::unique_ptr(new ClientStreamReader( - std::move(remote_stream), options.read_options, options.stop_token)); + *stream = std::unique_ptr( + new ClientStreamReader(std::move(remote_stream), options.read_options, + options.stop_token, options.memory_manager)); // Eagerly read the schema return static_cast(stream->get())->EnsureDataStarted(); } @@ -573,8 +583,8 @@ Status FlightClient::DoExchange(const FlightCallOptions& options, std::unique_ptr remote_stream; RETURN_NOT_OK(transport_->DoExchange(options, &remote_stream)); std::shared_ptr shared_stream = std::move(remote_stream); - *reader = std::unique_ptr( - new ClientStreamReader(shared_stream, options.read_options, options.stop_token)); + *reader = std::unique_ptr(new ClientStreamReader( + shared_stream, options.read_options, options.stop_token, options.memory_manager)); auto stream_writer = std::unique_ptr( new ClientStreamWriter(std::move(shared_stream), options.write_options, write_size_limit_bytes_, descriptor)); diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 812bd080f18..1fc9bd0952c 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -81,6 +81,7 @@ TEST_F(GrpcConnectivityTest, GetPort) { TestGetPort(); } TEST_F(GrpcConnectivityTest, BuilderHook) { TestBuilderHook(); } TEST_F(GrpcConnectivityTest, Shutdown) { TestShutdown(); } TEST_F(GrpcConnectivityTest, ShutdownWithDeadline) { TestShutdownWithDeadline(); } +TEST_F(GrpcConnectivityTest, BrokenConnection) { TestBrokenConnection(); } class GrpcDataTest : public DataTest { protected: @@ -100,6 +101,8 @@ TEST_F(GrpcDataTest, TestDoExchangePut) { TestDoExchangePut(); } TEST_F(GrpcDataTest, TestDoExchangeEcho) { TestDoExchangeEcho(); } TEST_F(GrpcDataTest, TestDoExchangeTotal) { TestDoExchangeTotal(); } TEST_F(GrpcDataTest, TestDoExchangeError) { TestDoExchangeError(); } +TEST_F(GrpcDataTest, TestDoExchangeConcurrency) { TestDoExchangeConcurrency(); } +TEST_F(GrpcDataTest, TestDoExchangeUndrained) { TestDoExchangeUndrained(); } TEST_F(GrpcDataTest, TestIssue5095) { TestIssue5095(); } class GrpcDoPutTest : public DoPutTest { @@ -112,6 +115,7 @@ TEST_F(GrpcDoPutTest, TestEmptyBatch) { TestEmptyBatch(); } TEST_F(GrpcDoPutTest, TestDicts) { TestDicts(); } TEST_F(GrpcDoPutTest, TestLargeBatch) { TestLargeBatch(); } TEST_F(GrpcDoPutTest, TestSizeLimit) { TestSizeLimit(); } +TEST_F(GrpcDoPutTest, TestUndrained) { TestUndrained(); } class GrpcAppMetadataTest : public AppMetadataTest { protected: diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index 099769ded9d..6dec505e3de 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -91,6 +91,23 @@ void ConnectivityTest::TestShutdownWithDeadline() { ASSERT_OK(server->Shutdown(&deadline)); ASSERT_OK(server->Wait()); } +void ConnectivityTest::TestBrokenConnection() { + std::unique_ptr server = ExampleTestServer(); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + FlightServerOptions options(location); + ASSERT_OK(server->Init(options)); + + std::unique_ptr client; + ASSERT_OK_AND_ASSIGN(location, + Location::ForScheme(transport(), "localhost", server->port())); + ASSERT_OK(FlightClient::Connect(location, &client)); + + ASSERT_OK(server->Shutdown()); + ASSERT_OK(server->Wait()); + + std::unique_ptr info; + ASSERT_FALSE(client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok()); +} //------------------------------------------------------------ // Tests of data plane methods @@ -500,6 +517,60 @@ void DataTest::TestDoExchangeError() { // buffer writes - a write won't immediately fail even if the server // would immediately return an error. } +void DataTest::TestDoExchangeConcurrency() { + // Ensure that we can do reads/writes on separate threads + auto descr = FlightDescriptor::Command("echo"); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + + RecordBatchVector batches; + ASSERT_OK(ExampleIntBatches(&batches)); + ASSERT_OK(writer->Begin(ExampleIntSchema())); + + std::thread reader_thread([&reader, &batches]() { + FlightStreamChunk chunk; + for (size_t i = 0; i < batches.size(); i++) { + ASSERT_OK(reader->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + ASSERT_EQ(nullptr, chunk.app_metadata); + AssertBatchesEqual(*batches[i], *chunk.data); + } + ASSERT_OK(reader->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); + ASSERT_EQ(nullptr, chunk.app_metadata); + }); + + for (const auto& batch : batches) { + ASSERT_OK(writer->WriteRecordBatch(*batch)); + } + ASSERT_OK(writer->DoneWriting()); + reader_thread.join(); + ASSERT_OK(writer->Close()); +} +void DataTest::TestDoExchangeUndrained() { + // Ensure if the application doesn't drain all messages, that the + // server/transport does + + auto descr = FlightDescriptor::Command("TestUndrained"); + auto schema = arrow::schema({arrow::field("ints", int64())}); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + + auto batch = RecordBatchFromJSON(schema, "[[1], [2], [3], [4]]"); + ASSERT_OK(writer->Begin(schema)); + // These calls may or may not fail depending on how quickly the + // transport reacts, whether it batches, writes, etc. + ARROW_UNUSED(writer->WriteRecordBatch(*batch)); + ARROW_UNUSED(writer->WriteRecordBatch(*batch)); + ARROW_UNUSED(writer->WriteRecordBatch(*batch)); + ARROW_UNUSED(writer->WriteRecordBatch(*batch)); + ASSERT_OK(writer->Close()); + + // We should be able to make another call + TestDoExchangeGet(); +} void DataTest::TestIssue5095() { // Make sure the server-side error message is reflected to the // client @@ -518,22 +589,49 @@ void DataTest::TestIssue5095() { //------------------------------------------------------------ // Specific tests for DoPut +static constexpr char kExpectedMetadata[] = "foo bar"; + class DoPutTestServer : public FlightServerBase { public: Status DoPut(const ServerCallContext& context, std::unique_ptr reader, std::unique_ptr writer) override { descriptor_ = reader->descriptor(); + + if (descriptor_.type == FlightDescriptor::DescriptorType::CMD) { + if (descriptor_.cmd == "TestUndrained") { + // Don't read all the messages + return Status::OK(); + } + } + int counter = 0; + FlightStreamChunk chunk; while (true) { - FlightStreamChunk chunk; RETURN_NOT_OK(reader->Next(&chunk)); if (!chunk.data) break; + if (counter % 2 == 1) { + if (!chunk.app_metadata) { + return Status::Invalid("Expected app_metadata"); + } else if (chunk.app_metadata->ToString() != std::to_string(counter)) { + return Status::Invalid("Expected app_metadata to be ", counter, " but got ", + chunk.app_metadata->ToString()); + } + } batches_.push_back(std::move(chunk.data)); auto buffer = Buffer::FromString(std::to_string(counter)); RETURN_NOT_OK(writer->WriteMetadata(*buffer)); counter++; } + + // Expect a metadata-only message + if (!chunk.app_metadata) { + return Status::Invalid("Expected app_metadata at end of stream (#1)"); + } else if (chunk.app_metadata->ToString() != kExpectedMetadata) { + return Status::Invalid("Expected app_metadata to be ", kExpectedMetadata, + " but got ", chunk.app_metadata->ToString()); + } + return Status::OK(); } @@ -573,17 +671,25 @@ void DoPutTest::CheckDoPut(const FlightDescriptor& descr, ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader)); // Ensure that the reader can be used independently of the writer - auto* reader_ref = reader.get(); - std::thread reader_thread([reader_ref, &batches]() { + std::thread reader_thread([&reader, &batches]() { for (size_t i = 0; i < batches.size(); i++) { std::shared_ptr out; - ASSERT_OK(reader_ref->ReadMetadata(&out)); + ASSERT_OK(reader->ReadMetadata(&out)); } }); + int64_t counter = 0; for (const auto& batch : batches) { - ASSERT_OK(stream->WriteRecordBatch(*batch)); + if (counter % 2 == 0) { + ASSERT_OK(stream->WriteRecordBatch(*batch)); + } else { + auto buffer = Buffer::FromString(std::to_string(counter)); + ASSERT_OK(stream->WriteWithMetadata(*batch, std::move(buffer))); + } + counter++; } + // Write a metadata-only message + ASSERT_OK(stream->WriteMetadata(Buffer::FromString(kExpectedMetadata))); ASSERT_OK(stream->DoneWriting()); reader_thread.join(); ASSERT_OK(stream->Close()); @@ -671,13 +777,12 @@ void DoPutTest::TestSizeLimit() { std::unique_ptr client; ASSERT_OK(FlightClient::Connect(location, client_options, &client)); - auto descr = FlightDescriptor::Path({"ints"}); + auto descr = FlightDescriptor::Command("simple"); // Batch is too large to fit in one message auto schema = arrow::schema({field("f1", arrow::int64())}); auto batch = arrow::ConstantArrayGenerator::Zeroes(768, schema); - RecordBatchVector batches; - batches.push_back(batch->Slice(0, 384)); - batches.push_back(batch->Slice(384)); + auto batch1 = batch->Slice(0, 384); + auto batch2 = batch->Slice(384); std::unique_ptr stream; std::unique_ptr reader; @@ -692,14 +797,37 @@ void DoPutTest::TestSizeLimit() { ASSERT_EQ(size_limit, detail->limit()); ASSERT_GT(detail->actual(), size_limit); - // But we can retry with a smaller batch - for (const auto& batch : batches) { - ASSERT_OK(stream->WriteRecordBatch(*batch)); - } + // But we can retry with smaller batches + ASSERT_OK(stream->WriteRecordBatch(*batch1)); + ASSERT_OK(stream->WriteWithMetadata(*batch2, Buffer::FromString("1"))); + + // Write a metadata-only message + ASSERT_OK(stream->WriteMetadata(Buffer::FromString(kExpectedMetadata))); ASSERT_OK(stream->DoneWriting()); ASSERT_OK(stream->Close()); - CheckBatches(descr, batches); + CheckBatches(descr, {batch1, batch2}); +} +void DoPutTest::TestUndrained() { + // Ensure if the application doesn't drain all messages, that the + // server/transport does + + auto descr = FlightDescriptor::Command("TestUndrained"); + auto schema = arrow::schema({arrow::field("ints", int64())}); + std::unique_ptr stream; + std::unique_ptr reader; + ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader)); + auto batch = RecordBatchFromJSON(schema, "[[1], [2], [3], [4]]"); + // These calls may or may not fail depending on how quickly the + // transport reacts, whether it batches, writes, etc. + ARROW_UNUSED(stream->WriteRecordBatch(*batch)); + ARROW_UNUSED(stream->WriteRecordBatch(*batch)); + ARROW_UNUSED(stream->WriteRecordBatch(*batch)); + ARROW_UNUSED(stream->WriteRecordBatch(*batch)); + ASSERT_OK(stream->Close()); + + // We should be able to make another call + CheckDoPut(FlightDescriptor::Command("foo"), schema, {batch, batch}); } //------------------------------------------------------------ @@ -1057,26 +1185,21 @@ arrow::Result> CopyBatchToHost(const RecordBatch& b class CudaTestServer : public FlightServerBase { public: - explicit CudaTestServer(std::shared_ptr device) : device_(std::move(device)) {} + explicit CudaTestServer(std::shared_ptr device, + std::shared_ptr context) + : device_(std::move(device)), context_(std::move(context)) {} Status DoGet(const ServerCallContext&, const Ticket&, std::unique_ptr* data_stream) override { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches)); + RETURN_NOT_OK(ExampleIntBatches(&batches_)); + ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches_)); *data_stream = std::unique_ptr(new RecordBatchStream(batch_reader)); return Status::OK(); } Status DoPut(const ServerCallContext&, std::unique_ptr reader, std::unique_ptr writer) override { - RecordBatchVector batches; - RETURN_NOT_OK(reader->ReadAll(&batches)); - for (const auto& batch : batches) { - for (const auto& column : batch->columns()) { - RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_)); - } - } + RETURN_NOT_OK(reader->ReadAll(&batches_)); return Status::OK(); } @@ -1095,13 +1218,20 @@ class CudaTestServer : public FlightServerBase { for (const auto& column : chunk.data->columns()) { RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_)); } + // XXX: do not assume transport will synchronize, we must + // synchronize or else data will be "missing" + RETURN_NOT_OK(context_->Synchronize()); RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); } return Status::OK(); } + const RecordBatchVector& batches() const { return batches_; } + private: + RecordBatchVector batches_; std::shared_ptr device_; + std::shared_ptr context_; }; // Store CUDA objects without exposing them in the public header @@ -1128,7 +1258,8 @@ void CudaDataTest::SetUp() { options->memory_manager = impl_->device->default_memory_manager(); return Status::OK(); }, - [](FlightClientOptions* options) { return Status::OK(); }, impl_->device)); + [](FlightClientOptions* options) { return Status::OK(); }, impl_->device, + impl_->context)); } void CudaDataTest::TearDown() { ASSERT_OK(client_->Close()); @@ -1140,17 +1271,34 @@ void CudaDataTest::TestDoGet() { FlightCallOptions options; options.memory_manager = impl_->device->default_memory_manager(); + const RecordBatchVector& batches = + reinterpret_cast(server_.get())->batches(); + Ticket ticket{""}; std::unique_ptr stream; ASSERT_OK(client_->DoGet(options, ticket, &stream)); - std::shared_ptr table; - ASSERT_OK(stream->ReadAll(&table)); - for (const auto& column : table->columns()) { - for (const auto& chunk : column->chunks()) { - ASSERT_OK(CheckBuffersOnDevice(*chunk, *impl_->device)); + size_t idx = 0; + while (true) { + FlightStreamChunk chunk; + ASSERT_OK(stream->Next(&chunk)); + if (!chunk.data) break; + + for (const auto& column : chunk.data->columns()) { + ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); + } + + if (idx >= batches.size()) { + FAIL() << "Server returned more than " << batches.size() << " batches"; + return; } + + // Bounce record batch back to host memory + ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*chunk.data)); + AssertBatchesEqual(*batches[idx], *host_batch); + idx++; } + ASSERT_EQ(idx, batches.size()) << "Server returned too few batches"; } void CudaDataTest::TestDoPut() { RecordBatchVector batches; @@ -1175,6 +1323,23 @@ void CudaDataTest::TestDoPut() { ASSERT_OK(writer->WriteRecordBatch(*cuda_batch)); } ASSERT_OK(writer->Close()); + ASSERT_OK(impl_->context->Synchronize()); + + const RecordBatchVector& written = + reinterpret_cast(server_.get())->batches(); + ASSERT_EQ(written.size(), batches.size()); + + size_t idx = 0; + for (const auto& batch : written) { + for (const auto& column : batch->columns()) { + ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); + } + + // Bounce record batch back to host memory + ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*batch)); + AssertBatchesEqual(*batches[idx], *host_batch); + idx++; + } } void CudaDataTest::TestDoExchange() { FlightCallOptions options; diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h index 1e256456557..7ef195d045d 100644 --- a/cpp/src/arrow/flight/test_definitions.h +++ b/cpp/src/arrow/flight/test_definitions.h @@ -50,6 +50,7 @@ class ARROW_FLIGHT_EXPORT ConnectivityTest : public FlightTest { void TestBuilderHook(); void TestShutdown(); void TestShutdownWithDeadline(); + void TestBrokenConnection(); }; /// Common tests of data plane methods @@ -74,6 +75,8 @@ class ARROW_FLIGHT_EXPORT DataTest : public FlightTest { void TestDoExchangeEcho(); void TestDoExchangeTotal(); void TestDoExchangeError(); + void TestDoExchangeConcurrency(); + void TestDoExchangeUndrained(); void TestIssue5095(); private: @@ -103,6 +106,7 @@ class ARROW_FLIGHT_EXPORT DoPutTest : public FlightTest { void TestDicts(); void TestLargeBatch(); void TestSizeLimit(); + void TestUndrained(); private: std::unique_ptr client_; diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index 490f53bfb2f..88bbf5977da 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -262,6 +262,9 @@ class FlightTestServer : public FlightServerBase { return RunExchangeEcho(std::move(reader), std::move(writer)); } else if (cmd == "large_batch") { return RunExchangeLargeBatch(std::move(reader), std::move(writer)); + } else if (cmd == "TestUndrained") { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + return Status::OK(); } else { return Status::NotImplemented("Scenario not implemented: ", cmd); } diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc index 9af58ba4d28..d651e805a0c 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc @@ -266,13 +266,8 @@ class GrpcClientAuthReader : public ClientAuthReader { template class FinishableDataStream : public internal::ClientDataStream { public: - FinishableDataStream(std::shared_ptr rpc, std::shared_ptr stream, - std::shared_ptr memory_manager) - : rpc_(std::move(rpc)), - stream_(std::move(stream)), - memory_manager_(memory_manager ? std::move(memory_manager) - : CPUDevice::Instance()->default_memory_manager()), - finished_(false) {} + FinishableDataStream(std::shared_ptr rpc, std::shared_ptr stream) + : rpc_(std::move(rpc)), stream_(std::move(stream)), finished_(false) {} void TryCancel() override { rpc_->context.TryCancel(); } @@ -316,7 +311,6 @@ class FinishableDataStream : public internal::ClientDataStream { std::shared_ptr rpc_; std::shared_ptr stream_; - std::shared_ptr memory_manager_; bool finished_; Status server_status_; // A transport-side error that needs to get combined with the server status @@ -330,9 +324,8 @@ template class WritableDataStream : public FinishableDataStream { public: using Base = FinishableDataStream; - WritableDataStream(std::shared_ptr rpc, std::shared_ptr stream, - std::shared_ptr memory_manager) - : Base(std::move(rpc), std::move(stream), std::move(memory_manager)), + WritableDataStream(std::shared_ptr rpc, std::shared_ptr stream) + : Base(std::move(rpc), std::move(stream)), read_mutex_(), finish_mutex_(), done_writing_(false) {} @@ -394,16 +387,7 @@ class GrpcClientGetStream using FinishableDataStream::FinishableDataStream; bool ReadData(internal::FlightData* data) override { - bool success = ReadPayload(stream_.get(), data); - if (ARROW_PREDICT_FALSE(!success)) return false; - if (data->body) { - auto status = Buffer::ViewOrCopy(data->body, memory_manager_).Value(&data->body); - if (!status.ok()) { - transport_status_ = std::move(status); - return false; - } - } - return true; + return ReadPayload(stream_.get(), data); } Status WritesDone() override { return Status::NotImplemented("NYI"); } }; @@ -413,10 +397,7 @@ class GrpcClientPutStream pb::PutResult> { public: using Stream = ::grpc::ClientReaderWriter; - GrpcClientPutStream(std::shared_ptr rpc, std::shared_ptr stream, - std::shared_ptr memory_manager) - : WritableDataStream(std::move(rpc), std::move(stream), std::move(memory_manager)) { - } + using WritableDataStream::WritableDataStream; bool ReadPutMetadata(std::shared_ptr* out) override { std::lock_guard guard(read_mutex_); @@ -440,23 +421,12 @@ class GrpcClientExchangeStream internal::FlightData> { public: using Stream = ::grpc::ClientReaderWriter; - GrpcClientExchangeStream(std::shared_ptr rpc, std::shared_ptr stream, - std::shared_ptr memory_manager) - : WritableDataStream(std::move(rpc), std::move(stream), std::move(memory_manager)) { - } + GrpcClientExchangeStream(std::shared_ptr rpc, std::shared_ptr stream) + : WritableDataStream(std::move(rpc), std::move(stream)) {} bool ReadData(internal::FlightData* data) override { std::lock_guard guard(read_mutex_); - bool success = ReadPayload(stream_.get(), data); - if (ARROW_PREDICT_FALSE(!success)) return false; - if (data->body) { - auto status = Buffer::ViewOrCopy(data->body, memory_manager_).Value(&data->body); - if (!status.ok()) { - transport_status_ = std::move(status); - return false; - } - } - return true; + return ReadPayload(stream_.get(), data); } arrow::Result WriteData(const FlightPayload& payload) override { return WritePayload(payload, this->stream_.get()); @@ -859,8 +829,8 @@ class GrpcClientImpl : public internal::ClientTransport { RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); std::shared_ptr<::grpc::ClientReader> stream = stub_->DoGet(&rpc->context, pb_ticket); - *out = std::unique_ptr(new GrpcClientGetStream( - std::move(rpc), std::move(stream), options.memory_manager)); + *out = std::unique_ptr( + new GrpcClientGetStream(std::move(rpc), std::move(stream))); return Status::OK(); } @@ -871,8 +841,8 @@ class GrpcClientImpl : public internal::ClientTransport { auto rpc = std::make_shared(options); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); std::shared_ptr stream = stub_->DoPut(&rpc->context); - *out = std::unique_ptr(new GrpcClientPutStream( - std::move(rpc), std::move(stream), options.memory_manager)); + *out = std::unique_ptr( + new GrpcClientPutStream(std::move(rpc), std::move(stream))); return Status::OK(); } @@ -883,8 +853,8 @@ class GrpcClientImpl : public internal::ClientTransport { auto rpc = std::make_shared(options); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); std::shared_ptr stream = stub_->DoExchange(&rpc->context); - *out = std::unique_ptr(new GrpcClientExchangeStream( - std::move(rpc), std::move(stream), options.memory_manager)); + *out = std::unique_ptr( + new GrpcClientExchangeStream(std::move(rpc), std::move(stream))); return Status::OK(); } From d3891bc4a3837dfe2ddb5eb176d283bafc5e0e86 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 15 Mar 2022 09:25:58 -0400 Subject: [PATCH 2/4] ARROW-15932: [C++][FlightRPC] Add convenience macros for testing --- cpp/src/arrow/flight/flight_test.cc | 54 ++--------------- cpp/src/arrow/flight/test_definitions.cc | 2 +- cpp/src/arrow/flight/test_definitions.h | 74 ++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 49 deletions(-) diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 1fc9bd0952c..d457490e47c 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -77,79 +77,37 @@ class GrpcConnectivityTest : public ConnectivityTest { protected: std::string transport() const override { return "grpc"; } }; -TEST_F(GrpcConnectivityTest, GetPort) { TestGetPort(); } -TEST_F(GrpcConnectivityTest, BuilderHook) { TestBuilderHook(); } -TEST_F(GrpcConnectivityTest, Shutdown) { TestShutdown(); } -TEST_F(GrpcConnectivityTest, ShutdownWithDeadline) { TestShutdownWithDeadline(); } -TEST_F(GrpcConnectivityTest, BrokenConnection) { TestBrokenConnection(); } +ARROW_FLIGHT_TEST_CONNECTIVITY(GrpcConnectivityTest); class GrpcDataTest : public DataTest { protected: std::string transport() const override { return "grpc"; } }; -TEST_F(GrpcDataTest, TestDoGetInts) { TestDoGetInts(); } -TEST_F(GrpcDataTest, TestDoGetFloats) { TestDoGetFloats(); } -TEST_F(GrpcDataTest, TestDoGetDicts) { TestDoGetDicts(); } -TEST_F(GrpcDataTest, TestDoGetLargeBatch) { TestDoGetLargeBatch(); } -TEST_F(GrpcDataTest, TestOverflowServerBatch) { TestOverflowServerBatch(); } -TEST_F(GrpcDataTest, TestOverflowClientBatch) { TestOverflowClientBatch(); } -TEST_F(GrpcDataTest, TestDoExchange) { TestDoExchange(); } -TEST_F(GrpcDataTest, TestDoExchangeNoData) { TestDoExchangeNoData(); } -TEST_F(GrpcDataTest, TestDoExchangeWriteOnlySchema) { TestDoExchangeWriteOnlySchema(); } -TEST_F(GrpcDataTest, TestDoExchangeGet) { TestDoExchangeGet(); } -TEST_F(GrpcDataTest, TestDoExchangePut) { TestDoExchangePut(); } -TEST_F(GrpcDataTest, TestDoExchangeEcho) { TestDoExchangeEcho(); } -TEST_F(GrpcDataTest, TestDoExchangeTotal) { TestDoExchangeTotal(); } -TEST_F(GrpcDataTest, TestDoExchangeError) { TestDoExchangeError(); } -TEST_F(GrpcDataTest, TestDoExchangeConcurrency) { TestDoExchangeConcurrency(); } -TEST_F(GrpcDataTest, TestDoExchangeUndrained) { TestDoExchangeUndrained(); } -TEST_F(GrpcDataTest, TestIssue5095) { TestIssue5095(); } +ARROW_FLIGHT_TEST_DATA(GrpcDataTest); class GrpcDoPutTest : public DoPutTest { protected: std::string transport() const override { return "grpc"; } }; -TEST_F(GrpcDoPutTest, TestInts) { TestInts(); } -TEST_F(GrpcDoPutTest, TestFloats) { TestFloats(); } -TEST_F(GrpcDoPutTest, TestEmptyBatch) { TestEmptyBatch(); } -TEST_F(GrpcDoPutTest, TestDicts) { TestDicts(); } -TEST_F(GrpcDoPutTest, TestLargeBatch) { TestLargeBatch(); } -TEST_F(GrpcDoPutTest, TestSizeLimit) { TestSizeLimit(); } -TEST_F(GrpcDoPutTest, TestUndrained) { TestUndrained(); } +ARROW_FLIGHT_TEST_DO_PUT(GrpcDoPutTest); class GrpcAppMetadataTest : public AppMetadataTest { protected: std::string transport() const override { return "grpc"; } }; -TEST_F(GrpcAppMetadataTest, TestDoGet) { TestDoGet(); } -TEST_F(GrpcAppMetadataTest, TestDoGetDictionaries) { TestDoGetDictionaries(); } -TEST_F(GrpcAppMetadataTest, TestDoPut) { TestDoPut(); } -TEST_F(GrpcAppMetadataTest, TestDoPutDictionaries) { TestDoPutDictionaries(); } -TEST_F(GrpcAppMetadataTest, TestDoPutReadMetadata) { TestDoPutReadMetadata(); } +ARROW_FLIGHT_TEST_APP_METADATA(GrpcAppMetadataTest); class GrpcIpcOptionsTest : public IpcOptionsTest { protected: std::string transport() const override { return "grpc"; } }; -TEST_F(GrpcIpcOptionsTest, TestDoGetReadOptions) { TestDoGetReadOptions(); } -TEST_F(GrpcIpcOptionsTest, TestDoPutWriteOptions) { TestDoPutWriteOptions(); } -TEST_F(GrpcIpcOptionsTest, TestDoExchangeClientWriteOptions) { - TestDoExchangeClientWriteOptions(); -} -TEST_F(GrpcIpcOptionsTest, TestDoExchangeClientWriteOptionsBegin) { - TestDoExchangeClientWriteOptionsBegin(); -} -TEST_F(GrpcIpcOptionsTest, TestDoExchangeServerWriteOptions) { - TestDoExchangeServerWriteOptions(); -} +ARROW_FLIGHT_TEST_IPC_OPTIONS(GrpcIpcOptionsTest); class GrpcCudaDataTest : public CudaDataTest { protected: std::string transport() const override { return "grpc"; } }; -TEST_F(GrpcCudaDataTest, TestDoGet) { TestDoGet(); } -TEST_F(GrpcCudaDataTest, TestDoPut) { TestDoPut(); } -TEST_F(GrpcCudaDataTest, TestDoExchange) { TestDoExchange(); } +ARROW_FLIGHT_TEST_CUDA_DATA(GrpcCudaDataTest); //------------------------------------------------------------ // Ad-hoc gRPC-specific tests diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index 6dec505e3de..abe06941be9 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -106,7 +106,7 @@ void ConnectivityTest::TestBrokenConnection() { ASSERT_OK(server->Wait()); std::unique_ptr info; - ASSERT_FALSE(client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok()); + ASSERT_RAISES(IOError, client->GetFlightInfo(FlightDescriptor::Command(""), &info)); } //------------------------------------------------------------ diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h index 7ef195d045d..601e8d0b4b1 100644 --- a/cpp/src/arrow/flight/test_definitions.h +++ b/cpp/src/arrow/flight/test_definitions.h @@ -29,10 +29,12 @@ #include #include #include +#include #include #include "arrow/flight/server.h" #include "arrow/flight/types.h" +#include "arrow/util/macros.h" namespace arrow { namespace flight { @@ -53,6 +55,15 @@ class ARROW_FLIGHT_EXPORT ConnectivityTest : public FlightTest { void TestBrokenConnection(); }; +#define ARROW_FLIGHT_TEST_CONNECTIVITY(FIXTURE) \ + static_assert(std::is_base_of::value, \ + ARROW_STRINGIFY(FIXTURE) " must inherit from ConnectivityTest"); \ + TEST_F(FIXTURE, GetPort) { TestGetPort(); } \ + TEST_F(FIXTURE, BuilderHook) { TestBuilderHook(); } \ + TEST_F(FIXTURE, Shutdown) { TestShutdown(); } \ + TEST_F(FIXTURE, ShutdownWithDeadline) { TestShutdownWithDeadline(); } \ + TEST_F(FIXTURE, BrokenConnection) { TestBrokenConnection(); } + /// Common tests of data plane methods class ARROW_FLIGHT_EXPORT DataTest : public FlightTest { public: @@ -89,6 +100,27 @@ class ARROW_FLIGHT_EXPORT DataTest : public FlightTest { std::unique_ptr server_; }; +#define ARROW_FLIGHT_TEST_DATA(FIXTURE) \ + static_assert(std::is_base_of::value, \ + ARROW_STRINGIFY(FIXTURE) " must inherit from DataTest"); \ + TEST_F(FIXTURE, TestDoGetInts) { TestDoGetInts(); } \ + TEST_F(FIXTURE, TestDoGetFloats) { TestDoGetFloats(); } \ + TEST_F(FIXTURE, TestDoGetDicts) { TestDoGetDicts(); } \ + TEST_F(FIXTURE, TestDoGetLargeBatch) { TestDoGetLargeBatch(); } \ + TEST_F(FIXTURE, TestOverflowServerBatch) { TestOverflowServerBatch(); } \ + TEST_F(FIXTURE, TestOverflowClientBatch) { TestOverflowClientBatch(); } \ + TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); } \ + TEST_F(FIXTURE, TestDoExchangeNoData) { TestDoExchangeNoData(); } \ + TEST_F(FIXTURE, TestDoExchangeWriteOnlySchema) { TestDoExchangeWriteOnlySchema(); } \ + TEST_F(FIXTURE, TestDoExchangeGet) { TestDoExchangeGet(); } \ + TEST_F(FIXTURE, TestDoExchangePut) { TestDoExchangePut(); } \ + TEST_F(FIXTURE, TestDoExchangeEcho) { TestDoExchangeEcho(); } \ + TEST_F(FIXTURE, TestDoExchangeTotal) { TestDoExchangeTotal(); } \ + TEST_F(FIXTURE, TestDoExchangeError) { TestDoExchangeError(); } \ + TEST_F(FIXTURE, TestDoExchangeConcurrency) { TestDoExchangeConcurrency(); } \ + TEST_F(FIXTURE, TestDoExchangeUndrained) { TestDoExchangeUndrained(); } \ + TEST_F(FIXTURE, TestIssue5095) { TestIssue5095(); } + /// \brief Specific tests of DoPut. class ARROW_FLIGHT_EXPORT DoPutTest : public FlightTest { public: @@ -113,6 +145,17 @@ class ARROW_FLIGHT_EXPORT DoPutTest : public FlightTest { std::unique_ptr server_; }; +#define ARROW_FLIGHT_TEST_DO_PUT(FIXTURE) \ + static_assert(std::is_base_of::value, \ + ARROW_STRINGIFY(FIXTURE) " must inherit from DoPutTest"); \ + TEST_F(FIXTURE, TestInts) { TestInts(); } \ + TEST_F(FIXTURE, TestFloats) { TestFloats(); } \ + TEST_F(FIXTURE, TestEmptyBatch) { TestEmptyBatch(); } \ + TEST_F(FIXTURE, TestDicts) { TestDicts(); } \ + TEST_F(FIXTURE, TestLargeBatch) { TestLargeBatch(); } \ + TEST_F(FIXTURE, TestSizeLimit) { TestSizeLimit(); } \ + TEST_F(FIXTURE, TestUndrained) { TestUndrained(); } + class ARROW_FLIGHT_EXPORT AppMetadataTestServer : public FlightServerBase { public: virtual ~AppMetadataTestServer() = default; @@ -143,6 +186,15 @@ class ARROW_FLIGHT_EXPORT AppMetadataTest : public FlightTest { std::unique_ptr server_; }; +#define ARROW_FLIGHT_TEST_APP_METADATA(FIXTURE) \ + static_assert(std::is_base_of::value, \ + ARROW_STRINGIFY(FIXTURE) " must inherit from AppMetadataTest"); \ + TEST_F(FIXTURE, TestDoGet) { TestDoGet(); } \ + TEST_F(FIXTURE, TestDoGetDictionaries) { TestDoGetDictionaries(); } \ + TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \ + TEST_F(FIXTURE, TestDoPutDictionaries) { TestDoPutDictionaries(); } \ + TEST_F(FIXTURE, TestDoPutReadMetadata) { TestDoPutReadMetadata(); } + /// \brief Tests of IPC options in data plane methods. class ARROW_FLIGHT_EXPORT IpcOptionsTest : public FlightTest { public: @@ -161,6 +213,21 @@ class ARROW_FLIGHT_EXPORT IpcOptionsTest : public FlightTest { std::unique_ptr server_; }; +#define ARROW_FLIGHT_TEST_IPC_OPTIONS(FIXTURE) \ + static_assert(std::is_base_of::value, \ + ARROW_STRINGIFY(FIXTURE) " must inherit from IpcOptionsTest"); \ + TEST_F(FIXTURE, TestDoGetReadOptions) { TestDoGetReadOptions(); } \ + TEST_F(FIXTURE, TestDoPutWriteOptions) { TestDoPutWriteOptions(); } \ + TEST_F(FIXTURE, TestDoExchangeClientWriteOptions) { \ + TestDoExchangeClientWriteOptions(); \ + } \ + TEST_F(FIXTURE, TestDoExchangeClientWriteOptionsBegin) { \ + TestDoExchangeClientWriteOptionsBegin(); \ + } \ + TEST_F(FIXTURE, TestDoExchangeServerWriteOptions) { \ + TestDoExchangeServerWriteOptions(); \ + } + /// \brief Tests of data plane methods with CUDA memory. /// /// If not built with ARROW_CUDA, tests are no-ops. @@ -181,5 +248,12 @@ class ARROW_FLIGHT_EXPORT CudaDataTest : public FlightTest { std::shared_ptr impl_; }; +#define ARROW_FLIGHT_TEST_CUDA_DATA(FIXTURE) \ + static_assert(std::is_base_of::value, \ + ARROW_STRINGIFY(FIXTURE) " must inherit from CudaDataTest"); \ + TEST_F(FIXTURE, TestDoGet) { TestDoGet(); } \ + TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \ + TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); } + } // namespace flight } // namespace arrow From e2ee96b2e57ab8be1b2b273e3c37252aa2db1192 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 17 Mar 2022 11:02:59 -0400 Subject: [PATCH 3/4] ARROW-15932: [C++][FlightRPC] Address feedback --- cpp/src/arrow/flight/test_definitions.cc | 44 +++++++++++------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index abe06941be9..41bd092f230 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -26,6 +26,7 @@ #include "arrow/flight/test_util.h" #include "arrow/table.h" #include "arrow/testing/generator.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/config.h" #include "arrow/util/logging.h" @@ -36,6 +37,8 @@ namespace arrow { namespace flight { +using arrow::internal::checked_cast; + //------------------------------------------------------------ // Tests of initialization/shutdown @@ -617,6 +620,8 @@ class DoPutTestServer : public FlightServerBase { return Status::Invalid("Expected app_metadata to be ", counter, " but got ", chunk.app_metadata->ToString()); } + } else if (chunk.app_metadata) { + return Status::Invalid("Expected no app_metadata"); } batches_.push_back(std::move(chunk.data)); auto buffer = Buffer::FromString(std::to_string(counter)); @@ -1163,6 +1168,13 @@ Status CheckBuffersOnDevice(const Array& array, const Device& device) { return Status::OK(); } +Status CheckBuffersOnDevice(const RecordBatch& batch, const Device& device) { + for (const auto& column : batch.columns()) { + RETURN_NOT_OK(CheckBuffersOnDevice(*column, device)); + } + return Status::OK(); +} + // Copy a record batch to host memory. arrow::Result> CopyBatchToHost(const RecordBatch& batch) { auto mm = CPUDevice::Instance()->default_memory_manager(); @@ -1215,9 +1227,7 @@ class CudaTestServer : public FlightServerBase { begun = true; RETURN_NOT_OK(writer->Begin(chunk.data->schema())); } - for (const auto& column : chunk.data->columns()) { - RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_)); - } + RETURN_NOT_OK(CheckBuffersOnDevice(*chunk.data, *device_)); // XXX: do not assume transport will synchronize, we must // synchronize or else data will be "missing" RETURN_NOT_OK(context_->Synchronize()); @@ -1272,7 +1282,7 @@ void CudaDataTest::TestDoGet() { options.memory_manager = impl_->device->default_memory_manager(); const RecordBatchVector& batches = - reinterpret_cast(server_.get())->batches(); + checked_cast(server_.get())->batches(); Ticket ticket{""}; std::unique_ptr stream; @@ -1284,10 +1294,7 @@ void CudaDataTest::TestDoGet() { ASSERT_OK(stream->Next(&chunk)); if (!chunk.data) break; - for (const auto& column : chunk.data->columns()) { - ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); - } - + ASSERT_OK(CheckBuffersOnDevice(*chunk.data, *impl_->device)); if (idx >= batches.size()) { FAIL() << "Server returned more than " << batches.size() << " batches"; return; @@ -1316,25 +1323,19 @@ void CudaDataTest::TestDoPut() { ASSERT_OK_AND_ASSIGN(auto cuda_batch, cuda::ReadRecordBatch(batch->schema(), &memo, buffer)); - for (const auto& column : cuda_batch->columns()) { - ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); - } - + ASSERT_OK(CheckBuffersOnDevice(*cuda_batch, *impl_->device)); ASSERT_OK(writer->WriteRecordBatch(*cuda_batch)); } ASSERT_OK(writer->Close()); ASSERT_OK(impl_->context->Synchronize()); const RecordBatchVector& written = - reinterpret_cast(server_.get())->batches(); + checked_cast(server_.get())->batches(); ASSERT_EQ(written.size(), batches.size()); size_t idx = 0; for (const auto& batch : written) { - for (const auto& column : batch->columns()) { - ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); - } - + ASSERT_OK(CheckBuffersOnDevice(*batch, *impl_->device)); // Bounce record batch back to host memory ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*batch)); AssertBatchesEqual(*batches[idx], *host_batch); @@ -1362,17 +1363,12 @@ void CudaDataTest::TestDoExchange() { ASSERT_OK_AND_ASSIGN(auto cuda_batch, cuda::ReadRecordBatch(batch->schema(), &write_memo, buffer)); - for (const auto& column : cuda_batch->columns()) { - ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); - } - + ASSERT_OK(CheckBuffersOnDevice(*cuda_batch, *impl_->device)); ASSERT_OK(writer->WriteRecordBatch(*cuda_batch)); FlightStreamChunk chunk; ASSERT_OK(reader->Next(&chunk)); - for (const auto& column : chunk.data->columns()) { - ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); - } + ASSERT_OK(CheckBuffersOnDevice(*chunk.data, *impl_->device)); // Bounce record batch back to host memory ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*chunk.data)); From ca5a9bff8f3737bd8e720d90b7de5209dd66720f Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 17 Mar 2022 13:16:09 -0400 Subject: [PATCH 4/4] ARROW-15932: [C++][FlightRPC] Improve flaky test --- cpp/src/arrow/flight/flight_test.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index d457490e47c..f4207a34f15 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -1224,8 +1224,10 @@ TEST_F(TestBasicAuthHandler, FailUnauthenticatedCalls) { std::shared_ptr schema( (new arrow::Schema(std::vector>()))); status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader); - ASSERT_OK(status); + // May or may not succeed depending on if the transport buffers the write + ARROW_UNUSED(status); status = writer->Close(); + // But this should definitely fail ASSERT_RAISES(IOError, status); ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token")); }