From f8124c8e8dd52ef602e25dad99e45f41d899e36c Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 23 Feb 2022 13:48:29 -0500
Subject: [PATCH 1/5] ARROW-15707: [C++][FlightRPC] Split out Flight tests that
don't use the network
---
cpp/src/arrow/flight/CMakeLists.txt | 6 +
cpp/src/arrow/flight/flight_cuda_test.cc | 10 +-
cpp/src/arrow/flight/flight_internals_test.cc | 449 ++++++++++++++
cpp/src/arrow/flight/flight_test.cc | 553 ++----------------
cpp/src/arrow/flight/test_util.cc | 32 +-
cpp/src/arrow/flight/test_util.h | 71 +--
cpp/src/arrow/flight/types.cc | 4 +
cpp/src/arrow/flight/types.h | 15 +-
8 files changed, 557 insertions(+), 583 deletions(-)
create mode 100644 cpp/src/arrow/flight/flight_internals_test.cc
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 5861d8475d6..14eebc262ee 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -237,6 +237,12 @@ foreach(LIB_TARGET ${ARROW_FLIGHT_TESTING_LIBRARIES})
${ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS})
endforeach()
+add_arrow_test(flight_internals_test
+ STATIC_LINK_LIBS
+ ${ARROW_FLIGHT_TEST_LINK_LIBS}
+ LABELS
+ "arrow_flight")
+
add_arrow_test(flight_test
STATIC_LINK_LIBS
${ARROW_FLIGHT_TEST_LINK_LIBS}
diff --git a/cpp/src/arrow/flight/flight_cuda_test.cc b/cpp/src/arrow/flight/flight_cuda_test.cc
index b2ed6394cb8..b812efd4677 100644
--- a/cpp/src/arrow/flight/flight_cuda_test.cc
+++ b/cpp/src/arrow/flight/flight_cuda_test.cc
@@ -70,16 +70,16 @@ class CudaTestServer : public FlightServerBase {
Status DoGet(const ServerCallContext&, const Ticket&,
std::unique_ptr* data_stream) override {
- BatchVector batches;
+ RecordBatchVector batches;
RETURN_NOT_OK(ExampleIntBatches(&batches));
- auto batch_reader = std::make_shared(batches[0]->schema(), batches);
+ ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches));
*data_stream = std::unique_ptr(new RecordBatchStream(batch_reader));
return Status::OK();
}
Status DoPut(const ServerCallContext&, std::unique_ptr reader,
std::unique_ptr writer) override {
- BatchVector batches;
+ RecordBatchVector batches;
RETURN_NOT_OK(reader->ReadAll(&batches));
for (const auto& batch : batches) {
for (const auto& column : batch->columns()) {
@@ -161,7 +161,7 @@ TEST_F(TestCuda, DoGet) {
TEST_F(TestCuda, DoPut) {
// Check that we can send a record batch containing references to
// GPU buffers.
- BatchVector batches;
+ RecordBatchVector batches;
ASSERT_OK(ExampleIntBatches(&batches));
std::unique_ptr writer;
@@ -190,7 +190,7 @@ TEST_F(TestCuda, DoExchange) {
FlightCallOptions options;
options.memory_manager = device_->default_memory_manager();
- BatchVector batches;
+ RecordBatchVector batches;
ASSERT_OK(ExampleIntBatches(&batches));
std::unique_ptr writer;
diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc
new file mode 100644
index 00000000000..525e19499c3
--- /dev/null
+++ b/cpp/src/arrow/flight/flight_internals_test.cc
@@ -0,0 +1,449 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// ----------------------------------------------------------------------
+// Tests for Flight which don't actually spin up a client/server
+
+#include
+#include
+
+#include "arrow/flight/client_cookie_middleware.h"
+#include "arrow/flight/client_header_internal.h"
+#include "arrow/flight/client_middleware.h"
+#include "arrow/flight/types.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/string.h"
+
+#include "arrow/flight/internal.h"
+#include "arrow/flight/test_util.h"
+
+namespace arrow {
+namespace flight {
+
+namespace pb = arrow::flight::protocol;
+
+// ----------------------------------------------------------------------
+// Core Flight types
+
+TEST(FlightTypes, FlightDescriptor) {
+ auto a = FlightDescriptor::Command("select * from table");
+ auto b = FlightDescriptor::Command("select * from table");
+ auto c = FlightDescriptor::Command("select foo from table");
+ auto d = FlightDescriptor::Path({"foo", "bar"});
+ auto e = FlightDescriptor::Path({"foo", "baz"});
+ auto f = FlightDescriptor::Path({"foo", "baz"});
+
+ ASSERT_EQ(a.ToString(), "FlightDescriptor");
+ ASSERT_EQ(d.ToString(), "FlightDescriptor");
+ ASSERT_TRUE(a.Equals(b));
+ ASSERT_FALSE(a.Equals(c));
+ ASSERT_FALSE(a.Equals(d));
+ ASSERT_FALSE(d.Equals(e));
+ ASSERT_TRUE(e.Equals(f));
+}
+
+// This tests the internal protobuf types which don't get exported in the Flight DLL.
+#ifndef _WIN32
+TEST(FlightTypes, FlightDescriptorToFromProto) {
+ FlightDescriptor descr_test;
+ pb::FlightDescriptor pb_descr;
+
+ FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}};
+ ASSERT_OK(internal::ToProto(descr1, &pb_descr));
+ ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
+ ASSERT_EQ(descr1, descr_test);
+
+ FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}};
+ ASSERT_OK(internal::ToProto(descr2, &pb_descr));
+ ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
+ ASSERT_EQ(descr2, descr_test);
+}
+#endif
+
+// ARROW-6017: we should be able to construct locations for unknown
+// schemes
+TEST(FlightTypes, LocationUnknownScheme) {
+ Location location;
+ ASSERT_OK(Location::Parse("s3://test", &location));
+ ASSERT_OK(Location::Parse("https://example.com/foo", &location));
+}
+
+TEST(FlightTypes, RoundTripTypes) {
+ Ticket ticket{"foo"};
+ ASSERT_OK_AND_ASSIGN(std::string ticket_serialized, ticket.SerializeToString());
+ ASSERT_OK_AND_ASSIGN(Ticket ticket_deserialized,
+ Ticket::Deserialize(ticket_serialized));
+ ASSERT_EQ(ticket.ticket, ticket_deserialized.ticket);
+
+ FlightDescriptor desc = FlightDescriptor::Command("select * from foo;");
+ ASSERT_OK_AND_ASSIGN(std::string desc_serialized, desc.SerializeToString());
+ ASSERT_OK_AND_ASSIGN(FlightDescriptor desc_deserialized,
+ FlightDescriptor::Deserialize(desc_serialized));
+ ASSERT_TRUE(desc.Equals(desc_deserialized));
+
+ desc = FlightDescriptor::Path({"a", "b", "test.arrow"});
+ ASSERT_OK_AND_ASSIGN(desc_serialized, desc.SerializeToString());
+ ASSERT_OK_AND_ASSIGN(desc_deserialized, FlightDescriptor::Deserialize(desc_serialized));
+ ASSERT_TRUE(desc.Equals(desc_deserialized));
+
+ FlightInfo::Data data;
+ std::shared_ptr schema =
+ arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()),
+ field("d", int64())});
+ Location location1, location2, location3;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 10010, &location1));
+ ASSERT_OK(Location::ForGrpcTls("localhost", 10010, &location2));
+ ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location3));
+ std::vector endpoints{FlightEndpoint{ticket, {location1, location2}},
+ FlightEndpoint{ticket, {location3}}};
+ ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data));
+ std::unique_ptr info = std::unique_ptr(new FlightInfo(data));
+ ASSERT_OK_AND_ASSIGN(std::string info_serialized, info->SerializeToString());
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr info_deserialized,
+ FlightInfo::Deserialize(info_serialized));
+ ASSERT_TRUE(info->descriptor().Equals(info_deserialized->descriptor()));
+ ASSERT_EQ(info->endpoints(), info_deserialized->endpoints());
+ ASSERT_EQ(info->total_records(), info_deserialized->total_records());
+ ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes());
+}
+
+TEST(FlightTypes, RoundtripStatus) {
+ // Make sure status codes round trip through our conversions
+
+ std::shared_ptr detail;
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Internal, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Internal, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::TimedOut, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::TimedOut, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Cancelled, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Cancelled, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Unauthenticated, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Unauthenticated, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Unauthorized, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Unauthorized, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Unavailable, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Unavailable, detail->code());
+
+ Status status = internal::FromGrpcStatus(
+ internal::ToGrpcStatus(Status::NotImplemented("Sentinel")));
+ ASSERT_TRUE(status.IsNotImplemented());
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
+
+ status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel")));
+ ASSERT_TRUE(status.IsInvalid());
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
+
+ status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel")));
+ ASSERT_TRUE(status.IsKeyError());
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
+
+ status =
+ internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel")));
+ ASSERT_TRUE(status.IsAlreadyExists());
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
+}
+
+// ----------------------------------------------------------------------
+// Cookie authentication/middleware
+
+// This test keeps an internal cookie cache and compares that with the middleware.
+class TestCookieMiddleware : public ::testing::Test {
+ public:
+ // Setup function creates middleware factory and starts it up.
+ void SetUp() {
+ factory_ = GetCookieFactory();
+ CallInfo callInfo;
+ factory_->StartCall(callInfo, &middleware_);
+ }
+
+ // Function to add incoming cookies to middleware and validate them.
+ void AddAndValidate(const std::string& incoming_cookie) {
+ // Add cookie
+ CallHeaders call_headers;
+ call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
+ arrow::util::string_view(incoming_cookie)));
+ middleware_->ReceivedHeaders(call_headers);
+ expected_cookie_cache_.UpdateCachedCookies(call_headers);
+
+ // Get cookie from middleware.
+ TestCallHeaders add_call_headers;
+ middleware_->SendingHeaders(&add_call_headers);
+ const std::string actual_cookies = add_call_headers.GetCookies();
+
+ // Validate cookie
+ const std::string expected_cookies = expected_cookie_cache_.GetValidCookiesAsString();
+ const std::vector split_expected_cookies =
+ SplitCookies(expected_cookies);
+ const std::vector split_actual_cookies = SplitCookies(actual_cookies);
+ EXPECT_EQ(split_expected_cookies, split_actual_cookies);
+ }
+
+ // Function to take a list of cookies and split them into a vector of individual
+ // cookies. This is done because the cookie cache is a map so ordering is not
+ // necessarily consistent.
+ static std::vector SplitCookies(const std::string& cookies) {
+ std::vector split_cookies;
+ std::string::size_type pos1 = 0;
+ std::string::size_type pos2 = 0;
+ while ((pos2 = cookies.find(';', pos1)) != std::string::npos) {
+ split_cookies.push_back(
+ arrow::internal::TrimString(cookies.substr(pos1, pos2 - pos1)));
+ pos1 = pos2 + 1;
+ }
+ if (pos1 < cookies.size()) {
+ split_cookies.push_back(arrow::internal::TrimString(cookies.substr(pos1)));
+ }
+ std::sort(split_cookies.begin(), split_cookies.end());
+ return split_cookies;
+ }
+
+ protected:
+ // Class to allow testing of the call headers.
+ class TestCallHeaders : public AddCallHeaders {
+ public:
+ TestCallHeaders() {}
+ ~TestCallHeaders() {}
+
+ // Function to add cookie header.
+ void AddHeader(const std::string& key, const std::string& value) {
+ ASSERT_EQ(key, "cookie");
+ outbound_cookie_ = value;
+ }
+
+ // Function to get outgoing cookie.
+ std::string GetCookies() { return outbound_cookie_; }
+
+ private:
+ std::string outbound_cookie_;
+ };
+
+ internal::CookieCache expected_cookie_cache_;
+ std::unique_ptr middleware_;
+ std::shared_ptr factory_;
+};
+
+TEST_F(TestCookieMiddleware, BasicParsing) {
+ AddAndValidate("id1=1; foo=bar;");
+ AddAndValidate("id1=1; foo=bar");
+ AddAndValidate("id2=2;");
+ AddAndValidate("id4=\"4\"");
+ AddAndValidate("id5=5; foo=bar; baz=buz;");
+}
+
+TEST_F(TestCookieMiddleware, Overwrite) {
+ AddAndValidate("id0=0");
+ AddAndValidate("id0=1");
+ AddAndValidate("id1=0");
+ AddAndValidate("id1=1");
+ AddAndValidate("id1=1");
+ AddAndValidate("id1=10");
+ AddAndValidate("id=3");
+ AddAndValidate("id=0");
+ AddAndValidate("id=0");
+}
+
+TEST_F(TestCookieMiddleware, MaxAge) {
+ AddAndValidate("id0=0; max-age=0;");
+ AddAndValidate("id1=0; max-age=-1;");
+ AddAndValidate("id2=0; max-age=0");
+ AddAndValidate("id3=0; max-age=-1");
+ AddAndValidate("id4=0; max-age=1");
+ AddAndValidate("id5=0; max-age=1");
+ AddAndValidate("id4=0; max-age=0");
+ AddAndValidate("id5=0; max-age=0");
+}
+
+TEST_F(TestCookieMiddleware, Expires) {
+ AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT;");
+ AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT");
+ AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
+ AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
+ AddAndValidate("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;");
+ AddAndValidate("id1=0; expires=Fri, 01 Jan 2038 22:15:36 GMT");
+ AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
+ AddAndValidate("id1=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
+}
+
+// This test is used to test the parsing capabilities of the cookie framework.
+class TestCookieParsing : public ::testing::Test {
+ public:
+ void VerifyParseCookie(const std::string& cookie_str, bool expired) {
+ internal::Cookie cookie = internal::Cookie::parse(cookie_str);
+ EXPECT_EQ(expired, cookie.IsExpired());
+ }
+
+ void VerifyCookieName(const std::string& cookie_str, const std::string& name) {
+ internal::Cookie cookie = internal::Cookie::parse(cookie_str);
+ EXPECT_EQ(name, cookie.GetName());
+ }
+
+ void VerifyCookieString(const std::string& cookie_str,
+ const std::string& cookie_as_string) {
+ internal::Cookie cookie = internal::Cookie::parse(cookie_str);
+ EXPECT_EQ(cookie_as_string, cookie.AsCookieString());
+ }
+
+ void VerifyCookieDateConverson(std::string date, const std::string& converted_date) {
+ internal::Cookie::ConvertCookieDate(&date);
+ EXPECT_EQ(converted_date, date);
+ }
+
+ void VerifyCookieAttributeParsing(
+ const std::string cookie_str, std::string::size_type start_pos,
+ const util::optional> cookie_attribute,
+ const std::string::size_type start_pos_after) {
+ util::optional> attr =
+ internal::Cookie::ParseCookieAttribute(cookie_str, &start_pos);
+
+ if (cookie_attribute == util::nullopt) {
+ EXPECT_EQ(cookie_attribute, attr);
+ } else {
+ EXPECT_EQ(cookie_attribute.value(), attr.value());
+ }
+ EXPECT_EQ(start_pos_after, start_pos);
+ }
+
+ void AddCookieVerifyCache(const std::vector& cookies,
+ const std::string& expected_cookies) {
+ internal::CookieCache cookie_cache;
+ for (auto& cookie : cookies) {
+ // Add cookie
+ CallHeaders call_headers;
+ call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
+ arrow::util::string_view(cookie)));
+ cookie_cache.UpdateCachedCookies(call_headers);
+ }
+ const std::string actual_cookies = cookie_cache.GetValidCookiesAsString();
+ const std::vector actual_split_cookies =
+ TestCookieMiddleware::SplitCookies(actual_cookies);
+ const std::vector expected_split_cookies =
+ TestCookieMiddleware::SplitCookies(expected_cookies);
+ }
+};
+
+TEST_F(TestCookieParsing, Expired) {
+ VerifyParseCookie("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;", true);
+ VerifyParseCookie("id1=0; max-age=-1;", true);
+ VerifyParseCookie("id0=0; max-age=0;", true);
+}
+
+TEST_F(TestCookieParsing, Invalid) {
+ VerifyParseCookie("id1=0; expires=0, 0 0 0 0:0:0 GMT;", true);
+ VerifyParseCookie("id1=0; expires=Fri, 01 FOO 2038 22:15:36 GMT", true);
+ VerifyParseCookie("id1=0; expires=foo", true);
+ VerifyParseCookie("id1=0; expires=", true);
+ VerifyParseCookie("id1=0; max-age=FOO", true);
+ VerifyParseCookie("id1=0; max-age=", true);
+}
+
+TEST_F(TestCookieParsing, NoExpiry) {
+ VerifyParseCookie("id1=0;", false);
+ VerifyParseCookie("id1=0; noexpiry=Fri, 01 Jan 2038 22:15:36 GMT", false);
+ VerifyParseCookie("id1=0; noexpiry=\"Fri, 01 Jan 2038 22:15:36 GMT\"", false);
+ VerifyParseCookie("id1=0; nomax-age=-1", false);
+ VerifyParseCookie("id1=0; nomax-age=\"-1\"", false);
+ VerifyParseCookie("id1=0; randomattr=foo", false);
+}
+
+TEST_F(TestCookieParsing, NotExpired) {
+ VerifyParseCookie("id5=0; max-age=1", false);
+ VerifyParseCookie("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;", false);
+}
+
+TEST_F(TestCookieParsing, GetName) {
+ VerifyCookieName("id1=1; foo=bar;", "id1");
+ VerifyCookieName("id1=1; foo=bar", "id1");
+ VerifyCookieName("id2=2;", "id2");
+ VerifyCookieName("id4=\"4\"", "id4");
+ VerifyCookieName("id5=5; foo=bar; baz=buz;", "id5");
+}
+
+TEST_F(TestCookieParsing, ToString) {
+ VerifyCookieString("id1=1; foo=bar;", "id1=1");
+ VerifyCookieString("id1=1; foo=bar", "id1=1");
+ VerifyCookieString("id2=2;", "id2=2");
+ VerifyCookieString("id4=\"4\"", "id4=4");
+ VerifyCookieString("id5=5; foo=bar; baz=buz;", "id5=5");
+}
+
+TEST_F(TestCookieParsing, DateConversion) {
+ VerifyCookieDateConverson("Mon, 01 jan 2038 22:15:36 GMT;", "01 01 2038 22:15:36");
+ VerifyCookieDateConverson("TUE, 10 Feb 2038 22:15:36 GMT", "10 02 2038 22:15:36");
+ VerifyCookieDateConverson("WED, 20 MAr 2038 22:15:36 GMT;", "20 03 2038 22:15:36");
+ VerifyCookieDateConverson("thu, 15 APR 2038 22:15:36 GMT", "15 04 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 30 mAY 2038 22:15:36 GMT;", "30 05 2038 22:15:36");
+ VerifyCookieDateConverson("Sat, 03 juN 2038 22:15:36 GMT", "03 06 2038 22:15:36");
+ VerifyCookieDateConverson("Sun, 01 JuL 2038 22:15:36 GMT;", "01 07 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 06 aUg 2038 22:15:36 GMT", "06 08 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 01 SEP 2038 22:15:36 GMT;", "01 09 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 01 OCT 2038 22:15:36 GMT", "01 10 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 01 Nov 2038 22:15:36 GMT;", "01 11 2038 22:15:36");
+ VerifyCookieDateConverson("Fri, 01 deC 2038 22:15:36 GMT", "01 12 2038 22:15:36");
+ VerifyCookieDateConverson("", "");
+ VerifyCookieDateConverson("Fri, 01 INVALID 2038 22:15:36 GMT;",
+ "01 INVALID 2038 22:15:36");
+}
+
+TEST_F(TestCookieParsing, ParseCookieAttribute) {
+ VerifyCookieAttributeParsing("", 0, util::nullopt, std::string::npos);
+
+ std::string cookie_string = "attr0=0; attr1=1; attr2=2; attr3=3";
+ auto attr_length = std::string("attr0=0;").length();
+ std::string::size_type start_pos = 0;
+ VerifyCookieAttributeParsing(cookie_string, start_pos, std::make_pair("attr0", "0"),
+ cookie_string.find("attr0=0;") + attr_length);
+ VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
+ std::make_pair("attr1", "1"),
+ cookie_string.find("attr1=1;") + attr_length);
+ VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
+ std::make_pair("attr2", "2"),
+ cookie_string.find("attr2=2;") + attr_length);
+ VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
+ std::make_pair("attr3", "3"), std::string::npos);
+ VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length - 1)),
+ util::nullopt, std::string::npos);
+ VerifyCookieAttributeParsing(cookie_string, std::string::npos, util::nullopt,
+ std::string::npos);
+}
+
+TEST_F(TestCookieParsing, CookieCache) {
+ AddCookieVerifyCache({"id0=0;"}, "");
+ AddCookieVerifyCache({"id0=0;", "id0=1;"}, "id0=1");
+ AddCookieVerifyCache({"id0=0;", "id1=1;"}, "id0=0; id1=1");
+ AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=0; id1=1; id2=2");
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 003c8e64c91..96bc693486e 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -39,14 +39,11 @@
#include "arrow/util/base64.h"
#include "arrow/util/logging.h"
#include "arrow/util/make_unique.h"
-#include "arrow/util/string.h"
#ifdef GRPCPP_GRPCPP_H
#error "gRPC headers should not be in public API"
#endif
-#include "arrow/flight/client_cookie_middleware.h"
-#include "arrow/flight/client_header_internal.h"
#include "arrow/flight/internal.h"
#include "arrow/flight/middleware_internal.h"
#include "arrow/flight/test_util.h"
@@ -65,120 +62,6 @@ const char kBasicPrefix[] = "Basic ";
const char kBearerPrefix[] = "Bearer ";
const char kAuthHeader[] = "authorization";
-void AssertEqual(const ActionType& expected, const ActionType& actual) {
- ASSERT_EQ(expected.type, actual.type);
- ASSERT_EQ(expected.description, actual.description);
-}
-
-void AssertEqual(const FlightDescriptor& expected, const FlightDescriptor& actual) {
- ASSERT_TRUE(expected.Equals(actual));
-}
-
-void AssertEqual(const Ticket& expected, const Ticket& actual) {
- ASSERT_EQ(expected.ticket, actual.ticket);
-}
-
-void AssertEqual(const Location& expected, const Location& actual) {
- ASSERT_EQ(expected, actual);
-}
-
-void AssertEqual(const std::vector& expected,
- const std::vector& actual) {
- ASSERT_EQ(expected.size(), actual.size());
- for (size_t i = 0; i < expected.size(); ++i) {
- AssertEqual(expected[i].ticket, actual[i].ticket);
-
- ASSERT_EQ(expected[i].locations.size(), actual[i].locations.size());
- for (size_t j = 0; j < expected[i].locations.size(); ++j) {
- AssertEqual(expected[i].locations[j], actual[i].locations[j]);
- }
- }
-}
-
-template
-void AssertEqual(const std::vector& expected, const std::vector& actual) {
- ASSERT_EQ(expected.size(), actual.size());
- for (size_t i = 0; i < expected.size(); ++i) {
- AssertEqual(expected[i], actual[i]);
- }
-}
-
-void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) {
- std::shared_ptr ex_schema, actual_schema;
- ipc::DictionaryMemo expected_memo;
- ipc::DictionaryMemo actual_memo;
- ASSERT_OK(expected.GetSchema(&expected_memo, &ex_schema));
- ASSERT_OK(actual.GetSchema(&actual_memo, &actual_schema));
-
- AssertSchemaEqual(*ex_schema, *actual_schema);
- ASSERT_EQ(expected.total_records(), actual.total_records());
- ASSERT_EQ(expected.total_bytes(), actual.total_bytes());
-
- AssertEqual(expected.descriptor(), actual.descriptor());
- AssertEqual(expected.endpoints(), actual.endpoints());
-}
-
-TEST(TestFlightDescriptor, Basics) {
- auto a = FlightDescriptor::Command("select * from table");
- auto b = FlightDescriptor::Command("select * from table");
- auto c = FlightDescriptor::Command("select foo from table");
- auto d = FlightDescriptor::Path({"foo", "bar"});
- auto e = FlightDescriptor::Path({"foo", "baz"});
- auto f = FlightDescriptor::Path({"foo", "baz"});
-
- ASSERT_EQ(a.ToString(), "FlightDescriptor");
- ASSERT_EQ(d.ToString(), "FlightDescriptor");
- ASSERT_TRUE(a.Equals(b));
- ASSERT_FALSE(a.Equals(c));
- ASSERT_FALSE(a.Equals(d));
- ASSERT_FALSE(d.Equals(e));
- ASSERT_TRUE(e.Equals(f));
-}
-
-// This tests the internal protobuf types which don't get exported in the Flight DLL.
-#ifndef _WIN32
-TEST(TestFlightDescriptor, ToFromProto) {
- FlightDescriptor descr_test;
- pb::FlightDescriptor pb_descr;
-
- FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}};
- ASSERT_OK(internal::ToProto(descr1, &pb_descr));
- ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
- AssertEqual(descr1, descr_test);
-
- FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}};
- ASSERT_OK(internal::ToProto(descr2, &pb_descr));
- ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
- AssertEqual(descr2, descr_test);
-}
-#endif
-
-TEST(TestFlight, DISABLED_StartStopTestServer) {
- TestServer server("flight-test-server");
- server.Start();
- ASSERT_TRUE(server.IsRunning());
-
- std::this_thread::sleep_for(std::chrono::duration(0.2));
-
- ASSERT_TRUE(server.IsRunning());
- int exit_code = server.Stop();
-#ifdef _WIN32
- // We do a hard kill on Windows
- ASSERT_EQ(259, exit_code);
-#else
- ASSERT_EQ(0, exit_code);
-#endif
- ASSERT_FALSE(server.IsRunning());
-}
-
-// ARROW-6017: we should be able to construct locations for unknown
-// schemes
-TEST(TestFlight, UnknownLocationScheme) {
- Location location;
- ASSERT_OK(Location::Parse("s3://test", &location));
- ASSERT_OK(Location::Parse("https://example.com/foo", &location));
-}
-
TEST(TestFlight, ConnectUri) {
TestServer server("flight-test-server");
server.Start();
@@ -221,98 +104,6 @@ TEST(TestFlight, ConnectUriUnix) {
}
#endif
-TEST(TestFlight, RoundTripTypes) {
- Ticket ticket{"foo"};
- ASSERT_OK_AND_ASSIGN(std::string ticket_serialized, ticket.SerializeToString());
- ASSERT_OK_AND_ASSIGN(Ticket ticket_deserialized,
- Ticket::Deserialize(ticket_serialized));
- ASSERT_EQ(ticket.ticket, ticket_deserialized.ticket);
-
- FlightDescriptor desc = FlightDescriptor::Command("select * from foo;");
- ASSERT_OK_AND_ASSIGN(std::string desc_serialized, desc.SerializeToString());
- ASSERT_OK_AND_ASSIGN(FlightDescriptor desc_deserialized,
- FlightDescriptor::Deserialize(desc_serialized));
- ASSERT_TRUE(desc.Equals(desc_deserialized));
-
- desc = FlightDescriptor::Path({"a", "b", "test.arrow"});
- ASSERT_OK_AND_ASSIGN(desc_serialized, desc.SerializeToString());
- ASSERT_OK_AND_ASSIGN(desc_deserialized, FlightDescriptor::Deserialize(desc_serialized));
- ASSERT_TRUE(desc.Equals(desc_deserialized));
-
- FlightInfo::Data data;
- std::shared_ptr schema =
- arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()),
- field("d", int64())});
- Location location1, location2, location3;
- ASSERT_OK(Location::ForGrpcTcp("localhost", 10010, &location1));
- ASSERT_OK(Location::ForGrpcTls("localhost", 10010, &location2));
- ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location3));
- std::vector endpoints{FlightEndpoint{ticket, {location1, location2}},
- FlightEndpoint{ticket, {location3}}};
- ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data));
- std::unique_ptr info = std::unique_ptr(new FlightInfo(data));
- ASSERT_OK_AND_ASSIGN(std::string info_serialized, info->SerializeToString());
- ASSERT_OK_AND_ASSIGN(std::unique_ptr info_deserialized,
- FlightInfo::Deserialize(info_serialized));
- ASSERT_TRUE(info->descriptor().Equals(info_deserialized->descriptor()));
- ASSERT_EQ(info->endpoints(), info_deserialized->endpoints());
- ASSERT_EQ(info->total_records(), info_deserialized->total_records());
- ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes());
-}
-
-TEST(TestFlight, RoundtripStatus) {
- // Make sure status codes round trip through our conversions
-
- std::shared_ptr detail;
- detail = FlightStatusDetail::UnwrapStatus(
- MakeFlightError(FlightStatusCode::Internal, "Test message"));
- ASSERT_NE(nullptr, detail);
- ASSERT_EQ(FlightStatusCode::Internal, detail->code());
-
- detail = FlightStatusDetail::UnwrapStatus(
- MakeFlightError(FlightStatusCode::TimedOut, "Test message"));
- ASSERT_NE(nullptr, detail);
- ASSERT_EQ(FlightStatusCode::TimedOut, detail->code());
-
- detail = FlightStatusDetail::UnwrapStatus(
- MakeFlightError(FlightStatusCode::Cancelled, "Test message"));
- ASSERT_NE(nullptr, detail);
- ASSERT_EQ(FlightStatusCode::Cancelled, detail->code());
-
- detail = FlightStatusDetail::UnwrapStatus(
- MakeFlightError(FlightStatusCode::Unauthenticated, "Test message"));
- ASSERT_NE(nullptr, detail);
- ASSERT_EQ(FlightStatusCode::Unauthenticated, detail->code());
-
- detail = FlightStatusDetail::UnwrapStatus(
- MakeFlightError(FlightStatusCode::Unauthorized, "Test message"));
- ASSERT_NE(nullptr, detail);
- ASSERT_EQ(FlightStatusCode::Unauthorized, detail->code());
-
- detail = FlightStatusDetail::UnwrapStatus(
- MakeFlightError(FlightStatusCode::Unavailable, "Test message"));
- ASSERT_NE(nullptr, detail);
- ASSERT_EQ(FlightStatusCode::Unavailable, detail->code());
-
- Status status = internal::FromGrpcStatus(
- internal::ToGrpcStatus(Status::NotImplemented("Sentinel")));
- ASSERT_TRUE(status.IsNotImplemented());
- ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
-
- status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel")));
- ASSERT_TRUE(status.IsInvalid());
- ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
-
- status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel")));
- ASSERT_TRUE(status.IsKeyError());
- ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
-
- status =
- internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel")));
- ASSERT_TRUE(status.IsAlreadyExists());
- ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
-}
-
TEST(TestFlight, GetPort) {
Location location;
std::unique_ptr server = ExampleTestServer();
@@ -418,7 +209,8 @@ class TestFlightClient : public ::testing::Test {
}
template
- void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
+ void CheckDoGet(const FlightDescriptor& descr,
+ const RecordBatchVector& expected_batches,
EndpointCheckFunc&& check_endpoints) {
auto expected_schema = expected_batches[0]->schema();
@@ -436,7 +228,7 @@ class TestFlightClient : public ::testing::Test {
CheckDoGet(ticket, expected_batches);
}
- void CheckDoGet(const Ticket& ticket, const BatchVector& expected_batches) {
+ void CheckDoGet(const Ticket& ticket, const RecordBatchVector& expected_batches) {
auto num_batches = static_cast(expected_batches.size());
ASSERT_GE(num_batches, 2);
@@ -515,7 +307,7 @@ class DoPutTestServer : public FlightServerBase {
protected:
FlightDescriptor descriptor_;
- BatchVector batches_;
+ RecordBatchVector batches_;
friend class TestDoPut;
};
@@ -523,7 +315,7 @@ class DoPutTestServer : public FlightServerBase {
class MetadataTestServer : public FlightServerBase {
Status DoGet(const ServerCallContext& context, const Ticket& request,
std::unique_ptr* data_stream) override {
- BatchVector batches;
+ RecordBatchVector batches;
if (request.ticket == "dicts") {
RETURN_NOT_OK(ExampleDictBatches(&batches));
} else if (request.ticket == "floats") {
@@ -531,9 +323,7 @@ class MetadataTestServer : public FlightServerBase {
} else {
RETURN_NOT_OK(ExampleIntBatches(&batches));
}
- std::shared_ptr batch_reader =
- std::make_shared(batches[0]->schema(), batches);
-
+ ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches));
*data_stream = std::unique_ptr(new NumberingStream(
std::unique_ptr(new RecordBatchStream(batch_reader))));
return Status::OK();
@@ -566,9 +356,9 @@ class MetadataTestServer : public FlightServerBase {
class OptionsTestServer : public FlightServerBase {
Status DoGet(const ServerCallContext& context, const Ticket& request,
std::unique_ptr* data_stream) override {
- BatchVector batches;
+ RecordBatchVector batches;
RETURN_NOT_OK(ExampleNestedBatches(&batches));
- auto reader = std::make_shared(batches[0]->schema(), batches);
+ ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make(batches));
*data_stream = std::unique_ptr(new RecordBatchStream(reader));
return Status::OK();
}
@@ -724,7 +514,7 @@ class TestDoPut : public ::testing::Test {
}
void CheckBatches(FlightDescriptor expected_descriptor,
- const BatchVector& expected_batches) {
+ const RecordBatchVector& expected_batches) {
ASSERT_TRUE(do_put_server_->descriptor_.Equals(expected_descriptor));
ASSERT_EQ(do_put_server_->batches_.size(), expected_batches.size());
for (size_t i = 0; i < expected_batches.size(); ++i) {
@@ -733,7 +523,7 @@ class TestDoPut : public ::testing::Test {
}
void CheckDoPut(FlightDescriptor descr, const std::shared_ptr& schema,
- const BatchVector& batches) {
+ const RecordBatchVector& batches) {
std::unique_ptr stream;
std::unique_ptr reader;
ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
@@ -1264,139 +1054,6 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test {
std::shared_ptr bearer_middleware_;
};
-// This test keeps an internal cookie cache and compares that with the middleware.
-class TestCookieMiddleware : public ::testing::Test {
- public:
- // Setup function creates middleware factory and starts it up.
- void SetUp() {
- factory_ = GetCookieFactory();
- CallInfo callInfo;
- factory_->StartCall(callInfo, &middleware_);
- }
-
- // Function to add incoming cookies to middleware and validate them.
- void AddAndValidate(const std::string& incoming_cookie) {
- // Add cookie
- CallHeaders call_headers;
- call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
- arrow::util::string_view(incoming_cookie)));
- middleware_->ReceivedHeaders(call_headers);
- expected_cookie_cache_.UpdateCachedCookies(call_headers);
-
- // Get cookie from middleware.
- TestCallHeaders add_call_headers;
- middleware_->SendingHeaders(&add_call_headers);
- const std::string actual_cookies = add_call_headers.GetCookies();
-
- // Validate cookie
- const std::string expected_cookies = expected_cookie_cache_.GetValidCookiesAsString();
- const std::vector split_expected_cookies =
- SplitCookies(expected_cookies);
- const std::vector split_actual_cookies = SplitCookies(actual_cookies);
- EXPECT_EQ(split_expected_cookies, split_actual_cookies);
- }
-
- // Function to take a list of cookies and split them into a vector of individual
- // cookies. This is done because the cookie cache is a map so ordering is not
- // necessarily consistent.
- static std::vector SplitCookies(const std::string& cookies) {
- std::vector split_cookies;
- std::string::size_type pos1 = 0;
- std::string::size_type pos2 = 0;
- while ((pos2 = cookies.find(';', pos1)) != std::string::npos) {
- split_cookies.push_back(
- arrow::internal::TrimString(cookies.substr(pos1, pos2 - pos1)));
- pos1 = pos2 + 1;
- }
- if (pos1 < cookies.size()) {
- split_cookies.push_back(arrow::internal::TrimString(cookies.substr(pos1)));
- }
- std::sort(split_cookies.begin(), split_cookies.end());
- return split_cookies;
- }
-
- protected:
- // Class to allow testing of the call headers.
- class TestCallHeaders : public AddCallHeaders {
- public:
- TestCallHeaders() {}
- ~TestCallHeaders() {}
-
- // Function to add cookie header.
- void AddHeader(const std::string& key, const std::string& value) {
- ASSERT_EQ(key, "cookie");
- outbound_cookie_ = value;
- }
-
- // Function to get outgoing cookie.
- std::string GetCookies() { return outbound_cookie_; }
-
- private:
- std::string outbound_cookie_;
- };
-
- internal::CookieCache expected_cookie_cache_;
- std::unique_ptr middleware_;
- std::shared_ptr factory_;
-};
-
-// This test is used to test the parsing capabilities of the cookie framework.
-class TestCookieParsing : public ::testing::Test {
- public:
- void VerifyParseCookie(const std::string& cookie_str, bool expired) {
- internal::Cookie cookie = internal::Cookie::parse(cookie_str);
- EXPECT_EQ(expired, cookie.IsExpired());
- }
-
- void VerifyCookieName(const std::string& cookie_str, const std::string& name) {
- internal::Cookie cookie = internal::Cookie::parse(cookie_str);
- EXPECT_EQ(name, cookie.GetName());
- }
-
- void VerifyCookieString(const std::string& cookie_str,
- const std::string& cookie_as_string) {
- internal::Cookie cookie = internal::Cookie::parse(cookie_str);
- EXPECT_EQ(cookie_as_string, cookie.AsCookieString());
- }
-
- void VerifyCookieDateConverson(std::string date, const std::string& converted_date) {
- internal::Cookie::ConvertCookieDate(&date);
- EXPECT_EQ(converted_date, date);
- }
-
- void VerifyCookieAttributeParsing(
- const std::string cookie_str, std::string::size_type start_pos,
- const util::optional> cookie_attribute,
- const std::string::size_type start_pos_after) {
- util::optional> attr =
- internal::Cookie::ParseCookieAttribute(cookie_str, &start_pos);
-
- if (cookie_attribute == util::nullopt) {
- EXPECT_EQ(cookie_attribute, attr);
- } else {
- EXPECT_EQ(cookie_attribute.value(), attr.value());
- }
- EXPECT_EQ(start_pos_after, start_pos);
- }
-
- void AddCookieVerifyCache(const std::vector& cookies,
- const std::string& expected_cookies) {
- internal::CookieCache cookie_cache;
- for (auto& cookie : cookies) {
- // Add cookie
- CallHeaders call_headers;
- call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
- arrow::util::string_view(cookie)));
- cookie_cache.UpdateCachedCookies(call_headers);
- }
- const std::string actual_cookies = cookie_cache.GetValidCookiesAsString();
- const std::vector actual_split_cookies =
- TestCookieMiddleware::SplitCookies(actual_cookies);
- const std::vector expected_split_cookies =
- TestCookieMiddleware::SplitCookies(expected_cookies);
- }
-};
-
TEST_F(TestErrorMiddleware, TestMetadata) {
Action action;
std::unique_ptr stream;
@@ -1474,13 +1131,13 @@ TEST_F(TestFlightClient, GetFlightInfoNotFound) {
TEST_F(TestFlightClient, DoGetInts) {
auto descr = FlightDescriptor::Path({"examples", "ints"});
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleIntBatches(&expected_batches));
auto check_endpoints = [](const std::vector& endpoints) {
// Two endpoints in the example FlightInfo
ASSERT_EQ(2, endpoints.size());
- AssertEqual(Ticket{"ticket-ints-1"}, endpoints[0].ticket);
+ ASSERT_EQ(Ticket{"ticket-ints-1"}, endpoints[0].ticket);
};
CheckDoGet(descr, expected_batches, check_endpoints);
@@ -1488,13 +1145,13 @@ TEST_F(TestFlightClient, DoGetInts) {
TEST_F(TestFlightClient, DoGetFloats) {
auto descr = FlightDescriptor::Path({"examples", "floats"});
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleFloatBatches(&expected_batches));
auto check_endpoints = [](const std::vector& endpoints) {
// One endpoint in the example FlightInfo
ASSERT_EQ(1, endpoints.size());
- AssertEqual(Ticket{"ticket-floats-1"}, endpoints[0].ticket);
+ ASSERT_EQ(Ticket{"ticket-floats-1"}, endpoints[0].ticket);
};
CheckDoGet(descr, expected_batches, check_endpoints);
@@ -1502,13 +1159,13 @@ TEST_F(TestFlightClient, DoGetFloats) {
TEST_F(TestFlightClient, DoGetDicts) {
auto descr = FlightDescriptor::Path({"examples", "dicts"});
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleDictBatches(&expected_batches));
auto check_endpoints = [](const std::vector& endpoints) {
// One endpoint in the example FlightInfo
ASSERT_EQ(1, endpoints.size());
- AssertEqual(Ticket{"ticket-dicts-1"}, endpoints[0].ticket);
+ ASSERT_EQ(Ticket{"ticket-dicts-1"}, endpoints[0].ticket);
};
CheckDoGet(descr, expected_batches, check_endpoints);
@@ -1517,7 +1174,7 @@ TEST_F(TestFlightClient, DoGetDicts) {
// Ensure the gRPC client is configured to allow large messages
// Tests a 32 MiB batch
TEST_F(TestFlightClient, DoGetLargeBatch) {
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleLargeBatches(&expected_batches));
Ticket ticket{"ticket-large-batch-1"};
CheckDoGet(ticket, expected_batches);
@@ -1542,7 +1199,7 @@ TEST_F(TestFlightClient, FlightDataOverflowServerBatch) {
std::unique_ptr reader;
std::unique_ptr writer;
ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- BatchVector batches;
+ RecordBatchVector batches;
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
reader->ReadAll(&batches));
@@ -1578,7 +1235,7 @@ TEST_F(TestFlightClient, FlightDataOverflowClientBatch) {
TEST_F(TestFlightClient, DoExchange) {
auto descr = FlightDescriptor::Command("counter");
- BatchVector batches;
+ RecordBatchVector batches;
auto a1 = ArrayFromJSON(int32(), "[4, 5, 6, null]");
auto schema = arrow::schema({field("f1", a1->type())});
batches.push_back(RecordBatch::Make(schema, a1->length(), {a1}));
@@ -1647,7 +1304,7 @@ TEST_F(TestFlightClient, DoExchangeGet) {
ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
AssertSchemaEqual(*ExampleIntSchema(), *server_schema);
- BatchVector batches;
+ RecordBatchVector batches;
ASSERT_OK(ExampleIntBatches(&batches));
FlightStreamChunk chunk;
for (const auto& batch : batches) {
@@ -1668,7 +1325,7 @@ TEST_F(TestFlightClient, DoExchangePut) {
std::unique_ptr writer;
ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
ASSERT_OK(writer->Begin(ExampleIntSchema()));
- BatchVector batches;
+ RecordBatchVector batches;
ASSERT_OK(ExampleIntBatches(&batches));
for (const auto& batch : batches) {
ASSERT_OK(writer->WriteRecordBatch(*batch));
@@ -1691,7 +1348,7 @@ TEST_F(TestFlightClient, DoExchangeEcho) {
std::unique_ptr writer;
ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
ASSERT_OK(writer->Begin(ExampleIntSchema()));
- BatchVector batches;
+ RecordBatchVector batches;
FlightStreamChunk chunk;
ASSERT_OK(ExampleIntBatches(&batches));
for (const auto& batch : batches) {
@@ -1813,7 +1470,7 @@ TEST_F(TestFlightClient, ListActions) {
ASSERT_OK(client_->ListActions(&actions));
std::vector expected = ExampleActionTypes();
- AssertEqual(expected, actions);
+ EXPECT_THAT(actions, ::testing::ContainerEq(expected));
}
TEST_F(TestFlightClient, DoAction) {
@@ -1938,7 +1595,7 @@ TEST_F(TestFlightClient, Close) {
TEST_F(TestDoPut, DoPutInts) {
auto descr = FlightDescriptor::Path({"ints"});
- BatchVector batches;
+ RecordBatchVector batches;
auto a0 = ArrayFromJSON(int8(), "[0, 1, 127, -128, null]");
auto a1 = ArrayFromJSON(uint8(), "[0, 1, 127, 255, null]");
auto a2 = ArrayFromJSON(int16(), "[0, 258, 32767, -32768, null]");
@@ -1961,7 +1618,7 @@ TEST_F(TestDoPut, DoPutInts) {
TEST_F(TestDoPut, DoPutFloats) {
auto descr = FlightDescriptor::Path({"floats"});
- BatchVector batches;
+ RecordBatchVector batches;
auto a0 = ArrayFromJSON(float32(), "[0, 1.2, -3.4, 5.6, null]");
auto a1 = ArrayFromJSON(float64(), "[0, 1.2, -3.4, 5.6, null]");
auto schema = arrow::schema({field("f0", a0->type()), field("f1", a1->type())});
@@ -1973,7 +1630,7 @@ TEST_F(TestDoPut, DoPutFloats) {
TEST_F(TestDoPut, DoPutEmptyBatch) {
// Sending and receiving a 0-sized batch shouldn't fail
auto descr = FlightDescriptor::Path({"ints"});
- BatchVector batches;
+ RecordBatchVector batches;
auto a1 = ArrayFromJSON(int32(), "[]");
auto schema = arrow::schema({field("f1", a1->type())});
batches.push_back(RecordBatch::Make(schema, a1->length(), {a1}));
@@ -1983,7 +1640,7 @@ TEST_F(TestDoPut, DoPutEmptyBatch) {
TEST_F(TestDoPut, DoPutDicts) {
auto descr = FlightDescriptor::Path({"dicts"});
- BatchVector batches;
+ RecordBatchVector batches;
auto dict_values = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"quux\"]");
auto ty = dictionary(int8(), dict_values->type());
auto schema = arrow::schema({field("f1", ty)});
@@ -2002,7 +1659,7 @@ TEST_F(TestDoPut, DoPutDicts) {
TEST_F(TestDoPut, DoPutLargeBatch) {
auto descr = FlightDescriptor::Path({"large-batches"});
auto schema = ExampleLargeSchema();
- BatchVector batches;
+ RecordBatchVector batches;
ASSERT_OK(ExampleLargeBatches(&batches));
CheckDoPut(descr, schema, batches);
}
@@ -2020,7 +1677,7 @@ TEST_F(TestDoPut, DoPutSizeLimit) {
// Batch is too large to fit in one message
auto schema = arrow::schema({field("f1", arrow::int64())});
auto batch = arrow::ConstantArrayGenerator::Zeroes(768, schema);
- BatchVector batches;
+ RecordBatchVector batches;
batches.push_back(batch->Slice(0, 384));
batches.push_back(batch->Slice(384));
@@ -2343,7 +2000,7 @@ TEST_F(TestMetadata, DoGet) {
std::unique_ptr stream;
ASSERT_OK(client_->DoGet(ticket, &stream));
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleIntBatches(&expected_batches));
FlightStreamChunk chunk;
@@ -2368,7 +2025,7 @@ TEST_F(TestMetadata, DoGetDictionaries) {
std::unique_ptr stream;
ASSERT_OK(client_->DoGet(ticket, &stream));
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleDictBatches(&expected_batches));
FlightStreamChunk chunk;
@@ -2390,7 +2047,7 @@ TEST_F(TestMetadata, DoPut) {
std::shared_ptr schema = ExampleIntSchema();
ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleIntBatches(&expected_batches));
std::shared_ptr chunk;
@@ -2411,7 +2068,7 @@ TEST_F(TestMetadata, DoPut) {
TEST_F(TestMetadata, DoPutDictionaries) {
std::unique_ptr writer;
std::unique_ptr reader;
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleDictBatches(&expected_batches));
// ARROW-8749: don't get the schema via ExampleDictSchema because
// DictionaryMemo uses field addresses to determine whether it's
@@ -2437,7 +2094,7 @@ TEST_F(TestMetadata, DoPutReadMetadata) {
std::shared_ptr schema = ExampleIntSchema();
ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleIntBatches(&expected_batches));
std::shared_ptr chunk;
@@ -2470,7 +2127,7 @@ TEST_F(TestOptions, DoPutWriteOptions) {
// Call DoPut, but with a very low write nesting depth set to fail the call.
std::unique_ptr writer;
std::unique_ptr reader;
- BatchVector expected_batches;
+ RecordBatchVector expected_batches;
ASSERT_OK(ExampleNestedBatches(&expected_batches));
auto options = FlightCallOptions();
@@ -2491,7 +2148,7 @@ TEST_F(TestOptions, DoExchangeClientWriteOptions) {
std::unique_ptr reader;
std::unique_ptr writer;
ASSERT_OK(client_->DoExchange(options, descr, &writer, &reader));
- BatchVector batches;
+ RecordBatchVector batches;
ASSERT_OK(ExampleNestedBatches(&batches));
ASSERT_OK(writer->Begin(batches[0]->schema()));
for (const auto& batch : batches) {
@@ -2509,7 +2166,7 @@ TEST_F(TestOptions, DoExchangeClientWriteOptionsBegin) {
std::unique_ptr reader;
std::unique_ptr writer;
ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- BatchVector batches;
+ RecordBatchVector batches;
ASSERT_OK(ExampleNestedBatches(&batches));
auto options = ipc::IpcWriteOptions::Defaults();
options.max_recursion_depth = 1;
@@ -2528,7 +2185,7 @@ TEST_F(TestOptions, DoExchangeServerWriteOptions) {
std::unique_ptr reader;
std::unique_ptr writer;
ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- BatchVector batches;
+ RecordBatchVector batches;
ASSERT_OK(ExampleNestedBatches(&batches));
ASSERT_OK(writer->Begin(batches[0]->schema()));
FlightStreamChunk chunk;
@@ -2648,140 +2305,6 @@ TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) { RunValidClientAuth();
TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) { RunInvalidClientAuth(); }
-TEST_F(TestCookieMiddleware, BasicParsing) {
- AddAndValidate("id1=1; foo=bar;");
- AddAndValidate("id1=1; foo=bar");
- AddAndValidate("id2=2;");
- AddAndValidate("id4=\"4\"");
- AddAndValidate("id5=5; foo=bar; baz=buz;");
-}
-
-TEST_F(TestCookieMiddleware, Overwrite) {
- AddAndValidate("id0=0");
- AddAndValidate("id0=1");
- AddAndValidate("id1=0");
- AddAndValidate("id1=1");
- AddAndValidate("id1=1");
- AddAndValidate("id1=10");
- AddAndValidate("id=3");
- AddAndValidate("id=0");
- AddAndValidate("id=0");
-}
-
-TEST_F(TestCookieMiddleware, MaxAge) {
- AddAndValidate("id0=0; max-age=0;");
- AddAndValidate("id1=0; max-age=-1;");
- AddAndValidate("id2=0; max-age=0");
- AddAndValidate("id3=0; max-age=-1");
- AddAndValidate("id4=0; max-age=1");
- AddAndValidate("id5=0; max-age=1");
- AddAndValidate("id4=0; max-age=0");
- AddAndValidate("id5=0; max-age=0");
-}
-
-TEST_F(TestCookieMiddleware, Expires) {
- AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT;");
- AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT");
- AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
- AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
- AddAndValidate("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;");
- AddAndValidate("id1=0; expires=Fri, 01 Jan 2038 22:15:36 GMT");
- AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
- AddAndValidate("id1=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
-}
-
-TEST_F(TestCookieParsing, Expired) {
- VerifyParseCookie("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;", true);
- VerifyParseCookie("id1=0; max-age=-1;", true);
- VerifyParseCookie("id0=0; max-age=0;", true);
-}
-
-TEST_F(TestCookieParsing, Invalid) {
- VerifyParseCookie("id1=0; expires=0, 0 0 0 0:0:0 GMT;", true);
- VerifyParseCookie("id1=0; expires=Fri, 01 FOO 2038 22:15:36 GMT", true);
- VerifyParseCookie("id1=0; expires=foo", true);
- VerifyParseCookie("id1=0; expires=", true);
- VerifyParseCookie("id1=0; max-age=FOO", true);
- VerifyParseCookie("id1=0; max-age=", true);
-}
-
-TEST_F(TestCookieParsing, NoExpiry) {
- VerifyParseCookie("id1=0;", false);
- VerifyParseCookie("id1=0; noexpiry=Fri, 01 Jan 2038 22:15:36 GMT", false);
- VerifyParseCookie("id1=0; noexpiry=\"Fri, 01 Jan 2038 22:15:36 GMT\"", false);
- VerifyParseCookie("id1=0; nomax-age=-1", false);
- VerifyParseCookie("id1=0; nomax-age=\"-1\"", false);
- VerifyParseCookie("id1=0; randomattr=foo", false);
-}
-
-TEST_F(TestCookieParsing, NotExpired) {
- VerifyParseCookie("id5=0; max-age=1", false);
- VerifyParseCookie("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;", false);
-}
-
-TEST_F(TestCookieParsing, GetName) {
- VerifyCookieName("id1=1; foo=bar;", "id1");
- VerifyCookieName("id1=1; foo=bar", "id1");
- VerifyCookieName("id2=2;", "id2");
- VerifyCookieName("id4=\"4\"", "id4");
- VerifyCookieName("id5=5; foo=bar; baz=buz;", "id5");
-}
-
-TEST_F(TestCookieParsing, ToString) {
- VerifyCookieString("id1=1; foo=bar;", "id1=1");
- VerifyCookieString("id1=1; foo=bar", "id1=1");
- VerifyCookieString("id2=2;", "id2=2");
- VerifyCookieString("id4=\"4\"", "id4=4");
- VerifyCookieString("id5=5; foo=bar; baz=buz;", "id5=5");
-}
-
-TEST_F(TestCookieParsing, DateConversion) {
- VerifyCookieDateConverson("Mon, 01 jan 2038 22:15:36 GMT;", "01 01 2038 22:15:36");
- VerifyCookieDateConverson("TUE, 10 Feb 2038 22:15:36 GMT", "10 02 2038 22:15:36");
- VerifyCookieDateConverson("WED, 20 MAr 2038 22:15:36 GMT;", "20 03 2038 22:15:36");
- VerifyCookieDateConverson("thu, 15 APR 2038 22:15:36 GMT", "15 04 2038 22:15:36");
- VerifyCookieDateConverson("Fri, 30 mAY 2038 22:15:36 GMT;", "30 05 2038 22:15:36");
- VerifyCookieDateConverson("Sat, 03 juN 2038 22:15:36 GMT", "03 06 2038 22:15:36");
- VerifyCookieDateConverson("Sun, 01 JuL 2038 22:15:36 GMT;", "01 07 2038 22:15:36");
- VerifyCookieDateConverson("Fri, 06 aUg 2038 22:15:36 GMT", "06 08 2038 22:15:36");
- VerifyCookieDateConverson("Fri, 01 SEP 2038 22:15:36 GMT;", "01 09 2038 22:15:36");
- VerifyCookieDateConverson("Fri, 01 OCT 2038 22:15:36 GMT", "01 10 2038 22:15:36");
- VerifyCookieDateConverson("Fri, 01 Nov 2038 22:15:36 GMT;", "01 11 2038 22:15:36");
- VerifyCookieDateConverson("Fri, 01 deC 2038 22:15:36 GMT", "01 12 2038 22:15:36");
- VerifyCookieDateConverson("", "");
- VerifyCookieDateConverson("Fri, 01 INVALID 2038 22:15:36 GMT;",
- "01 INVALID 2038 22:15:36");
-}
-
-TEST_F(TestCookieParsing, ParseCookieAttribute) {
- VerifyCookieAttributeParsing("", 0, util::nullopt, std::string::npos);
-
- std::string cookie_string = "attr0=0; attr1=1; attr2=2; attr3=3";
- auto attr_length = std::string("attr0=0;").length();
- std::string::size_type start_pos = 0;
- VerifyCookieAttributeParsing(cookie_string, start_pos, std::make_pair("attr0", "0"),
- cookie_string.find("attr0=0;") + attr_length);
- VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
- std::make_pair("attr1", "1"),
- cookie_string.find("attr1=1;") + attr_length);
- VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
- std::make_pair("attr2", "2"),
- cookie_string.find("attr2=2;") + attr_length);
- VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
- std::make_pair("attr3", "3"), std::string::npos);
- VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length - 1)),
- util::nullopt, std::string::npos);
- VerifyCookieAttributeParsing(cookie_string, std::string::npos, util::nullopt,
- std::string::npos);
-}
-
-TEST_F(TestCookieParsing, CookieCache) {
- AddCookieVerifyCache({"id0=0;"}, "");
- AddCookieVerifyCache({"id0=0;", "id0=1;"}, "id0=1");
- AddCookieVerifyCache({"id0=0;", "id1=1;"}, "id0=0; id1=1");
- AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=0; id1=1; id2=2");
-}
-
class ForeverFlightListing : public FlightListing {
Status Next(std::unique_ptr* info) override {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc
index 2325445f86c..4405c24eea5 100644
--- a/cpp/src/arrow/flight/test_util.cc
+++ b/cpp/src/arrow/flight/test_util.cc
@@ -151,24 +151,24 @@ const std::string& TestServer::unix_sock() const { return unix_sock_; }
Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) {
if (ticket.ticket == "ticket-ints-1") {
- BatchVector batches;
+ RecordBatchVector batches;
RETURN_NOT_OK(ExampleIntBatches(&batches));
- *out = std::make_shared(batches[0]->schema(), batches);
+ ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
return Status::OK();
} else if (ticket.ticket == "ticket-floats-1") {
- BatchVector batches;
+ RecordBatchVector batches;
RETURN_NOT_OK(ExampleFloatBatches(&batches));
- *out = std::make_shared(batches[0]->schema(), batches);
+ ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
return Status::OK();
} else if (ticket.ticket == "ticket-dicts-1") {
- BatchVector batches;
+ RecordBatchVector batches;
RETURN_NOT_OK(ExampleDictBatches(&batches));
- *out = std::make_shared(batches[0]->schema(), batches);
+ ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
return Status::OK();
} else if (ticket.ticket == "ticket-large-batch-1") {
- BatchVector batches;
+ RecordBatchVector batches;
RETURN_NOT_OK(ExampleLargeBatches(&batches));
- *out = std::make_shared(batches[0]->schema(), batches);
+ ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
return Status::OK();
} else {
return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket);
@@ -233,7 +233,7 @@ class FlightTestServer : public FlightServerBase {
Status DoPut(const ServerCallContext&, std::unique_ptr reader,
std::unique_ptr writer) override {
- BatchVector batches;
+ RecordBatchVector batches;
return reader->ReadAll(&batches);
}
@@ -270,7 +270,7 @@ class FlightTestServer : public FlightServerBase {
Status RunExchangeGet(std::unique_ptr reader,
std::unique_ptr writer) {
RETURN_NOT_OK(writer->Begin(ExampleIntSchema()));
- BatchVector batches;
+ RecordBatchVector batches;
RETURN_NOT_OK(ExampleIntBatches(&batches));
for (const auto& batch : batches) {
RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
@@ -285,7 +285,7 @@ class FlightTestServer : public FlightServerBase {
if (!schema->Equals(ExampleIntSchema(), false)) {
return Status::Invalid("Schema is not as expected");
}
- BatchVector batches;
+ RecordBatchVector batches;
RETURN_NOT_OK(ExampleIntBatches(&batches));
FlightStreamChunk chunk;
for (const auto& batch : batches) {
@@ -590,7 +590,7 @@ std::vector ExampleFlightInfo() {
FlightInfo(flight4)};
}
-Status ExampleIntBatches(BatchVector* out) {
+Status ExampleIntBatches(RecordBatchVector* out) {
std::shared_ptr batch;
for (int i = 0; i < 5; ++i) {
// Make all different sizes, use different random seed
@@ -600,7 +600,7 @@ Status ExampleIntBatches(BatchVector* out) {
return Status::OK();
}
-Status ExampleFloatBatches(BatchVector* out) {
+Status ExampleFloatBatches(RecordBatchVector* out) {
std::shared_ptr batch;
for (int i = 0; i < 5; ++i) {
// Make all different sizes, use different random seed
@@ -610,7 +610,7 @@ Status ExampleFloatBatches(BatchVector* out) {
return Status::OK();
}
-Status ExampleDictBatches(BatchVector* out) {
+Status ExampleDictBatches(RecordBatchVector* out) {
// Just the same batch, repeated a few times
std::shared_ptr batch;
for (int i = 0; i < 3; ++i) {
@@ -620,7 +620,7 @@ Status ExampleDictBatches(BatchVector* out) {
return Status::OK();
}
-Status ExampleNestedBatches(BatchVector* out) {
+Status ExampleNestedBatches(RecordBatchVector* out) {
std::shared_ptr batch;
for (int i = 0; i < 3; ++i) {
RETURN_NOT_OK(ipc::test::MakeListRecordBatch(&batch));
@@ -629,7 +629,7 @@ Status ExampleNestedBatches(BatchVector* out) {
return Status::OK();
}
-Status ExampleLargeBatches(BatchVector* out) {
+Status ExampleLargeBatches(RecordBatchVector* out) {
const auto array_length = 32768;
std::shared_ptr batch;
std::vector> arrays;
diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h
index a3c7350fa74..75a534d56b3 100644
--- a/cpp/src/arrow/flight/test_util.h
+++ b/cpp/src/arrow/flight/test_util.h
@@ -17,6 +17,9 @@
#pragma once
+#include
+#include
+
#include
#include
#include
@@ -25,6 +28,7 @@
#include
#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
#include "arrow/testing/util.h"
#include "arrow/flight/client.h"
@@ -46,6 +50,24 @@ class child;
namespace arrow {
namespace flight {
+// ----------------------------------------------------------------------
+// Helpers to compare values for equality
+
+inline void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) {
+ std::shared_ptr ex_schema, actual_schema;
+ ipc::DictionaryMemo expected_memo;
+ ipc::DictionaryMemo actual_memo;
+ ASSERT_OK(expected.GetSchema(&expected_memo, &ex_schema));
+ ASSERT_OK(actual.GetSchema(&actual_memo, &actual_schema));
+
+ AssertSchemaEqual(*ex_schema, *actual_schema);
+ ASSERT_EQ(expected.total_records(), actual.total_records());
+ ASSERT_EQ(expected.total_bytes(), actual.total_bytes());
+
+ ASSERT_EQ(expected.descriptor(), actual.descriptor());
+ ASSERT_THAT(actual.endpoints(), ::testing::ContainerEq(expected.endpoints()));
+}
+
// ----------------------------------------------------------------------
// Fixture to use for running test servers
@@ -100,43 +122,6 @@ Status MakeServer(std::unique_ptr* server,
return FlightClient::Connect(real_location, client_options, client);
}
-// ----------------------------------------------------------------------
-// A RecordBatchReader for serving a sequence of in-memory record batches
-
-// Silence warning
-// "non dll-interface class RecordBatchReader used as base for dll-interface class"
-#ifdef _MSC_VER
-#pragma warning(push)
-#pragma warning(disable : 4275)
-#endif
-
-class ARROW_FLIGHT_EXPORT BatchIterator : public RecordBatchReader {
- public:
- BatchIterator(const std::shared_ptr& schema,
- const std::vector>& batches)
- : schema_(schema), batches_(batches), position_(0) {}
-
- std::shared_ptr schema() const override { return schema_; }
-
- Status ReadNext(std::shared_ptr* out) override {
- if (position_ >= batches_.size()) {
- *out = nullptr;
- } else {
- *out = batches_[position_++];
- }
- return Status::OK();
- }
-
- private:
- std::shared_ptr schema_;
- std::vector> batches_;
- size_t position_;
-};
-
-#ifdef _MSC_VER
-#pragma warning(pop)
-#endif
-
// ----------------------------------------------------------------------
// A FlightDataStream that numbers the record batches
/// \brief A basic implementation of FlightDataStream that will provide
@@ -157,8 +142,6 @@ class ARROW_FLIGHT_EXPORT NumberingStream : public FlightDataStream {
// ----------------------------------------------------------------------
// Example data for test-server and unit tests
-using BatchVector = std::vector>;
-
ARROW_FLIGHT_EXPORT
std::shared_ptr ExampleIntSchema();
@@ -172,19 +155,19 @@ ARROW_FLIGHT_EXPORT
std::shared_ptr ExampleLargeSchema();
ARROW_FLIGHT_EXPORT
-Status ExampleIntBatches(BatchVector* out);
+Status ExampleIntBatches(RecordBatchVector* out);
ARROW_FLIGHT_EXPORT
-Status ExampleFloatBatches(BatchVector* out);
+Status ExampleFloatBatches(RecordBatchVector* out);
ARROW_FLIGHT_EXPORT
-Status ExampleDictBatches(BatchVector* out);
+Status ExampleDictBatches(RecordBatchVector* out);
ARROW_FLIGHT_EXPORT
-Status ExampleNestedBatches(BatchVector* out);
+Status ExampleNestedBatches(RecordBatchVector* out);
ARROW_FLIGHT_EXPORT
-Status ExampleLargeBatches(BatchVector* out);
+Status ExampleLargeBatches(RecordBatchVector* out);
ARROW_FLIGHT_EXPORT
arrow::Result> VeryLargeBatch();
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index 2567bdfedd9..5dbba6fba5f 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -324,6 +324,10 @@ bool FlightEndpoint::Equals(const FlightEndpoint& other) const {
return ticket == other.ticket && locations == other.locations;
}
+bool ActionType::Equals(const ActionType& other) const {
+ return type == other.type && description == other.description;
+}
+
Status MetadataRecordBatchReader::ReadAll(
std::vector>* batches) {
FlightStreamChunk chunk;
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 04a9b4b716f..bc68cc280b0 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -139,6 +139,15 @@ struct ARROW_FLIGHT_EXPORT ActionType {
/// \brief A human-readable description of the action.
std::string description;
+
+ bool Equals(const ActionType& other) const;
+
+ friend bool operator==(const ActionType& left, const ActionType& right) {
+ return left.Equals(right);
+ }
+ friend bool operator!=(const ActionType& left, const ActionType& right) {
+ return !(left == right);
+ }
};
/// \brief Opaque selection criteria for ListFlights RPC
@@ -406,9 +415,9 @@ class ARROW_FLIGHT_EXPORT FlightInfo {
const std::vector& endpoints,
int64_t total_records, int64_t total_bytes);
- /// \brief Deserialize the Arrow schema of the dataset, to be passed
- /// to each call to DoGet. Populate any dictionary encoded fields
- /// into a DictionaryMemo for bookkeeping
+ /// \brief Deserialize the Arrow schema of the dataset. Populate any
+ /// dictionary encoded fields into a DictionaryMemo for
+ /// bookkeeping
/// \param[in,out] dictionary_memo for dictionary bookkeeping, will
/// be modified
/// \param[out] out the reconstructed Schema
From b17d1a0f35e1df9d0d55ea0e238554e2bc6cc52a Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 24 Feb 2022 12:36:26 -0500
Subject: [PATCH 2/5] ARROW-15707: [C++][FlightRPC] Make Flight data plane
tests reusable
---
cpp/src/arrow/flight/CMakeLists.txt | 1 +
cpp/src/arrow/flight/flight_test.cc | 924 +------------------
cpp/src/arrow/flight/test_definitions.cc | 1029 ++++++++++++++++++++++
cpp/src/arrow/flight/test_definitions.h | 114 +++
cpp/src/arrow/flight/types.cc | 9 +
cpp/src/arrow/flight/types.h | 4 +
6 files changed, 1166 insertions(+), 915 deletions(-)
create mode 100644 cpp/src/arrow/flight/test_definitions.cc
create mode 100644 cpp/src/arrow/flight/test_definitions.h
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 14eebc262ee..cae36562be3 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -213,6 +213,7 @@ if(ARROW_TESTING)
OUTPUTS
ARROW_FLIGHT_TESTING_LIBRARIES
SOURCES
+ test_definitions.cc
test_util.cc
DEPENDENCIES
GTest::gtest
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 96bc693486e..1626bb6d41d 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -46,6 +46,7 @@
#include "arrow/flight/internal.h"
#include "arrow/flight/middleware_internal.h"
+#include "arrow/flight/test_definitions.h"
#include "arrow/flight/test_util.h"
namespace arrow {
@@ -62,6 +63,12 @@ const char kBasicPrefix[] = "Basic ";
const char kBearerPrefix[] = "Bearer ";
const char kAuthHeader[] = "authorization";
+INSTANTIATE_TEST_SUITE_P(GrpcConnectivity, ConnectivityTest, testing::Values("grpc"));
+INSTANTIATE_TEST_SUITE_P(GrpcData, DataTest, testing::Values("grpc"));
+INSTANTIATE_TEST_SUITE_P(GrpcDoPut, DoPutTest, testing::Values("grpc"));
+INSTANTIATE_TEST_SUITE_P(GrpcAppMetadata, AppMetadataTest, testing::Values("grpc"));
+INSTANTIATE_TEST_SUITE_P(GrpcIpcOptions, IpcOptionsTest, testing::Values("grpc"));
+
TEST(TestFlight, ConnectUri) {
TestServer server("flight-test-server");
server.Start();
@@ -104,16 +111,6 @@ TEST(TestFlight, ConnectUriUnix) {
}
#endif
-TEST(TestFlight, GetPort) {
- Location location;
- std::unique_ptr server = ExampleTestServer();
-
- ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
- FlightServerOptions options(location);
- ASSERT_OK(server->Init(options));
- ASSERT_GT(server->port(), 0);
-}
-
// CI environments don't have an IPv6 interface configured
TEST(TestFlight, DISABLED_IpV6Port) {
Location location, location2;
@@ -131,56 +128,6 @@ TEST(TestFlight, DISABLED_IpV6Port) {
ASSERT_OK(client->ListFlights(&listing));
}
-TEST(TestFlight, BuilderHook) {
- Location location;
- std::unique_ptr server = ExampleTestServer();
-
- ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
- FlightServerOptions options(location);
- bool builder_hook_run = false;
- options.builder_hook = [&builder_hook_run](void* builder) {
- ASSERT_NE(nullptr, builder);
- builder_hook_run = true;
- };
- ASSERT_OK(server->Init(options));
- ASSERT_TRUE(builder_hook_run);
- ASSERT_GT(server->port(), 0);
- ASSERT_OK(server->Shutdown());
-}
-
-TEST(TestFlight, ServeShutdown) {
- // Regression test for ARROW-15181
- constexpr int kIterations = 10;
- for (int i = 0; i < kIterations; i++) {
- Location location;
- std::unique_ptr server = ExampleTestServer();
-
- ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
- FlightServerOptions options(location);
- ASSERT_OK(server->Init(options));
- ASSERT_GT(server->port(), 0);
- std::thread t([&]() { ASSERT_OK(server->Serve()); });
- ASSERT_OK(server->Shutdown());
- ASSERT_OK(server->Wait());
- t.join();
- }
-}
-
-TEST(TestFlight, ServeShutdownWithDeadline) {
- Location location;
- std::unique_ptr server = ExampleTestServer();
-
- ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
- FlightServerOptions options(location);
- ASSERT_OK(server->Init(options));
- ASSERT_GT(server->port(), 0);
-
- auto deadline = std::chrono::system_clock::now() + std::chrono::microseconds(10);
-
- ASSERT_OK(server->Shutdown(&deadline));
- ASSERT_OK(server->Wait());
-}
-
// ----------------------------------------------------------------------
// Client tests
@@ -296,119 +243,6 @@ class TlsTestServer : public FlightServerBase {
}
};
-class DoPutTestServer : public FlightServerBase {
- public:
- Status DoPut(const ServerCallContext& context,
- std::unique_ptr reader,
- std::unique_ptr writer) override {
- descriptor_ = reader->descriptor();
- return reader->ReadAll(&batches_);
- }
-
- protected:
- FlightDescriptor descriptor_;
- RecordBatchVector batches_;
-
- friend class TestDoPut;
-};
-
-class MetadataTestServer : public FlightServerBase {
- Status DoGet(const ServerCallContext& context, const Ticket& request,
- std::unique_ptr* data_stream) override {
- RecordBatchVector batches;
- if (request.ticket == "dicts") {
- RETURN_NOT_OK(ExampleDictBatches(&batches));
- } else if (request.ticket == "floats") {
- RETURN_NOT_OK(ExampleFloatBatches(&batches));
- } else {
- RETURN_NOT_OK(ExampleIntBatches(&batches));
- }
- ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches));
- *data_stream = std::unique_ptr(new NumberingStream(
- std::unique_ptr(new RecordBatchStream(batch_reader))));
- return Status::OK();
- }
-
- Status DoPut(const ServerCallContext& context,
- std::unique_ptr reader,
- std::unique_ptr writer) override {
- FlightStreamChunk chunk;
- int counter = 0;
- while (true) {
- RETURN_NOT_OK(reader->Next(&chunk));
- if (chunk.data == nullptr) break;
- if (chunk.app_metadata == nullptr) {
- return Status::Invalid("Expected application metadata to be provided");
- }
- if (std::to_string(counter) != chunk.app_metadata->ToString()) {
- return Status::Invalid("Expected metadata value: " + std::to_string(counter) +
- " but got: " + chunk.app_metadata->ToString());
- }
- auto metadata = Buffer::FromString(std::to_string(counter));
- RETURN_NOT_OK(writer->WriteMetadata(*metadata));
- counter++;
- }
- return Status::OK();
- }
-};
-
-// Server for testing custom IPC options support
-class OptionsTestServer : public FlightServerBase {
- Status DoGet(const ServerCallContext& context, const Ticket& request,
- std::unique_ptr* data_stream) override {
- RecordBatchVector batches;
- RETURN_NOT_OK(ExampleNestedBatches(&batches));
- ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make(batches));
- *data_stream = std::unique_ptr(new RecordBatchStream(reader));
- return Status::OK();
- }
-
- // Just echo the number of batches written. The client will try to
- // call this method with different write options set.
- Status DoPut(const ServerCallContext& context,
- std::unique_ptr reader,
- std::unique_ptr writer) override {
- FlightStreamChunk chunk;
- int counter = 0;
- while (true) {
- RETURN_NOT_OK(reader->Next(&chunk));
- if (chunk.data == nullptr) break;
- counter++;
- }
- auto metadata = Buffer::FromString(std::to_string(counter));
- return writer->WriteMetadata(*metadata);
- }
-
- // Echo client data, but with write options set to limit the nesting
- // level.
- Status DoExchange(const ServerCallContext& context,
- std::unique_ptr reader,
- std::unique_ptr writer) override {
- FlightStreamChunk chunk;
- auto options = ipc::IpcWriteOptions::Defaults();
- options.max_recursion_depth = 1;
- bool begun = false;
- while (true) {
- RETURN_NOT_OK(reader->Next(&chunk));
- if (!chunk.data && !chunk.app_metadata) {
- break;
- }
- if (!begun && chunk.data) {
- begun = true;
- RETURN_NOT_OK(writer->Begin(chunk.data->schema(), options));
- }
- if (chunk.data && chunk.app_metadata) {
- RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata));
- } else if (chunk.data) {
- RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
- } else if (chunk.app_metadata) {
- RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata));
- }
- }
- return Status::OK();
- }
-};
-
class HeaderAuthTestServer : public FlightServerBase {
public:
Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
@@ -417,42 +251,6 @@ class HeaderAuthTestServer : public FlightServerBase {
}
};
-class TestMetadata : public ::testing::Test {
- public:
- void SetUp() {
- ASSERT_OK(MakeServer(
- &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
- [](FlightClientOptions* options) { return Status::OK(); }));
- }
-
- void TearDown() {
- ASSERT_OK(client_->Close());
- ASSERT_OK(server_->Shutdown());
- }
-
- protected:
- std::unique_ptr client_;
- std::unique_ptr server_;
-};
-
-class TestOptions : public ::testing::Test {
- public:
- void SetUp() {
- ASSERT_OK(MakeServer(
- &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
- [](FlightClientOptions* options) { return Status::OK(); }));
- }
-
- void TearDown() {
- ASSERT_OK(client_->Close());
- ASSERT_OK(server_->Shutdown());
- }
-
- protected:
- std::unique_ptr client_;
- std::unique_ptr server_;
-};
-
class TestAuthHandler : public ::testing::Test {
public:
void SetUp() {
@@ -499,49 +297,6 @@ class TestBasicAuthHandler : public ::testing::Test {
std::unique_ptr server_;
};
-class TestDoPut : public ::testing::Test {
- public:
- void SetUp() {
- ASSERT_OK(MakeServer(
- &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
- [](FlightClientOptions* options) { return Status::OK(); }));
- do_put_server_ = (DoPutTestServer*)server_.get();
- }
-
- void TearDown() {
- ASSERT_OK(client_->Close());
- ASSERT_OK(server_->Shutdown());
- }
-
- void CheckBatches(FlightDescriptor expected_descriptor,
- const RecordBatchVector& expected_batches) {
- ASSERT_TRUE(do_put_server_->descriptor_.Equals(expected_descriptor));
- ASSERT_EQ(do_put_server_->batches_.size(), expected_batches.size());
- for (size_t i = 0; i < expected_batches.size(); ++i) {
- ASSERT_BATCHES_EQUAL(*do_put_server_->batches_[i], *expected_batches[i]);
- }
- }
-
- void CheckDoPut(FlightDescriptor descr, const std::shared_ptr& schema,
- const RecordBatchVector& batches) {
- std::unique_ptr stream;
- std::unique_ptr reader;
- ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
- for (const auto& batch : batches) {
- ASSERT_OK(stream->WriteRecordBatch(*batch));
- }
- ASSERT_OK(stream->DoneWriting());
- ASSERT_OK(stream->Close());
-
- CheckBatches(descr, batches);
- }
-
- protected:
- std::unique_ptr client_;
- std::unique_ptr server_;
- DoPutTestServer* do_put_server_;
-};
-
class TestTls : public ::testing::Test {
public:
void SetUp() {
@@ -870,7 +625,7 @@ class PropagatingTestServer : public FlightServerBase {
class TestRejectServerMiddleware : public ::testing::Test {
public:
void SetUp() {
- ASSERT_OK(MakeServer(
+ ASSERT_OK(MakeServer(
&server_, &client_,
[](FlightServerOptions* options) {
options->middleware.push_back(
@@ -894,7 +649,7 @@ class TestCountingServerMiddleware : public ::testing::Test {
public:
void SetUp() {
request_counter_ = std::make_shared();
- ASSERT_OK(MakeServer(
+ ASSERT_OK(MakeServer(
&server_, &client_,
[&](FlightServerOptions* options) {
options->middleware.push_back({"request_counter", request_counter_});
@@ -1129,342 +884,6 @@ TEST_F(TestFlightClient, GetFlightInfoNotFound) {
ASSERT_NE(st.message().find("Flight not found"), std::string::npos);
}
-TEST_F(TestFlightClient, DoGetInts) {
- auto descr = FlightDescriptor::Path({"examples", "ints"});
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleIntBatches(&expected_batches));
-
- auto check_endpoints = [](const std::vector& endpoints) {
- // Two endpoints in the example FlightInfo
- ASSERT_EQ(2, endpoints.size());
- ASSERT_EQ(Ticket{"ticket-ints-1"}, endpoints[0].ticket);
- };
-
- CheckDoGet(descr, expected_batches, check_endpoints);
-}
-
-TEST_F(TestFlightClient, DoGetFloats) {
- auto descr = FlightDescriptor::Path({"examples", "floats"});
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleFloatBatches(&expected_batches));
-
- auto check_endpoints = [](const std::vector& endpoints) {
- // One endpoint in the example FlightInfo
- ASSERT_EQ(1, endpoints.size());
- ASSERT_EQ(Ticket{"ticket-floats-1"}, endpoints[0].ticket);
- };
-
- CheckDoGet(descr, expected_batches, check_endpoints);
-}
-
-TEST_F(TestFlightClient, DoGetDicts) {
- auto descr = FlightDescriptor::Path({"examples", "dicts"});
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleDictBatches(&expected_batches));
-
- auto check_endpoints = [](const std::vector& endpoints) {
- // One endpoint in the example FlightInfo
- ASSERT_EQ(1, endpoints.size());
- ASSERT_EQ(Ticket{"ticket-dicts-1"}, endpoints[0].ticket);
- };
-
- CheckDoGet(descr, expected_batches, check_endpoints);
-}
-
-// Ensure the gRPC client is configured to allow large messages
-// Tests a 32 MiB batch
-TEST_F(TestFlightClient, DoGetLargeBatch) {
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleLargeBatches(&expected_batches));
- Ticket ticket{"ticket-large-batch-1"};
- CheckDoGet(ticket, expected_batches);
-}
-
-TEST_F(TestFlightClient, FlightDataOverflowServerBatch) {
- // Regression test for ARROW-13253
- // N.B. this is rather a slow and memory-hungry test
- {
- // DoGet: check for overflow on large batch
- Ticket ticket{"ARROW-13253-DoGet-Batch"};
- std::unique_ptr stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
- FlightStreamChunk chunk;
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
- stream->Next(&chunk));
- }
- {
- // DoExchange: check for overflow on large batch from server
- auto descr = FlightDescriptor::Command("large_batch");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- RecordBatchVector batches;
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
- reader->ReadAll(&batches));
- }
-}
-
-TEST_F(TestFlightClient, FlightDataOverflowClientBatch) {
- ASSERT_OK_AND_ASSIGN(auto batch, VeryLargeBatch());
- {
- // DoPut: check for overflow on large batch
- std::unique_ptr stream;
- std::unique_ptr reader;
- auto descr = FlightDescriptor::Path({""});
- ASSERT_OK(client_->DoPut(descr, batch->schema(), &stream, &reader));
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
- stream->WriteRecordBatch(*batch));
- ASSERT_OK(stream->Close());
- }
- {
- // DoExchange: check for overflow on large batch from client
- auto descr = FlightDescriptor::Command("counter");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- ASSERT_OK(writer->Begin(batch->schema()));
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
- writer->WriteRecordBatch(*batch));
- ASSERT_OK(writer->Close());
- }
-}
-
-TEST_F(TestFlightClient, DoExchange) {
- auto descr = FlightDescriptor::Command("counter");
- RecordBatchVector batches;
- auto a1 = ArrayFromJSON(int32(), "[4, 5, 6, null]");
- auto schema = arrow::schema({field("f1", a1->type())});
- batches.push_back(RecordBatch::Make(schema, a1->length(), {a1}));
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- ASSERT_OK(writer->Begin(schema));
- for (const auto& batch : batches) {
- ASSERT_OK(writer->WriteRecordBatch(*batch));
- }
- ASSERT_OK(writer->DoneWriting());
- FlightStreamChunk chunk;
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_NE(nullptr, chunk.app_metadata);
- ASSERT_EQ(nullptr, chunk.data);
- ASSERT_EQ("1", chunk.app_metadata->ToString());
- ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
- AssertSchemaEqual(schema, server_schema);
- for (const auto& batch : batches) {
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_BATCHES_EQUAL(*batch, *chunk.data);
- }
- ASSERT_OK(writer->Close());
-}
-
-// Test pure-metadata DoExchange to ensure nothing blocks waiting for
-// schema messages
-TEST_F(TestFlightClient, DoExchangeNoData) {
- auto descr = FlightDescriptor::Command("counter");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- ASSERT_OK(writer->DoneWriting());
- FlightStreamChunk chunk;
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_EQ(nullptr, chunk.data);
- ASSERT_NE(nullptr, chunk.app_metadata);
- ASSERT_EQ("0", chunk.app_metadata->ToString());
- ASSERT_OK(writer->Close());
-}
-
-// Test sending a schema without any data, as this hits an edge case
-// in the client-side writer.
-TEST_F(TestFlightClient, DoExchangeWriteOnlySchema) {
- auto descr = FlightDescriptor::Command("counter");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- auto schema = arrow::schema({field("f1", arrow::int32())});
- ASSERT_OK(writer->Begin(schema));
- ASSERT_OK(writer->WriteMetadata(Buffer::FromString("foo")));
- ASSERT_OK(writer->DoneWriting());
- FlightStreamChunk chunk;
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_EQ(nullptr, chunk.data);
- ASSERT_NE(nullptr, chunk.app_metadata);
- ASSERT_EQ("0", chunk.app_metadata->ToString());
- ASSERT_OK(writer->Close());
-}
-
-// Emulate DoGet
-TEST_F(TestFlightClient, DoExchangeGet) {
- auto descr = FlightDescriptor::Command("get");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
- AssertSchemaEqual(*ExampleIntSchema(), *server_schema);
- RecordBatchVector batches;
- ASSERT_OK(ExampleIntBatches(&batches));
- FlightStreamChunk chunk;
- for (const auto& batch : batches) {
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_NE(nullptr, chunk.data);
- AssertBatchesEqual(*batch, *chunk.data);
- }
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_EQ(nullptr, chunk.data);
- ASSERT_EQ(nullptr, chunk.app_metadata);
- ASSERT_OK(writer->Close());
-}
-
-// Emulate DoPut
-TEST_F(TestFlightClient, DoExchangePut) {
- auto descr = FlightDescriptor::Command("put");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- ASSERT_OK(writer->Begin(ExampleIntSchema()));
- RecordBatchVector batches;
- ASSERT_OK(ExampleIntBatches(&batches));
- for (const auto& batch : batches) {
- ASSERT_OK(writer->WriteRecordBatch(*batch));
- }
- ASSERT_OK(writer->DoneWriting());
- FlightStreamChunk chunk;
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_NE(nullptr, chunk.app_metadata);
- AssertBufferEqual(*chunk.app_metadata, "done");
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_EQ(nullptr, chunk.data);
- ASSERT_EQ(nullptr, chunk.app_metadata);
- ASSERT_OK(writer->Close());
-}
-
-// Test the echo server
-TEST_F(TestFlightClient, DoExchangeEcho) {
- auto descr = FlightDescriptor::Command("echo");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- ASSERT_OK(writer->Begin(ExampleIntSchema()));
- RecordBatchVector batches;
- FlightStreamChunk chunk;
- ASSERT_OK(ExampleIntBatches(&batches));
- for (const auto& batch : batches) {
- ASSERT_OK(writer->WriteRecordBatch(*batch));
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_NE(nullptr, chunk.data);
- ASSERT_EQ(nullptr, chunk.app_metadata);
- AssertBatchesEqual(*batch, *chunk.data);
- }
- for (int i = 0; i < 10; i++) {
- const auto buf = Buffer::FromString(std::to_string(i));
- ASSERT_OK(writer->WriteMetadata(buf));
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_EQ(nullptr, chunk.data);
- ASSERT_NE(nullptr, chunk.app_metadata);
- AssertBufferEqual(*buf, *chunk.app_metadata);
- }
- int index = 0;
- for (const auto& batch : batches) {
- const auto buf = Buffer::FromString(std::to_string(index));
- ASSERT_OK(writer->WriteWithMetadata(*batch, buf));
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_NE(nullptr, chunk.data);
- ASSERT_NE(nullptr, chunk.app_metadata);
- AssertBatchesEqual(*batch, *chunk.data);
- AssertBufferEqual(*buf, *chunk.app_metadata);
- index++;
- }
- ASSERT_OK(writer->DoneWriting());
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_EQ(nullptr, chunk.data);
- ASSERT_EQ(nullptr, chunk.app_metadata);
- ASSERT_OK(writer->Close());
-}
-
-// Test interleaved reading/writing
-TEST_F(TestFlightClient, DoExchangeTotal) {
- auto descr = FlightDescriptor::Command("total");
- std::unique_ptr reader;
- std::unique_ptr writer;
- {
- auto a1 = ArrayFromJSON(arrow::int32(), "[4, 5, 6, null]");
- auto schema = arrow::schema({field("f1", a1->type())});
- // XXX: as noted in flight/client.cc, Begin() is lazy and the
- // schema message won't be written until some data is also
- // written. There's also timing issues; hence we check each status
- // here.
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("Field is not INT64: f1"), ([&]() {
- RETURN_NOT_OK(client_->DoExchange(descr, &writer, &reader));
- RETURN_NOT_OK(writer->Begin(schema));
- auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1});
- RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
- return writer->Close();
- })());
- }
- {
- auto a1 = ArrayFromJSON(arrow::int64(), "[1, 2, null, 3]");
- auto a2 = ArrayFromJSON(arrow::int64(), "[null, 4, 5, 6]");
- auto schema = arrow::schema({field("f1", a1->type()), field("f2", a2->type())});
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- ASSERT_OK(writer->Begin(schema));
- auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1, a2});
- FlightStreamChunk chunk;
- ASSERT_OK(writer->WriteRecordBatch(*batch));
- ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
- AssertSchemaEqual(*schema, *server_schema);
-
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_NE(nullptr, chunk.data);
- auto expected1 = RecordBatch::Make(
- schema, /* num_rows */ 1,
- {ArrayFromJSON(arrow::int64(), "[6]"), ArrayFromJSON(arrow::int64(), "[15]")});
- AssertBatchesEqual(*expected1, *chunk.data);
-
- ASSERT_OK(writer->WriteRecordBatch(*batch));
- ASSERT_OK(reader->Next(&chunk));
- ASSERT_NE(nullptr, chunk.data);
- auto expected2 = RecordBatch::Make(
- schema, /* num_rows */ 1,
- {ArrayFromJSON(arrow::int64(), "[12]"), ArrayFromJSON(arrow::int64(), "[30]")});
- AssertBatchesEqual(*expected2, *chunk.data);
-
- ASSERT_OK(writer->Close());
- }
-}
-
-// Ensure server errors get propagated no matter what we try
-TEST_F(TestFlightClient, DoExchangeError) {
- auto descr = FlightDescriptor::Command("error");
- std::unique_ptr reader;
- std::unique_ptr writer;
- {
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- auto status = writer->Close();
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- NotImplemented, ::testing::HasSubstr("Expected error"), writer->Close());
- }
- {
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- FlightStreamChunk chunk;
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next(&chunk));
- }
- {
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- NotImplemented, ::testing::HasSubstr("Expected error"), reader->GetSchema());
- }
- // writer->Begin isn't tested here because, as noted in client.cc,
- // OpenRecordBatchWriter lazily writes the initial message - hence
- // Begin() won't fail. Additionally, it appears gRPC may buffer
- // writes - a write won't immediately fail even when the server
- // immediately fails.
-}
-
TEST_F(TestFlightClient, ListActions) {
std::vector actions;
ASSERT_OK(client_->ListActions(&actions));
@@ -1510,21 +929,6 @@ TEST_F(TestFlightClient, RoundTripStatus) {
ASSERT_RAISES(OutOfMemory, status);
}
-TEST_F(TestFlightClient, Issue5095) {
- // Make sure the server-side error message is reflected to the
- // client
- Ticket ticket1{"ARROW-5095-fail"};
- std::unique_ptr stream;
- Status status = client_->DoGet(ticket1, &stream);
- ASSERT_RAISES(UnknownError, status);
- ASSERT_THAT(status.message(), ::testing::HasSubstr("Server-side error"));
-
- Ticket ticket2{"ARROW-5095-success"};
- status = client_->DoGet(ticket2, &stream);
- ASSERT_RAISES(KeyError, status);
- ASSERT_THAT(status.message(), ::testing::HasSubstr("No data"));
-}
-
// Test setting generic transport options by configuring gRPC to fail
// all calls.
TEST_F(TestFlightClient, GenericOptions) {
@@ -1593,117 +997,6 @@ TEST_F(TestFlightClient, Close) {
client_->ListFlights(&listing));
}
-TEST_F(TestDoPut, DoPutInts) {
- auto descr = FlightDescriptor::Path({"ints"});
- RecordBatchVector batches;
- auto a0 = ArrayFromJSON(int8(), "[0, 1, 127, -128, null]");
- auto a1 = ArrayFromJSON(uint8(), "[0, 1, 127, 255, null]");
- auto a2 = ArrayFromJSON(int16(), "[0, 258, 32767, -32768, null]");
- auto a3 = ArrayFromJSON(uint16(), "[0, 258, 32767, 65535, null]");
- auto a4 = ArrayFromJSON(int32(), "[0, 65538, 2147483647, -2147483648, null]");
- auto a5 = ArrayFromJSON(uint32(), "[0, 65538, 2147483647, 4294967295, null]");
- auto a6 = ArrayFromJSON(
- int64(), "[0, 4294967298, 9223372036854775807, -9223372036854775808, null]");
- auto a7 = ArrayFromJSON(
- uint64(), "[0, 4294967298, 9223372036854775807, 18446744073709551615, null]");
- auto schema = arrow::schema({field("f0", a0->type()), field("f1", a1->type()),
- field("f2", a2->type()), field("f3", a3->type()),
- field("f4", a4->type()), field("f5", a5->type()),
- field("f6", a6->type()), field("f7", a7->type())});
- batches.push_back(
- RecordBatch::Make(schema, a0->length(), {a0, a1, a2, a3, a4, a5, a6, a7}));
-
- CheckDoPut(descr, schema, batches);
-}
-
-TEST_F(TestDoPut, DoPutFloats) {
- auto descr = FlightDescriptor::Path({"floats"});
- RecordBatchVector batches;
- auto a0 = ArrayFromJSON(float32(), "[0, 1.2, -3.4, 5.6, null]");
- auto a1 = ArrayFromJSON(float64(), "[0, 1.2, -3.4, 5.6, null]");
- auto schema = arrow::schema({field("f0", a0->type()), field("f1", a1->type())});
- batches.push_back(RecordBatch::Make(schema, a0->length(), {a0, a1}));
-
- CheckDoPut(descr, schema, batches);
-}
-
-TEST_F(TestDoPut, DoPutEmptyBatch) {
- // Sending and receiving a 0-sized batch shouldn't fail
- auto descr = FlightDescriptor::Path({"ints"});
- RecordBatchVector batches;
- auto a1 = ArrayFromJSON(int32(), "[]");
- auto schema = arrow::schema({field("f1", a1->type())});
- batches.push_back(RecordBatch::Make(schema, a1->length(), {a1}));
-
- CheckDoPut(descr, schema, batches);
-}
-
-TEST_F(TestDoPut, DoPutDicts) {
- auto descr = FlightDescriptor::Path({"dicts"});
- RecordBatchVector batches;
- auto dict_values = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"quux\"]");
- auto ty = dictionary(int8(), dict_values->type());
- auto schema = arrow::schema({field("f1", ty)});
- // Make several batches
- for (const char* json : {"[1, 0, 1]", "[null]", "[null, 1]"}) {
- auto indices = ArrayFromJSON(int8(), json);
- auto dict_array = std::make_shared(ty, indices, dict_values);
- batches.push_back(RecordBatch::Make(schema, dict_array->length(), {dict_array}));
- }
-
- CheckDoPut(descr, schema, batches);
-}
-
-// Ensure the gRPC server is configured to allow large messages
-// Tests a 32 MiB batch
-TEST_F(TestDoPut, DoPutLargeBatch) {
- auto descr = FlightDescriptor::Path({"large-batches"});
- auto schema = ExampleLargeSchema();
- RecordBatchVector batches;
- ASSERT_OK(ExampleLargeBatches(&batches));
- CheckDoPut(descr, schema, batches);
-}
-
-TEST_F(TestDoPut, DoPutSizeLimit) {
- const int64_t size_limit = 4096;
- Location location;
- ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
- auto client_options = FlightClientOptions::Defaults();
- client_options.write_size_limit_bytes = size_limit;
- std::unique_ptr client;
- ASSERT_OK(FlightClient::Connect(location, client_options, &client));
-
- auto descr = FlightDescriptor::Path({"ints"});
- // Batch is too large to fit in one message
- auto schema = arrow::schema({field("f1", arrow::int64())});
- auto batch = arrow::ConstantArrayGenerator::Zeroes(768, schema);
- RecordBatchVector batches;
- batches.push_back(batch->Slice(0, 384));
- batches.push_back(batch->Slice(384));
-
- std::unique_ptr stream;
- std::unique_ptr reader;
- ASSERT_OK(client->DoPut(descr, schema, &stream, &reader));
-
- // Large batch will exceed the limit
- const auto status = stream->WriteRecordBatch(*batch);
- EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("exceeded soft limit"),
- status);
- auto detail = FlightWriteSizeStatusDetail::UnwrapStatus(status);
- ASSERT_NE(nullptr, detail);
- ASSERT_EQ(size_limit, detail->limit());
- ASSERT_GT(detail->actual(), size_limit);
-
- // But we can retry with a smaller batch
- for (const auto& batch : batches) {
- ASSERT_OK(stream->WriteRecordBatch(*batch));
- }
-
- ASSERT_OK(stream->DoneWriting());
- ASSERT_OK(stream->Close());
- CheckBatches(descr, batches);
-}
-
TEST_F(TestAuthHandler, PassAuthenticatedCalls) {
ASSERT_OK(client_->Authenticate(
{},
@@ -1995,205 +1288,6 @@ TEST_F(TestTls, OverrideHostnameGeneric) {
// necessarily stable
}
-TEST_F(TestMetadata, DoGet) {
- Ticket ticket{""};
- std::unique_ptr stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
-
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleIntBatches(&expected_batches));
-
- FlightStreamChunk chunk;
- auto num_batches = static_cast(expected_batches.size());
- for (int i = 0; i < num_batches; ++i) {
- ASSERT_OK(stream->Next(&chunk));
- ASSERT_NE(nullptr, chunk.data);
- ASSERT_NE(nullptr, chunk.app_metadata);
- ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
- ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString());
- }
- ASSERT_OK(stream->Next(&chunk));
- ASSERT_EQ(nullptr, chunk.data);
-}
-
-// Test dictionaries. This tests a corner case in the reader:
-// dictionary batches come in between the schema and the first record
-// batch, so the server must take care to read application metadata
-// from the record batch, and not one of the dictionary batches.
-TEST_F(TestMetadata, DoGetDictionaries) {
- Ticket ticket{"dicts"};
- std::unique_ptr stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
-
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleDictBatches(&expected_batches));
-
- FlightStreamChunk chunk;
- auto num_batches = static_cast(expected_batches.size());
- for (int i = 0; i < num_batches; ++i) {
- ASSERT_OK(stream->Next(&chunk));
- ASSERT_NE(nullptr, chunk.data);
- ASSERT_NE(nullptr, chunk.app_metadata);
- ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
- ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString());
- }
- ASSERT_OK(stream->Next(&chunk));
- ASSERT_EQ(nullptr, chunk.data);
-}
-
-TEST_F(TestMetadata, DoPut) {
- std::unique_ptr writer;
- std::unique_ptr reader;
- std::shared_ptr schema = ExampleIntSchema();
- ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
-
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleIntBatches(&expected_batches));
-
- std::shared_ptr chunk;
- std::shared_ptr metadata;
- auto num_batches = static_cast(expected_batches.size());
- for (int i = 0; i < num_batches; ++i) {
- ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
- Buffer::FromString(std::to_string(i))));
- }
- // This eventually calls grpc::ClientReaderWriter::Finish which can
- // hang if there are unread messages. So make sure our wrapper
- // around this doesn't hang (because it drains any unread messages)
- ASSERT_OK(writer->Close());
-}
-
-// Test DoPut() with dictionaries. This tests a corner case in the
-// server-side reader; see DoGetDictionaries above.
-TEST_F(TestMetadata, DoPutDictionaries) {
- std::unique_ptr writer;
- std::unique_ptr reader;
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleDictBatches(&expected_batches));
- // ARROW-8749: don't get the schema via ExampleDictSchema because
- // DictionaryMemo uses field addresses to determine whether it's
- // seen a field before. Hence, if we use a schema that is different
- // (identity-wise) than the schema of the first batch we write,
- // we'll end up generating a duplicate set of dictionaries that
- // confuses the reader.
- ASSERT_OK(client_->DoPut(FlightDescriptor{}, expected_batches[0]->schema(), &writer,
- &reader));
- std::shared_ptr chunk;
- std::shared_ptr metadata;
- auto num_batches = static_cast(expected_batches.size());
- for (int i = 0; i < num_batches; ++i) {
- ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
- Buffer::FromString(std::to_string(i))));
- }
- ASSERT_OK(writer->Close());
-}
-
-TEST_F(TestMetadata, DoPutReadMetadata) {
- std::unique_ptr writer;
- std::unique_ptr reader;
- std::shared_ptr schema = ExampleIntSchema();
- ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
-
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleIntBatches(&expected_batches));
-
- std::shared_ptr chunk;
- std::shared_ptr metadata;
- auto num_batches = static_cast(expected_batches.size());
- for (int i = 0; i < num_batches; ++i) {
- ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
- Buffer::FromString(std::to_string(i))));
- ASSERT_OK(reader->ReadMetadata(&metadata));
- ASSERT_NE(nullptr, metadata);
- ASSERT_EQ(std::to_string(i), metadata->ToString());
- }
- // As opposed to DoPutDrainMetadata, now we've read the messages, so
- // make sure this still closes as expected.
- ASSERT_OK(writer->Close());
-}
-
-TEST_F(TestOptions, DoGetReadOptions) {
- // Call DoGet, but with a very low read nesting depth set to fail the call.
- Ticket ticket{""};
- auto options = FlightCallOptions();
- options.read_options.max_recursion_depth = 1;
- std::unique_ptr stream;
- ASSERT_OK(client_->DoGet(options, ticket, &stream));
- FlightStreamChunk chunk;
- ASSERT_RAISES(Invalid, stream->Next(&chunk));
-}
-
-TEST_F(TestOptions, DoPutWriteOptions) {
- // Call DoPut, but with a very low write nesting depth set to fail the call.
- std::unique_ptr writer;
- std::unique_ptr reader;
- RecordBatchVector expected_batches;
- ASSERT_OK(ExampleNestedBatches(&expected_batches));
-
- auto options = FlightCallOptions();
- options.write_options.max_recursion_depth = 1;
- ASSERT_OK(client_->DoPut(options, FlightDescriptor{}, expected_batches[0]->schema(),
- &writer, &reader));
- for (const auto& batch : expected_batches) {
- ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch));
- }
-}
-
-TEST_F(TestOptions, DoExchangeClientWriteOptions) {
- // Call DoExchange and write nested data, but with a very low nesting depth set to
- // fail the call.
- auto options = FlightCallOptions();
- options.write_options.max_recursion_depth = 1;
- auto descr = FlightDescriptor::Command("");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(options, descr, &writer, &reader));
- RecordBatchVector batches;
- ASSERT_OK(ExampleNestedBatches(&batches));
- ASSERT_OK(writer->Begin(batches[0]->schema()));
- for (const auto& batch : batches) {
- ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch));
- }
- ASSERT_OK(writer->DoneWriting());
- ASSERT_OK(writer->Close());
-}
-
-TEST_F(TestOptions, DoExchangeClientWriteOptionsBegin) {
- // Call DoExchange and write nested data, but with a very low nesting depth set to
- // fail the call. Here the options are set explicitly when we write data and not in the
- // call options.
- auto descr = FlightDescriptor::Command("");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- RecordBatchVector batches;
- ASSERT_OK(ExampleNestedBatches(&batches));
- auto options = ipc::IpcWriteOptions::Defaults();
- options.max_recursion_depth = 1;
- ASSERT_OK(writer->Begin(batches[0]->schema(), options));
- for (const auto& batch : batches) {
- ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch));
- }
- ASSERT_OK(writer->DoneWriting());
- ASSERT_OK(writer->Close());
-}
-
-TEST_F(TestOptions, DoExchangeServerWriteOptions) {
- // Call DoExchange and write nested data, but with a very low nesting depth set to fail
- // the call. (The low nesting depth is set on the server side.)
- auto descr = FlightDescriptor::Command("");
- std::unique_ptr reader;
- std::unique_ptr writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
- RecordBatchVector batches;
- ASSERT_OK(ExampleNestedBatches(&batches));
- ASSERT_OK(writer->Begin(batches[0]->schema()));
- FlightStreamChunk chunk;
- ASSERT_OK(writer->WriteRecordBatch(*batches[0]));
- ASSERT_OK(writer->DoneWriting());
- ASSERT_RAISES(Invalid, writer->Close());
-}
-
TEST_F(TestRejectServerMiddleware, Rejected) {
std::unique_ptr info;
const auto& status = client_->GetFlightInfo(FlightDescriptor{}, &info);
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
new file mode 100644
index 00000000000..ccb2eff272f
--- /dev/null
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -0,0 +1,1029 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/test_definitions.h"
+
+#include
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/flight/api.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/testing/generator.h"
+
+namespace arrow {
+namespace flight {
+
+//------------------------------------------------------------
+// Tests of initialization/shutdown
+
+GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ConnectivityTest);
+
+TEST_P(ConnectivityTest, GetPort) {
+ std::unique_ptr server = ExampleTestServer();
+
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
+ ASSERT_GT(server->port(), 0);
+}
+
+TEST_P(ConnectivityTest, BuilderHook) {
+ std::unique_ptr server = ExampleTestServer();
+
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ FlightServerOptions options(location);
+ bool builder_hook_run = false;
+ options.builder_hook = [&builder_hook_run](void* builder) {
+ ASSERT_NE(nullptr, builder);
+ builder_hook_run = true;
+ };
+ ASSERT_OK(server->Init(options));
+ ASSERT_TRUE(builder_hook_run);
+ ASSERT_GT(server->port(), 0);
+ ASSERT_OK(server->Shutdown());
+}
+
+TEST_P(ConnectivityTest, ServeShutdown) {
+ // Regression test for ARROW-15181
+ constexpr int kIterations = 10;
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ for (int i = 0; i < kIterations; i++) {
+ std::unique_ptr server = ExampleTestServer();
+
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
+ ASSERT_GT(server->port(), 0);
+ std::thread t([&]() { ASSERT_OK(server->Serve()); });
+ ASSERT_OK(server->Shutdown());
+ ASSERT_OK(server->Wait());
+ t.join();
+ }
+}
+
+TEST_P(ConnectivityTest, ServeShutdownWithDeadline) {
+ std::unique_ptr server = ExampleTestServer();
+
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
+ ASSERT_GT(server->port(), 0);
+
+ auto deadline = std::chrono::system_clock::now() + std::chrono::microseconds(10);
+
+ ASSERT_OK(server->Shutdown(&deadline));
+ ASSERT_OK(server->Wait());
+}
+
+//------------------------------------------------------------
+// Tests of data plane methods
+
+GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(DataTest);
+
+DataTest::~DataTest() = default;
+void DataTest::SetUp() {
+ server_ = ExampleTestServer();
+
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ FlightServerOptions options(location);
+ ASSERT_OK(server_->Init(options));
+
+ ASSERT_OK(ConnectClient());
+}
+void DataTest::TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+}
+Status DataTest::ConnectClient() {
+ ARROW_ASSIGN_OR_RAISE(auto location,
+ Location::ForScheme(transport(), "localhost", server_->port()));
+ return FlightClient::Connect(location, &client_);
+}
+void DataTest::CheckDoGet(
+ const FlightDescriptor& descr, const RecordBatchVector& expected_batches,
+ std::function&)> check_endpoints) {
+ auto expected_schema = expected_batches[0]->schema();
+
+ std::unique_ptr info;
+ ASSERT_OK(client_->GetFlightInfo(descr, &info));
+ check_endpoints(info->endpoints());
+
+ std::shared_ptr schema;
+ ipc::DictionaryMemo dict_memo;
+ ASSERT_OK(info->GetSchema(&dict_memo, &schema));
+ AssertSchemaEqual(*expected_schema, *schema);
+
+ // By convention, fetch the first endpoint
+ Ticket ticket = info->endpoints()[0].ticket;
+ CheckDoGet(ticket, expected_batches);
+}
+void DataTest::CheckDoGet(const Ticket& ticket,
+ const RecordBatchVector& expected_batches) {
+ auto num_batches = static_cast(expected_batches.size());
+ ASSERT_GE(num_batches, 2);
+
+ std::unique_ptr stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+
+ std::unique_ptr stream2;
+ ASSERT_OK(client_->DoGet(ticket, &stream2));
+ ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2)));
+
+ FlightStreamChunk chunk;
+ std::shared_ptr batch;
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_OK(reader->ReadNext(&batch));
+ ASSERT_NE(nullptr, chunk.data);
+ ASSERT_NE(nullptr, batch);
+#if !defined(__MINGW32__)
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *batch);
+#else
+ // In MINGW32, the following code does not have the reproducibility at the LSB
+ // even when this is called twice with the same seed.
+ // As a workaround, use approxEqual
+ // /* from GenerateTypedData in random.cc */
+ // std::default_random_engine rng(seed); // seed = 282475250
+ // std::uniform_real_distribution dist;
+ // std::generate(data, data + n, // n = 10
+ // [&dist, &rng] { return static_cast(dist(rng)); });
+ // /* data[1] = 0x40852cdfe23d3976 or 0x40852cdfe23d3975 */
+ ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *chunk.data);
+ ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *batch);
+#endif
+ }
+
+ // Stream exhausted
+ ASSERT_OK(stream->Next(&chunk));
+ ASSERT_OK(reader->ReadNext(&batch));
+ ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_EQ(nullptr, batch);
+}
+
+TEST_P(DataTest, DoGetInts) {
+ auto descr = FlightDescriptor::Path({"examples", "ints"});
+ RecordBatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ auto check_endpoints = [](const std::vector& endpoints) {
+ // Two endpoints in the example FlightInfo
+ ASSERT_EQ(2, endpoints.size());
+ ASSERT_EQ(Ticket{"ticket-ints-1"}, endpoints[0].ticket);
+ };
+
+ CheckDoGet(descr, expected_batches, check_endpoints);
+}
+
+TEST_P(DataTest, DoGetFloats) {
+ auto descr = FlightDescriptor::Path({"examples", "floats"});
+ RecordBatchVector expected_batches;
+ ASSERT_OK(ExampleFloatBatches(&expected_batches));
+
+ auto check_endpoints = [](const std::vector& endpoints) {
+ // One endpoint in the example FlightInfo
+ ASSERT_EQ(1, endpoints.size());
+ ASSERT_EQ(Ticket{"ticket-floats-1"}, endpoints[0].ticket);
+ };
+
+ CheckDoGet(descr, expected_batches, check_endpoints);
+}
+
+TEST_P(DataTest, DoGetDicts) {
+ auto descr = FlightDescriptor::Path({"examples", "dicts"});
+ RecordBatchVector expected_batches;
+ ASSERT_OK(ExampleDictBatches(&expected_batches));
+
+ auto check_endpoints = [](const std::vector& endpoints) {
+ // One endpoint in the example FlightInfo
+ ASSERT_EQ(1, endpoints.size());
+ ASSERT_EQ(Ticket{"ticket-dicts-1"}, endpoints[0].ticket);
+ };
+
+ CheckDoGet(descr, expected_batches, check_endpoints);
+}
+
+// Ensure clients are configured to allow large messages by default
+// Tests a 32 MiB batch
+TEST_P(DataTest, DoGetLargeBatch) {
+ RecordBatchVector expected_batches;
+ ASSERT_OK(ExampleLargeBatches(&expected_batches));
+ Ticket ticket{"ticket-large-batch-1"};
+ CheckDoGet(ticket, expected_batches);
+}
+
+TEST_P(DataTest, FlightDataOverflowServerBatch) {
+ // Regression test for ARROW-13253
+ // N.B. this is rather a slow and memory-hungry test
+ {
+ // DoGet: check for overflow on large batch
+ Ticket ticket{"ARROW-13253-DoGet-Batch"};
+ std::unique_ptr stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+ FlightStreamChunk chunk;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
+ stream->Next(&chunk));
+ }
+ {
+ // DoExchange: check for overflow on large batch from server
+ auto descr = FlightDescriptor::Command("large_batch");
+ std::unique_ptr reader;
+ std::unique_ptr writer;
+ ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ RecordBatchVector batches;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
+ reader->ReadAll(&batches));
+ }
+}
+
+TEST_P(DataTest, FlightDataOverflowClientBatch) {
+ ASSERT_OK_AND_ASSIGN(auto batch, VeryLargeBatch());
+ {
+ // DoPut: check for overflow on large batch
+ std::unique_ptr stream;
+ std::unique_ptr reader;
+ auto descr = FlightDescriptor::Path({""});
+ ASSERT_OK(client_->DoPut(descr, batch->schema(), &stream, &reader));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
+ stream->WriteRecordBatch(*batch));
+ ASSERT_OK(stream->Close());
+ }
+ {
+ // DoExchange: check for overflow on large batch from client
+ auto descr = FlightDescriptor::Command("counter");
+ std::unique_ptr