diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 5861d8475d6..cae36562be3 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -213,6 +213,7 @@ if(ARROW_TESTING) OUTPUTS ARROW_FLIGHT_TESTING_LIBRARIES SOURCES + test_definitions.cc test_util.cc DEPENDENCIES GTest::gtest @@ -237,6 +238,12 @@ foreach(LIB_TARGET ${ARROW_FLIGHT_TESTING_LIBRARIES}) ${ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS}) endforeach() +add_arrow_test(flight_internals_test + STATIC_LINK_LIBS + ${ARROW_FLIGHT_TEST_LINK_LIBS} + LABELS + "arrow_flight") + add_arrow_test(flight_test STATIC_LINK_LIBS ${ARROW_FLIGHT_TEST_LINK_LIBS} diff --git a/cpp/src/arrow/flight/flight_cuda_test.cc b/cpp/src/arrow/flight/flight_cuda_test.cc index b2ed6394cb8..b812efd4677 100644 --- a/cpp/src/arrow/flight/flight_cuda_test.cc +++ b/cpp/src/arrow/flight/flight_cuda_test.cc @@ -70,16 +70,16 @@ class CudaTestServer : public FlightServerBase { Status DoGet(const ServerCallContext&, const Ticket&, std::unique_ptr* data_stream) override { - BatchVector batches; + RecordBatchVector batches; RETURN_NOT_OK(ExampleIntBatches(&batches)); - auto batch_reader = std::make_shared(batches[0]->schema(), batches); + ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches)); *data_stream = std::unique_ptr(new RecordBatchStream(batch_reader)); return Status::OK(); } Status DoPut(const ServerCallContext&, std::unique_ptr reader, std::unique_ptr writer) override { - BatchVector batches; + RecordBatchVector batches; RETURN_NOT_OK(reader->ReadAll(&batches)); for (const auto& batch : batches) { for (const auto& column : batch->columns()) { @@ -161,7 +161,7 @@ TEST_F(TestCuda, DoGet) { TEST_F(TestCuda, DoPut) { // Check that we can send a record batch containing references to // GPU buffers. 
- BatchVector batches; + RecordBatchVector batches; ASSERT_OK(ExampleIntBatches(&batches)); std::unique_ptr writer; @@ -190,7 +190,7 @@ TEST_F(TestCuda, DoExchange) { FlightCallOptions options; options.memory_manager = device_->default_memory_manager(); - BatchVector batches; + RecordBatchVector batches; ASSERT_OK(ExampleIntBatches(&batches)); std::unique_ptr writer; diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc new file mode 100644 index 00000000000..525e19499c3 --- /dev/null +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -0,0 +1,449 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// ---------------------------------------------------------------------- +// Tests for Flight which don't actually spin up a client/server + +#include +#include + +#include "arrow/flight/client_cookie_middleware.h" +#include "arrow/flight/client_header_internal.h" +#include "arrow/flight/client_middleware.h" +#include "arrow/flight/types.h" +#include "arrow/status.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/string.h" + +#include "arrow/flight/internal.h" +#include "arrow/flight/test_util.h" + +namespace arrow { +namespace flight { + +namespace pb = arrow::flight::protocol; + +// ---------------------------------------------------------------------- +// Core Flight types + +TEST(FlightTypes, FlightDescriptor) { + auto a = FlightDescriptor::Command("select * from table"); + auto b = FlightDescriptor::Command("select * from table"); + auto c = FlightDescriptor::Command("select foo from table"); + auto d = FlightDescriptor::Path({"foo", "bar"}); + auto e = FlightDescriptor::Path({"foo", "baz"}); + auto f = FlightDescriptor::Path({"foo", "baz"}); + + ASSERT_EQ(a.ToString(), "FlightDescriptor"); + ASSERT_EQ(d.ToString(), "FlightDescriptor"); + ASSERT_TRUE(a.Equals(b)); + ASSERT_FALSE(a.Equals(c)); + ASSERT_FALSE(a.Equals(d)); + ASSERT_FALSE(d.Equals(e)); + ASSERT_TRUE(e.Equals(f)); +} + +// This tests the internal protobuf types which don't get exported in the Flight DLL. 
+#ifndef _WIN32 +TEST(FlightTypes, FlightDescriptorToFromProto) { + FlightDescriptor descr_test; + pb::FlightDescriptor pb_descr; + + FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}}; + ASSERT_OK(internal::ToProto(descr1, &pb_descr)); + ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); + ASSERT_EQ(descr1, descr_test); + + FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}}; + ASSERT_OK(internal::ToProto(descr2, &pb_descr)); + ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); + ASSERT_EQ(descr2, descr_test); +} +#endif + +// ARROW-6017: we should be able to construct locations for unknown +// schemes +TEST(FlightTypes, LocationUnknownScheme) { + Location location; + ASSERT_OK(Location::Parse("s3://test", &location)); + ASSERT_OK(Location::Parse("https://example.com/foo", &location)); +} + +TEST(FlightTypes, RoundTripTypes) { + Ticket ticket{"foo"}; + ASSERT_OK_AND_ASSIGN(std::string ticket_serialized, ticket.SerializeToString()); + ASSERT_OK_AND_ASSIGN(Ticket ticket_deserialized, + Ticket::Deserialize(ticket_serialized)); + ASSERT_EQ(ticket.ticket, ticket_deserialized.ticket); + + FlightDescriptor desc = FlightDescriptor::Command("select * from foo;"); + ASSERT_OK_AND_ASSIGN(std::string desc_serialized, desc.SerializeToString()); + ASSERT_OK_AND_ASSIGN(FlightDescriptor desc_deserialized, + FlightDescriptor::Deserialize(desc_serialized)); + ASSERT_TRUE(desc.Equals(desc_deserialized)); + + desc = FlightDescriptor::Path({"a", "b", "test.arrow"}); + ASSERT_OK_AND_ASSIGN(desc_serialized, desc.SerializeToString()); + ASSERT_OK_AND_ASSIGN(desc_deserialized, FlightDescriptor::Deserialize(desc_serialized)); + ASSERT_TRUE(desc.Equals(desc_deserialized)); + + FlightInfo::Data data; + std::shared_ptr schema = + arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()), + field("d", int64())}); + Location location1, location2, location3; + ASSERT_OK(Location::ForGrpcTcp("localhost", 10010, &location1)); + 
ASSERT_OK(Location::ForGrpcTls("localhost", 10010, &location2)); + ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location3)); + std::vector endpoints{FlightEndpoint{ticket, {location1, location2}}, + FlightEndpoint{ticket, {location3}}}; + ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data)); + std::unique_ptr info = std::unique_ptr(new FlightInfo(data)); + ASSERT_OK_AND_ASSIGN(std::string info_serialized, info->SerializeToString()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr info_deserialized, + FlightInfo::Deserialize(info_serialized)); + ASSERT_TRUE(info->descriptor().Equals(info_deserialized->descriptor())); + ASSERT_EQ(info->endpoints(), info_deserialized->endpoints()); + ASSERT_EQ(info->total_records(), info_deserialized->total_records()); + ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes()); +} + +TEST(FlightTypes, RoundtripStatus) { + // Make sure status codes round trip through our conversions + + std::shared_ptr detail; + detail = FlightStatusDetail::UnwrapStatus( + MakeFlightError(FlightStatusCode::Internal, "Test message")); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(FlightStatusCode::Internal, detail->code()); + + detail = FlightStatusDetail::UnwrapStatus( + MakeFlightError(FlightStatusCode::TimedOut, "Test message")); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(FlightStatusCode::TimedOut, detail->code()); + + detail = FlightStatusDetail::UnwrapStatus( + MakeFlightError(FlightStatusCode::Cancelled, "Test message")); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(FlightStatusCode::Cancelled, detail->code()); + + detail = FlightStatusDetail::UnwrapStatus( + MakeFlightError(FlightStatusCode::Unauthenticated, "Test message")); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(FlightStatusCode::Unauthenticated, detail->code()); + + detail = FlightStatusDetail::UnwrapStatus( + MakeFlightError(FlightStatusCode::Unauthorized, "Test message")); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(FlightStatusCode::Unauthorized, detail->code()); + + detail = 
FlightStatusDetail::UnwrapStatus( + MakeFlightError(FlightStatusCode::Unavailable, "Test message")); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(FlightStatusCode::Unavailable, detail->code()); + + Status status = internal::FromGrpcStatus( + internal::ToGrpcStatus(Status::NotImplemented("Sentinel"))); + ASSERT_TRUE(status.IsNotImplemented()); + ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); + + status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel"))); + ASSERT_TRUE(status.IsInvalid()); + ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); + + status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel"))); + ASSERT_TRUE(status.IsKeyError()); + ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); + + status = + internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel"))); + ASSERT_TRUE(status.IsAlreadyExists()); + ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); +} + +// ---------------------------------------------------------------------- +// Cookie authentication/middleware + +// This test keeps an internal cookie cache and compares that with the middleware. +class TestCookieMiddleware : public ::testing::Test { + public: + // Setup function creates middleware factory and starts it up. + void SetUp() { + factory_ = GetCookieFactory(); + CallInfo callInfo; + factory_->StartCall(callInfo, &middleware_); + } + + // Function to add incoming cookies to middleware and validate them. + void AddAndValidate(const std::string& incoming_cookie) { + // Add cookie + CallHeaders call_headers; + call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"), + arrow::util::string_view(incoming_cookie))); + middleware_->ReceivedHeaders(call_headers); + expected_cookie_cache_.UpdateCachedCookies(call_headers); + + // Get cookie from middleware. 
+ TestCallHeaders add_call_headers; + middleware_->SendingHeaders(&add_call_headers); + const std::string actual_cookies = add_call_headers.GetCookies(); + + // Validate cookie + const std::string expected_cookies = expected_cookie_cache_.GetValidCookiesAsString(); + const std::vector split_expected_cookies = + SplitCookies(expected_cookies); + const std::vector split_actual_cookies = SplitCookies(actual_cookies); + EXPECT_EQ(split_expected_cookies, split_actual_cookies); + } + + // Function to take a list of cookies and split them into a vector of individual + // cookies. This is done because the cookie cache is a map so ordering is not + // necessarily consistent. + static std::vector SplitCookies(const std::string& cookies) { + std::vector split_cookies; + std::string::size_type pos1 = 0; + std::string::size_type pos2 = 0; + while ((pos2 = cookies.find(';', pos1)) != std::string::npos) { + split_cookies.push_back( + arrow::internal::TrimString(cookies.substr(pos1, pos2 - pos1))); + pos1 = pos2 + 1; + } + if (pos1 < cookies.size()) { + split_cookies.push_back(arrow::internal::TrimString(cookies.substr(pos1))); + } + std::sort(split_cookies.begin(), split_cookies.end()); + return split_cookies; + } + + protected: + // Class to allow testing of the call headers. + class TestCallHeaders : public AddCallHeaders { + public: + TestCallHeaders() {} + ~TestCallHeaders() {} + + // Function to add cookie header. + void AddHeader(const std::string& key, const std::string& value) { + ASSERT_EQ(key, "cookie"); + outbound_cookie_ = value; + } + + // Function to get outgoing cookie. 
+ std::string GetCookies() { return outbound_cookie_; } + + private: + std::string outbound_cookie_; + }; + + internal::CookieCache expected_cookie_cache_; + std::unique_ptr middleware_; + std::shared_ptr factory_; +}; + +TEST_F(TestCookieMiddleware, BasicParsing) { + AddAndValidate("id1=1; foo=bar;"); + AddAndValidate("id1=1; foo=bar"); + AddAndValidate("id2=2;"); + AddAndValidate("id4=\"4\""); + AddAndValidate("id5=5; foo=bar; baz=buz;"); +} + +TEST_F(TestCookieMiddleware, Overwrite) { + AddAndValidate("id0=0"); + AddAndValidate("id0=1"); + AddAndValidate("id1=0"); + AddAndValidate("id1=1"); + AddAndValidate("id1=1"); + AddAndValidate("id1=10"); + AddAndValidate("id=3"); + AddAndValidate("id=0"); + AddAndValidate("id=0"); +} + +TEST_F(TestCookieMiddleware, MaxAge) { + AddAndValidate("id0=0; max-age=0;"); + AddAndValidate("id1=0; max-age=-1;"); + AddAndValidate("id2=0; max-age=0"); + AddAndValidate("id3=0; max-age=-1"); + AddAndValidate("id4=0; max-age=1"); + AddAndValidate("id5=0; max-age=1"); + AddAndValidate("id4=0; max-age=0"); + AddAndValidate("id5=0; max-age=0"); +} + +TEST_F(TestCookieMiddleware, Expires) { + AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT;"); + AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT"); + AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;"); + AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT"); + AddAndValidate("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;"); + AddAndValidate("id1=0; expires=Fri, 01 Jan 2038 22:15:36 GMT"); + AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;"); + AddAndValidate("id1=0; expires=Fri, 22 Dec 2017 22:15:36 GMT"); +} + +// This test is used to test the parsing capabilities of the cookie framework. 
+class TestCookieParsing : public ::testing::Test { + public: + void VerifyParseCookie(const std::string& cookie_str, bool expired) { + internal::Cookie cookie = internal::Cookie::parse(cookie_str); + EXPECT_EQ(expired, cookie.IsExpired()); + } + + void VerifyCookieName(const std::string& cookie_str, const std::string& name) { + internal::Cookie cookie = internal::Cookie::parse(cookie_str); + EXPECT_EQ(name, cookie.GetName()); + } + + void VerifyCookieString(const std::string& cookie_str, + const std::string& cookie_as_string) { + internal::Cookie cookie = internal::Cookie::parse(cookie_str); + EXPECT_EQ(cookie_as_string, cookie.AsCookieString()); + } + + void VerifyCookieDateConverson(std::string date, const std::string& converted_date) { + internal::Cookie::ConvertCookieDate(&date); + EXPECT_EQ(converted_date, date); + } + + void VerifyCookieAttributeParsing( + const std::string cookie_str, std::string::size_type start_pos, + const util::optional> cookie_attribute, + const std::string::size_type start_pos_after) { + util::optional> attr = + internal::Cookie::ParseCookieAttribute(cookie_str, &start_pos); + + if (cookie_attribute == util::nullopt) { + EXPECT_EQ(cookie_attribute, attr); + } else { + EXPECT_EQ(cookie_attribute.value(), attr.value()); + } + EXPECT_EQ(start_pos_after, start_pos); + } + + void AddCookieVerifyCache(const std::vector& cookies, + const std::string& expected_cookies) { + internal::CookieCache cookie_cache; + for (auto& cookie : cookies) { + // Add cookie + CallHeaders call_headers; + call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"), + arrow::util::string_view(cookie))); + cookie_cache.UpdateCachedCookies(call_headers); + } + const std::string actual_cookies = cookie_cache.GetValidCookiesAsString(); + const std::vector actual_split_cookies = + TestCookieMiddleware::SplitCookies(actual_cookies); + const std::vector expected_split_cookies = + TestCookieMiddleware::SplitCookies(expected_cookies); + } +}; + 
+TEST_F(TestCookieParsing, Expired) { + VerifyParseCookie("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;", true); + VerifyParseCookie("id1=0; max-age=-1;", true); + VerifyParseCookie("id0=0; max-age=0;", true); +} + +TEST_F(TestCookieParsing, Invalid) { + VerifyParseCookie("id1=0; expires=0, 0 0 0 0:0:0 GMT;", true); + VerifyParseCookie("id1=0; expires=Fri, 01 FOO 2038 22:15:36 GMT", true); + VerifyParseCookie("id1=0; expires=foo", true); + VerifyParseCookie("id1=0; expires=", true); + VerifyParseCookie("id1=0; max-age=FOO", true); + VerifyParseCookie("id1=0; max-age=", true); +} + +TEST_F(TestCookieParsing, NoExpiry) { + VerifyParseCookie("id1=0;", false); + VerifyParseCookie("id1=0; noexpiry=Fri, 01 Jan 2038 22:15:36 GMT", false); + VerifyParseCookie("id1=0; noexpiry=\"Fri, 01 Jan 2038 22:15:36 GMT\"", false); + VerifyParseCookie("id1=0; nomax-age=-1", false); + VerifyParseCookie("id1=0; nomax-age=\"-1\"", false); + VerifyParseCookie("id1=0; randomattr=foo", false); +} + +TEST_F(TestCookieParsing, NotExpired) { + VerifyParseCookie("id5=0; max-age=1", false); + VerifyParseCookie("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;", false); +} + +TEST_F(TestCookieParsing, GetName) { + VerifyCookieName("id1=1; foo=bar;", "id1"); + VerifyCookieName("id1=1; foo=bar", "id1"); + VerifyCookieName("id2=2;", "id2"); + VerifyCookieName("id4=\"4\"", "id4"); + VerifyCookieName("id5=5; foo=bar; baz=buz;", "id5"); +} + +TEST_F(TestCookieParsing, ToString) { + VerifyCookieString("id1=1; foo=bar;", "id1=1"); + VerifyCookieString("id1=1; foo=bar", "id1=1"); + VerifyCookieString("id2=2;", "id2=2"); + VerifyCookieString("id4=\"4\"", "id4=4"); + VerifyCookieString("id5=5; foo=bar; baz=buz;", "id5=5"); +} + +TEST_F(TestCookieParsing, DateConversion) { + VerifyCookieDateConverson("Mon, 01 jan 2038 22:15:36 GMT;", "01 01 2038 22:15:36"); + VerifyCookieDateConverson("TUE, 10 Feb 2038 22:15:36 GMT", "10 02 2038 22:15:36"); + VerifyCookieDateConverson("WED, 20 MAr 2038 22:15:36 GMT;", "20 03 
2038 22:15:36"); + VerifyCookieDateConverson("thu, 15 APR 2038 22:15:36 GMT", "15 04 2038 22:15:36"); + VerifyCookieDateConverson("Fri, 30 mAY 2038 22:15:36 GMT;", "30 05 2038 22:15:36"); + VerifyCookieDateConverson("Sat, 03 juN 2038 22:15:36 GMT", "03 06 2038 22:15:36"); + VerifyCookieDateConverson("Sun, 01 JuL 2038 22:15:36 GMT;", "01 07 2038 22:15:36"); + VerifyCookieDateConverson("Fri, 06 aUg 2038 22:15:36 GMT", "06 08 2038 22:15:36"); + VerifyCookieDateConverson("Fri, 01 SEP 2038 22:15:36 GMT;", "01 09 2038 22:15:36"); + VerifyCookieDateConverson("Fri, 01 OCT 2038 22:15:36 GMT", "01 10 2038 22:15:36"); + VerifyCookieDateConverson("Fri, 01 Nov 2038 22:15:36 GMT;", "01 11 2038 22:15:36"); + VerifyCookieDateConverson("Fri, 01 deC 2038 22:15:36 GMT", "01 12 2038 22:15:36"); + VerifyCookieDateConverson("", ""); + VerifyCookieDateConverson("Fri, 01 INVALID 2038 22:15:36 GMT;", + "01 INVALID 2038 22:15:36"); +} + +TEST_F(TestCookieParsing, ParseCookieAttribute) { + VerifyCookieAttributeParsing("", 0, util::nullopt, std::string::npos); + + std::string cookie_string = "attr0=0; attr1=1; attr2=2; attr3=3"; + auto attr_length = std::string("attr0=0;").length(); + std::string::size_type start_pos = 0; + VerifyCookieAttributeParsing(cookie_string, start_pos, std::make_pair("attr0", "0"), + cookie_string.find("attr0=0;") + attr_length); + VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)), + std::make_pair("attr1", "1"), + cookie_string.find("attr1=1;") + attr_length); + VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)), + std::make_pair("attr2", "2"), + cookie_string.find("attr2=2;") + attr_length); + VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)), + std::make_pair("attr3", "3"), std::string::npos); + VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length - 1)), + util::nullopt, std::string::npos); + VerifyCookieAttributeParsing(cookie_string, std::string::npos, 
util::nullopt, + std::string::npos); +} + +TEST_F(TestCookieParsing, CookieCache) { + AddCookieVerifyCache({"id0=0;"}, ""); + AddCookieVerifyCache({"id0=0;", "id0=1;"}, "id0=1"); + AddCookieVerifyCache({"id0=0;", "id1=1;"}, "id0=0; id1=1"); + AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=0; id1=1; id2=2"); +} + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 003c8e64c91..72644448f2e 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -39,16 +39,14 @@ #include "arrow/util/base64.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" -#include "arrow/util/string.h" #ifdef GRPCPP_GRPCPP_H #error "gRPC headers should not be in public API" #endif -#include "arrow/flight/client_cookie_middleware.h" -#include "arrow/flight/client_header_internal.h" #include "arrow/flight/internal.h" #include "arrow/flight/middleware_internal.h" +#include "arrow/flight/test_definitions.h" #include "arrow/flight/test_util.h" namespace arrow { @@ -65,119 +63,79 @@ const char kBasicPrefix[] = "Basic "; const char kBearerPrefix[] = "Bearer "; const char kAuthHeader[] = "authorization"; -void AssertEqual(const ActionType& expected, const ActionType& actual) { - ASSERT_EQ(expected.type, actual.type); - ASSERT_EQ(expected.description, actual.description); -} +//------------------------------------------------------------ +// Common transport tests -void AssertEqual(const FlightDescriptor& expected, const FlightDescriptor& actual) { - ASSERT_TRUE(expected.Equals(actual)); -} +class GrpcConnectivityTest : public ConnectivityTest { + protected: + std::string transport() const override { return "grpc"; } +}; +TEST_F(GrpcConnectivityTest, GetPort) { TestGetPort(); } +TEST_F(GrpcConnectivityTest, BuilderHook) { TestBuilderHook(); } +TEST_F(GrpcConnectivityTest, Shutdown) { TestShutdown(); } +TEST_F(GrpcConnectivityTest, ShutdownWithDeadline) { 
TestShutdownWithDeadline(); } -void AssertEqual(const Ticket& expected, const Ticket& actual) { - ASSERT_EQ(expected.ticket, actual.ticket); -} +class GrpcDataTest : public DataTest { + protected: + std::string transport() const override { return "grpc"; } +}; +TEST_F(GrpcDataTest, TestDoGetInts) { TestDoGetInts(); } +TEST_F(GrpcDataTest, TestDoGetFloats) { TestDoGetFloats(); } +TEST_F(GrpcDataTest, TestDoGetDicts) { TestDoGetDicts(); } +TEST_F(GrpcDataTest, TestDoGetLargeBatch) { TestDoGetLargeBatch(); } +TEST_F(GrpcDataTest, TestOverflowServerBatch) { TestOverflowServerBatch(); } +TEST_F(GrpcDataTest, TestOverflowClientBatch) { TestOverflowClientBatch(); } +TEST_F(GrpcDataTest, TestDoExchange) { TestDoExchange(); } +TEST_F(GrpcDataTest, TestDoExchangeNoData) { TestDoExchangeNoData(); } +TEST_F(GrpcDataTest, TestDoExchangeWriteOnlySchema) { TestDoExchangeWriteOnlySchema(); } +TEST_F(GrpcDataTest, TestDoExchangeGet) { TestDoExchangeGet(); } +TEST_F(GrpcDataTest, TestDoExchangePut) { TestDoExchangePut(); } +TEST_F(GrpcDataTest, TestDoExchangeEcho) { TestDoExchangeEcho(); } +TEST_F(GrpcDataTest, TestDoExchangeTotal) { TestDoExchangeTotal(); } +TEST_F(GrpcDataTest, TestDoExchangeError) { TestDoExchangeError(); } +TEST_F(GrpcDataTest, TestIssue5095) { TestIssue5095(); } + +class GrpcDoPutTest : public DoPutTest { + protected: + std::string transport() const override { return "grpc"; } +}; +TEST_F(GrpcDoPutTest, TestInts) { TestInts(); } +TEST_F(GrpcDoPutTest, TestDoPutFloats) { TestDoPutFloats(); } +TEST_F(GrpcDoPutTest, TestDoPutEmptyBatch) { TestDoPutEmptyBatch(); } +TEST_F(GrpcDoPutTest, TestDoPutDicts) { TestDoPutDicts(); } +TEST_F(GrpcDoPutTest, TestDoPutLargeBatch) { TestDoPutLargeBatch(); } +TEST_F(GrpcDoPutTest, TestDoPutSizeLimit) { TestDoPutSizeLimit(); } + +class GrpcAppMetadataTest : public AppMetadataTest { + protected: + std::string transport() const override { return "grpc"; } +}; -void AssertEqual(const Location& expected, const Location& actual) { - 
ASSERT_EQ(expected, actual); -} +TEST_F(GrpcAppMetadataTest, TestDoGet) { TestDoGet(); } +TEST_F(GrpcAppMetadataTest, TestDoGetDictionaries) { TestDoGetDictionaries(); } +TEST_F(GrpcAppMetadataTest, TestDoPut) { TestDoPut(); } +TEST_F(GrpcAppMetadataTest, TestDoPutDictionaries) { TestDoPutDictionaries(); } +TEST_F(GrpcAppMetadataTest, TestDoPutReadMetadata) { TestDoPutReadMetadata(); } -void AssertEqual(const std::vector& expected, - const std::vector& actual) { - ASSERT_EQ(expected.size(), actual.size()); - for (size_t i = 0; i < expected.size(); ++i) { - AssertEqual(expected[i].ticket, actual[i].ticket); +class GrpcIpcOptionsTest : public IpcOptionsTest { + protected: + std::string transport() const override { return "grpc"; } +}; - ASSERT_EQ(expected[i].locations.size(), actual[i].locations.size()); - for (size_t j = 0; j < expected[i].locations.size(); ++j) { - AssertEqual(expected[i].locations[j], actual[i].locations[j]); - } - } +TEST_F(GrpcIpcOptionsTest, TestDoGetReadOptions) { TestDoGetReadOptions(); } +TEST_F(GrpcIpcOptionsTest, TestDoPutWriteOptions) { TestDoPutWriteOptions(); } +TEST_F(GrpcIpcOptionsTest, TestDoExchangeClientWriteOptions) { + TestDoExchangeClientWriteOptions(); } - -template -void AssertEqual(const std::vector& expected, const std::vector& actual) { - ASSERT_EQ(expected.size(), actual.size()); - for (size_t i = 0; i < expected.size(); ++i) { - AssertEqual(expected[i], actual[i]); - } +TEST_F(GrpcIpcOptionsTest, TestDoExchangeClientWriteOptionsBegin) { + TestDoExchangeClientWriteOptionsBegin(); } - -void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) { - std::shared_ptr ex_schema, actual_schema; - ipc::DictionaryMemo expected_memo; - ipc::DictionaryMemo actual_memo; - ASSERT_OK(expected.GetSchema(&expected_memo, &ex_schema)); - ASSERT_OK(actual.GetSchema(&actual_memo, &actual_schema)); - - AssertSchemaEqual(*ex_schema, *actual_schema); - ASSERT_EQ(expected.total_records(), actual.total_records()); - 
ASSERT_EQ(expected.total_bytes(), actual.total_bytes()); - - AssertEqual(expected.descriptor(), actual.descriptor()); - AssertEqual(expected.endpoints(), actual.endpoints()); +TEST_F(GrpcIpcOptionsTest, TestDoExchangeServerWriteOptions) { + TestDoExchangeServerWriteOptions(); } -TEST(TestFlightDescriptor, Basics) { - auto a = FlightDescriptor::Command("select * from table"); - auto b = FlightDescriptor::Command("select * from table"); - auto c = FlightDescriptor::Command("select foo from table"); - auto d = FlightDescriptor::Path({"foo", "bar"}); - auto e = FlightDescriptor::Path({"foo", "baz"}); - auto f = FlightDescriptor::Path({"foo", "baz"}); - - ASSERT_EQ(a.ToString(), "FlightDescriptor"); - ASSERT_EQ(d.ToString(), "FlightDescriptor"); - ASSERT_TRUE(a.Equals(b)); - ASSERT_FALSE(a.Equals(c)); - ASSERT_FALSE(a.Equals(d)); - ASSERT_FALSE(d.Equals(e)); - ASSERT_TRUE(e.Equals(f)); -} - -// This tests the internal protobuf types which don't get exported in the Flight DLL. -#ifndef _WIN32 -TEST(TestFlightDescriptor, ToFromProto) { - FlightDescriptor descr_test; - pb::FlightDescriptor pb_descr; - - FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}}; - ASSERT_OK(internal::ToProto(descr1, &pb_descr)); - ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); - AssertEqual(descr1, descr_test); - - FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}}; - ASSERT_OK(internal::ToProto(descr2, &pb_descr)); - ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); - AssertEqual(descr2, descr_test); -} -#endif - -TEST(TestFlight, DISABLED_StartStopTestServer) { - TestServer server("flight-test-server"); - server.Start(); - ASSERT_TRUE(server.IsRunning()); - - std::this_thread::sleep_for(std::chrono::duration(0.2)); - - ASSERT_TRUE(server.IsRunning()); - int exit_code = server.Stop(); -#ifdef _WIN32 - // We do a hard kill on Windows - ASSERT_EQ(259, exit_code); -#else - ASSERT_EQ(0, exit_code); -#endif - ASSERT_FALSE(server.IsRunning()); -} - -// 
ARROW-6017: we should be able to construct locations for unknown -// schemes -TEST(TestFlight, UnknownLocationScheme) { - Location location; - ASSERT_OK(Location::Parse("s3://test", &location)); - ASSERT_OK(Location::Parse("https://example.com/foo", &location)); -} +//------------------------------------------------------------ +// Ad-hoc gRPC-specific tests TEST(TestFlight, ConnectUri) { TestServer server("flight-test-server"); @@ -221,108 +179,6 @@ TEST(TestFlight, ConnectUriUnix) { } #endif -TEST(TestFlight, RoundTripTypes) { - Ticket ticket{"foo"}; - ASSERT_OK_AND_ASSIGN(std::string ticket_serialized, ticket.SerializeToString()); - ASSERT_OK_AND_ASSIGN(Ticket ticket_deserialized, - Ticket::Deserialize(ticket_serialized)); - ASSERT_EQ(ticket.ticket, ticket_deserialized.ticket); - - FlightDescriptor desc = FlightDescriptor::Command("select * from foo;"); - ASSERT_OK_AND_ASSIGN(std::string desc_serialized, desc.SerializeToString()); - ASSERT_OK_AND_ASSIGN(FlightDescriptor desc_deserialized, - FlightDescriptor::Deserialize(desc_serialized)); - ASSERT_TRUE(desc.Equals(desc_deserialized)); - - desc = FlightDescriptor::Path({"a", "b", "test.arrow"}); - ASSERT_OK_AND_ASSIGN(desc_serialized, desc.SerializeToString()); - ASSERT_OK_AND_ASSIGN(desc_deserialized, FlightDescriptor::Deserialize(desc_serialized)); - ASSERT_TRUE(desc.Equals(desc_deserialized)); - - FlightInfo::Data data; - std::shared_ptr schema = - arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()), - field("d", int64())}); - Location location1, location2, location3; - ASSERT_OK(Location::ForGrpcTcp("localhost", 10010, &location1)); - ASSERT_OK(Location::ForGrpcTls("localhost", 10010, &location2)); - ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location3)); - std::vector endpoints{FlightEndpoint{ticket, {location1, location2}}, - FlightEndpoint{ticket, {location3}}}; - ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data)); - std::unique_ptr info = 
std::unique_ptr(new FlightInfo(data)); - ASSERT_OK_AND_ASSIGN(std::string info_serialized, info->SerializeToString()); - ASSERT_OK_AND_ASSIGN(std::unique_ptr info_deserialized, - FlightInfo::Deserialize(info_serialized)); - ASSERT_TRUE(info->descriptor().Equals(info_deserialized->descriptor())); - ASSERT_EQ(info->endpoints(), info_deserialized->endpoints()); - ASSERT_EQ(info->total_records(), info_deserialized->total_records()); - ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes()); -} - -TEST(TestFlight, RoundtripStatus) { - // Make sure status codes round trip through our conversions - - std::shared_ptr detail; - detail = FlightStatusDetail::UnwrapStatus( - MakeFlightError(FlightStatusCode::Internal, "Test message")); - ASSERT_NE(nullptr, detail); - ASSERT_EQ(FlightStatusCode::Internal, detail->code()); - - detail = FlightStatusDetail::UnwrapStatus( - MakeFlightError(FlightStatusCode::TimedOut, "Test message")); - ASSERT_NE(nullptr, detail); - ASSERT_EQ(FlightStatusCode::TimedOut, detail->code()); - - detail = FlightStatusDetail::UnwrapStatus( - MakeFlightError(FlightStatusCode::Cancelled, "Test message")); - ASSERT_NE(nullptr, detail); - ASSERT_EQ(FlightStatusCode::Cancelled, detail->code()); - - detail = FlightStatusDetail::UnwrapStatus( - MakeFlightError(FlightStatusCode::Unauthenticated, "Test message")); - ASSERT_NE(nullptr, detail); - ASSERT_EQ(FlightStatusCode::Unauthenticated, detail->code()); - - detail = FlightStatusDetail::UnwrapStatus( - MakeFlightError(FlightStatusCode::Unauthorized, "Test message")); - ASSERT_NE(nullptr, detail); - ASSERT_EQ(FlightStatusCode::Unauthorized, detail->code()); - - detail = FlightStatusDetail::UnwrapStatus( - MakeFlightError(FlightStatusCode::Unavailable, "Test message")); - ASSERT_NE(nullptr, detail); - ASSERT_EQ(FlightStatusCode::Unavailable, detail->code()); - - Status status = internal::FromGrpcStatus( - internal::ToGrpcStatus(Status::NotImplemented("Sentinel"))); - 
ASSERT_TRUE(status.IsNotImplemented()); - ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); - - status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel"))); - ASSERT_TRUE(status.IsInvalid()); - ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); - - status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel"))); - ASSERT_TRUE(status.IsKeyError()); - ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); - - status = - internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel"))); - ASSERT_TRUE(status.IsAlreadyExists()); - ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); -} - -TEST(TestFlight, GetPort) { - Location location; - std::unique_ptr server = ExampleTestServer(); - - ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location)); - FlightServerOptions options(location); - ASSERT_OK(server->Init(options)); - ASSERT_GT(server->port(), 0); -} - // CI environments don't have an IPv6 interface configured TEST(TestFlight, DISABLED_IpV6Port) { Location location, location2; @@ -340,56 +196,6 @@ TEST(TestFlight, DISABLED_IpV6Port) { ASSERT_OK(client->ListFlights(&listing)); } -TEST(TestFlight, BuilderHook) { - Location location; - std::unique_ptr server = ExampleTestServer(); - - ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location)); - FlightServerOptions options(location); - bool builder_hook_run = false; - options.builder_hook = [&builder_hook_run](void* builder) { - ASSERT_NE(nullptr, builder); - builder_hook_run = true; - }; - ASSERT_OK(server->Init(options)); - ASSERT_TRUE(builder_hook_run); - ASSERT_GT(server->port(), 0); - ASSERT_OK(server->Shutdown()); -} - -TEST(TestFlight, ServeShutdown) { - // Regression test for ARROW-15181 - constexpr int kIterations = 10; - for (int i = 0; i < kIterations; i++) { - Location location; - std::unique_ptr server = ExampleTestServer(); - - ASSERT_OK(Location::ForGrpcTcp("localhost", 0, 
&location)); - FlightServerOptions options(location); - ASSERT_OK(server->Init(options)); - ASSERT_GT(server->port(), 0); - std::thread t([&]() { ASSERT_OK(server->Serve()); }); - ASSERT_OK(server->Shutdown()); - ASSERT_OK(server->Wait()); - t.join(); - } -} - -TEST(TestFlight, ServeShutdownWithDeadline) { - Location location; - std::unique_ptr server = ExampleTestServer(); - - ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location)); - FlightServerOptions options(location); - ASSERT_OK(server->Init(options)); - ASSERT_GT(server->port(), 0); - - auto deadline = std::chrono::system_clock::now() + std::chrono::microseconds(10); - - ASSERT_OK(server->Shutdown(&deadline)); - ASSERT_OK(server->Wait()); -} - // ---------------------------------------------------------------------- // Client tests @@ -418,7 +224,8 @@ class TestFlightClient : public ::testing::Test { } template - void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches, + void CheckDoGet(const FlightDescriptor& descr, + const RecordBatchVector& expected_batches, EndpointCheckFunc&& check_endpoints) { auto expected_schema = expected_batches[0]->schema(); @@ -436,7 +243,7 @@ class TestFlightClient : public ::testing::Test { CheckDoGet(ticket, expected_batches); } - void CheckDoGet(const Ticket& ticket, const BatchVector& expected_batches) { + void CheckDoGet(const Ticket& ticket, const RecordBatchVector& expected_batches) { auto num_batches = static_cast(expected_batches.size()); ASSERT_GE(num_batches, 2); @@ -504,121 +311,6 @@ class TlsTestServer : public FlightServerBase { } }; -class DoPutTestServer : public FlightServerBase { - public: - Status DoPut(const ServerCallContext& context, - std::unique_ptr reader, - std::unique_ptr writer) override { - descriptor_ = reader->descriptor(); - return reader->ReadAll(&batches_); - } - - protected: - FlightDescriptor descriptor_; - BatchVector batches_; - - friend class TestDoPut; -}; - -class MetadataTestServer : public FlightServerBase 
{ - Status DoGet(const ServerCallContext& context, const Ticket& request, - std::unique_ptr* data_stream) override { - BatchVector batches; - if (request.ticket == "dicts") { - RETURN_NOT_OK(ExampleDictBatches(&batches)); - } else if (request.ticket == "floats") { - RETURN_NOT_OK(ExampleFloatBatches(&batches)); - } else { - RETURN_NOT_OK(ExampleIntBatches(&batches)); - } - std::shared_ptr batch_reader = - std::make_shared(batches[0]->schema(), batches); - - *data_stream = std::unique_ptr(new NumberingStream( - std::unique_ptr(new RecordBatchStream(batch_reader)))); - return Status::OK(); - } - - Status DoPut(const ServerCallContext& context, - std::unique_ptr reader, - std::unique_ptr writer) override { - FlightStreamChunk chunk; - int counter = 0; - while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); - if (chunk.data == nullptr) break; - if (chunk.app_metadata == nullptr) { - return Status::Invalid("Expected application metadata to be provided"); - } - if (std::to_string(counter) != chunk.app_metadata->ToString()) { - return Status::Invalid("Expected metadata value: " + std::to_string(counter) + - " but got: " + chunk.app_metadata->ToString()); - } - auto metadata = Buffer::FromString(std::to_string(counter)); - RETURN_NOT_OK(writer->WriteMetadata(*metadata)); - counter++; - } - return Status::OK(); - } -}; - -// Server for testing custom IPC options support -class OptionsTestServer : public FlightServerBase { - Status DoGet(const ServerCallContext& context, const Ticket& request, - std::unique_ptr* data_stream) override { - BatchVector batches; - RETURN_NOT_OK(ExampleNestedBatches(&batches)); - auto reader = std::make_shared(batches[0]->schema(), batches); - *data_stream = std::unique_ptr(new RecordBatchStream(reader)); - return Status::OK(); - } - - // Just echo the number of batches written. The client will try to - // call this method with different write options set. 
- Status DoPut(const ServerCallContext& context, - std::unique_ptr reader, - std::unique_ptr writer) override { - FlightStreamChunk chunk; - int counter = 0; - while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); - if (chunk.data == nullptr) break; - counter++; - } - auto metadata = Buffer::FromString(std::to_string(counter)); - return writer->WriteMetadata(*metadata); - } - - // Echo client data, but with write options set to limit the nesting - // level. - Status DoExchange(const ServerCallContext& context, - std::unique_ptr reader, - std::unique_ptr writer) override { - FlightStreamChunk chunk; - auto options = ipc::IpcWriteOptions::Defaults(); - options.max_recursion_depth = 1; - bool begun = false; - while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (!begun && chunk.data) { - begun = true; - RETURN_NOT_OK(writer->Begin(chunk.data->schema(), options)); - } - if (chunk.data && chunk.app_metadata) { - RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata)); - } else if (chunk.data) { - RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); - } else if (chunk.app_metadata) { - RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata)); - } - } - return Status::OK(); - } -}; - class HeaderAuthTestServer : public FlightServerBase { public: Status ListFlights(const ServerCallContext& context, const Criteria* criteria, @@ -627,42 +319,6 @@ class HeaderAuthTestServer : public FlightServerBase { } }; -class TestMetadata : public ::testing::Test { - public: - void SetUp() { - ASSERT_OK(MakeServer( - &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, - [](FlightClientOptions* options) { return Status::OK(); })); - } - - void TearDown() { - ASSERT_OK(client_->Close()); - ASSERT_OK(server_->Shutdown()); - } - - protected: - std::unique_ptr client_; - std::unique_ptr server_; -}; - -class TestOptions : public ::testing::Test { - public: - void SetUp() { - 
ASSERT_OK(MakeServer( - &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, - [](FlightClientOptions* options) { return Status::OK(); })); - } - - void TearDown() { - ASSERT_OK(client_->Close()); - ASSERT_OK(server_->Shutdown()); - } - - protected: - std::unique_ptr client_; - std::unique_ptr server_; -}; - class TestAuthHandler : public ::testing::Test { public: void SetUp() { @@ -709,49 +365,6 @@ class TestBasicAuthHandler : public ::testing::Test { std::unique_ptr server_; }; -class TestDoPut : public ::testing::Test { - public: - void SetUp() { - ASSERT_OK(MakeServer( - &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, - [](FlightClientOptions* options) { return Status::OK(); })); - do_put_server_ = (DoPutTestServer*)server_.get(); - } - - void TearDown() { - ASSERT_OK(client_->Close()); - ASSERT_OK(server_->Shutdown()); - } - - void CheckBatches(FlightDescriptor expected_descriptor, - const BatchVector& expected_batches) { - ASSERT_TRUE(do_put_server_->descriptor_.Equals(expected_descriptor)); - ASSERT_EQ(do_put_server_->batches_.size(), expected_batches.size()); - for (size_t i = 0; i < expected_batches.size(); ++i) { - ASSERT_BATCHES_EQUAL(*do_put_server_->batches_[i], *expected_batches[i]); - } - } - - void CheckDoPut(FlightDescriptor descr, const std::shared_ptr& schema, - const BatchVector& batches) { - std::unique_ptr stream; - std::unique_ptr reader; - ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader)); - for (const auto& batch : batches) { - ASSERT_OK(stream->WriteRecordBatch(*batch)); - } - ASSERT_OK(stream->DoneWriting()); - ASSERT_OK(stream->Close()); - - CheckBatches(descr, batches); - } - - protected: - std::unique_ptr client_; - std::unique_ptr server_; - DoPutTestServer* do_put_server_; -}; - class TestTls : public ::testing::Test { public: void SetUp() { @@ -1080,7 +693,7 @@ class PropagatingTestServer : public FlightServerBase { class TestRejectServerMiddleware : public 
::testing::Test { public: void SetUp() { - ASSERT_OK(MakeServer( + ASSERT_OK(MakeServer( &server_, &client_, [](FlightServerOptions* options) { options->middleware.push_back( @@ -1104,7 +717,7 @@ class TestCountingServerMiddleware : public ::testing::Test { public: void SetUp() { request_counter_ = std::make_shared(); - ASSERT_OK(MakeServer( + ASSERT_OK(MakeServer( &server_, &client_, [&](FlightServerOptions* options) { options->middleware.push_back({"request_counter", request_counter_}); @@ -1264,139 +877,6 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { std::shared_ptr bearer_middleware_; }; -// This test keeps an internal cookie cache and compares that with the middleware. -class TestCookieMiddleware : public ::testing::Test { - public: - // Setup function creates middleware factory and starts it up. - void SetUp() { - factory_ = GetCookieFactory(); - CallInfo callInfo; - factory_->StartCall(callInfo, &middleware_); - } - - // Function to add incoming cookies to middleware and validate them. - void AddAndValidate(const std::string& incoming_cookie) { - // Add cookie - CallHeaders call_headers; - call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"), - arrow::util::string_view(incoming_cookie))); - middleware_->ReceivedHeaders(call_headers); - expected_cookie_cache_.UpdateCachedCookies(call_headers); - - // Get cookie from middleware. - TestCallHeaders add_call_headers; - middleware_->SendingHeaders(&add_call_headers); - const std::string actual_cookies = add_call_headers.GetCookies(); - - // Validate cookie - const std::string expected_cookies = expected_cookie_cache_.GetValidCookiesAsString(); - const std::vector split_expected_cookies = - SplitCookies(expected_cookies); - const std::vector split_actual_cookies = SplitCookies(actual_cookies); - EXPECT_EQ(split_expected_cookies, split_actual_cookies); - } - - // Function to take a list of cookies and split them into a vector of individual - // cookies. 
This is done because the cookie cache is a map so ordering is not - // necessarily consistent. - static std::vector SplitCookies(const std::string& cookies) { - std::vector split_cookies; - std::string::size_type pos1 = 0; - std::string::size_type pos2 = 0; - while ((pos2 = cookies.find(';', pos1)) != std::string::npos) { - split_cookies.push_back( - arrow::internal::TrimString(cookies.substr(pos1, pos2 - pos1))); - pos1 = pos2 + 1; - } - if (pos1 < cookies.size()) { - split_cookies.push_back(arrow::internal::TrimString(cookies.substr(pos1))); - } - std::sort(split_cookies.begin(), split_cookies.end()); - return split_cookies; - } - - protected: - // Class to allow testing of the call headers. - class TestCallHeaders : public AddCallHeaders { - public: - TestCallHeaders() {} - ~TestCallHeaders() {} - - // Function to add cookie header. - void AddHeader(const std::string& key, const std::string& value) { - ASSERT_EQ(key, "cookie"); - outbound_cookie_ = value; - } - - // Function to get outgoing cookie. - std::string GetCookies() { return outbound_cookie_; } - - private: - std::string outbound_cookie_; - }; - - internal::CookieCache expected_cookie_cache_; - std::unique_ptr middleware_; - std::shared_ptr factory_; -}; - -// This test is used to test the parsing capabilities of the cookie framework. 
-class TestCookieParsing : public ::testing::Test { - public: - void VerifyParseCookie(const std::string& cookie_str, bool expired) { - internal::Cookie cookie = internal::Cookie::parse(cookie_str); - EXPECT_EQ(expired, cookie.IsExpired()); - } - - void VerifyCookieName(const std::string& cookie_str, const std::string& name) { - internal::Cookie cookie = internal::Cookie::parse(cookie_str); - EXPECT_EQ(name, cookie.GetName()); - } - - void VerifyCookieString(const std::string& cookie_str, - const std::string& cookie_as_string) { - internal::Cookie cookie = internal::Cookie::parse(cookie_str); - EXPECT_EQ(cookie_as_string, cookie.AsCookieString()); - } - - void VerifyCookieDateConverson(std::string date, const std::string& converted_date) { - internal::Cookie::ConvertCookieDate(&date); - EXPECT_EQ(converted_date, date); - } - - void VerifyCookieAttributeParsing( - const std::string cookie_str, std::string::size_type start_pos, - const util::optional> cookie_attribute, - const std::string::size_type start_pos_after) { - util::optional> attr = - internal::Cookie::ParseCookieAttribute(cookie_str, &start_pos); - - if (cookie_attribute == util::nullopt) { - EXPECT_EQ(cookie_attribute, attr); - } else { - EXPECT_EQ(cookie_attribute.value(), attr.value()); - } - EXPECT_EQ(start_pos_after, start_pos); - } - - void AddCookieVerifyCache(const std::vector& cookies, - const std::string& expected_cookies) { - internal::CookieCache cookie_cache; - for (auto& cookie : cookies) { - // Add cookie - CallHeaders call_headers; - call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"), - arrow::util::string_view(cookie))); - cookie_cache.UpdateCachedCookies(call_headers); - } - const std::string actual_cookies = cookie_cache.GetValidCookiesAsString(); - const std::vector actual_split_cookies = - TestCookieMiddleware::SplitCookies(actual_cookies); - const std::vector expected_split_cookies = - TestCookieMiddleware::SplitCookies(expected_cookies); - } -}; - 
TEST_F(TestErrorMiddleware, TestMetadata) { Action action; std::unique_ptr stream; @@ -1472,348 +952,12 @@ TEST_F(TestFlightClient, GetFlightInfoNotFound) { ASSERT_NE(st.message().find("Flight not found"), std::string::npos); } -TEST_F(TestFlightClient, DoGetInts) { - auto descr = FlightDescriptor::Path({"examples", "ints"}); - BatchVector expected_batches; - ASSERT_OK(ExampleIntBatches(&expected_batches)); - - auto check_endpoints = [](const std::vector& endpoints) { - // Two endpoints in the example FlightInfo - ASSERT_EQ(2, endpoints.size()); - AssertEqual(Ticket{"ticket-ints-1"}, endpoints[0].ticket); - }; - - CheckDoGet(descr, expected_batches, check_endpoints); -} - -TEST_F(TestFlightClient, DoGetFloats) { - auto descr = FlightDescriptor::Path({"examples", "floats"}); - BatchVector expected_batches; - ASSERT_OK(ExampleFloatBatches(&expected_batches)); - - auto check_endpoints = [](const std::vector& endpoints) { - // One endpoint in the example FlightInfo - ASSERT_EQ(1, endpoints.size()); - AssertEqual(Ticket{"ticket-floats-1"}, endpoints[0].ticket); - }; - - CheckDoGet(descr, expected_batches, check_endpoints); -} - -TEST_F(TestFlightClient, DoGetDicts) { - auto descr = FlightDescriptor::Path({"examples", "dicts"}); - BatchVector expected_batches; - ASSERT_OK(ExampleDictBatches(&expected_batches)); - - auto check_endpoints = [](const std::vector& endpoints) { - // One endpoint in the example FlightInfo - ASSERT_EQ(1, endpoints.size()); - AssertEqual(Ticket{"ticket-dicts-1"}, endpoints[0].ticket); - }; - - CheckDoGet(descr, expected_batches, check_endpoints); -} - -// Ensure the gRPC client is configured to allow large messages -// Tests a 32 MiB batch -TEST_F(TestFlightClient, DoGetLargeBatch) { - BatchVector expected_batches; - ASSERT_OK(ExampleLargeBatches(&expected_batches)); - Ticket ticket{"ticket-large-batch-1"}; - CheckDoGet(ticket, expected_batches); -} - -TEST_F(TestFlightClient, FlightDataOverflowServerBatch) { - // Regression test for ARROW-13253 
- // N.B. this is rather a slow and memory-hungry test - { - // DoGet: check for overflow on large batch - Ticket ticket{"ARROW-13253-DoGet-Batch"}; - std::unique_ptr stream; - ASSERT_OK(client_->DoGet(ticket, &stream)); - FlightStreamChunk chunk; - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), - stream->Next(&chunk)); - } - { - // DoExchange: check for overflow on large batch from server - auto descr = FlightDescriptor::Command("large_batch"); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - BatchVector batches; - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), - reader->ReadAll(&batches)); - } -} - -TEST_F(TestFlightClient, FlightDataOverflowClientBatch) { - ASSERT_OK_AND_ASSIGN(auto batch, VeryLargeBatch()); - { - // DoPut: check for overflow on large batch - std::unique_ptr stream; - std::unique_ptr reader; - auto descr = FlightDescriptor::Path({""}); - ASSERT_OK(client_->DoPut(descr, batch->schema(), &stream, &reader)); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), - stream->WriteRecordBatch(*batch)); - ASSERT_OK(stream->Close()); - } - { - // DoExchange: check for overflow on large batch from client - auto descr = FlightDescriptor::Command("counter"); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - ASSERT_OK(writer->Begin(batch->schema())); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), - writer->WriteRecordBatch(*batch)); - ASSERT_OK(writer->Close()); - } -} - -TEST_F(TestFlightClient, DoExchange) { - auto descr = FlightDescriptor::Command("counter"); - BatchVector batches; - auto a1 = ArrayFromJSON(int32(), "[4, 5, 6, null]"); - auto schema = 
arrow::schema({field("f1", a1->type())}); - batches.push_back(RecordBatch::Make(schema, a1->length(), {a1})); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - ASSERT_OK(writer->Begin(schema)); - for (const auto& batch : batches) { - ASSERT_OK(writer->WriteRecordBatch(*batch)); - } - ASSERT_OK(writer->DoneWriting()); - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); - ASSERT_NE(nullptr, chunk.app_metadata); - ASSERT_EQ(nullptr, chunk.data); - ASSERT_EQ("1", chunk.app_metadata->ToString()); - ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema()); - AssertSchemaEqual(schema, server_schema); - for (const auto& batch : batches) { - ASSERT_OK(reader->Next(&chunk)); - ASSERT_BATCHES_EQUAL(*batch, *chunk.data); - } - ASSERT_OK(writer->Close()); -} - -// Test pure-metadata DoExchange to ensure nothing blocks waiting for -// schema messages -TEST_F(TestFlightClient, DoExchangeNoData) { - auto descr = FlightDescriptor::Command("counter"); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - ASSERT_OK(writer->DoneWriting()); - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); - ASSERT_EQ(nullptr, chunk.data); - ASSERT_NE(nullptr, chunk.app_metadata); - ASSERT_EQ("0", chunk.app_metadata->ToString()); - ASSERT_OK(writer->Close()); -} - -// Test sending a schema without any data, as this hits an edge case -// in the client-side writer. 
-TEST_F(TestFlightClient, DoExchangeWriteOnlySchema) { - auto descr = FlightDescriptor::Command("counter"); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - auto schema = arrow::schema({field("f1", arrow::int32())}); - ASSERT_OK(writer->Begin(schema)); - ASSERT_OK(writer->WriteMetadata(Buffer::FromString("foo"))); - ASSERT_OK(writer->DoneWriting()); - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); - ASSERT_EQ(nullptr, chunk.data); - ASSERT_NE(nullptr, chunk.app_metadata); - ASSERT_EQ("0", chunk.app_metadata->ToString()); - ASSERT_OK(writer->Close()); -} - -// Emulate DoGet -TEST_F(TestFlightClient, DoExchangeGet) { - auto descr = FlightDescriptor::Command("get"); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema()); - AssertSchemaEqual(*ExampleIntSchema(), *server_schema); - BatchVector batches; - ASSERT_OK(ExampleIntBatches(&batches)); - FlightStreamChunk chunk; - for (const auto& batch : batches) { - ASSERT_OK(reader->Next(&chunk)); - ASSERT_NE(nullptr, chunk.data); - AssertBatchesEqual(*batch, *chunk.data); - } - ASSERT_OK(reader->Next(&chunk)); - ASSERT_EQ(nullptr, chunk.data); - ASSERT_EQ(nullptr, chunk.app_metadata); - ASSERT_OK(writer->Close()); -} - -// Emulate DoPut -TEST_F(TestFlightClient, DoExchangePut) { - auto descr = FlightDescriptor::Command("put"); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - ASSERT_OK(writer->Begin(ExampleIntSchema())); - BatchVector batches; - ASSERT_OK(ExampleIntBatches(&batches)); - for (const auto& batch : batches) { - ASSERT_OK(writer->WriteRecordBatch(*batch)); - } - ASSERT_OK(writer->DoneWriting()); - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); - ASSERT_NE(nullptr, chunk.app_metadata); - AssertBufferEqual(*chunk.app_metadata, "done"); - 
ASSERT_OK(reader->Next(&chunk)); - ASSERT_EQ(nullptr, chunk.data); - ASSERT_EQ(nullptr, chunk.app_metadata); - ASSERT_OK(writer->Close()); -} - -// Test the echo server -TEST_F(TestFlightClient, DoExchangeEcho) { - auto descr = FlightDescriptor::Command("echo"); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - ASSERT_OK(writer->Begin(ExampleIntSchema())); - BatchVector batches; - FlightStreamChunk chunk; - ASSERT_OK(ExampleIntBatches(&batches)); - for (const auto& batch : batches) { - ASSERT_OK(writer->WriteRecordBatch(*batch)); - ASSERT_OK(reader->Next(&chunk)); - ASSERT_NE(nullptr, chunk.data); - ASSERT_EQ(nullptr, chunk.app_metadata); - AssertBatchesEqual(*batch, *chunk.data); - } - for (int i = 0; i < 10; i++) { - const auto buf = Buffer::FromString(std::to_string(i)); - ASSERT_OK(writer->WriteMetadata(buf)); - ASSERT_OK(reader->Next(&chunk)); - ASSERT_EQ(nullptr, chunk.data); - ASSERT_NE(nullptr, chunk.app_metadata); - AssertBufferEqual(*buf, *chunk.app_metadata); - } - int index = 0; - for (const auto& batch : batches) { - const auto buf = Buffer::FromString(std::to_string(index)); - ASSERT_OK(writer->WriteWithMetadata(*batch, buf)); - ASSERT_OK(reader->Next(&chunk)); - ASSERT_NE(nullptr, chunk.data); - ASSERT_NE(nullptr, chunk.app_metadata); - AssertBatchesEqual(*batch, *chunk.data); - AssertBufferEqual(*buf, *chunk.app_metadata); - index++; - } - ASSERT_OK(writer->DoneWriting()); - ASSERT_OK(reader->Next(&chunk)); - ASSERT_EQ(nullptr, chunk.data); - ASSERT_EQ(nullptr, chunk.app_metadata); - ASSERT_OK(writer->Close()); -} - -// Test interleaved reading/writing -TEST_F(TestFlightClient, DoExchangeTotal) { - auto descr = FlightDescriptor::Command("total"); - std::unique_ptr reader; - std::unique_ptr writer; - { - auto a1 = ArrayFromJSON(arrow::int32(), "[4, 5, 6, null]"); - auto schema = arrow::schema({field("f1", a1->type())}); - // XXX: as noted in flight/client.cc, Begin() is lazy and the - // 
schema message won't be written until some data is also - // written. There's also timing issues; hence we check each status - // here. - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, ::testing::HasSubstr("Field is not INT64: f1"), ([&]() { - RETURN_NOT_OK(client_->DoExchange(descr, &writer, &reader)); - RETURN_NOT_OK(writer->Begin(schema)); - auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1}); - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - return writer->Close(); - })()); - } - { - auto a1 = ArrayFromJSON(arrow::int64(), "[1, 2, null, 3]"); - auto a2 = ArrayFromJSON(arrow::int64(), "[null, 4, 5, 6]"); - auto schema = arrow::schema({field("f1", a1->type()), field("f2", a2->type())}); - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - ASSERT_OK(writer->Begin(schema)); - auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1, a2}); - FlightStreamChunk chunk; - ASSERT_OK(writer->WriteRecordBatch(*batch)); - ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema()); - AssertSchemaEqual(*schema, *server_schema); - - ASSERT_OK(reader->Next(&chunk)); - ASSERT_NE(nullptr, chunk.data); - auto expected1 = RecordBatch::Make( - schema, /* num_rows */ 1, - {ArrayFromJSON(arrow::int64(), "[6]"), ArrayFromJSON(arrow::int64(), "[15]")}); - AssertBatchesEqual(*expected1, *chunk.data); - - ASSERT_OK(writer->WriteRecordBatch(*batch)); - ASSERT_OK(reader->Next(&chunk)); - ASSERT_NE(nullptr, chunk.data); - auto expected2 = RecordBatch::Make( - schema, /* num_rows */ 1, - {ArrayFromJSON(arrow::int64(), "[12]"), ArrayFromJSON(arrow::int64(), "[30]")}); - AssertBatchesEqual(*expected2, *chunk.data); - - ASSERT_OK(writer->Close()); - } -} - -// Ensure server errors get propagated no matter what we try -TEST_F(TestFlightClient, DoExchangeError) { - auto descr = FlightDescriptor::Command("error"); - std::unique_ptr reader; - std::unique_ptr writer; - { - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - auto status = writer->Close(); - 
EXPECT_RAISES_WITH_MESSAGE_THAT( - NotImplemented, ::testing::HasSubstr("Expected error"), writer->Close()); - } - { - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - FlightStreamChunk chunk; - EXPECT_RAISES_WITH_MESSAGE_THAT( - NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next(&chunk)); - } - { - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - EXPECT_RAISES_WITH_MESSAGE_THAT( - NotImplemented, ::testing::HasSubstr("Expected error"), reader->GetSchema()); - } - // writer->Begin isn't tested here because, as noted in client.cc, - // OpenRecordBatchWriter lazily writes the initial message - hence - // Begin() won't fail. Additionally, it appears gRPC may buffer - // writes - a write won't immediately fail even when the server - // immediately fails. -} - TEST_F(TestFlightClient, ListActions) { std::vector actions; ASSERT_OK(client_->ListActions(&actions)); std::vector expected = ExampleActionTypes(); - AssertEqual(expected, actions); + EXPECT_THAT(actions, ::testing::ContainerEq(expected)); } TEST_F(TestFlightClient, DoAction) { @@ -1853,21 +997,6 @@ TEST_F(TestFlightClient, RoundTripStatus) { ASSERT_RAISES(OutOfMemory, status); } -TEST_F(TestFlightClient, Issue5095) { - // Make sure the server-side error message is reflected to the - // client - Ticket ticket1{"ARROW-5095-fail"}; - std::unique_ptr stream; - Status status = client_->DoGet(ticket1, &stream); - ASSERT_RAISES(UnknownError, status); - ASSERT_THAT(status.message(), ::testing::HasSubstr("Server-side error")); - - Ticket ticket2{"ARROW-5095-success"}; - status = client_->DoGet(ticket2, &stream); - ASSERT_RAISES(KeyError, status); - ASSERT_THAT(status.message(), ::testing::HasSubstr("No data")); -} - // Test setting generic transport options by configuring gRPC to fail // all calls. 
TEST_F(TestFlightClient, GenericOptions) { @@ -1936,117 +1065,6 @@ TEST_F(TestFlightClient, Close) { client_->ListFlights(&listing)); } -TEST_F(TestDoPut, DoPutInts) { - auto descr = FlightDescriptor::Path({"ints"}); - BatchVector batches; - auto a0 = ArrayFromJSON(int8(), "[0, 1, 127, -128, null]"); - auto a1 = ArrayFromJSON(uint8(), "[0, 1, 127, 255, null]"); - auto a2 = ArrayFromJSON(int16(), "[0, 258, 32767, -32768, null]"); - auto a3 = ArrayFromJSON(uint16(), "[0, 258, 32767, 65535, null]"); - auto a4 = ArrayFromJSON(int32(), "[0, 65538, 2147483647, -2147483648, null]"); - auto a5 = ArrayFromJSON(uint32(), "[0, 65538, 2147483647, 4294967295, null]"); - auto a6 = ArrayFromJSON( - int64(), "[0, 4294967298, 9223372036854775807, -9223372036854775808, null]"); - auto a7 = ArrayFromJSON( - uint64(), "[0, 4294967298, 9223372036854775807, 18446744073709551615, null]"); - auto schema = arrow::schema({field("f0", a0->type()), field("f1", a1->type()), - field("f2", a2->type()), field("f3", a3->type()), - field("f4", a4->type()), field("f5", a5->type()), - field("f6", a6->type()), field("f7", a7->type())}); - batches.push_back( - RecordBatch::Make(schema, a0->length(), {a0, a1, a2, a3, a4, a5, a6, a7})); - - CheckDoPut(descr, schema, batches); -} - -TEST_F(TestDoPut, DoPutFloats) { - auto descr = FlightDescriptor::Path({"floats"}); - BatchVector batches; - auto a0 = ArrayFromJSON(float32(), "[0, 1.2, -3.4, 5.6, null]"); - auto a1 = ArrayFromJSON(float64(), "[0, 1.2, -3.4, 5.6, null]"); - auto schema = arrow::schema({field("f0", a0->type()), field("f1", a1->type())}); - batches.push_back(RecordBatch::Make(schema, a0->length(), {a0, a1})); - - CheckDoPut(descr, schema, batches); -} - -TEST_F(TestDoPut, DoPutEmptyBatch) { - // Sending and receiving a 0-sized batch shouldn't fail - auto descr = FlightDescriptor::Path({"ints"}); - BatchVector batches; - auto a1 = ArrayFromJSON(int32(), "[]"); - auto schema = arrow::schema({field("f1", a1->type())}); - 
batches.push_back(RecordBatch::Make(schema, a1->length(), {a1})); - - CheckDoPut(descr, schema, batches); -} - -TEST_F(TestDoPut, DoPutDicts) { - auto descr = FlightDescriptor::Path({"dicts"}); - BatchVector batches; - auto dict_values = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"quux\"]"); - auto ty = dictionary(int8(), dict_values->type()); - auto schema = arrow::schema({field("f1", ty)}); - // Make several batches - for (const char* json : {"[1, 0, 1]", "[null]", "[null, 1]"}) { - auto indices = ArrayFromJSON(int8(), json); - auto dict_array = std::make_shared(ty, indices, dict_values); - batches.push_back(RecordBatch::Make(schema, dict_array->length(), {dict_array})); - } - - CheckDoPut(descr, schema, batches); -} - -// Ensure the gRPC server is configured to allow large messages -// Tests a 32 MiB batch -TEST_F(TestDoPut, DoPutLargeBatch) { - auto descr = FlightDescriptor::Path({"large-batches"}); - auto schema = ExampleLargeSchema(); - BatchVector batches; - ASSERT_OK(ExampleLargeBatches(&batches)); - CheckDoPut(descr, schema, batches); -} - -TEST_F(TestDoPut, DoPutSizeLimit) { - const int64_t size_limit = 4096; - Location location; - ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location)); - auto client_options = FlightClientOptions::Defaults(); - client_options.write_size_limit_bytes = size_limit; - std::unique_ptr client; - ASSERT_OK(FlightClient::Connect(location, client_options, &client)); - - auto descr = FlightDescriptor::Path({"ints"}); - // Batch is too large to fit in one message - auto schema = arrow::schema({field("f1", arrow::int64())}); - auto batch = arrow::ConstantArrayGenerator::Zeroes(768, schema); - BatchVector batches; - batches.push_back(batch->Slice(0, 384)); - batches.push_back(batch->Slice(384)); - - std::unique_ptr stream; - std::unique_ptr reader; - ASSERT_OK(client->DoPut(descr, schema, &stream, &reader)); - - // Large batch will exceed the limit - const auto status = stream->WriteRecordBatch(*batch); - 
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("exceeded soft limit"), - status); - auto detail = FlightWriteSizeStatusDetail::UnwrapStatus(status); - ASSERT_NE(nullptr, detail); - ASSERT_EQ(size_limit, detail->limit()); - ASSERT_GT(detail->actual(), size_limit); - - // But we can retry with a smaller batch - for (const auto& batch : batches) { - ASSERT_OK(stream->WriteRecordBatch(*batch)); - } - - ASSERT_OK(stream->DoneWriting()); - ASSERT_OK(stream->Close()); - CheckBatches(descr, batches); -} - TEST_F(TestAuthHandler, PassAuthenticatedCalls) { ASSERT_OK(client_->Authenticate( {}, @@ -2338,205 +1356,6 @@ TEST_F(TestTls, OverrideHostnameGeneric) { // necessarily stable } -TEST_F(TestMetadata, DoGet) { - Ticket ticket{""}; - std::unique_ptr stream; - ASSERT_OK(client_->DoGet(ticket, &stream)); - - BatchVector expected_batches; - ASSERT_OK(ExampleIntBatches(&expected_batches)); - - FlightStreamChunk chunk; - auto num_batches = static_cast(expected_batches.size()); - for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(stream->Next(&chunk)); - ASSERT_NE(nullptr, chunk.data); - ASSERT_NE(nullptr, chunk.app_metadata); - ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data); - ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString()); - } - ASSERT_OK(stream->Next(&chunk)); - ASSERT_EQ(nullptr, chunk.data); -} - -// Test dictionaries. This tests a corner case in the reader: -// dictionary batches come in between the schema and the first record -// batch, so the server must take care to read application metadata -// from the record batch, and not one of the dictionary batches. 
-TEST_F(TestMetadata, DoGetDictionaries) { - Ticket ticket{"dicts"}; - std::unique_ptr stream; - ASSERT_OK(client_->DoGet(ticket, &stream)); - - BatchVector expected_batches; - ASSERT_OK(ExampleDictBatches(&expected_batches)); - - FlightStreamChunk chunk; - auto num_batches = static_cast(expected_batches.size()); - for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(stream->Next(&chunk)); - ASSERT_NE(nullptr, chunk.data); - ASSERT_NE(nullptr, chunk.app_metadata); - ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data); - ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString()); - } - ASSERT_OK(stream->Next(&chunk)); - ASSERT_EQ(nullptr, chunk.data); -} - -TEST_F(TestMetadata, DoPut) { - std::unique_ptr writer; - std::unique_ptr reader; - std::shared_ptr schema = ExampleIntSchema(); - ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader)); - - BatchVector expected_batches; - ASSERT_OK(ExampleIntBatches(&expected_batches)); - - std::shared_ptr chunk; - std::shared_ptr metadata; - auto num_batches = static_cast(expected_batches.size()); - for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i], - Buffer::FromString(std::to_string(i)))); - } - // This eventually calls grpc::ClientReaderWriter::Finish which can - // hang if there are unread messages. So make sure our wrapper - // around this doesn't hang (because it drains any unread messages) - ASSERT_OK(writer->Close()); -} - -// Test DoPut() with dictionaries. This tests a corner case in the -// server-side reader; see DoGetDictionaries above. -TEST_F(TestMetadata, DoPutDictionaries) { - std::unique_ptr writer; - std::unique_ptr reader; - BatchVector expected_batches; - ASSERT_OK(ExampleDictBatches(&expected_batches)); - // ARROW-8749: don't get the schema via ExampleDictSchema because - // DictionaryMemo uses field addresses to determine whether it's - // seen a field before. 
Hence, if we use a schema that is different - // (identity-wise) than the schema of the first batch we write, - // we'll end up generating a duplicate set of dictionaries that - // confuses the reader. - ASSERT_OK(client_->DoPut(FlightDescriptor{}, expected_batches[0]->schema(), &writer, - &reader)); - std::shared_ptr chunk; - std::shared_ptr metadata; - auto num_batches = static_cast(expected_batches.size()); - for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i], - Buffer::FromString(std::to_string(i)))); - } - ASSERT_OK(writer->Close()); -} - -TEST_F(TestMetadata, DoPutReadMetadata) { - std::unique_ptr writer; - std::unique_ptr reader; - std::shared_ptr schema = ExampleIntSchema(); - ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader)); - - BatchVector expected_batches; - ASSERT_OK(ExampleIntBatches(&expected_batches)); - - std::shared_ptr chunk; - std::shared_ptr metadata; - auto num_batches = static_cast(expected_batches.size()); - for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i], - Buffer::FromString(std::to_string(i)))); - ASSERT_OK(reader->ReadMetadata(&metadata)); - ASSERT_NE(nullptr, metadata); - ASSERT_EQ(std::to_string(i), metadata->ToString()); - } - // As opposed to DoPutDrainMetadata, now we've read the messages, so - // make sure this still closes as expected. - ASSERT_OK(writer->Close()); -} - -TEST_F(TestOptions, DoGetReadOptions) { - // Call DoGet, but with a very low read nesting depth set to fail the call. - Ticket ticket{""}; - auto options = FlightCallOptions(); - options.read_options.max_recursion_depth = 1; - std::unique_ptr stream; - ASSERT_OK(client_->DoGet(options, ticket, &stream)); - FlightStreamChunk chunk; - ASSERT_RAISES(Invalid, stream->Next(&chunk)); -} - -TEST_F(TestOptions, DoPutWriteOptions) { - // Call DoPut, but with a very low write nesting depth set to fail the call. 
- std::unique_ptr writer; - std::unique_ptr reader; - BatchVector expected_batches; - ASSERT_OK(ExampleNestedBatches(&expected_batches)); - - auto options = FlightCallOptions(); - options.write_options.max_recursion_depth = 1; - ASSERT_OK(client_->DoPut(options, FlightDescriptor{}, expected_batches[0]->schema(), - &writer, &reader)); - for (const auto& batch : expected_batches) { - ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch)); - } -} - -TEST_F(TestOptions, DoExchangeClientWriteOptions) { - // Call DoExchange and write nested data, but with a very low nesting depth set to - // fail the call. - auto options = FlightCallOptions(); - options.write_options.max_recursion_depth = 1; - auto descr = FlightDescriptor::Command(""); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(options, descr, &writer, &reader)); - BatchVector batches; - ASSERT_OK(ExampleNestedBatches(&batches)); - ASSERT_OK(writer->Begin(batches[0]->schema())); - for (const auto& batch : batches) { - ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch)); - } - ASSERT_OK(writer->DoneWriting()); - ASSERT_OK(writer->Close()); -} - -TEST_F(TestOptions, DoExchangeClientWriteOptionsBegin) { - // Call DoExchange and write nested data, but with a very low nesting depth set to - // fail the call. Here the options are set explicitly when we write data and not in the - // call options. 
- auto descr = FlightDescriptor::Command(""); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - BatchVector batches; - ASSERT_OK(ExampleNestedBatches(&batches)); - auto options = ipc::IpcWriteOptions::Defaults(); - options.max_recursion_depth = 1; - ASSERT_OK(writer->Begin(batches[0]->schema(), options)); - for (const auto& batch : batches) { - ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch)); - } - ASSERT_OK(writer->DoneWriting()); - ASSERT_OK(writer->Close()); -} - -TEST_F(TestOptions, DoExchangeServerWriteOptions) { - // Call DoExchange and write nested data, but with a very low nesting depth set to fail - // the call. (The low nesting depth is set on the server side.) - auto descr = FlightDescriptor::Command(""); - std::unique_ptr reader; - std::unique_ptr writer; - ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - BatchVector batches; - ASSERT_OK(ExampleNestedBatches(&batches)); - ASSERT_OK(writer->Begin(batches[0]->schema())); - FlightStreamChunk chunk; - ASSERT_OK(writer->WriteRecordBatch(*batches[0])); - ASSERT_OK(writer->DoneWriting()); - ASSERT_RAISES(Invalid, writer->Close()); -} - TEST_F(TestRejectServerMiddleware, Rejected) { std::unique_ptr info; const auto& status = client_->GetFlightInfo(FlightDescriptor{}, &info); @@ -2648,140 +1467,6 @@ TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) { RunValidClientAuth(); TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) { RunInvalidClientAuth(); } -TEST_F(TestCookieMiddleware, BasicParsing) { - AddAndValidate("id1=1; foo=bar;"); - AddAndValidate("id1=1; foo=bar"); - AddAndValidate("id2=2;"); - AddAndValidate("id4=\"4\""); - AddAndValidate("id5=5; foo=bar; baz=buz;"); -} - -TEST_F(TestCookieMiddleware, Overwrite) { - AddAndValidate("id0=0"); - AddAndValidate("id0=1"); - AddAndValidate("id1=0"); - AddAndValidate("id1=1"); - AddAndValidate("id1=1"); - AddAndValidate("id1=10"); - AddAndValidate("id=3"); - 
AddAndValidate("id=0"); - AddAndValidate("id=0"); -} - -TEST_F(TestCookieMiddleware, MaxAge) { - AddAndValidate("id0=0; max-age=0;"); - AddAndValidate("id1=0; max-age=-1;"); - AddAndValidate("id2=0; max-age=0"); - AddAndValidate("id3=0; max-age=-1"); - AddAndValidate("id4=0; max-age=1"); - AddAndValidate("id5=0; max-age=1"); - AddAndValidate("id4=0; max-age=0"); - AddAndValidate("id5=0; max-age=0"); -} - -TEST_F(TestCookieMiddleware, Expires) { - AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT;"); - AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT"); - AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;"); - AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT"); - AddAndValidate("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;"); - AddAndValidate("id1=0; expires=Fri, 01 Jan 2038 22:15:36 GMT"); - AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;"); - AddAndValidate("id1=0; expires=Fri, 22 Dec 2017 22:15:36 GMT"); -} - -TEST_F(TestCookieParsing, Expired) { - VerifyParseCookie("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;", true); - VerifyParseCookie("id1=0; max-age=-1;", true); - VerifyParseCookie("id0=0; max-age=0;", true); -} - -TEST_F(TestCookieParsing, Invalid) { - VerifyParseCookie("id1=0; expires=0, 0 0 0 0:0:0 GMT;", true); - VerifyParseCookie("id1=0; expires=Fri, 01 FOO 2038 22:15:36 GMT", true); - VerifyParseCookie("id1=0; expires=foo", true); - VerifyParseCookie("id1=0; expires=", true); - VerifyParseCookie("id1=0; max-age=FOO", true); - VerifyParseCookie("id1=0; max-age=", true); -} - -TEST_F(TestCookieParsing, NoExpiry) { - VerifyParseCookie("id1=0;", false); - VerifyParseCookie("id1=0; noexpiry=Fri, 01 Jan 2038 22:15:36 GMT", false); - VerifyParseCookie("id1=0; noexpiry=\"Fri, 01 Jan 2038 22:15:36 GMT\"", false); - VerifyParseCookie("id1=0; nomax-age=-1", false); - VerifyParseCookie("id1=0; nomax-age=\"-1\"", false); - VerifyParseCookie("id1=0; randomattr=foo", false); -} - -TEST_F(TestCookieParsing, NotExpired) { - 
VerifyParseCookie("id5=0; max-age=1", false); - VerifyParseCookie("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;", false); -} - -TEST_F(TestCookieParsing, GetName) { - VerifyCookieName("id1=1; foo=bar;", "id1"); - VerifyCookieName("id1=1; foo=bar", "id1"); - VerifyCookieName("id2=2;", "id2"); - VerifyCookieName("id4=\"4\"", "id4"); - VerifyCookieName("id5=5; foo=bar; baz=buz;", "id5"); -} - -TEST_F(TestCookieParsing, ToString) { - VerifyCookieString("id1=1; foo=bar;", "id1=1"); - VerifyCookieString("id1=1; foo=bar", "id1=1"); - VerifyCookieString("id2=2;", "id2=2"); - VerifyCookieString("id4=\"4\"", "id4=4"); - VerifyCookieString("id5=5; foo=bar; baz=buz;", "id5=5"); -} - -TEST_F(TestCookieParsing, DateConversion) { - VerifyCookieDateConverson("Mon, 01 jan 2038 22:15:36 GMT;", "01 01 2038 22:15:36"); - VerifyCookieDateConverson("TUE, 10 Feb 2038 22:15:36 GMT", "10 02 2038 22:15:36"); - VerifyCookieDateConverson("WED, 20 MAr 2038 22:15:36 GMT;", "20 03 2038 22:15:36"); - VerifyCookieDateConverson("thu, 15 APR 2038 22:15:36 GMT", "15 04 2038 22:15:36"); - VerifyCookieDateConverson("Fri, 30 mAY 2038 22:15:36 GMT;", "30 05 2038 22:15:36"); - VerifyCookieDateConverson("Sat, 03 juN 2038 22:15:36 GMT", "03 06 2038 22:15:36"); - VerifyCookieDateConverson("Sun, 01 JuL 2038 22:15:36 GMT;", "01 07 2038 22:15:36"); - VerifyCookieDateConverson("Fri, 06 aUg 2038 22:15:36 GMT", "06 08 2038 22:15:36"); - VerifyCookieDateConverson("Fri, 01 SEP 2038 22:15:36 GMT;", "01 09 2038 22:15:36"); - VerifyCookieDateConverson("Fri, 01 OCT 2038 22:15:36 GMT", "01 10 2038 22:15:36"); - VerifyCookieDateConverson("Fri, 01 Nov 2038 22:15:36 GMT;", "01 11 2038 22:15:36"); - VerifyCookieDateConverson("Fri, 01 deC 2038 22:15:36 GMT", "01 12 2038 22:15:36"); - VerifyCookieDateConverson("", ""); - VerifyCookieDateConverson("Fri, 01 INVALID 2038 22:15:36 GMT;", - "01 INVALID 2038 22:15:36"); -} - -TEST_F(TestCookieParsing, ParseCookieAttribute) { - VerifyCookieAttributeParsing("", 0, util::nullopt, 
std::string::npos); - - std::string cookie_string = "attr0=0; attr1=1; attr2=2; attr3=3"; - auto attr_length = std::string("attr0=0;").length(); - std::string::size_type start_pos = 0; - VerifyCookieAttributeParsing(cookie_string, start_pos, std::make_pair("attr0", "0"), - cookie_string.find("attr0=0;") + attr_length); - VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)), - std::make_pair("attr1", "1"), - cookie_string.find("attr1=1;") + attr_length); - VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)), - std::make_pair("attr2", "2"), - cookie_string.find("attr2=2;") + attr_length); - VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)), - std::make_pair("attr3", "3"), std::string::npos); - VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length - 1)), - util::nullopt, std::string::npos); - VerifyCookieAttributeParsing(cookie_string, std::string::npos, util::nullopt, - std::string::npos); -} - -TEST_F(TestCookieParsing, CookieCache) { - AddCookieVerifyCache({"id0=0;"}, ""); - AddCookieVerifyCache({"id0=0;", "id0=1;"}, "id0=1"); - AddCookieVerifyCache({"id0=0;", "id1=1;"}, "id0=0; id1=1"); - AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=0; id1=1; id2=2"); -} - class ForeverFlightListing : public FlightListing { Status Next(std::unique_ptr* info) override { std::this_thread::sleep_for(std::chrono::milliseconds(100)); diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc new file mode 100644 index 00000000000..ad40fb6a6cd --- /dev/null +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -0,0 +1,980 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/test_definitions.h" + +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/array_dict.h" +#include "arrow/flight/api.h" +#include "arrow/flight/test_util.h" +#include "arrow/testing/generator.h" + +namespace arrow { +namespace flight { + +//------------------------------------------------------------ +// Tests of initialization/shutdown + +void ConnectivityTest::TestGetPort() { + std::unique_ptr server = ExampleTestServer(); + + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + FlightServerOptions options(location); + ASSERT_OK(server->Init(options)); + ASSERT_GT(server->port(), 0); +} +void ConnectivityTest::TestBuilderHook() { + std::unique_ptr server = ExampleTestServer(); + + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + FlightServerOptions options(location); + bool builder_hook_run = false; + options.builder_hook = [&builder_hook_run](void* builder) { + ASSERT_NE(nullptr, builder); + builder_hook_run = true; + }; + ASSERT_OK(server->Init(options)); + ASSERT_TRUE(builder_hook_run); + ASSERT_GT(server->port(), 0); + ASSERT_OK(server->Shutdown()); +} +void ConnectivityTest::TestShutdown() { + // Regression test for ARROW-15181 + constexpr int kIterations = 10; + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + for (int 
i = 0; i < kIterations; i++) { + std::unique_ptr server = ExampleTestServer(); + + FlightServerOptions options(location); + ASSERT_OK(server->Init(options)); + ASSERT_GT(server->port(), 0); + std::thread t([&]() { ASSERT_OK(server->Serve()); }); + ASSERT_OK(server->Shutdown()); + ASSERT_OK(server->Wait()); + t.join(); + } +} +void ConnectivityTest::TestShutdownWithDeadline() { + std::unique_ptr server = ExampleTestServer(); + + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + FlightServerOptions options(location); + ASSERT_OK(server->Init(options)); + ASSERT_GT(server->port(), 0); + + auto deadline = std::chrono::system_clock::now() + std::chrono::microseconds(10); + + ASSERT_OK(server->Shutdown(&deadline)); + ASSERT_OK(server->Wait()); +} + +//------------------------------------------------------------ +// Tests of data plane methods + +void DataTest::SetUp() { + server_ = ExampleTestServer(); + + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + FlightServerOptions options(location); + ASSERT_OK(server_->Init(options)); + + ASSERT_OK(ConnectClient()); +} +void DataTest::TearDown() { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); +} +Status DataTest::ConnectClient() { + ARROW_ASSIGN_OR_RAISE(auto location, + Location::ForScheme(transport(), "localhost", server_->port())); + return FlightClient::Connect(location, &client_); +} +void DataTest::CheckDoGet( + const FlightDescriptor& descr, const RecordBatchVector& expected_batches, + std::function&)> check_endpoints) { + auto expected_schema = expected_batches[0]->schema(); + + std::unique_ptr info; + ASSERT_OK(client_->GetFlightInfo(descr, &info)); + check_endpoints(info->endpoints()); + + std::shared_ptr schema; + ipc::DictionaryMemo dict_memo; + ASSERT_OK(info->GetSchema(&dict_memo, &schema)); + AssertSchemaEqual(*expected_schema, *schema); + + // By convention, fetch the first endpoint + Ticket ticket = 
info->endpoints()[0].ticket; + CheckDoGet(ticket, expected_batches); +} +void DataTest::CheckDoGet(const Ticket& ticket, + const RecordBatchVector& expected_batches) { + auto num_batches = static_cast(expected_batches.size()); + ASSERT_GE(num_batches, 2); + + std::unique_ptr stream; + ASSERT_OK(client_->DoGet(ticket, &stream)); + + std::unique_ptr stream2; + ASSERT_OK(client_->DoGet(ticket, &stream2)); + ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2))); + + FlightStreamChunk chunk; + std::shared_ptr batch; + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK(reader->ReadNext(&batch)); + ASSERT_NE(nullptr, chunk.data); + ASSERT_NE(nullptr, batch); +#if !defined(__MINGW32__) + ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data); + ASSERT_BATCHES_EQUAL(*expected_batches[i], *batch); +#else + // In MINGW32, the following code does not have the reproducibility at the LSB + // even when this is called twice with the same seed. 
+ // As a workaround, use approxEqual + // /* from GenerateTypedData in random.cc */ + // std::default_random_engine rng(seed); // seed = 282475250 + // std::uniform_real_distribution dist; + // std::generate(data, data + n, // n = 10 + // [&dist, &rng] { return static_cast(dist(rng)); }); + // /* data[1] = 0x40852cdfe23d3976 or 0x40852cdfe23d3975 */ + ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *chunk.data); + ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *batch); +#endif + } + + // Stream exhausted + ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK(reader->ReadNext(&batch)); + ASSERT_EQ(nullptr, chunk.data); + ASSERT_EQ(nullptr, batch); +} + +void DataTest::TestDoGetInts() { + auto descr = FlightDescriptor::Path({"examples", "ints"}); + RecordBatchVector expected_batches; + ASSERT_OK(ExampleIntBatches(&expected_batches)); + + auto check_endpoints = [](const std::vector& endpoints) { + // Two endpoints in the example FlightInfo + ASSERT_EQ(2, endpoints.size()); + ASSERT_EQ(Ticket{"ticket-ints-1"}, endpoints[0].ticket); + }; + + CheckDoGet(descr, expected_batches, check_endpoints); +} +void DataTest::TestDoGetFloats() { + auto descr = FlightDescriptor::Path({"examples", "floats"}); + RecordBatchVector expected_batches; + ASSERT_OK(ExampleFloatBatches(&expected_batches)); + + auto check_endpoints = [](const std::vector& endpoints) { + // One endpoint in the example FlightInfo + ASSERT_EQ(1, endpoints.size()); + ASSERT_EQ(Ticket{"ticket-floats-1"}, endpoints[0].ticket); + }; + + CheckDoGet(descr, expected_batches, check_endpoints); +} +void DataTest::TestDoGetDicts() { + auto descr = FlightDescriptor::Path({"examples", "dicts"}); + RecordBatchVector expected_batches; + ASSERT_OK(ExampleDictBatches(&expected_batches)); + + auto check_endpoints = [](const std::vector& endpoints) { + // One endpoint in the example FlightInfo + ASSERT_EQ(1, endpoints.size()); + ASSERT_EQ(Ticket{"ticket-dicts-1"}, endpoints[0].ticket); + }; + + CheckDoGet(descr, expected_batches, 
check_endpoints); +} +// Ensure clients are configured to allow large messages by default +// Tests a 32 MiB batch +void DataTest::TestDoGetLargeBatch() { + RecordBatchVector expected_batches; + ASSERT_OK(ExampleLargeBatches(&expected_batches)); + Ticket ticket{"ticket-large-batch-1"}; + CheckDoGet(ticket, expected_batches); +} +void DataTest::TestOverflowServerBatch() { + // Regression test for ARROW-13253 + // N.B. this is rather a slow and memory-hungry test + { + // DoGet: check for overflow on large batch + Ticket ticket{"ARROW-13253-DoGet-Batch"}; + std::unique_ptr stream; + ASSERT_OK(client_->DoGet(ticket, &stream)); + FlightStreamChunk chunk; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), + stream->Next(&chunk)); + } + { + // DoExchange: check for overflow on large batch from server + auto descr = FlightDescriptor::Command("large_batch"); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + RecordBatchVector batches; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), + reader->ReadAll(&batches)); + } +} +void DataTest::TestOverflowClientBatch() { + ASSERT_OK_AND_ASSIGN(auto batch, VeryLargeBatch()); + { + // DoPut: check for overflow on large batch + std::unique_ptr stream; + std::unique_ptr reader; + auto descr = FlightDescriptor::Path({""}); + ASSERT_OK(client_->DoPut(descr, batch->schema(), &stream, &reader)); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), + stream->WriteRecordBatch(*batch)); + ASSERT_OK(stream->Close()); + } + { + // DoExchange: check for overflow on large batch from client + auto descr = FlightDescriptor::Command("counter"); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + 
ASSERT_OK(writer->Begin(batch->schema())); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), + writer->WriteRecordBatch(*batch)); + ASSERT_OK(writer->Close()); + } +} +void DataTest::TestDoExchange() { + auto descr = FlightDescriptor::Command("counter"); + RecordBatchVector batches; + auto a1 = ArrayFromJSON(int32(), "[4, 5, 6, null]"); + auto schema = arrow::schema({field("f1", a1->type())}); + batches.push_back(RecordBatch::Make(schema, a1->length(), {a1})); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + ASSERT_OK(writer->Begin(schema)); + for (const auto& batch : batches) { + ASSERT_OK(writer->WriteRecordBatch(*batch)); + } + ASSERT_OK(writer->DoneWriting()); + FlightStreamChunk chunk; + ASSERT_OK(reader->Next(&chunk)); + ASSERT_NE(nullptr, chunk.app_metadata); + ASSERT_EQ(nullptr, chunk.data); + ASSERT_EQ("1", chunk.app_metadata->ToString()); + ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema()); + AssertSchemaEqual(schema, server_schema); + for (const auto& batch : batches) { + ASSERT_OK(reader->Next(&chunk)); + ASSERT_BATCHES_EQUAL(*batch, *chunk.data); + } + ASSERT_OK(writer->Close()); +} +// Test pure-metadata DoExchange to ensure nothing blocks waiting for +// schema messages +void DataTest::TestDoExchangeNoData() { + auto descr = FlightDescriptor::Command("counter"); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + ASSERT_OK(writer->DoneWriting()); + FlightStreamChunk chunk; + ASSERT_OK(reader->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); + ASSERT_NE(nullptr, chunk.app_metadata); + ASSERT_EQ("0", chunk.app_metadata->ToString()); + ASSERT_OK(writer->Close()); +} +// Test sending a schema without any data, as this hits an edge case +// in the client-side writer. 
+void DataTest::TestDoExchangeWriteOnlySchema() { + auto descr = FlightDescriptor::Command("counter"); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + auto schema = arrow::schema({field("f1", arrow::int32())}); + ASSERT_OK(writer->Begin(schema)); + ASSERT_OK(writer->WriteMetadata(Buffer::FromString("foo"))); + ASSERT_OK(writer->DoneWriting()); + FlightStreamChunk chunk; + ASSERT_OK(reader->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); + ASSERT_NE(nullptr, chunk.app_metadata); + ASSERT_EQ("0", chunk.app_metadata->ToString()); + ASSERT_OK(writer->Close()); +} +// Emulate DoGet +void DataTest::TestDoExchangeGet() { + auto descr = FlightDescriptor::Command("get"); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema()); + AssertSchemaEqual(*ExampleIntSchema(), *server_schema); + RecordBatchVector batches; + ASSERT_OK(ExampleIntBatches(&batches)); + FlightStreamChunk chunk; + for (const auto& batch : batches) { + ASSERT_OK(reader->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + AssertBatchesEqual(*batch, *chunk.data); + } + ASSERT_OK(reader->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); + ASSERT_EQ(nullptr, chunk.app_metadata); + ASSERT_OK(writer->Close()); +} +// Emulate DoPut +void DataTest::TestDoExchangePut() { + auto descr = FlightDescriptor::Command("put"); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + ASSERT_OK(writer->Begin(ExampleIntSchema())); + RecordBatchVector batches; + ASSERT_OK(ExampleIntBatches(&batches)); + for (const auto& batch : batches) { + ASSERT_OK(writer->WriteRecordBatch(*batch)); + } + ASSERT_OK(writer->DoneWriting()); + FlightStreamChunk chunk; + ASSERT_OK(reader->Next(&chunk)); + ASSERT_NE(nullptr, chunk.app_metadata); + AssertBufferEqual(*chunk.app_metadata, "done"); + 
ASSERT_OK(reader->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); + ASSERT_EQ(nullptr, chunk.app_metadata); + ASSERT_OK(writer->Close()); +} +// Test the echo server +void DataTest::TestDoExchangeEcho() { + auto descr = FlightDescriptor::Command("echo"); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + ASSERT_OK(writer->Begin(ExampleIntSchema())); + RecordBatchVector batches; + FlightStreamChunk chunk; + ASSERT_OK(ExampleIntBatches(&batches)); + for (const auto& batch : batches) { + ASSERT_OK(writer->WriteRecordBatch(*batch)); + ASSERT_OK(reader->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + ASSERT_EQ(nullptr, chunk.app_metadata); + AssertBatchesEqual(*batch, *chunk.data); + } + for (int i = 0; i < 10; i++) { + const auto buf = Buffer::FromString(std::to_string(i)); + ASSERT_OK(writer->WriteMetadata(buf)); + ASSERT_OK(reader->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); + ASSERT_NE(nullptr, chunk.app_metadata); + AssertBufferEqual(*buf, *chunk.app_metadata); + } + int index = 0; + for (const auto& batch : batches) { + const auto buf = Buffer::FromString(std::to_string(index)); + ASSERT_OK(writer->WriteWithMetadata(*batch, buf)); + ASSERT_OK(reader->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + ASSERT_NE(nullptr, chunk.app_metadata); + AssertBatchesEqual(*batch, *chunk.data); + AssertBufferEqual(*buf, *chunk.app_metadata); + index++; + } + ASSERT_OK(writer->DoneWriting()); + ASSERT_OK(reader->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); + ASSERT_EQ(nullptr, chunk.app_metadata); + ASSERT_OK(writer->Close()); +} +// Test interleaved reading/writing +void DataTest::TestDoExchangeTotal() { + auto descr = FlightDescriptor::Command("total"); + std::unique_ptr reader; + std::unique_ptr writer; + { + auto a1 = ArrayFromJSON(arrow::int32(), "[4, 5, 6, null]"); + auto schema = arrow::schema({field("f1", a1->type())}); + // XXX: as noted in flight/client.cc, Begin() is lazy and the + // schema 
message won't be written until some data is also + // written. There's also timing issues; hence we check each status + // here. + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Field is not INT64: f1"), ([&]() { + RETURN_NOT_OK(client_->DoExchange(descr, &writer, &reader)); + RETURN_NOT_OK(writer->Begin(schema)); + auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1}); + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + return writer->Close(); + })()); + } + { + auto a1 = ArrayFromJSON(arrow::int64(), "[1, 2, null, 3]"); + auto a2 = ArrayFromJSON(arrow::int64(), "[null, 4, 5, 6]"); + auto schema = arrow::schema({field("f1", a1->type()), field("f2", a2->type())}); + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + ASSERT_OK(writer->Begin(schema)); + auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1, a2}); + FlightStreamChunk chunk; + ASSERT_OK(writer->WriteRecordBatch(*batch)); + ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema()); + AssertSchemaEqual(*schema, *server_schema); + + ASSERT_OK(reader->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + auto expected1 = RecordBatch::Make( + schema, /* num_rows */ 1, + {ArrayFromJSON(arrow::int64(), "[6]"), ArrayFromJSON(arrow::int64(), "[15]")}); + AssertBatchesEqual(*expected1, *chunk.data); + + ASSERT_OK(writer->WriteRecordBatch(*batch)); + ASSERT_OK(reader->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + auto expected2 = RecordBatch::Make( + schema, /* num_rows */ 1, + {ArrayFromJSON(arrow::int64(), "[12]"), ArrayFromJSON(arrow::int64(), "[30]")}); + AssertBatchesEqual(*expected2, *chunk.data); + + ASSERT_OK(writer->Close()); + } +} +// Ensure server errors get propagated no matter what we try +void DataTest::TestDoExchangeError() { + auto descr = FlightDescriptor::Command("error"); + std::unique_ptr reader; + std::unique_ptr writer; + { + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + auto status = writer->Close(); + 
EXPECT_RAISES_WITH_MESSAGE_THAT( + NotImplemented, ::testing::HasSubstr("Expected error"), writer->Close()); + } + { + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + FlightStreamChunk chunk; + EXPECT_RAISES_WITH_MESSAGE_THAT( + NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next(&chunk)); + } + { + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + EXPECT_RAISES_WITH_MESSAGE_THAT( + NotImplemented, ::testing::HasSubstr("Expected error"), reader->GetSchema()); + } + // writer->Begin isn't tested here because, as noted in client.cc, + // OpenRecordBatchWriter lazily writes the initial message - hence + // Begin() won't fail. Additionally, it appears gRPC may buffer + // writes - a write won't immediately fail even when the server + // immediately fails. +} +void DataTest::TestIssue5095() { + // Make sure the server-side error message is reflected to the + // client + Ticket ticket1{"ARROW-5095-fail"}; + std::unique_ptr stream; + Status status = client_->DoGet(ticket1, &stream); + ASSERT_RAISES(UnknownError, status); + ASSERT_THAT(status.message(), ::testing::HasSubstr("Server-side error")); + + Ticket ticket2{"ARROW-5095-success"}; + status = client_->DoGet(ticket2, &stream); + ASSERT_RAISES(KeyError, status); + ASSERT_THAT(status.message(), ::testing::HasSubstr("No data")); +} + +//------------------------------------------------------------ +// Specific tests for DoPut + +class DoPutTestServer : public FlightServerBase { + public: + Status DoPut(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override { + descriptor_ = reader->descriptor(); + return reader->ReadAll(&batches_); + } + + protected: + FlightDescriptor descriptor_; + RecordBatchVector batches_; + + friend class DoPutTest; +}; + +void DoPutTest::SetUp() { + ASSERT_OK(MakeServer( + &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, + [](FlightClientOptions* options) { return Status::OK(); })); +} +void 
DoPutTest::TearDown() { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); +} +void DoPutTest::CheckBatches(const FlightDescriptor& expected_descriptor, + const RecordBatchVector& expected_batches) { + auto* do_put_server = (DoPutTestServer*)server_.get(); + ASSERT_EQ(do_put_server->descriptor_, expected_descriptor); + ASSERT_EQ(do_put_server->batches_.size(), expected_batches.size()); + for (size_t i = 0; i < expected_batches.size(); ++i) { + ASSERT_BATCHES_EQUAL(*do_put_server->batches_[i], *expected_batches[i]); + } +} +void DoPutTest::CheckDoPut(const FlightDescriptor& descr, + const std::shared_ptr& schema, + const RecordBatchVector& batches) { + std::unique_ptr stream; + std::unique_ptr reader; + ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader)); + for (const auto& batch : batches) { + ASSERT_OK(stream->WriteRecordBatch(*batch)); + } + ASSERT_OK(stream->DoneWriting()); + ASSERT_OK(stream->Close()); + + CheckBatches(descr, batches); +} + +void DoPutTest::TestInts() { + auto descr = FlightDescriptor::Path({"ints"}); + RecordBatchVector batches; + auto a0 = ArrayFromJSON(int8(), "[0, 1, 127, -128, null]"); + auto a1 = ArrayFromJSON(uint8(), "[0, 1, 127, 255, null]"); + auto a2 = ArrayFromJSON(int16(), "[0, 258, 32767, -32768, null]"); + auto a3 = ArrayFromJSON(uint16(), "[0, 258, 32767, 65535, null]"); + auto a4 = ArrayFromJSON(int32(), "[0, 65538, 2147483647, -2147483648, null]"); + auto a5 = ArrayFromJSON(uint32(), "[0, 65538, 2147483647, 4294967295, null]"); + auto a6 = ArrayFromJSON( + int64(), "[0, 4294967298, 9223372036854775807, -9223372036854775808, null]"); + auto a7 = ArrayFromJSON( + uint64(), "[0, 4294967298, 9223372036854775807, 18446744073709551615, null]"); + auto schema = arrow::schema({field("f0", a0->type()), field("f1", a1->type()), + field("f2", a2->type()), field("f3", a3->type()), + field("f4", a4->type()), field("f5", a5->type()), + field("f6", a6->type()), field("f7", a7->type())}); + batches.push_back( + 
RecordBatch::Make(schema, a0->length(), {a0, a1, a2, a3, a4, a5, a6, a7})); + + CheckDoPut(descr, schema, batches); +} + +void DoPutTest::TestDoPutFloats() { + auto descr = FlightDescriptor::Path({"floats"}); + RecordBatchVector batches; + auto a0 = ArrayFromJSON(float32(), "[0, 1.2, -3.4, 5.6, null]"); + auto a1 = ArrayFromJSON(float64(), "[0, 1.2, -3.4, 5.6, null]"); + auto schema = arrow::schema({field("f0", a0->type()), field("f1", a1->type())}); + batches.push_back(RecordBatch::Make(schema, a0->length(), {a0, a1})); + + CheckDoPut(descr, schema, batches); +} + +void DoPutTest::TestDoPutEmptyBatch() { + // Sending and receiving a 0-sized batch shouldn't fail + auto descr = FlightDescriptor::Path({"ints"}); + RecordBatchVector batches; + auto a1 = ArrayFromJSON(int32(), "[]"); + auto schema = arrow::schema({field("f1", a1->type())}); + batches.push_back(RecordBatch::Make(schema, a1->length(), {a1})); + + CheckDoPut(descr, schema, batches); +} + +void DoPutTest::TestDoPutDicts() { + auto descr = FlightDescriptor::Path({"dicts"}); + RecordBatchVector batches; + auto dict_values = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"quux\"]"); + auto ty = dictionary(int8(), dict_values->type()); + auto schema = arrow::schema({field("f1", ty)}); + // Make several batches + for (const char* json : {"[1, 0, 1]", "[null]", "[null, 1]"}) { + auto indices = ArrayFromJSON(int8(), json); + auto dict_array = std::make_shared(ty, indices, dict_values); + batches.push_back(RecordBatch::Make(schema, dict_array->length(), {dict_array})); + } + + CheckDoPut(descr, schema, batches); +} + +// Ensure the server is configured to allow large messages by default +// Tests a 32 MiB batch +void DoPutTest::TestDoPutLargeBatch() { + auto descr = FlightDescriptor::Path({"large-batches"}); + auto schema = ExampleLargeSchema(); + RecordBatchVector batches; + ASSERT_OK(ExampleLargeBatches(&batches)); + CheckDoPut(descr, schema, batches); +} + +void DoPutTest::TestDoPutSizeLimit() { + const int64_t 
size_limit = 4096; + Location location; + ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location)); + auto client_options = FlightClientOptions::Defaults(); + client_options.write_size_limit_bytes = size_limit; + std::unique_ptr client; + ASSERT_OK(FlightClient::Connect(location, client_options, &client)); + + auto descr = FlightDescriptor::Path({"ints"}); + // Batch is too large to fit in one message + auto schema = arrow::schema({field("f1", arrow::int64())}); + auto batch = arrow::ConstantArrayGenerator::Zeroes(768, schema); + RecordBatchVector batches; + batches.push_back(batch->Slice(0, 384)); + batches.push_back(batch->Slice(384)); + + std::unique_ptr stream; + std::unique_ptr reader; + ASSERT_OK(client->DoPut(descr, schema, &stream, &reader)); + + // Large batch will exceed the limit + const auto status = stream->WriteRecordBatch(*batch); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("exceeded soft limit"), + status); + auto detail = FlightWriteSizeStatusDetail::UnwrapStatus(status); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(size_limit, detail->limit()); + ASSERT_GT(detail->actual(), size_limit); + + // But we can retry with a smaller batch + for (const auto& batch : batches) { + ASSERT_OK(stream->WriteRecordBatch(*batch)); + } + + ASSERT_OK(stream->DoneWriting()); + ASSERT_OK(stream->Close()); + CheckBatches(descr, batches); +} + +//------------------------------------------------------------ +// Test app_metadata in data plane methods + +Status AppMetadataTestServer::DoGet(const ServerCallContext& context, + const Ticket& request, + std::unique_ptr* data_stream) { + RecordBatchVector batches; + if (request.ticket == "dicts") { + RETURN_NOT_OK(ExampleDictBatches(&batches)); + } else if (request.ticket == "floats") { + RETURN_NOT_OK(ExampleFloatBatches(&batches)); + } else { + RETURN_NOT_OK(ExampleIntBatches(&batches)); + } + ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches)); + *data_stream = 
std::unique_ptr(new NumberingStream( + std::unique_ptr(new RecordBatchStream(batch_reader)))); + return Status::OK(); +} +Status AppMetadataTestServer::DoPut(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) { + FlightStreamChunk chunk; + int counter = 0; + while (true) { + RETURN_NOT_OK(reader->Next(&chunk)); + if (chunk.data == nullptr) break; + if (chunk.app_metadata == nullptr) { + return Status::Invalid("Expected application metadata to be provided"); + } + if (std::to_string(counter) != chunk.app_metadata->ToString()) { + return Status::Invalid("Expected metadata value: " + std::to_string(counter) + + " but got: " + chunk.app_metadata->ToString()); + } + auto metadata = Buffer::FromString(std::to_string(counter)); + RETURN_NOT_OK(writer->WriteMetadata(*metadata)); + counter++; + } + return Status::OK(); +} + +void AppMetadataTest::SetUp() { + ASSERT_OK(MakeServer( + &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, + [](FlightClientOptions* options) { return Status::OK(); })); +} +void AppMetadataTest::TearDown() { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); +} +void AppMetadataTest::TestDoGet() { + Ticket ticket{""}; + std::unique_ptr stream; + ASSERT_OK(client_->DoGet(ticket, &stream)); + + RecordBatchVector expected_batches; + ASSERT_OK(ExampleIntBatches(&expected_batches)); + + FlightStreamChunk chunk; + auto num_batches = static_cast(expected_batches.size()); + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(stream->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + ASSERT_NE(nullptr, chunk.app_metadata); + ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data); + ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString()); + } + ASSERT_OK(stream->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); +} +// Test dictionaries. 
This tests a corner case in the reader: +// dictionary batches come in between the schema and the first record +// batch, so the reader must take care to read application metadata +// from the record batch, and not one of the dictionary batches. +void AppMetadataTest::TestDoGetDictionaries() { + Ticket ticket{"dicts"}; + std::unique_ptr stream; + ASSERT_OK(client_->DoGet(ticket, &stream)); + + RecordBatchVector expected_batches; + ASSERT_OK(ExampleDictBatches(&expected_batches)); + + FlightStreamChunk chunk; + auto num_batches = static_cast(expected_batches.size()); + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(stream->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + ASSERT_NE(nullptr, chunk.app_metadata); + ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data); + ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString()); + } + ASSERT_OK(stream->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); +} +void AppMetadataTest::TestDoPut() { + std::unique_ptr writer; + std::unique_ptr reader; + std::shared_ptr schema = ExampleIntSchema(); + ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader)); + + RecordBatchVector expected_batches; + ASSERT_OK(ExampleIntBatches(&expected_batches)); + + std::shared_ptr chunk; + std::shared_ptr metadata; + auto num_batches = static_cast(expected_batches.size()); + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i], + Buffer::FromString(std::to_string(i)))); + } + // This eventually calls grpc::ClientReaderWriter::Finish which can + // hang if there are unread messages. So make sure our wrapper + // around this doesn't hang (because it drains any unread messages) + ASSERT_OK(writer->Close()); +} +// Test DoPut() with dictionaries. This tests a corner case in the +// server-side reader; see DoGetDictionaries above. 
+void AppMetadataTest::TestDoPutDictionaries() { + std::unique_ptr writer; + std::unique_ptr reader; + RecordBatchVector expected_batches; + ASSERT_OK(ExampleDictBatches(&expected_batches)); + // ARROW-8749: don't get the schema via ExampleDictSchema because + // DictionaryMemo uses field addresses to determine whether it's + // seen a field before. Hence, if we use a schema that is different + // (identity-wise) than the schema of the first batch we write, + // we'll end up generating a duplicate set of dictionaries that + // confuses the reader. + ASSERT_OK(client_->DoPut(FlightDescriptor{}, expected_batches[0]->schema(), &writer, + &reader)); + std::shared_ptr chunk; + std::shared_ptr metadata; + auto num_batches = static_cast(expected_batches.size()); + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i], + Buffer::FromString(std::to_string(i)))); + } + ASSERT_OK(writer->Close()); +} +void AppMetadataTest::TestDoPutReadMetadata() { + std::unique_ptr writer; + std::unique_ptr reader; + std::shared_ptr schema = ExampleIntSchema(); + ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader)); + + RecordBatchVector expected_batches; + ASSERT_OK(ExampleIntBatches(&expected_batches)); + + std::shared_ptr chunk; + std::shared_ptr metadata; + auto num_batches = static_cast(expected_batches.size()); + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i], + Buffer::FromString(std::to_string(i)))); + ASSERT_OK(reader->ReadMetadata(&metadata)); + ASSERT_NE(nullptr, metadata); + ASSERT_EQ(std::to_string(i), metadata->ToString()); + } + // As opposed to DoPutDrainMetadata, now we've read the messages, so + // make sure this still closes as expected. 
+ ASSERT_OK(writer->Close()); +} + +//------------------------------------------------------------ +// Test IPC options in data plane methods + +class IpcOptionsTestServer : public FlightServerBase { + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) override { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleNestedBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make(batches)); + *data_stream = std::unique_ptr(new RecordBatchStream(reader)); + return Status::OK(); + } + + // Just echo the number of batches written. The client will try to + // call this method with different write options set. + Status DoPut(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override { + FlightStreamChunk chunk; + int counter = 0; + while (true) { + RETURN_NOT_OK(reader->Next(&chunk)); + if (chunk.data == nullptr) break; + counter++; + } + auto metadata = Buffer::FromString(std::to_string(counter)); + return writer->WriteMetadata(*metadata); + } + + // Echo client data, but with write options set to limit the nesting + // level. 
+ Status DoExchange(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override { + FlightStreamChunk chunk; + auto options = ipc::IpcWriteOptions::Defaults(); + options.max_recursion_depth = 1; + bool begun = false; + while (true) { + RETURN_NOT_OK(reader->Next(&chunk)); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (!begun && chunk.data) { + begun = true; + RETURN_NOT_OK(writer->Begin(chunk.data->schema(), options)); + } + if (chunk.data && chunk.app_metadata) { + RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata)); + } else if (chunk.data) { + RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); + } else if (chunk.app_metadata) { + RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata)); + } + } + return Status::OK(); + } +}; + +void IpcOptionsTest::SetUp() { + ASSERT_OK(MakeServer( + &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, + [](FlightClientOptions* options) { return Status::OK(); })); +} +void IpcOptionsTest::TearDown() { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); +} +void IpcOptionsTest::TestDoGetReadOptions() { + // Call DoGet, but with a very low read nesting depth set to fail the call. + Ticket ticket{""}; + auto options = FlightCallOptions(); + options.read_options.max_recursion_depth = 1; + std::unique_ptr stream; + ASSERT_OK(client_->DoGet(options, ticket, &stream)); + FlightStreamChunk chunk; + ASSERT_RAISES(Invalid, stream->Next(&chunk)); +} +void IpcOptionsTest::TestDoPutWriteOptions() { + // Call DoPut, but with a very low write nesting depth set to fail the call. 
+ std::unique_ptr writer; + std::unique_ptr reader; + RecordBatchVector expected_batches; + ASSERT_OK(ExampleNestedBatches(&expected_batches)); + + auto options = FlightCallOptions(); + options.write_options.max_recursion_depth = 1; + ASSERT_OK(client_->DoPut(options, FlightDescriptor{}, expected_batches[0]->schema(), + &writer, &reader)); + for (const auto& batch : expected_batches) { + ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch)); + } +} +void IpcOptionsTest::TestDoExchangeClientWriteOptions() { + // Call DoExchange and write nested data, but with a very low nesting depth set to + // fail the call. + auto options = FlightCallOptions(); + options.write_options.max_recursion_depth = 1; + auto descr = FlightDescriptor::Command(""); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(options, descr, &writer, &reader)); + RecordBatchVector batches; + ASSERT_OK(ExampleNestedBatches(&batches)); + ASSERT_OK(writer->Begin(batches[0]->schema())); + for (const auto& batch : batches) { + ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch)); + } + ASSERT_OK(writer->DoneWriting()); + ASSERT_OK(writer->Close()); +} +void IpcOptionsTest::TestDoExchangeClientWriteOptionsBegin() { + // Call DoExchange and write nested data, but with a very low nesting depth set to + // fail the call. Here the options are set explicitly when we write data and not in the + // call options. 
+ auto descr = FlightDescriptor::Command(""); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + RecordBatchVector batches; + ASSERT_OK(ExampleNestedBatches(&batches)); + auto options = ipc::IpcWriteOptions::Defaults(); + options.max_recursion_depth = 1; + ASSERT_OK(writer->Begin(batches[0]->schema(), options)); + for (const auto& batch : batches) { + ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch)); + } + ASSERT_OK(writer->DoneWriting()); + ASSERT_OK(writer->Close()); +} +void IpcOptionsTest::TestDoExchangeServerWriteOptions() { + // Call DoExchange and write nested data, but with a very low nesting depth set to fail + // the call. (The low nesting depth is set on the server side.) + auto descr = FlightDescriptor::Command(""); + std::unique_ptr reader; + std::unique_ptr writer; + ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); + RecordBatchVector batches; + ASSERT_OK(ExampleNestedBatches(&batches)); + ASSERT_OK(writer->Begin(batches[0]->schema())); + FlightStreamChunk chunk; + ASSERT_OK(writer->WriteRecordBatch(*batches[0])); + ASSERT_OK(writer->DoneWriting()); + ASSERT_RAISES(Invalid, writer->Close()); +} + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h new file mode 100644 index 00000000000..ff107f802e3 --- /dev/null +++ b/cpp/src/arrow/flight/test_definitions.h @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Common test definitions for Flight. Individual transport +// implementations can instantiate these tests. +// +// While Googletest's value-parameterized tests would be a more +// natural way to do this, they cause runtime issues on MinGW/MSVC +// (Googletest thinks the test suite has been defined twice). + +#pragma once + +#include + +#include +#include +#include +#include + +#include "arrow/flight/server.h" +#include "arrow/flight/types.h" + +namespace arrow { +namespace flight { + +class ARROW_FLIGHT_EXPORT FlightTest : public testing::Test { + protected: + virtual std::string transport() const = 0; +}; + +/// Common tests of startup/shutdown +class ARROW_FLIGHT_EXPORT ConnectivityTest : public FlightTest { + public: + // Test methods + void TestGetPort(); + void TestBuilderHook(); + void TestShutdown(); + void TestShutdownWithDeadline(); +}; + +/// Common tests of data plane methods +class ARROW_FLIGHT_EXPORT DataTest : public FlightTest { + public: + void SetUp(); + void TearDown(); + Status ConnectClient(); + + // Test methods + void TestDoGetInts(); + void TestDoGetFloats(); + void TestDoGetDicts(); + void TestDoGetLargeBatch(); + void TestOverflowServerBatch(); + void TestOverflowClientBatch(); + void TestDoExchange(); + void TestDoExchangeNoData(); + void TestDoExchangeWriteOnlySchema(); + void TestDoExchangeGet(); + void TestDoExchangePut(); + void TestDoExchangeEcho(); + void TestDoExchangeTotal(); + void TestDoExchangeError(); + void TestIssue5095(); + + private: + void CheckDoGet( + const FlightDescriptor& descr, const 
RecordBatchVector& expected_batches, + std::function&)> check_endpoints); + void CheckDoGet(const Ticket& ticket, const RecordBatchVector& expected_batches); + + std::unique_ptr client_; + std::unique_ptr server_; +}; + +class ARROW_FLIGHT_EXPORT DoPutTest : public FlightTest { + public: + void SetUp(); + void TearDown(); + void CheckBatches(const FlightDescriptor& expected_descriptor, + const RecordBatchVector& expected_batches); + void CheckDoPut(const FlightDescriptor& descr, const std::shared_ptr& schema, + const RecordBatchVector& batches); + + // Test methods + void TestInts(); + void TestDoPutFloats(); + void TestDoPutEmptyBatch(); + void TestDoPutDicts(); + void TestDoPutLargeBatch(); + void TestDoPutSizeLimit(); + + private: + std::unique_ptr client_; + std::unique_ptr server_; +}; + +class ARROW_FLIGHT_EXPORT AppMetadataTestServer : public FlightServerBase { + public: + virtual ~AppMetadataTestServer() = default; + + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) override; + + Status DoPut(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override; +}; + +class ARROW_FLIGHT_EXPORT AppMetadataTest : public FlightTest { + public: + void SetUp(); + void TearDown(); + + // Test methods + void TestDoGet(); + void TestDoGetDictionaries(); + void TestDoPut(); + void TestDoPutDictionaries(); + void TestDoPutReadMetadata(); + + private: + std::unique_ptr client_; + std::unique_ptr server_; +}; + +class ARROW_FLIGHT_EXPORT IpcOptionsTest : public FlightTest { + public: + void SetUp(); + void TearDown(); + + // Test methods + void TestDoGetReadOptions(); + void TestDoPutWriteOptions(); + void TestDoExchangeClientWriteOptions(); + void TestDoExchangeClientWriteOptionsBegin(); + void TestDoExchangeServerWriteOptions(); + + private: + std::unique_ptr client_; + std::unique_ptr server_; +}; + +} // namespace flight +} // namespace arrow diff --git 
a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index 2325445f86c..4405c24eea5 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -151,24 +151,24 @@ const std::string& TestServer::unix_sock() const { return unix_sock_; } Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { if (ticket.ticket == "ticket-ints-1") { - BatchVector batches; + RecordBatchVector batches; RETURN_NOT_OK(ExampleIntBatches(&batches)); - *out = std::make_shared(batches[0]->schema(), batches); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); return Status::OK(); } else if (ticket.ticket == "ticket-floats-1") { - BatchVector batches; + RecordBatchVector batches; RETURN_NOT_OK(ExampleFloatBatches(&batches)); - *out = std::make_shared(batches[0]->schema(), batches); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); return Status::OK(); } else if (ticket.ticket == "ticket-dicts-1") { - BatchVector batches; + RecordBatchVector batches; RETURN_NOT_OK(ExampleDictBatches(&batches)); - *out = std::make_shared(batches[0]->schema(), batches); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); return Status::OK(); } else if (ticket.ticket == "ticket-large-batch-1") { - BatchVector batches; + RecordBatchVector batches; RETURN_NOT_OK(ExampleLargeBatches(&batches)); - *out = std::make_shared(batches[0]->schema(), batches); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); return Status::OK(); } else { return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket); @@ -233,7 +233,7 @@ class FlightTestServer : public FlightServerBase { Status DoPut(const ServerCallContext&, std::unique_ptr reader, std::unique_ptr writer) override { - BatchVector batches; + RecordBatchVector batches; return reader->ReadAll(&batches); } @@ -270,7 +270,7 @@ class FlightTestServer : public FlightServerBase { Status RunExchangeGet(std::unique_ptr reader, std::unique_ptr 
writer) { RETURN_NOT_OK(writer->Begin(ExampleIntSchema())); - BatchVector batches; + RecordBatchVector batches; RETURN_NOT_OK(ExampleIntBatches(&batches)); for (const auto& batch : batches) { RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); @@ -285,7 +285,7 @@ class FlightTestServer : public FlightServerBase { if (!schema->Equals(ExampleIntSchema(), false)) { return Status::Invalid("Schema is not as expected"); } - BatchVector batches; + RecordBatchVector batches; RETURN_NOT_OK(ExampleIntBatches(&batches)); FlightStreamChunk chunk; for (const auto& batch : batches) { @@ -590,7 +590,7 @@ std::vector ExampleFlightInfo() { FlightInfo(flight4)}; } -Status ExampleIntBatches(BatchVector* out) { +Status ExampleIntBatches(RecordBatchVector* out) { std::shared_ptr batch; for (int i = 0; i < 5; ++i) { // Make all different sizes, use different random seed @@ -600,7 +600,7 @@ Status ExampleIntBatches(BatchVector* out) { return Status::OK(); } -Status ExampleFloatBatches(BatchVector* out) { +Status ExampleFloatBatches(RecordBatchVector* out) { std::shared_ptr batch; for (int i = 0; i < 5; ++i) { // Make all different sizes, use different random seed @@ -610,7 +610,7 @@ Status ExampleFloatBatches(BatchVector* out) { return Status::OK(); } -Status ExampleDictBatches(BatchVector* out) { +Status ExampleDictBatches(RecordBatchVector* out) { // Just the same batch, repeated a few times std::shared_ptr batch; for (int i = 0; i < 3; ++i) { @@ -620,7 +620,7 @@ Status ExampleDictBatches(BatchVector* out) { return Status::OK(); } -Status ExampleNestedBatches(BatchVector* out) { +Status ExampleNestedBatches(RecordBatchVector* out) { std::shared_ptr batch; for (int i = 0; i < 3; ++i) { RETURN_NOT_OK(ipc::test::MakeListRecordBatch(&batch)); @@ -629,7 +629,7 @@ Status ExampleNestedBatches(BatchVector* out) { return Status::OK(); } -Status ExampleLargeBatches(BatchVector* out) { +Status ExampleLargeBatches(RecordBatchVector* out) { const auto array_length = 32768; std::shared_ptr batch; 
std::vector> arrays; diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index a3c7350fa74..75a534d56b3 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -17,6 +17,9 @@ #pragma once +#include +#include + #include #include #include @@ -25,6 +28,7 @@ #include #include "arrow/status.h" +#include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" #include "arrow/flight/client.h" @@ -46,6 +50,24 @@ class child; namespace arrow { namespace flight { +// ---------------------------------------------------------------------- +// Helpers to compare values for equality + +inline void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) { + std::shared_ptr ex_schema, actual_schema; + ipc::DictionaryMemo expected_memo; + ipc::DictionaryMemo actual_memo; + ASSERT_OK(expected.GetSchema(&expected_memo, &ex_schema)); + ASSERT_OK(actual.GetSchema(&actual_memo, &actual_schema)); + + AssertSchemaEqual(*ex_schema, *actual_schema); + ASSERT_EQ(expected.total_records(), actual.total_records()); + ASSERT_EQ(expected.total_bytes(), actual.total_bytes()); + + ASSERT_EQ(expected.descriptor(), actual.descriptor()); + ASSERT_THAT(actual.endpoints(), ::testing::ContainerEq(expected.endpoints())); +} + // ---------------------------------------------------------------------- // Fixture to use for running test servers @@ -100,43 +122,6 @@ Status MakeServer(std::unique_ptr* server, return FlightClient::Connect(real_location, client_options, client); } -// ---------------------------------------------------------------------- -// A RecordBatchReader for serving a sequence of in-memory record batches - -// Silence warning -// "non dll-interface class RecordBatchReader used as base for dll-interface class" -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4275) -#endif - -class ARROW_FLIGHT_EXPORT BatchIterator : public RecordBatchReader { - public: - BatchIterator(const std::shared_ptr& schema, 
- const std::vector>& batches) - : schema_(schema), batches_(batches), position_(0) {} - - std::shared_ptr schema() const override { return schema_; } - - Status ReadNext(std::shared_ptr* out) override { - if (position_ >= batches_.size()) { - *out = nullptr; - } else { - *out = batches_[position_++]; - } - return Status::OK(); - } - - private: - std::shared_ptr schema_; - std::vector> batches_; - size_t position_; -}; - -#ifdef _MSC_VER -#pragma warning(pop) -#endif - // ---------------------------------------------------------------------- // A FlightDataStream that numbers the record batches /// \brief A basic implementation of FlightDataStream that will provide @@ -157,8 +142,6 @@ class ARROW_FLIGHT_EXPORT NumberingStream : public FlightDataStream { // ---------------------------------------------------------------------- // Example data for test-server and unit tests -using BatchVector = std::vector>; - ARROW_FLIGHT_EXPORT std::shared_ptr ExampleIntSchema(); @@ -172,19 +155,19 @@ ARROW_FLIGHT_EXPORT std::shared_ptr ExampleLargeSchema(); ARROW_FLIGHT_EXPORT -Status ExampleIntBatches(BatchVector* out); +Status ExampleIntBatches(RecordBatchVector* out); ARROW_FLIGHT_EXPORT -Status ExampleFloatBatches(BatchVector* out); +Status ExampleFloatBatches(RecordBatchVector* out); ARROW_FLIGHT_EXPORT -Status ExampleDictBatches(BatchVector* out); +Status ExampleDictBatches(RecordBatchVector* out); ARROW_FLIGHT_EXPORT -Status ExampleNestedBatches(BatchVector* out); +Status ExampleNestedBatches(RecordBatchVector* out); ARROW_FLIGHT_EXPORT -Status ExampleLargeBatches(BatchVector* out); +Status ExampleLargeBatches(RecordBatchVector* out); ARROW_FLIGHT_EXPORT arrow::Result> VeryLargeBatch(); diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 2567bdfedd9..6338c36567b 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -306,6 +306,15 @@ Status Location::ForGrpcUnix(const std::string& path, Location* location) { return 
Location::Parse(uri_string.str(), location); } +arrow::Result Location::ForScheme(const std::string& scheme, + const std::string& host, const int port) { + Location location; + std::stringstream uri_string; + uri_string << scheme << "://" << host << ':' << port; + RETURN_NOT_OK(Location::Parse(uri_string.str(), &location)); + return location; +} + std::string Location::ToString() const { return uri_->ToString(); } std::string Location::scheme() const { std::string scheme = uri_->scheme(); @@ -324,6 +333,10 @@ bool FlightEndpoint::Equals(const FlightEndpoint& other) const { return ticket == other.ticket && locations == other.locations; } +bool ActionType::Equals(const ActionType& other) const { + return type == other.type && description == other.description; +} + Status MetadataRecordBatchReader::ReadAll( std::vector>* batches) { FlightStreamChunk chunk; diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 04a9b4b716f..a609cacb95e 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -139,6 +139,15 @@ struct ARROW_FLIGHT_EXPORT ActionType { /// \brief A human-readable description of the action. std::string description; + + bool Equals(const ActionType& other) const; + + friend bool operator==(const ActionType& left, const ActionType& right) { + return left.Equals(right); + } + friend bool operator!=(const ActionType& left, const ActionType& right) { + return !(left == right); + } }; /// \brief Opaque selection criteria for ListFlights RPC @@ -312,6 +321,10 @@ struct ARROW_FLIGHT_EXPORT Location { /// \param[out] location The resulting location static Status ForGrpcUnix(const std::string& path, Location* location); + /// \brief Initialize a location based on a URI scheme + static arrow::Result ForScheme(const std::string& scheme, + const std::string& host, const int port); + /// \brief Get a representation of this URI as a string. 
std::string ToString() const; @@ -406,9 +419,9 @@ class ARROW_FLIGHT_EXPORT FlightInfo { const std::vector& endpoints, int64_t total_records, int64_t total_bytes); - /// \brief Deserialize the Arrow schema of the dataset, to be passed - /// to each call to DoGet. Populate any dictionary encoded fields - /// into a DictionaryMemo for bookkeeping + /// \brief Deserialize the Arrow schema of the dataset. Populate any + /// dictionary encoded fields into a DictionaryMemo for + /// bookkeeping /// \param[in,out] dictionary_memo for dictionary bookkeeping, will /// be modified /// \param[out] out the reconstructed Schema