diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 3de8ff76569..fd7027c30eb 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -334,6 +334,10 @@ if(ARROW_GANDIVA) set(ARROW_WITH_RE2 ON) endif() +if(ARROW_BUILD_INTEGRATION AND ARROW_FLIGHT) + set(ARROW_FLIGHT_SQL ON) +endif() + if(ARROW_FLIGHT_SQL) set(ARROW_FLIGHT ON) endif() diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 502629a92a4..cc979a22e09 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -736,6 +736,12 @@ if(ARROW_FLIGHT_SQL) add_subdirectory(flight/sql) endif() +if(ARROW_FLIGHT + AND ARROW_FLIGHT_SQL + AND ARROW_BUILD_INTEGRATION) + add_subdirectory(flight/integration_tests) +endif() + if(ARROW_HIVESERVER2) add_subdirectory(dbi/hiveserver2) endif() diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 55e89b2eb99..2cf8c9913e5 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -202,7 +202,6 @@ if(ARROW_TESTING) OUTPUTS ARROW_FLIGHT_TESTING_LIBRARIES SOURCES - test_integration.cc test_util.cc DEPENDENCIES GTest::gtest @@ -246,21 +245,6 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS) add_dependencies(arrow_flight flight-test-server) endif() -if(ARROW_BUILD_INTEGRATION) - add_executable(flight-test-integration-server test_integration_server.cc) - target_link_libraries(flight-test-integration-server ${ARROW_FLIGHT_TEST_LINK_LIBS} - ${GFLAGS_LIBRARIES} GTest::gtest) - - add_executable(flight-test-integration-client test_integration_client.cc) - target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} - ${GFLAGS_LIBRARIES} GTest::gtest) - - add_dependencies(arrow_flight flight-test-integration-client - flight-test-integration-server) - add_dependencies(arrow-integration flight-test-integration-client - flight-test-integration-server) -endif() - if(ARROW_BUILD_BENCHMARKS) # Perf server for benchmarks set(PERF_PROTO_GENERATED_FILES "${CMAKE_CURRENT_BINARY_DIR}/perf.pb.cc" diff --git a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt new file mode 100644 index 00000000000..3a878d7f305 --- /dev/null +++ b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_custom_target(arrow_flight_integration_tests) + +if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + set(ARROW_FLIGHT_TEST_LINK_LIBS + arrow_flight_static + arrow_flight_testing_static + arrow_flight_sql_static + ${ARROW_FLIGHT_STATIC_LINK_LIBS} + ${ARROW_TEST_LINK_LIBS}) +else() + set(ARROW_FLIGHT_TEST_LINK_LIBS + arrow_flight_shared + arrow_flight_testing_shared + arrow_flight_sql_shared + ${ARROW_TEST_LINK_LIBS} + ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS}) +endif() + +add_executable(flight-test-integration-server test_integration_server.cc + test_integration.cc) +target_link_libraries(flight-test-integration-server ${ARROW_FLIGHT_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES} GTest::gtest) + +add_executable(flight-test-integration-client test_integration_client.cc + test_integration.cc) +target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES} GTest::gtest) + +add_dependencies(arrow-integration flight-test-integration-client + flight-test-integration-server) diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc new file mode 100644 index 00000000000..1e08f47b579 --- /dev/null +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -0,0 +1,684 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/integration_tests/test_integration.h" +#include "arrow/flight/client_middleware.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/server.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/types.h" +#include "arrow/ipc/dictionary.h" +#include "arrow/testing/gtest_util.h" + +#include +#include +#include +#include +#include + +namespace arrow { +namespace flight { +namespace integration_tests { + +/// \brief The server for the basic auth integration test. +class AuthBasicProtoServer : public FlightServerBase { + Status DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* result) override { + // Respond with the authenticated username. + auto buf = Buffer::FromString(context.peer_identity()); + *result = std::unique_ptr(new SimpleResultStream({Result{buf}})); + return Status::OK(); + } +}; + +/// Validate the result of a DoAction. +Status CheckActionResults(FlightClient* client, const Action& action, + std::vector results) { + std::unique_ptr stream; + RETURN_NOT_OK(client->DoAction(action, &stream)); + std::unique_ptr result; + for (const std::string& expected : results) { + RETURN_NOT_OK(stream->Next(&result)); + if (!result) { + return Status::Invalid("Action result stream ended early"); + } + const auto actual = result->body->ToString(); + if (expected != actual) { + return Status::Invalid("Got wrong result; expected", expected, "but got", actual); + } + } + RETURN_NOT_OK(stream->Next(&result)); + if (result) { + return Status::Invalid("Action result stream had too many entries"); + } + return Status::OK(); +} + +// The expected username for the basic auth integration test. +constexpr auto kAuthUsername = "arrow"; +// The expected password for the basic auth integration test. +constexpr auto kAuthPassword = "flight"; + +/// \brief A scenario testing the basic auth protobuf. +class AuthBasicProtoScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + server->reset(new AuthBasicProtoServer()); + options->auth_handler = + std::make_shared(kAuthUsername, kAuthPassword); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } + + Status RunClient(std::unique_ptr client) override { + Action action; + std::unique_ptr stream; + std::shared_ptr detail; + const auto& status = client->DoAction(action, &stream); + detail = FlightStatusDetail::UnwrapStatus(status); + // This client is unauthenticated and should fail. + if (detail == nullptr) { + return Status::Invalid("Expected UNAUTHENTICATED but got ", status.ToString()); + } + if (detail->code() != FlightStatusCode::Unauthenticated) { + return Status::Invalid("Expected UNAUTHENTICATED but got ", detail->ToString()); + } + + auto client_handler = std::unique_ptr( + new TestClientBasicAuthHandler(kAuthUsername, kAuthPassword)); + RETURN_NOT_OK(client->Authenticate({}, std::move(client_handler))); + return CheckActionResults(client.get(), action, {kAuthUsername}); + } +}; + +/// \brief Test middleware that echoes back the value of a particular +/// incoming header. +/// +/// In Java, gRPC may consolidate this header with HTTP/2 trailers if +/// the call fails, but C++ generally doesn't do this. The integration +/// test confirms the presence of this header to ensure we can read it +/// regardless of what gRPC does. +class TestServerMiddleware : public ServerMiddleware { + public: + explicit TestServerMiddleware(std::string received) : received_(received) {} + + void SendingHeaders(AddCallHeaders* outgoing_headers) override { + outgoing_headers->AddHeader("x-middleware", received_); + } + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "GrpcTrailersMiddleware"; } + + private: + std::string received_; +}; + +class TestServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override { + const std::pair& iter_pair = + incoming_headers.equal_range("x-middleware"); + std::string received = ""; + if (iter_pair.first != iter_pair.second) { + const util::string_view& value = (*iter_pair.first).second; + received = std::string(value); + } + *middleware = std::make_shared(received); + return Status::OK(); + } +}; + +/// \brief Test middleware that adds a header on every outgoing call, +/// and gets the value of the expected header sent by the server. +class TestClientMiddleware : public ClientMiddleware { + public: + explicit TestClientMiddleware(std::string* received_header) + : received_header_(received_header) {} + + void SendingHeaders(AddCallHeaders* outgoing_headers) { + outgoing_headers->AddHeader("x-middleware", "expected value"); + } + + void ReceivedHeaders(const CallHeaders& incoming_headers) { + // We expect the server to always send this header. gRPC/Java may + // send it in trailers instead of headers, so we expect Flight to + // account for this. + const std::pair& iter_pair = + incoming_headers.equal_range("x-middleware"); + if (iter_pair.first != iter_pair.second) { + const util::string_view& value = (*iter_pair.first).second; + *received_header_ = std::string(value); + } + } + + void CallCompleted(const Status& status) {} + + private: + std::string* received_header_; +}; + +class TestClientMiddlewareFactory : public ClientMiddlewareFactory { + public: + void StartCall(const CallInfo& info, std::unique_ptr* middleware) { + *middleware = + std::unique_ptr(new TestClientMiddleware(&received_header_)); + } + + std::string received_header_; +}; + +/// \brief The server used for testing middleware. Implements only one +/// endpoint, GetFlightInfo, in such a way that it either succeeds or +/// returns an error based on the input, in order to test both paths. +class MiddlewareServer : public FlightServerBase { + Status GetFlightInfo(const ServerCallContext& context, + const FlightDescriptor& descriptor, + std::unique_ptr* result) override { + if (descriptor.type == FlightDescriptor::DescriptorType::CMD && + descriptor.cmd == "success") { + // Don't fail + std::shared_ptr schema = arrow::schema({}); + Location location; + // Return a fake location - the test doesn't read it + RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 10010, &location)); + std::vector endpoints{FlightEndpoint{{"foo"}, {location}}}; + ARROW_ASSIGN_OR_RAISE(auto info, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); + *result = std::unique_ptr(new FlightInfo(info)); + return Status::OK(); + } + // Fail the call immediately. In some gRPC implementations, this + // means that gRPC sends only HTTP/2 trailers and not headers. We want + // Flight middleware to be agnostic to this difference. + return Status::UnknownError("Unknown"); + } +}; + +/// \brief The middleware scenario. +/// +/// This tests that the server and client get expected header values. +class MiddlewareScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + options->middleware.push_back( + {"grpc_trailers", std::make_shared()}); + server->reset(new MiddlewareServer()); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { + client_middleware_ = std::make_shared(); + options->middleware.push_back(client_middleware_); + return Status::OK(); + } + + Status RunClient(std::unique_ptr client) override { + std::unique_ptr info; + // This call is expected to fail. In gRPC/Java, this causes the + // server to combine headers and HTTP/2 trailers, so to read the + // expected header, Flight must check for both headers and + // trailers. + if (client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok()) { + return Status::Invalid("Expected call to fail"); + } + if (client_middleware_->received_header_ != "expected value") { + return Status::Invalid( + "Expected to receive header 'x-middleware: expected value', but instead got: '", + client_middleware_->received_header_, "'"); + } + std::cerr << "Headers received successfully on failing call." << std::endl; + + // This call should succeed + client_middleware_->received_header_ = ""; + RETURN_NOT_OK(client->GetFlightInfo(FlightDescriptor::Command("success"), &info)); + if (client_middleware_->received_header_ != "expected value") { + return Status::Invalid( + "Expected to receive header 'x-middleware: expected value', but instead got '", + client_middleware_->received_header_, "'"); + } + std::cerr << "Headers received successfully on passing call." << std::endl; + return Status::OK(); + } + + std::shared_ptr client_middleware_; +}; + +/// \brief Schema to be returned for mocking the statement/prepared statement results. +/// Must be the same across all languages. +std::shared_ptr GetQuerySchema() { + return arrow::schema({arrow::field("id", int64())}); +} + +constexpr int64_t kUpdateStatementExpectedRows = 10000L; +constexpr int64_t kUpdatePreparedStatementExpectedRows = 20000L; + +template +arrow::Status AssertEq(const T& expected, const T& actual) { + if (expected != actual) { + return Status::Invalid("Expected \"", expected, "\", got \'", actual, "\""); + } + return Status::OK(); +} + +/// \brief The server used for testing Flight SQL, this implements a static Flight SQL +/// server which only asserts that commands called during integration tests are being +/// parsed correctly and returns the expected schemas to be validated on client. +class FlightSqlScenarioServer : public sql::FlightSqlServerBase { + public: + arrow::Result> GetFlightInfoStatement( + const ServerCallContext& context, const sql::StatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("SELECT STATEMENT", command.query)); + + ARROW_ASSIGN_OR_RAISE(auto handle, + sql::CreateStatementQueryTicket("SELECT STATEMENT HANDLE")); + + std::vector endpoints{FlightEndpoint{{handle}, {}}}; + ARROW_ASSIGN_OR_RAISE( + auto result, FlightInfo::Make(*GetQuerySchema(), descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); + } + + arrow::Result> DoGetStatement( + const ServerCallContext& context, + const sql::StatementQueryTicket& command) override { + return DoGetForTestCase(GetQuerySchema()); + } + + arrow::Result> GetFlightInfoPreparedStatement( + const ServerCallContext& context, const sql::PreparedStatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", + command.prepared_statement_handle)); + + return GetFlightInfoForCommand(descriptor, GetQuerySchema()); + } + + arrow::Result> DoGetPreparedStatement( + const ServerCallContext& context, + const sql::PreparedStatementQuery& command) override { + return DoGetForTestCase(GetQuerySchema()); + } + + arrow::Result> GetFlightInfoCatalogs( + const ServerCallContext& context, const FlightDescriptor& descriptor) override { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetCatalogsSchema()); + } + + arrow::Result> DoGetCatalogs( + const ServerCallContext& context) override { + return DoGetForTestCase(sql::SqlSchema::GetCatalogsSchema()); + } + + arrow::Result> GetFlightInfoSqlInfo( + const ServerCallContext& context, const sql::GetSqlInfo& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq(2, command.info.size())); + ARROW_RETURN_NOT_OK(AssertEq( + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, command.info[0])); + ARROW_RETURN_NOT_OK(AssertEq( + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, command.info[1])); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema()); + } + + arrow::Result> DoGetSqlInfo( + const ServerCallContext& context, const sql::GetSqlInfo& command) override { + return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema()); + } + + arrow::Result> GetFlightInfoSchemas( + const ServerCallContext& context, const sql::GetDbSchemas& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value())); + ARROW_RETURN_NOT_OK(AssertEq("db_schema_filter_pattern", + command.db_schema_filter_pattern.value())); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetDbSchemasSchema()); + } + + arrow::Result> DoGetDbSchemas( + const ServerCallContext& context, const sql::GetDbSchemas& command) override { + return DoGetForTestCase(sql::SqlSchema::GetDbSchemasSchema()); + } + + arrow::Result> GetFlightInfoTables( + const ServerCallContext& context, const sql::GetTables& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value())); + ARROW_RETURN_NOT_OK(AssertEq("db_schema_filter_pattern", + command.db_schema_filter_pattern.value())); + ARROW_RETURN_NOT_OK(AssertEq("table_filter_pattern", + command.table_name_filter_pattern.value())); + ARROW_RETURN_NOT_OK(AssertEq(2, command.table_types.size())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_types[0])); + ARROW_RETURN_NOT_OK(AssertEq("view", command.table_types[1])); + ARROW_RETURN_NOT_OK(AssertEq(true, command.include_schema)); + + return GetFlightInfoForCommand(descriptor, + sql::SqlSchema::GetTablesSchemaWithIncludedSchema()); + } + + arrow::Result> DoGetTables( + const ServerCallContext& context, const sql::GetTables& command) override { + return DoGetForTestCase(sql::SqlSchema::GetTablesSchemaWithIncludedSchema()); + } + + arrow::Result> GetFlightInfoTableTypes( + const ServerCallContext& context, const FlightDescriptor& descriptor) override { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> DoGetTableTypes( + const ServerCallContext& context) override { + return DoGetForTestCase(sql::SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> GetFlightInfoPrimaryKeys( + const ServerCallContext& context, const sql::GetPrimaryKeys& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("catalog", command.table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("db_schema", command.table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetPrimaryKeysSchema()); + } + + arrow::Result> DoGetPrimaryKeys( + const ServerCallContext& context, const sql::GetPrimaryKeys& command) override { + return DoGetForTestCase(sql::SqlSchema::GetPrimaryKeysSchema()); + } + + arrow::Result> GetFlightInfoExportedKeys( + const ServerCallContext& context, const sql::GetExportedKeys& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("catalog", command.table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("db_schema", command.table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetExportedKeysSchema()); + } + + arrow::Result> DoGetExportedKeys( + const ServerCallContext& context, const sql::GetExportedKeys& command) override { + return DoGetForTestCase(sql::SqlSchema::GetExportedKeysSchema()); + } + + arrow::Result> GetFlightInfoImportedKeys( + const ServerCallContext& context, const sql::GetImportedKeys& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("catalog", command.table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("db_schema", command.table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetImportedKeysSchema()); + } + + arrow::Result> DoGetImportedKeys( + const ServerCallContext& context, const sql::GetImportedKeys& command) override { + return DoGetForTestCase(sql::SqlSchema::GetImportedKeysSchema()); + } + + arrow::Result> GetFlightInfoCrossReference( + const ServerCallContext& context, const sql::GetCrossReference& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("pk_catalog", command.pk_table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("pk_db_schema", command.pk_table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("pk_table", command.pk_table_ref.table)); + ARROW_RETURN_NOT_OK( + AssertEq("fk_catalog", command.fk_table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("fk_db_schema", command.fk_table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("fk_table", command.fk_table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> DoGetCrossReference( + const ServerCallContext& context, const sql::GetCrossReference& command) override { + return DoGetForTestCase(sql::SqlSchema::GetCrossReferenceSchema()); + } + + arrow::Result DoPutCommandStatementUpdate( + const ServerCallContext& context, const sql::StatementUpdate& command) override { + ARROW_RETURN_NOT_OK(AssertEq("UPDATE STATEMENT", command.query)); + + return kUpdateStatementExpectedRows; + } + + arrow::Result CreatePreparedStatement( + const ServerCallContext& context, + const sql::ActionCreatePreparedStatementRequest& request) override { + ARROW_RETURN_NOT_OK( + AssertEq(true, request.query == "SELECT PREPARED STATEMENT" || + request.query == "UPDATE PREPARED STATEMENT")); + + sql::ActionCreatePreparedStatementResult result; + result.prepared_statement_handle = request.query + " HANDLE"; + + return result; + } + + Status ClosePreparedStatement( + const ServerCallContext& context, + const sql::ActionClosePreparedStatementRequest& request) override { + return Status::OK(); + } + + Status DoPutPreparedStatementQuery(const ServerCallContext& context, + const sql::PreparedStatementQuery& command, + FlightMessageReader* reader, + FlightMetadataWriter* writer) override { + ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", + command.prepared_statement_handle)); + + ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); + ARROW_RETURN_NOT_OK(AssertEq(*GetQuerySchema(), *actual_schema)); + + return Status::OK(); + } + + arrow::Result DoPutPreparedStatementUpdate( + const ServerCallContext& context, const sql::PreparedStatementUpdate& command, + FlightMessageReader* reader) override { + ARROW_RETURN_NOT_OK(AssertEq("UPDATE PREPARED STATEMENT HANDLE", + command.prepared_statement_handle)); + + return kUpdatePreparedStatementExpectedRows; + } + + private: + arrow::Result> GetFlightInfoForCommand( + const FlightDescriptor& descriptor, const std::shared_ptr& schema) { + std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto result, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); + } + + arrow::Result> DoGetForTestCase( + const std::shared_ptr& schema) { + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({}, schema)); + return std::unique_ptr(new RecordBatchStream(reader)); + } +}; + +/// \brief Integration test scenario for validating Flight SQL specs across multiple +/// implementations. This should ensure that RPC objects are being built and parsed +/// correctly for multiple languages and that the Arrow schemas are returned as expected. +class FlightSqlScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + server->reset(new FlightSqlScenarioServer()); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } + + Status Validate(std::shared_ptr expected_schema, + arrow::Result> flight_info_result, + sql::FlightSqlClient* sql_client) { + FlightCallOptions call_options; + + ARROW_ASSIGN_OR_RAISE(auto flight_info, flight_info_result); + ARROW_ASSIGN_OR_RAISE( + auto reader, sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); + + ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); + + AssertSchemaEqual(expected_schema, actual_schema); + + return Status::OK(); + } + + Status RunClient(std::unique_ptr client) override { + sql::FlightSqlClient sql_client(std::move(client)); + + ARROW_RETURN_NOT_OK(ValidateMetadataRetrieval(&sql_client)); + + ARROW_RETURN_NOT_OK(ValidateStatementExecution(&sql_client)); + + ARROW_RETURN_NOT_OK(ValidatePreparedStatementExecution(&sql_client)); + + return Status::OK(); + } + + Status ValidateMetadataRetrieval(sql::FlightSqlClient* sql_client) { + FlightCallOptions options; + + std::string catalog = "catalog"; + std::string db_schema_filter_pattern = "db_schema_filter_pattern"; + std::string table_filter_pattern = "table_filter_pattern"; + std::string table = "table"; + std::string db_schema = "db_schema"; + std::vector table_types = {"table", "view"}; + + sql::TableRef table_ref = {catalog, db_schema, table}; + sql::TableRef pk_table_ref = {"pk_catalog", "pk_db_schema", "pk_table"}; + sql::TableRef fk_table_ref = {"fk_catalog", "fk_db_schema", "fk_table"}; + + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetCatalogsSchema(), + sql_client->GetCatalogs(options), sql_client)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetDbSchemasSchema(), + sql_client->GetDbSchemas(options, &catalog, &db_schema_filter_pattern), + sql_client)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetTablesSchemaWithIncludedSchema(), + sql_client->GetTables(options, &catalog, &db_schema_filter_pattern, + &table_filter_pattern, true, &table_types), + sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetTableTypesSchema(), + sql_client->GetTableTypes(options), sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetPrimaryKeysSchema(), + sql_client->GetPrimaryKeys(options, table_ref), + sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetExportedKeysSchema(), + sql_client->GetExportedKeys(options, table_ref), + sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetImportedKeysSchema(), + sql_client->GetImportedKeys(options, table_ref), + sql_client)); + ARROW_RETURN_NOT_OK(Validate( + sql::SqlSchema::GetCrossReferenceSchema(), + sql_client->GetCrossReference(options, pk_table_ref, fk_table_ref), sql_client)); + ARROW_RETURN_NOT_OK(Validate( + sql::SqlSchema::GetSqlInfoSchema(), + sql_client->GetSqlInfo( + options, {sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY}), + sql_client)); + + return Status::OK(); + } + + Status ValidateStatementExecution(sql::FlightSqlClient* sql_client) { + FlightCallOptions options; + + ARROW_RETURN_NOT_OK(Validate( + GetQuerySchema(), sql_client->Execute(options, "SELECT STATEMENT"), sql_client)); + ARROW_ASSIGN_OR_RAISE(auto update_statement_result, + sql_client->ExecuteUpdate(options, "UPDATE STATEMENT")); + if (update_statement_result != kUpdateStatementExpectedRows) { + return Status::Invalid("Expected 'UPDATE STATEMENT' return ", + kUpdateStatementExpectedRows, ", got ", + update_statement_result); + } + + return Status::OK(); + } + + Status ValidatePreparedStatementExecution(sql::FlightSqlClient* sql_client) { + FlightCallOptions options; + + ARROW_ASSIGN_OR_RAISE(auto select_prepared_statement, + sql_client->Prepare(options, "SELECT PREPARED STATEMENT")); + + auto parameters = + RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); + ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters)); + + ARROW_RETURN_NOT_OK( + Validate(GetQuerySchema(), select_prepared_statement->Execute(), sql_client)); + ARROW_RETURN_NOT_OK(select_prepared_statement->Close()); + + ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement, + sql_client->Prepare(options, "UPDATE PREPARED STATEMENT")); + ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement_result, + update_prepared_statement->ExecuteUpdate()); + if (update_prepared_statement_result != kUpdatePreparedStatementExpectedRows) { + return Status::Invalid("Expected 'UPDATE STATEMENT' return ", + kUpdatePreparedStatementExpectedRows, ", got ", + update_prepared_statement_result); + } + ARROW_RETURN_NOT_OK(update_prepared_statement->Close()); + + return Status::OK(); + } +}; + +Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) { + if (scenario_name == "auth:basic_proto") { + *out = std::make_shared(); + return Status::OK(); + } else if (scenario_name == "middleware") { + *out = std::make_shared(); + return Status::OK(); + } else if (scenario_name == "flight_sql") { + *out = std::make_shared(); + return Status::OK(); + } + return Status::KeyError("Scenario not found: ", scenario_name); +} + +} // namespace integration_tests +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration.h b/cpp/src/arrow/flight/integration_tests/test_integration.h similarity index 95% rename from cpp/src/arrow/flight/test_integration.h rename to cpp/src/arrow/flight/integration_tests/test_integration.h index 5d9bd7fd7bd..74093f8cd23 100644 --- a/cpp/src/arrow/flight/test_integration.h +++ b/cpp/src/arrow/flight/integration_tests/test_integration.h @@ -17,6 +17,8 @@ // Integration test scenarios for Arrow Flight. +#pragma once + #include "arrow/flight/visibility.h" #include @@ -28,16 +30,20 @@ namespace arrow { namespace flight { +namespace integration_tests { /// \brief An integration test for Arrow Flight. class ARROW_FLIGHT_EXPORT Scenario { public: virtual ~Scenario() = default; + /// \brief Set up the server. virtual Status MakeServer(std::unique_ptr* server, FlightServerOptions* options) = 0; + /// \brief Set up the client. virtual Status MakeClient(FlightClientOptions* options) = 0; + /// \brief Run the scenario as the client. virtual Status RunClient(std::unique_ptr client) = 0; }; @@ -45,5 +51,6 @@ class ARROW_FLIGHT_EXPORT Scenario { /// \brief Get the implementation of an integration test scenario by name. Status GetScenario(const std::string& scenario_name, std::shared_ptr* out); +} // namespace integration_tests } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration_client.cc b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc similarity index 94% rename from cpp/src/arrow/flight/test_integration_client.cc rename to cpp/src/arrow/flight/integration_tests/test_integration_client.cc index 6c1d6904603..366284389f1 100644 --- a/cpp/src/arrow/flight/test_integration_client.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc @@ -41,7 +41,7 @@ #include "arrow/util/logging.h" #include "arrow/flight/api.h" -#include "arrow/flight/test_integration.h" +#include "arrow/flight/integration_tests/test_integration.h" #include "arrow/flight/test_util.h" DEFINE_string(host, "localhost", "Server port to connect to"); @@ -51,6 +51,7 @@ DEFINE_string(scenario, "", "Integration test scenario to run"); namespace arrow { namespace flight { +namespace integration_tests { /// \brief Helper to read all batches from a JsonReader Status ReadBatches(std::unique_ptr& reader, @@ -133,7 +134,7 @@ Status ConsumeFlightLocation( return Status::OK(); } -class IntegrationTestScenario : public flight::Scenario { +class IntegrationTestScenario : public Scenario { public: Status MakeServer(std::unique_ptr* server, FlightServerOptions* options) override { @@ -201,12 +202,13 @@ class IntegrationTestScenario : public flight::Scenario { } }; +} // namespace integration_tests } // namespace flight } // namespace arrow constexpr int kRetries = 3; -arrow::Status RunScenario(arrow::flight::Scenario* scenario) { +arrow::Status RunScenario(arrow::flight::integration_tests::Scenario* scenario) { auto options = arrow::flight::FlightClientOptions::Defaults(); std::unique_ptr client; @@ -222,11 +224,13 @@ int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); - std::shared_ptr scenario; + std::shared_ptr scenario; if (!FLAGS_scenario.empty()) { - ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario)); + ARROW_CHECK_OK( + arrow::flight::integration_tests::GetScenario(FLAGS_scenario, &scenario)); } else { - scenario = std::make_shared(); + scenario = + std::make_shared(); } // ARROW-11908: retry a few times in case a client is slow to bring up the server diff --git a/cpp/src/arrow/flight/test_integration_server.cc b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc similarity index 94% rename from cpp/src/arrow/flight/test_integration_server.cc rename to cpp/src/arrow/flight/integration_tests/test_integration_server.cc index 4b904b0eba1..92b2241a872 100644 --- a/cpp/src/arrow/flight/test_integration_server.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc @@ -34,10 +34,10 @@ #include "arrow/testing/json_integration.h" #include "arrow/util/logging.h" +#include "arrow/flight/integration_tests/test_integration.h" #include "arrow/flight/internal.h" #include "arrow/flight/server.h" #include "arrow/flight/server_auth.h" -#include "arrow/flight/test_integration.h" #include "arrow/flight/test_util.h" DEFINE_int32(port, 31337, "Server port to listen on"); @@ -45,6 +45,7 @@ DEFINE_string(scenario, "", "Integration test senario to run"); namespace arrow { namespace flight { +namespace integration_tests { struct IntegrationDataset { std::shared_ptr schema; @@ -175,6 +176,7 @@ class IntegrationTestScenario : public Scenario { } }; +} // namespace integration_tests } // namespace flight } // namespace arrow @@ -184,12 +186,14 @@ int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing server for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); - std::shared_ptr scenario; + std::shared_ptr scenario; if (!FLAGS_scenario.empty()) { - ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario)); + ARROW_CHECK_OK( + arrow::flight::integration_tests::GetScenario(FLAGS_scenario, &scenario)); } else { - scenario = std::make_shared(); + scenario = + std::make_shared(); } arrow::flight::Location location; ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 6d328c07b0e..bbbe801ea24 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -689,7 +689,7 @@ arrow::Result FlightSqlServerBase::DoPutCommandStatementUpdate( } std::shared_ptr SqlSchema::GetCatalogsSchema() { - return arrow::schema({field("catalog_name", utf8())}); + return arrow::schema({field("catalog_name", utf8(), false)}); } std::shared_ptr SqlSchema::GetDbSchemasSchema() { @@ -699,23 +699,26 @@ std::shared_ptr SqlSchema::GetDbSchemasSchema() { std::shared_ptr SqlSchema::GetTablesSchema() { return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), - field("table_name", utf8()), field("table_type", utf8())}); + field("table_name", utf8(), false), + field("table_type", utf8(), false)}); } std::shared_ptr SqlSchema::GetTablesSchemaWithIncludedSchema() { return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), - field("table_name", utf8()), field("table_type", utf8()), - field("table_schema", binary())}); + field("table_name", utf8(), false), + field("table_type", utf8(), false), + field("table_schema", binary(), false)}); } std::shared_ptr SqlSchema::GetTableTypesSchema() { - return arrow::schema({field("table_type", utf8())}); + return arrow::schema({field("table_type", utf8(), false)}); } std::shared_ptr SqlSchema::GetPrimaryKeysSchema() { - return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), - field("table_name", utf8()), field("column_name", utf8()), - field("key_sequence", int64()), field("key_name", utf8())}); + return arrow::schema( + {field("catalog_name", utf8()), field("db_schema_name", utf8()), + field("table_name", utf8(), false), field("column_name", utf8(), false), + field("key_sequence", int32(), false), field("key_name", utf8())}); } std::shared_ptr GetImportedExportedKeysAndCrossReferenceSchema() { @@ -742,7 +745,7 @@ std::shared_ptr SqlSchema::GetCrossReferenceSchema() { } std::shared_ptr SqlSchema::GetSqlInfoSchema() { - return arrow::schema({field("name", uint32(), false), + return arrow::schema({field("info_name", uint32(), false), field("value", dense_union({field("string_value", utf8(), false), field("bool_value", boolean(), false), diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc index 8dfea7a013e..d74b6d40137 100644 --- a/cpp/src/arrow/flight/sql/server_test.cc +++ b/cpp/src/arrow/flight/sql/server_test.cc @@ -609,7 +609,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetPrimaryKeys) { const auto key_name = ArrayFromJSON(utf8(), R"([null])"); const auto table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); const auto column_name = ArrayFromJSON(utf8(), R"(["id"])"); - const auto key_sequence = ArrayFromJSON(int64(), R"([1])"); + const auto key_sequence = ArrayFromJSON(int32(), R"([1])"); const std::shared_ptr& expected_table = Table::Make( SqlSchema::GetPrimaryKeysSchema(), @@ -758,7 +758,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetSqlInfoNoInfo) { ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetSqlInfo(call_options, {999999})); EXPECT_RAISES_WITH_MESSAGE_THAT( - KeyError, ::testing::HasSubstr("No information for SQL info number 999999."), + KeyError, ::testing::HasSubstr("No information for SQL info number 999999"), sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); } diff --git a/cpp/src/arrow/flight/sql/test_server_cli.cc b/cpp/src/arrow/flight/sql/test_server_cli.cc index 8074ab534bd..e0ba5340e8d 100644 --- a/cpp/src/arrow/flight/sql/test_server_cli.cc +++ b/cpp/src/arrow/flight/sql/test_server_cli.cc @@ -17,14 +17,13 @@ #include +#include #include #include #include #include "arrow/flight/server.h" #include "arrow/flight/sql/example/sqlite_server.h" -#include "arrow/flight/test_integration.h" -#include "arrow/flight/test_util.h" #include "arrow/io/test_common.h" #include "arrow/testing/json_integration.h" #include "arrow/util/logging.h" diff --git a/cpp/src/arrow/flight/test_integration.cc b/cpp/src/arrow/flight/test_integration.cc deleted file mode 100644 index 29ce5601f37..00000000000 --- a/cpp/src/arrow/flight/test_integration.cc +++ /dev/null @@ -1,270 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/flight/test_integration.h" -#include "arrow/flight/client_middleware.h" -#include "arrow/flight/server_middleware.h" -#include "arrow/flight/test_util.h" -#include "arrow/flight/types.h" -#include "arrow/ipc/dictionary.h" - -#include -#include -#include -#include -#include - -namespace arrow { -namespace flight { - -/// \brief The server for the basic auth integration test. -class AuthBasicProtoServer : public FlightServerBase { - Status DoAction(const ServerCallContext& context, const Action& action, - std::unique_ptr* result) override { - // Respond with the authenticated username. - auto buf = Buffer::FromString(context.peer_identity()); - *result = std::unique_ptr(new SimpleResultStream({Result{buf}})); - return Status::OK(); - } -}; - -/// Validate the result of a DoAction. -Status CheckActionResults(FlightClient* client, const Action& action, - std::vector results) { - std::unique_ptr stream; - RETURN_NOT_OK(client->DoAction(action, &stream)); - std::unique_ptr result; - for (const std::string& expected : results) { - RETURN_NOT_OK(stream->Next(&result)); - if (!result) { - return Status::Invalid("Action result stream ended early"); - } - const auto actual = result->body->ToString(); - if (expected != actual) { - return Status::Invalid("Got wrong result; expected", expected, "but got", actual); - } - } - RETURN_NOT_OK(stream->Next(&result)); - if (result) { - return Status::Invalid("Action result stream had too many entries"); - } - return Status::OK(); -} - -// The expected username for the basic auth integration test. -constexpr auto kAuthUsername = "arrow"; -// The expected password for the basic auth integration test. -constexpr auto kAuthPassword = "flight"; - -/// \brief A scenario testing the basic auth protobuf. -class AuthBasicProtoScenario : public Scenario { - Status MakeServer(std::unique_ptr* server, - FlightServerOptions* options) override { - server->reset(new AuthBasicProtoServer()); - options->auth_handler = - std::make_shared(kAuthUsername, kAuthPassword); - return Status::OK(); - } - - Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } - - Status RunClient(std::unique_ptr client) override { - Action action; - std::unique_ptr stream; - std::shared_ptr detail; - const auto& status = client->DoAction(action, &stream); - detail = FlightStatusDetail::UnwrapStatus(status); - // This client is unauthenticated and should fail. - if (detail == nullptr) { - return Status::Invalid("Expected UNAUTHENTICATED but got ", status.ToString()); - } - if (detail->code() != FlightStatusCode::Unauthenticated) { - return Status::Invalid("Expected UNAUTHENTICATED but got ", detail->ToString()); - } - - auto client_handler = std::unique_ptr( - new TestClientBasicAuthHandler(kAuthUsername, kAuthPassword)); - RETURN_NOT_OK(client->Authenticate({}, std::move(client_handler))); - return CheckActionResults(client.get(), action, {kAuthUsername}); - } -}; - -/// \brief Test middleware that echoes back the value of a particular -/// incoming header. -/// -/// In Java, gRPC may consolidate this header with HTTP/2 trailers if -/// the call fails, but C++ generally doesn't do this. The integration -/// test confirms the presence of this header to ensure we can read it -/// regardless of what gRPC does. -class TestServerMiddleware : public ServerMiddleware { - public: - explicit TestServerMiddleware(std::string received) : received_(received) {} - void SendingHeaders(AddCallHeaders* outgoing_headers) override { - outgoing_headers->AddHeader("x-middleware", received_); - } - void CallCompleted(const Status& status) override {} - - std::string name() const override { return "GrpcTrailersMiddleware"; } - - private: - std::string received_; -}; - -class TestServerMiddlewareFactory : public ServerMiddlewareFactory { - public: - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, - std::shared_ptr* middleware) override { - const std::pair& iter_pair = - incoming_headers.equal_range("x-middleware"); - std::string received = ""; - if (iter_pair.first != iter_pair.second) { - const util::string_view& value = (*iter_pair.first).second; - received = std::string(value); - } - *middleware = std::make_shared(received); - return Status::OK(); - } -}; - -/// \brief Test middleware that adds a header on every outgoing call, -/// and gets the value of the expected header sent by the server. -class TestClientMiddleware : public ClientMiddleware { - public: - explicit TestClientMiddleware(std::string* received_header) - : received_header_(received_header) {} - - void SendingHeaders(AddCallHeaders* outgoing_headers) { - outgoing_headers->AddHeader("x-middleware", "expected value"); - } - - void ReceivedHeaders(const CallHeaders& incoming_headers) { - // We expect the server to always send this header. gRPC/Java may - // send it in trailers instead of headers, so we expect Flight to - // account for this. - const std::pair& iter_pair = - incoming_headers.equal_range("x-middleware"); - if (iter_pair.first != iter_pair.second) { - const util::string_view& value = (*iter_pair.first).second; - *received_header_ = std::string(value); - } - } - - void CallCompleted(const Status& status) {} - - private: - std::string* received_header_; -}; - -class TestClientMiddlewareFactory : public ClientMiddlewareFactory { - public: - void StartCall(const CallInfo& info, std::unique_ptr* middleware) { - *middleware = - std::unique_ptr(new TestClientMiddleware(&received_header_)); - } - - std::string received_header_; -}; - -/// \brief The server used for testing middleware. Implements only one -/// endpoint, GetFlightInfo, in such a way that it either succeeds or -/// returns an error based on the input, in order to test both paths. -class MiddlewareServer : public FlightServerBase { - Status GetFlightInfo(const ServerCallContext& context, - const FlightDescriptor& descriptor, - std::unique_ptr* result) override { - if (descriptor.type == FlightDescriptor::DescriptorType::CMD && - descriptor.cmd == "success") { - // Don't fail - std::shared_ptr schema = arrow::schema({}); - Location location; - // Return a fake location - the test doesn't read it - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 10010, &location)); - std::vector endpoints{FlightEndpoint{{"foo"}, {location}}}; - ARROW_ASSIGN_OR_RAISE(auto info, - FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); - *result = std::unique_ptr(new FlightInfo(info)); - return Status::OK(); - } - // Fail the call immediately. In some gRPC implementations, this - // means that gRPC sends only HTTP/2 trailers and not headers. We want - // Flight middleware to be agnostic to this difference. - return Status::UnknownError("Unknown"); - } -}; - -/// \brief The middleware scenario. -/// -/// This tests that the server and client get expected header values. -class MiddlewareScenario : public Scenario { - Status MakeServer(std::unique_ptr* server, - FlightServerOptions* options) override { - options->middleware.push_back( - {"grpc_trailers", std::make_shared()}); - server->reset(new MiddlewareServer()); - return Status::OK(); - } - - Status MakeClient(FlightClientOptions* options) override { - client_middleware_ = std::make_shared(); - options->middleware.push_back(client_middleware_); - return Status::OK(); - } - - Status RunClient(std::unique_ptr client) override { - std::unique_ptr info; - // This call is expected to fail. In gRPC/Java, this causes the - // server to combine headers and HTTP/2 trailers, so to read the - // expected header, Flight must check for both headers and - // trailers. - if (client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok()) { - return Status::Invalid("Expected call to fail"); - } - if (client_middleware_->received_header_ != "expected value") { - return Status::Invalid( - "Expected to receive header 'x-middleware: expected value', but instead got: '", - client_middleware_->received_header_, "'"); - } - std::cerr << "Headers received successfully on failing call." << std::endl; - - // This call should succeed - client_middleware_->received_header_ = ""; - RETURN_NOT_OK(client->GetFlightInfo(FlightDescriptor::Command("success"), &info)); - if (client_middleware_->received_header_ != "expected value") { - return Status::Invalid( - "Expected to receive header 'x-middleware: expected value', but instead got '", - client_middleware_->received_header_, "'"); - } - std::cerr << "Headers received successfully on passing call." << std::endl; - return Status::OK(); - } - - std::shared_ptr client_middleware_; -}; - -Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) { - if (scenario_name == "auth:basic_proto") { - *out = std::make_shared(); - return Status::OK(); - } else if (scenario_name == "middleware") { - *out = std::make_shared(); - return Status::OK(); - } - return Status::KeyError("Scenario not found: ", scenario_name); -} - -} // namespace flight -} // namespace arrow diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index 96ebd48912b..74bbed1fc4f 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -382,6 +382,11 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, description="Ensure headers are propagated via middleware.", skip={"Rust"} # TODO(ARROW-10961): tonic upgrade needed ), + Scenario( + "flight_sql", + description="Ensure Flight SQL protocol is working as expected.", + skip={"Rust", "Go"} + ), ] runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py index 5104a0cc755..69c6e54e056 100644 --- a/dev/archery/archery/integration/tester_java.py +++ b/dev/archery/archery/integration/tester_java.py @@ -49,11 +49,12 @@ class JavaTester(Tester): ARROW_FLIGHT_JAR = os.environ.get( 'ARROW_FLIGHT_JAVA_INTEGRATION_JAR', os.path.join(ARROW_ROOT_DEFAULT, - 'java/flight/flight-core/target/flight-core-{}-' - 'jar-with-dependencies.jar'.format(_arrow_version))) - ARROW_FLIGHT_SERVER = ('org.apache.arrow.flight.example.integration.' + 'java/flight/flight-integration-tests/target/' + 'flight-integration-tests-{}-jar-with-dependencies.jar' + .format(_arrow_version))) + ARROW_FLIGHT_SERVER = ('org.apache.arrow.flight.integration.tests.' 'IntegrationTestServer') - ARROW_FLIGHT_CLIENT = ('org.apache.arrow.flight.example.integration.' + ARROW_FLIGHT_CLIENT = ('org.apache.arrow.flight.integration.tests.' 'IntegrationTestClient') name = 'Java' diff --git a/java/flight/flight-core/pom.xml b/java/flight/flight-core/pom.xml index c8ab5ac1d26..d870faf9c50 100644 --- a/java/flight/flight-core/pom.xml +++ b/java/flight/flight-core/pom.xml @@ -93,11 +93,6 @@ com.google.guava guava - - commons-cli - commons-cli - 1.4 - io.grpc grpc-stub diff --git a/java/flight/flight-integration-tests/pom.xml b/java/flight/flight-integration-tests/pom.xml new file mode 100644 index 00000000000..2bd9a9f4e04 --- /dev/null +++ b/java/flight/flight-integration-tests/pom.xml @@ -0,0 +1,86 @@ + + + + 4.0.0 + + arrow-flight + org.apache.arrow + 7.0.0-SNAPSHOT + ../pom.xml + + + flight-integration-tests + Arrow Flight Integration Tests + 7.0.0-SNAPSHOT + jar + + + + org.apache.arrow + arrow-vector + ${project.version} + + + org.apache.arrow + arrow-memory-core + ${project.version} + + + org.apache.arrow + flight-core + ${project.version} + + + org.apache.arrow + flight-sql + ${project.version} + + + com.google.protobuf + protobuf-java + ${dep.protobuf.version} + + + commons-cli + commons-cli + 1.4 + + + org.slf4j + slf4j-api + + + + + + + maven-assembly-plugin + 3.0.0 + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + + + + diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/AuthBasicProtoScenario.java similarity index 98% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/AuthBasicProtoScenario.java index 3955d7d21bf..1c95d4d5593 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/AuthBasicProtoScenario.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.nio.charset.StandardCharsets; import java.util.Arrays; diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java new file mode 100644 index 00000000000..374e634e8a3 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.util.Arrays; + +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Integration test scenario for validating Flight SQL specs across multiple implementations. + * This should ensure that RPC objects are being built and parsed correctly for multiple languages + * and that the Arrow schemas are returned as expected. + */ +public class FlightSqlScenario implements Scenario { + + public static final long UPDATE_STATEMENT_EXPECTED_ROWS = 10000L; + public static final long UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS = 20000L; + + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { + return new FlightSqlScenarioProducer(allocator); + } + + @Override + public void buildServer(FlightServer.Builder builder) throws Exception { + + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) + throws Exception { + final FlightSqlClient sqlClient = new FlightSqlClient(client); + + validateMetadataRetrieval(sqlClient); + + validateStatementExecution(sqlClient); + + validatePreparedStatementExecution(sqlClient, allocator); + } + + private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { + final CallOption[] options = new CallOption[0]; + + validate(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, sqlClient.getCatalogs(options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA, + sqlClient.getSchemas("catalog", "db_schema_filter_pattern", options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA, + sqlClient.getTables("catalog", "db_schema_filter_pattern", "table_filter_pattern", + Arrays.asList("table", "view"), true, options), sqlClient); + validate(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA, sqlClient.getTableTypes(options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_PRIMARY_KEYS_SCHEMA, + sqlClient.getPrimaryKeys(TableRef.of("catalog", "db_schema", "table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_EXPORTED_KEYS_SCHEMA, + sqlClient.getExportedKeys(TableRef.of("catalog", "db_schema", "table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_IMPORTED_KEYS_SCHEMA, + sqlClient.getImportedKeys(TableRef.of("catalog", "db_schema", "table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_CROSS_REFERENCE_SCHEMA, + sqlClient.getCrossReference(TableRef.of("pk_catalog", "pk_db_schema", "pk_table"), + TableRef.of("fk_catalog", "fk_db_schema", "fk_table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, + sqlClient.getSqlInfo(new FlightSql.SqlInfo[] {FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY}, options), sqlClient); + } + + private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception { + final CallOption[] options = new CallOption[0]; + + validate(FlightSqlScenarioProducer.getQuerySchema(), + sqlClient.execute("SELECT STATEMENT", options), sqlClient); + + IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options), + UPDATE_STATEMENT_EXPECTED_ROWS); + } + + private void validatePreparedStatementExecution(FlightSqlClient sqlClient, + BufferAllocator allocator) throws Exception { + final CallOption[] options = new CallOption[0]; + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( + "SELECT PREPARED STATEMENT"); + VectorSchemaRoot parameters = VectorSchemaRoot.create( + FlightSqlScenarioProducer.getQuerySchema(), allocator)) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + + validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(options), + sqlClient); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( + "UPDATE PREPARED STATEMENT")) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(options), + UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + } + } + + private void validate(Schema expectedSchema, FlightInfo flightInfo, + FlightSqlClient sqlClient) throws Exception { + Ticket ticket = flightInfo.getEndpoints().get(0).getTicket(); + try (FlightStream stream = sqlClient.getStream(ticket)) { + Schema actualSchema = stream.getSchema(); + IntegrationAssertions.assertEquals(expectedSchema, actualSchema); + } + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java new file mode 100644 index 00000000000..f3554e1d3d8 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import static com.google.protobuf.Any.pack; +import static java.util.Collections.singletonList; + +import java.util.List; + +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; + +/** + * Hardcoded Flight SQL producer used for cross-language integration tests. + */ +public class FlightSqlScenarioProducer implements FlightSqlProducer { + private final BufferAllocator allocator; + + public FlightSqlScenarioProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + /** + * Schema to be returned for mocking the statement/prepared statement results. + * Must be the same across all languages. + */ + static Schema getQuerySchema() { + return new Schema( + singletonList( + new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null) + ) + ); + } + + @Override + public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request, + CallContext context, StreamListener listener) { + IntegrationAssertions.assertTrue("Expect to be one of the two queries used on tests", + request.getQuery().equals("SELECT PREPARED STATEMENT") || + request.getQuery().equals("UPDATE PREPARED STATEMENT")); + + final FlightSql.ActionCreatePreparedStatementResult + result = FlightSql.ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(ByteString.copyFromUtf8(request.getQuery() + " HANDLE")) + .build(); + listener.onNext(new Result(pack(result).toByteArray())); + listener.onCompleted(); + } + + @Override + public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request, + CallContext context, StreamListener listener) { + IntegrationAssertions.assertTrue("Expect to be one of the two queries used on tests", + request.getPreparedStatementHandle().toStringUtf8().equals("SELECT PREPARED STATEMENT HANDLE") || + request.getPreparedStatementHandle().toStringUtf8().equals("UPDATE PREPARED STATEMENT HANDLE")); + + listener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); + + ByteString handle = ByteString.copyFromUtf8("SELECT STATEMENT HANDLE"); + + FlightSql.TicketStatementQuery ticket = FlightSql.TicketStatementQuery.newBuilder() + .setStatementHandle(handle) + .build(); + return getFlightInfoForSchema(ticket, descriptor, getQuerySchema()); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), + "SELECT PREPARED STATEMENT HANDLE"); + + return getFlightInfoForSchema(command, descriptor, getQuerySchema()); + } + + @Override + public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + return new SchemaResult(getQuerySchema()); + } + + @Override + public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, getQuerySchema()); + } + + @Override + public void getStreamPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, getQuerySchema()); + } + + private Runnable acceptPutReturnConstant(StreamListener ackStream, long value) { + return () -> { + final FlightSql.DoPutUpdateResult build = + FlightSql.DoPutUpdateResult.newBuilder().setRecordCount(value).build(); + + try (final ArrowBuf buffer = allocator.buffer(build.getSerializedSize())) { + buffer.writeBytes(build.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + ackStream.onCompleted(); + } + }; + } + + @Override + public Runnable acceptPutStatement(FlightSql.CommandStatementUpdate command, CallContext context, + FlightStream flightStream, + StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getQuery(), "UPDATE STATEMENT"); + + return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_STATEMENT_EXPECTED_ROWS); + } + + @Override + public Runnable acceptPutPreparedStatementUpdate(FlightSql.CommandPreparedStatementUpdate command, + CallContext context, FlightStream flightStream, + StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), + "UPDATE PREPARED STATEMENT HANDLE"); + + return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + } + + @Override + public Runnable acceptPutPreparedStatementQuery(FlightSql.CommandPreparedStatementQuery command, + CallContext context, FlightStream flightStream, + StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), + "SELECT PREPARED STATEMENT HANDLE"); + + IntegrationAssertions.assertEquals(getQuerySchema(), flightStream.getSchema()); + + return ackStream::onCompleted; + } + + @Override + public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getInfoCount(), 2); + IntegrationAssertions.assertEquals(request.getInfo(0), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE); + IntegrationAssertions.assertEquals(request.getInfo(1), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA); + } + + @Override + public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoCatalogs(FlightSql.CommandGetCatalogs request, CallContext context, + FlightDescriptor descriptor) { + return getFlightInfoForSchema(request, descriptor, Schemas.GET_CATALOGS_SCHEMA); + } + + private void putEmptyBatchToStreamListener(ServerStreamListener stream, Schema schema) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + stream.start(root); + stream.putNext(); + stream.completed(); + } + } + + @Override + public void getStreamCatalogs(CallContext context, ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_CATALOGS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoSchemas(FlightSql.CommandGetDbSchemas request, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchemaFilterPattern(), + "db_schema_filter_pattern"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_SCHEMAS_SCHEMA); + } + + @Override + public void getStreamSchemas(FlightSql.CommandGetDbSchemas command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_SCHEMAS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoTables(FlightSql.CommandGetTables request, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchemaFilterPattern(), + "db_schema_filter_pattern"); + IntegrationAssertions.assertEquals(request.getTableNameFilterPattern(), "table_filter_pattern"); + IntegrationAssertions.assertEquals(request.getTableTypesCount(), 2); + IntegrationAssertions.assertEquals(request.getTableTypes(0), "table"); + IntegrationAssertions.assertEquals(request.getTableTypes(1), "view"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLES_SCHEMA); + } + + @Override + public void getStreamTables(FlightSql.CommandGetTables command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_TABLES_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoTableTypes(FlightSql.CommandGetTableTypes request, + CallContext context, FlightDescriptor descriptor) { + return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLE_TYPES_SCHEMA); + } + + @Override + public void getStreamTableTypes(CallContext context, ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_TABLE_TYPES_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoPrimaryKeys(FlightSql.CommandGetPrimaryKeys request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchema(), "db_schema"); + IntegrationAssertions.assertEquals(request.getTable(), "table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_PRIMARY_KEYS_SCHEMA); + } + + @Override + public void getStreamPrimaryKeys(FlightSql.CommandGetPrimaryKeys command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_PRIMARY_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoExportedKeys(FlightSql.CommandGetExportedKeys request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchema(), "db_schema"); + IntegrationAssertions.assertEquals(request.getTable(), "table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_EXPORTED_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoImportedKeys(FlightSql.CommandGetImportedKeys request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchema(), "db_schema"); + IntegrationAssertions.assertEquals(request.getTable(), "table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_IMPORTED_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoCrossReference(FlightSql.CommandGetCrossReference request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getPkCatalog(), "pk_catalog"); + IntegrationAssertions.assertEquals(request.getPkDbSchema(), "pk_db_schema"); + IntegrationAssertions.assertEquals(request.getPkTable(), "pk_table"); + IntegrationAssertions.assertEquals(request.getFkCatalog(), "fk_catalog"); + IntegrationAssertions.assertEquals(request.getFkDbSchema(), "fk_db_schema"); + IntegrationAssertions.assertEquals(request.getFkTable(), "fk_table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_CROSS_REFERENCE_SCHEMA); + } + + @Override + public void getStreamExportedKeys(FlightSql.CommandGetExportedKeys command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_EXPORTED_KEYS_SCHEMA); + } + + @Override + public void getStreamImportedKeys(FlightSql.CommandGetImportedKeys command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_IMPORTED_KEYS_SCHEMA); + } + + @Override + public void getStreamCrossReference(FlightSql.CommandGetCrossReference command, + CallContext context, ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_CROSS_REFERENCE_SCHEMA); + } + + @Override + public void close() throws Exception { + + } + + @Override + public void listFlights(CallContext context, Criteria criteria, + StreamListener listener) { + + } + + private FlightInfo getFlightInfoForSchema(final T request, + final FlightDescriptor descriptor, + final Schema schema) { + final Ticket ticket = new Ticket(pack(request).toByteArray()); + final List endpoints = singletonList(new FlightEndpoint(ticket)); + + return new FlightInfo(schema, descriptor, endpoints, -1, -1); + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java similarity index 88% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java index 576d1887f39..e124ed0ea74 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.util.Objects; @@ -63,6 +63,15 @@ static void assertFalse(String message, boolean value) { } } + /** + * Assert that the value is true, using the given message as an error otherwise. + */ + static void assertTrue(String message, boolean value) { + if (!value) { + throw new AssertionError("Expected true: " + message); + } + } + /** * An interface used with {@link #assertThrows(Class, AssertThrows)}. */ diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java similarity index 93% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java index 27a545f84fd..2a36747b618 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; @@ -91,7 +91,7 @@ private void run(String[] args) throws Exception { final Location defaultLocation = Location.forGrpcInsecure(host, port); try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) { + final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) { if (cmd.hasOption("scenario")) { Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client); @@ -109,7 +109,7 @@ private static void testStream(BufferAllocator allocator, Location server, Fligh // 1. Read data from JSON and upload to server. FlightDescriptor descriptor = FlightDescriptor.path(inputPath); try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator); - VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { + VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, reader, new AsyncPutListener() { int counter = 0; @@ -157,10 +157,10 @@ public void onNext(PutResult val) { for (Location location : locations) { System.out.println("Verifying location " + location.getUri()); try (FlightClient readClient = FlightClient.builder(allocator, location).build(); - FlightStream stream = readClient.getStream(endpoint.getTicket()); - VectorSchemaRoot root = stream.getRoot(); - VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); - JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator)) { + FlightStream stream = readClient.getStream(endpoint.getTicket()); + VectorSchemaRoot root = stream.getRoot(); + VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator)) { VectorLoader loader = new VectorLoader(downloadedRoot); VectorUnloader unloader = new VectorUnloader(root); diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestServer.java similarity index 98% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestServer.java index da336c5024a..7f5e15fe376 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestServer.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/MiddlewareScenario.java similarity index 99% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/MiddlewareScenario.java index c710ce98b56..c284a577c08 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/MiddlewareScenario.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.nio.charset.StandardCharsets; import java.util.Arrays; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenario.java similarity index 96% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenario.java index b3b962d2e73..bcc657b765c 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenario.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightProducer; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java similarity index 96% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index cd9859b4f36..16cc856daf5 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.util.Map; import java.util.TreeMap; @@ -41,6 +41,7 @@ private Scenarios() { scenarios = new TreeMap<>(); scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new); scenarios.put("middleware", MiddlewareScenario::new); + scenarios.put("flight_sql", FlightSqlScenario::new); } private static Scenarios getInstance() { diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java index 87c8b3e092d..f1eaf2f8988 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java @@ -62,7 +62,6 @@ import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; import org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery; -import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -595,13 +594,13 @@ void getStreamCrossReference(CommandGetCrossReference command, CallContext conte final class Schemas { public static final Schema GET_TABLES_SCHEMA = new Schema(asList( Field.nullable("catalog_name", VARCHAR.getType()), - Field.nullable("schema_name", VARCHAR.getType()), + Field.nullable("db_schema_name", VARCHAR.getType()), Field.notNullable("table_name", VARCHAR.getType()), Field.notNullable("table_type", VARCHAR.getType()), Field.notNullable("table_schema", MinorType.VARBINARY.getType()))); public static final Schema GET_TABLES_SCHEMA_NO_SCHEMA = new Schema(asList( Field.nullable("catalog_name", VARCHAR.getType()), - Field.nullable("schema_name", VARCHAR.getType()), + Field.nullable("db_schema_name", VARCHAR.getType()), Field.notNullable("table_name", VARCHAR.getType()), Field.notNullable("table_type", VARCHAR.getType()))); public static final Schema GET_CATALOGS_SCHEMA = new Schema( @@ -611,15 +610,15 @@ final class Schemas { public static final Schema GET_SCHEMAS_SCHEMA = new Schema(asList( Field.nullable("catalog_name", VARCHAR.getType()), - Field.notNullable("schema_name", VARCHAR.getType()))); + Field.notNullable("db_schema_name", VARCHAR.getType()))); private static final Schema GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA = new Schema(asList( Field.nullable("pk_catalog_name", VARCHAR.getType()), - Field.nullable("pk_schema_name", VARCHAR.getType()), + Field.nullable("pk_db_schema_name", VARCHAR.getType()), Field.notNullable("pk_table_name", VARCHAR.getType()), Field.notNullable("pk_column_name", VARCHAR.getType()), Field.nullable("fk_catalog_name", VARCHAR.getType()), - Field.nullable("fk_schema_name", VARCHAR.getType()), + Field.nullable("fk_db_schema_name", VARCHAR.getType()), Field.notNullable("fk_table_name", VARCHAR.getType()), Field.notNullable("fk_column_name", VARCHAR.getType()), Field.notNullable("key_sequence", INT.getType()), @@ -631,32 +630,32 @@ final class Schemas { public static final Schema GET_EXPORTED_KEYS_SCHEMA = GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA; public static final Schema GET_CROSS_REFERENCE_SCHEMA = GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA; private static final List GET_SQL_INFO_DENSE_UNION_SCHEMA_FIELDS = asList( - Field.nullable("string_value", VARCHAR.getType()), - Field.nullable("bool_value", BIT.getType()), - Field.nullable("bigint_value", BIGINT.getType()), - Field.nullable("int32_bitmask", INT.getType()), + Field.notNullable("string_value", VARCHAR.getType()), + Field.notNullable("bool_value", BIT.getType()), + Field.notNullable("bigint_value", BIGINT.getType()), + Field.notNullable("int32_bitmask", INT.getType()), new Field( - "string_list", FieldType.nullable(LIST.getType()), - singletonList(Field.nullable(ListVector.DATA_VECTOR_NAME, VARCHAR.getType()))), + "string_list", FieldType.notNullable(LIST.getType()), + singletonList(Field.nullable("item", VARCHAR.getType()))), new Field( - "int32_to_int32_list_map", FieldType.nullable(new ArrowType.Map(false)), + "int32_to_int32_list_map", FieldType.notNullable(new ArrowType.Map(false)), singletonList(new Field(DATA_VECTOR_NAME, new FieldType(false, STRUCT.getType(), null), ImmutableList.of( Field.notNullable(KEY_NAME, INT.getType()), new Field( VALUE_NAME, FieldType.nullable(LIST.getType()), - singletonList(Field.nullable(ListVector.DATA_VECTOR_NAME, INT.getType())))))))); + singletonList(Field.nullable("item", INT.getType())))))))); public static final Schema GET_SQL_INFO_SCHEMA = new Schema(asList( Field.notNullable("info_name", UINT4.getType()), new Field("value", - FieldType.nullable( + FieldType.notNullable( new Union(UnionMode.Dense, range(0, GET_SQL_INFO_DENSE_UNION_SCHEMA_FIELDS.size()).toArray())), GET_SQL_INFO_DENSE_UNION_SCHEMA_FIELDS))); public static final Schema GET_PRIMARY_KEYS_SCHEMA = new Schema(asList( Field.nullable("catalog_name", VARCHAR.getType()), - Field.nullable("schema_name", VARCHAR.getType()), + Field.nullable("db_schema_name", VARCHAR.getType()), Field.notNullable("table_name", VARCHAR.getType()), Field.notNullable("column_name", VARCHAR.getType()), Field.notNullable("key_sequence", INT.getType()), diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index 687840386e9..634343c236c 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -356,7 +356,7 @@ private static VectorSchemaRoot getSchemasRoot(final ResultSet data, final Buffe throws SQLException { final VarCharVector catalogs = new VarCharVector("catalog_name", allocator); final VarCharVector schemas = - new VarCharVector("schema_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); + new VarCharVector("db_schema_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); final List vectors = ImmutableList.of(catalogs, schemas); vectors.forEach(FieldVector::allocateNew); final Map vectorToColumnName = ImmutableMap.of( @@ -449,7 +449,7 @@ private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMet */ Objects.requireNonNull(allocator, "BufferAllocator cannot be null."); final VarCharVector catalogNameVector = new VarCharVector("catalog_name", allocator); - final VarCharVector schemaNameVector = new VarCharVector("schema_name", allocator); + final VarCharVector schemaNameVector = new VarCharVector("db_schema_name", allocator); final VarCharVector tableNameVector = new VarCharVector("table_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); final VarCharVector tableTypeVector = @@ -1409,7 +1409,7 @@ public void getStreamPrimaryKeys(final CommandGetPrimaryKeys command, final Call final ResultSet primaryKeys = connection.getMetaData().getPrimaryKeys(catalog, schema, table); final VarCharVector catalogNameVector = new VarCharVector("catalog_name", rootAllocator); - final VarCharVector schemaNameVector = new VarCharVector("schema_name", rootAllocator); + final VarCharVector schemaNameVector = new VarCharVector("db_schema_name", rootAllocator); final VarCharVector tableNameVector = new VarCharVector("table_name", rootAllocator); final VarCharVector columnNameVector = new VarCharVector("column_name", rootAllocator); final IntVector keySequenceVector = new IntVector("key_sequence", rootAllocator); @@ -1527,11 +1527,11 @@ public void getStreamCrossReference(CommandGetCrossReference command, CallContex private VectorSchemaRoot createVectors(ResultSet keys) throws SQLException { final VarCharVector pkCatalogNameVector = new VarCharVector("pk_catalog_name", rootAllocator); - final VarCharVector pkSchemaNameVector = new VarCharVector("pk_schema_name", rootAllocator); + final VarCharVector pkSchemaNameVector = new VarCharVector("pk_db_schema_name", rootAllocator); final VarCharVector pkTableNameVector = new VarCharVector("pk_table_name", rootAllocator); final VarCharVector pkColumnNameVector = new VarCharVector("pk_column_name", rootAllocator); final VarCharVector fkCatalogNameVector = new VarCharVector("fk_catalog_name", rootAllocator); - final VarCharVector fkSchemaNameVector = new VarCharVector("fk_schema_name", rootAllocator); + final VarCharVector fkSchemaNameVector = new VarCharVector("fk_db_schema_name", rootAllocator); final VarCharVector fkTableNameVector = new VarCharVector("fk_table_name", rootAllocator); final VarCharVector fkColumnNameVector = new VarCharVector("fk_column_name", rootAllocator); final IntVector keySequenceVector = new IntVector("key_sequence", rootAllocator); diff --git a/java/flight/pom.xml b/java/flight/pom.xml index 2cb409aaad0..7cb0e1d7171 100644 --- a/java/flight/pom.xml +++ b/java/flight/pom.xml @@ -33,6 +33,7 @@ flight-core flight-grpc flight-sql + flight-integration-tests