diff --git a/ci/docker/conda-cpp.dockerfile b/ci/docker/conda-cpp.dockerfile index 8fd5e46fd6d..9363e67f796 100644 --- a/ci/docker/conda-cpp.dockerfile +++ b/ci/docker/conda-cpp.dockerfile @@ -41,6 +41,7 @@ ENV ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ ARROW_DEPENDENCY_SOURCE=CONDA \ ARROW_FLIGHT=ON \ + ARROW_FLIGHT_SQL=ON \ ARROW_GANDIVA=ON \ ARROW_HOME=$CONDA_PREFIX \ ARROW_ORC=ON \ diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index f791ddd5645..02718e57836 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -70,6 +70,7 @@ cmake \ -DARROW_EXTRA_ERROR_CONTEXT=${ARROW_EXTRA_ERROR_CONTEXT:-OFF} \ -DARROW_FILESYSTEM=${ARROW_FILESYSTEM:-ON} \ -DARROW_FLIGHT=${ARROW_FLIGHT:-OFF} \ + -DARROW_FLIGHT_SQL=${ARROW_FLIGHT_SQL:-OFF} \ -DARROW_FUZZING=${ARROW_FUZZING:-OFF} \ -DARROW_GANDIVA_JAVA=${ARROW_GANDIVA_JAVA:-OFF} \ -DARROW_GANDIVA_PC_CXX_FLAGS=${ARROW_GANDIVA_PC_CXX_FLAGS:-} \ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 0e7b7b79a9f..fd7027c30eb 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -334,6 +334,14 @@ if(ARROW_GANDIVA) set(ARROW_WITH_RE2 ON) endif() +if(ARROW_BUILD_INTEGRATION AND ARROW_FLIGHT) + set(ARROW_FLIGHT_SQL ON) +endif() + +if(ARROW_FLIGHT_SQL) + set(ARROW_FLIGHT ON) +endif() + if(ARROW_CUDA OR ARROW_FLIGHT OR ARROW_PARQUET diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index f2ddff3997d..2afbdab4a40 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -226,6 +226,8 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_FLIGHT "Build the Arrow Flight RPC System (requires GRPC, Protocol Buffers)" OFF) + define_option(ARROW_FLIGHT_SQL "Build the Arrow Flight SQL extension" OFF) + define_option(ARROW_GANDIVA "Build the Gandiva libraries" OFF) define_option(ARROW_GCS diff --git a/cpp/cmake_modules/FindArrowFlightSql.cmake b/cpp/cmake_modules/FindArrowFlightSql.cmake new file mode 100644 index 00000000000..cbca81cac44 --- /dev/null +++ b/cpp/cmake_modules/FindArrowFlightSql.cmake @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# - Find Arrow Flight SQL +# +# This module requires Arrow from which it uses +# arrow_find_package() +# +# This module defines +# ARROW_FLIGHT_SQL_FOUND, whether Flight has been found +# ARROW_FLIGHT_SQL_IMPORT_LIB, +# path to libarrow_flight's import library (Windows only) +# ARROW_FLIGHT_SQL_INCLUDE_DIR, directory containing headers +# ARROW_FLIGHT_SQL_LIBS, deprecated. Use ARROW_FLIGHT_SQL_LIB_DIR instead +# ARROW_FLIGHT_SQL_LIB_DIR, directory containing Flight libraries +# ARROW_FLIGHT_SQL_SHARED_IMP_LIB, deprecated. 
Use ARROW_FLIGHT_SQL_IMPORT_LIB instead +# ARROW_FLIGHT_SQL_SHARED_LIB, path to libarrow_flight's shared library +# ARROW_FLIGHT_SQL_STATIC_LIB, path to libarrow_flight.a + +if(DEFINED ARROW_FLIGHT_SQL_FOUND) + return() +endif() + +set(find_package_arguments) +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION) + list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}") +endif() +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED) + list(APPEND find_package_arguments REQUIRED) +endif() +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY) + list(APPEND find_package_arguments QUIET) +endif() +find_package(Arrow ${find_package_arguments}) + +if(ARROW_FOUND) + arrow_find_package(ARROW_FLIGHT_SQL + "${ARROW_HOME}" + arrow_flight_sql + arrow/flight/sql/api.h + ArrowFlightSql + arrow-flight-sql) + if(NOT ARROW_FLIGHT_SQL_VERSION) + set(ARROW_FLIGHT_SQL_VERSION "${ARROW_VERSION}") + endif() +endif() + +if("${ARROW_FLIGHT_SQL_VERSION}" VERSION_EQUAL "${ARROW_VERSION}") + set(ARROW_FLIGHT_SQL_VERSION_MATCH TRUE) +else() + set(ARROW_FLIGHT_SQL_VERSION_MATCH FALSE) +endif() + +mark_as_advanced(ARROW_FLIGHT_SQL_IMPORT_LIB + ARROW_FLIGHT_SQL_INCLUDE_DIR + ARROW_FLIGHT_SQL_LIBS + ARROW_FLIGHT_SQL_LIB_DIR + ARROW_FLIGHT_SQL_SHARED_IMP_LIB + ARROW_FLIGHT_SQL_SHARED_LIB + ARROW_FLIGHT_SQL_STATIC_LIB + ARROW_FLIGHT_SQL_VERSION + ARROW_FLIGHT_SQL_VERSION_MATCH) + +find_package_handle_standard_args( + ArrowFlightSql + REQUIRED_VARS ARROW_FLIGHT_SQL_INCLUDE_DIR ARROW_FLIGHT_SQL_LIB_DIR + ARROW_FLIGHT_SQL_VERSION_MATCH + VERSION_VAR ARROW_FLIGHT_SQL_VERSION) +set(ARROW_FLIGHT_SQL_FOUND ${ArrowFlightSql_FOUND}) + +if(ArrowFlightSql_FOUND AND NOT ArrowFlightSql_FIND_QUIETLY) + message(STATUS "Found the Arrow Flight SQL by ${ARROW_FLIGHT_SQL_FIND_APPROACH}") + message(STATUS "Found the Arrow Flight SQL shared library: ${ARROW_FLIGHT_SQL_SHARED_LIB}" + ) + message(STATUS "Found the Arrow Flight SQL import library: ${ARROW_FLIGHT_SQL_IMPORT_LIB}" + ) + message(STATUS "Found the Arrow Flight SQL static library: ${ARROW_FLIGHT_SQL_STATIC_LIB}" + ) +endif() diff --git a/cpp/cmake_modules/FindSQLite3Alt.cmake b/cpp/cmake_modules/FindSQLite3Alt.cmake new file mode 100644 index 00000000000..73a45f098c6 --- /dev/null +++ b/cpp/cmake_modules/FindSQLite3Alt.cmake @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Once done this will define +# - FindSQLite3Alt +# +# This module will set the following variables if found: +# SQLite3_INCLUDE_DIRS - SQLite3 include dir. +# SQLite3_LIBRARIES - List of libraries when using SQLite3. +# SQLite3_FOUND - True if SQLite3 found. 
+# +# Usage of this module as follows: +# find_package(SQLite3Alt) + +find_path(SQLite3_INCLUDE_DIR sqlite3.h) +find_library(SQLite3_LIBRARY NAMES sqlite3) + +# handle the QUIETLY and REQUIRED arguments and set SQLite3_FOUND to TRUE if +# all listed variables are TRUE +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(SQLite3Alt REQUIRED_VARS SQLite3_LIBRARY + SQLite3_INCLUDE_DIR) + +mark_as_advanced(SQLite3_LIBRARY SQLite3_INCLUDE_DIR) + +if(SQLite3Alt_FOUND) + set(SQLite3_INCLUDE_DIRS ${SQLite3_INCLUDE_DIR}) + set(SQLite3_LIBRARIES ${SQLite3_LIBRARY}) +endif() diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 5736c557bd0..cc979a22e09 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -732,6 +732,16 @@ if(ARROW_FLIGHT) add_subdirectory(flight) endif() +if(ARROW_FLIGHT_SQL) + add_subdirectory(flight/sql) +endif() + +if(ARROW_FLIGHT + AND ARROW_FLIGHT_SQL + AND ARROW_BUILD_INTEGRATION) + add_subdirectory(flight/integration_tests) +endif() + if(ARROW_HIVESERVER2) add_subdirectory(dbi/hiveserver2) endif() diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 8a3228e5026..2cf8c9913e5 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -25,7 +25,23 @@ if(WIN32) list(APPEND ARROW_FLIGHT_LINK_LIBS ws2_32.lib) endif() -if(ARROW_TEST_LINKAGE STREQUAL "static") +set(ARROW_FLIGHT_TEST_LINKAGE + "${ARROW_TEST_LINKAGE}" + PARENT_SCOPE) +if(Protobuf_USE_STATIC_LIBS) + message(STATUS "Linking Arrow Flight tests statically due to static Protobuf") + set(ARROW_FLIGHT_TEST_LINKAGE + "static" + PARENT_SCOPE) +endif() +if(NOT ARROW_GRPC_USE_SHARED) + message(STATUS "Linking Arrow Flight tests statically due to static gRPC") + set(ARROW_FLIGHT_TEST_LINKAGE + "static" + PARENT_SCOPE) +endif() + +if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") set(ARROW_FLIGHT_TEST_LINK_LIBS arrow_flight_static arrow_flight_testing_static ${ARROW_FLIGHT_STATIC_LINK_LIBS} ${ARROW_TEST_LINK_LIBS}) @@ -186,7 +202,6 @@ if(ARROW_TESTING) OUTPUTS ARROW_FLIGHT_TESTING_LIBRARIES SOURCES - test_integration.cc test_util.cc DEPENDENCIES GTest::gtest @@ -230,21 +245,6 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS) add_dependencies(arrow_flight flight-test-server) endif() -if(ARROW_BUILD_INTEGRATION) - add_executable(flight-test-integration-server test_integration_server.cc) - target_link_libraries(flight-test-integration-server ${ARROW_FLIGHT_TEST_LINK_LIBS} - ${GFLAGS_LIBRARIES} GTest::gtest) - - add_executable(flight-test-integration-client test_integration_client.cc) - target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} - ${GFLAGS_LIBRARIES} GTest::gtest) - - add_dependencies(arrow_flight flight-test-integration-client - flight-test-integration-server) - add_dependencies(arrow-integration flight-test-integration-client - flight-test-integration-server) -endif() - if(ARROW_BUILD_BENCHMARKS) # Perf server for benchmarks set(PERF_PROTO_GENERATED_FILES "${CMAKE_CURRENT_BINARY_DIR}/perf.pb.cc" diff --git a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt new file mode 100644 index 00000000000..3a878d7f305 --- /dev/null +++ b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_custom_target(arrow_flight_integration_tests) + +if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + set(ARROW_FLIGHT_TEST_LINK_LIBS + arrow_flight_static + arrow_flight_testing_static + arrow_flight_sql_static + ${ARROW_FLIGHT_STATIC_LINK_LIBS} + ${ARROW_TEST_LINK_LIBS}) +else() + set(ARROW_FLIGHT_TEST_LINK_LIBS + arrow_flight_shared + arrow_flight_testing_shared + arrow_flight_sql_shared + ${ARROW_TEST_LINK_LIBS} + ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS}) +endif() + +add_executable(flight-test-integration-server test_integration_server.cc + test_integration.cc) +target_link_libraries(flight-test-integration-server ${ARROW_FLIGHT_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES} GTest::gtest) + +add_executable(flight-test-integration-client test_integration_client.cc + test_integration.cc) +target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES} GTest::gtest) + +add_dependencies(arrow-integration flight-test-integration-client + flight-test-integration-server) diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc new file mode 100644 index 00000000000..1e08f47b579 --- /dev/null +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -0,0 +1,684 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/integration_tests/test_integration.h" +#include "arrow/flight/client_middleware.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/server.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/types.h" +#include "arrow/ipc/dictionary.h" +#include "arrow/testing/gtest_util.h" + +#include +#include +#include +#include +#include + +namespace arrow { +namespace flight { +namespace integration_tests { + +/// \brief The server for the basic auth integration test. 
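For orientation, here is a rough client-side sketch (not part of the patch) of the flow that the basic-auth scenario below exercises against the AuthBasicProtoServer defined just after this: connect, authenticate with basic credentials, then issue a DoAction and read back the echoed username. It assumes the Status-returning FlightClient::Connect overload of this API generation and reuses TestClientBasicAuthHandler from arrow/flight/test_util.h, as the scenario itself does.

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>

#include "arrow/flight/api.h"
#include "arrow/flight/test_util.h"  // TestClientBasicAuthHandler

namespace flight = arrow::flight;

// Connect, authenticate with basic credentials, then call DoAction and
// print the username echoed back by the server.
arrow::Status RunBasicAuthClient(const std::string& host, int port) {
  flight::Location location;
  ARROW_RETURN_NOT_OK(flight::Location::ForGrpcTcp(host, port, &location));

  std::unique_ptr<flight::FlightClient> client;
  ARROW_RETURN_NOT_OK(flight::FlightClient::Connect(location, &client));

  // Handshake with the expected username/password before issuing calls.
  std::unique_ptr<flight::ClientAuthHandler> auth(
      new flight::TestClientBasicAuthHandler("arrow", "flight"));
  ARROW_RETURN_NOT_OK(client->Authenticate({}, std::move(auth)));

  flight::Action action;  // empty action: the test server ignores the type
  std::unique_ptr<flight::ResultStream> results;
  ARROW_RETURN_NOT_OK(client->DoAction(action, &results));

  std::unique_ptr<flight::Result> result;
  ARROW_RETURN_NOT_OK(results->Next(&result));
  if (result) {
    std::cout << "Authenticated as: " << result->body->ToString() << std::endl;
  }
  return arrow::Status::OK();
}
```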
+class AuthBasicProtoServer : public FlightServerBase { + Status DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* result) override { + // Respond with the authenticated username. + auto buf = Buffer::FromString(context.peer_identity()); + *result = std::unique_ptr(new SimpleResultStream({Result{buf}})); + return Status::OK(); + } +}; + +/// Validate the result of a DoAction. +Status CheckActionResults(FlightClient* client, const Action& action, + std::vector results) { + std::unique_ptr stream; + RETURN_NOT_OK(client->DoAction(action, &stream)); + std::unique_ptr result; + for (const std::string& expected : results) { + RETURN_NOT_OK(stream->Next(&result)); + if (!result) { + return Status::Invalid("Action result stream ended early"); + } + const auto actual = result->body->ToString(); + if (expected != actual) { + return Status::Invalid("Got wrong result; expected", expected, "but got", actual); + } + } + RETURN_NOT_OK(stream->Next(&result)); + if (result) { + return Status::Invalid("Action result stream had too many entries"); + } + return Status::OK(); +} + +// The expected username for the basic auth integration test. +constexpr auto kAuthUsername = "arrow"; +// The expected password for the basic auth integration test. +constexpr auto kAuthPassword = "flight"; + +/// \brief A scenario testing the basic auth protobuf. +class AuthBasicProtoScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + server->reset(new AuthBasicProtoServer()); + options->auth_handler = + std::make_shared(kAuthUsername, kAuthPassword); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } + + Status RunClient(std::unique_ptr client) override { + Action action; + std::unique_ptr stream; + std::shared_ptr detail; + const auto& status = client->DoAction(action, &stream); + detail = FlightStatusDetail::UnwrapStatus(status); + // This client is unauthenticated and should fail. + if (detail == nullptr) { + return Status::Invalid("Expected UNAUTHENTICATED but got ", status.ToString()); + } + if (detail->code() != FlightStatusCode::Unauthenticated) { + return Status::Invalid("Expected UNAUTHENTICATED but got ", detail->ToString()); + } + + auto client_handler = std::unique_ptr( + new TestClientBasicAuthHandler(kAuthUsername, kAuthPassword)); + RETURN_NOT_OK(client->Authenticate({}, std::move(client_handler))); + return CheckActionResults(client.get(), action, {kAuthUsername}); + } +}; + +/// \brief Test middleware that echoes back the value of a particular +/// incoming header. +/// +/// In Java, gRPC may consolidate this header with HTTP/2 trailers if +/// the call fails, but C++ generally doesn't do this. The integration +/// test confirms the presence of this header to ensure we can read it +/// regardless of what gRPC does. 
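The middleware classes below pull the incoming "x-middleware" value out of CallHeaders with multimap equal_range. As an illustrative aside (GetHeaderOrEmpty is a hypothetical helper name, not part of the patch), the lookup pattern they repeat is just:

```cpp
#include <string>

#include "arrow/flight/server_middleware.h"  // brings in CallHeaders
#include "arrow/util/string_view.h"

namespace flight = arrow::flight;

// Hypothetical helper: return the first value recorded for `name`, or an
// empty string if the header is absent. CallHeaders is a multimap of string
// views keyed by lowercased header name, so equal_range yields all values.
std::string GetHeaderOrEmpty(const flight::CallHeaders& headers,
                             const std::string& name) {
  const auto range = headers.equal_range(name);
  if (range.first == range.second) {
    return "";
  }
  return std::string(range.first->second);
}
```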
+class TestServerMiddleware : public ServerMiddleware { + public: + explicit TestServerMiddleware(std::string received) : received_(received) {} + + void SendingHeaders(AddCallHeaders* outgoing_headers) override { + outgoing_headers->AddHeader("x-middleware", received_); + } + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "GrpcTrailersMiddleware"; } + + private: + std::string received_; +}; + +class TestServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override { + const std::pair& iter_pair = + incoming_headers.equal_range("x-middleware"); + std::string received = ""; + if (iter_pair.first != iter_pair.second) { + const util::string_view& value = (*iter_pair.first).second; + received = std::string(value); + } + *middleware = std::make_shared(received); + return Status::OK(); + } +}; + +/// \brief Test middleware that adds a header on every outgoing call, +/// and gets the value of the expected header sent by the server. +class TestClientMiddleware : public ClientMiddleware { + public: + explicit TestClientMiddleware(std::string* received_header) + : received_header_(received_header) {} + + void SendingHeaders(AddCallHeaders* outgoing_headers) { + outgoing_headers->AddHeader("x-middleware", "expected value"); + } + + void ReceivedHeaders(const CallHeaders& incoming_headers) { + // We expect the server to always send this header. gRPC/Java may + // send it in trailers instead of headers, so we expect Flight to + // account for this. + const std::pair& iter_pair = + incoming_headers.equal_range("x-middleware"); + if (iter_pair.first != iter_pair.second) { + const util::string_view& value = (*iter_pair.first).second; + *received_header_ = std::string(value); + } + } + + void CallCompleted(const Status& status) {} + + private: + std::string* received_header_; +}; + +class TestClientMiddlewareFactory : public ClientMiddlewareFactory { + public: + void StartCall(const CallInfo& info, std::unique_ptr* middleware) { + *middleware = + std::unique_ptr(new TestClientMiddleware(&received_header_)); + } + + std::string received_header_; +}; + +/// \brief The server used for testing middleware. Implements only one +/// endpoint, GetFlightInfo, in such a way that it either succeeds or +/// returns an error based on the input, in order to test both paths. +class MiddlewareServer : public FlightServerBase { + Status GetFlightInfo(const ServerCallContext& context, + const FlightDescriptor& descriptor, + std::unique_ptr* result) override { + if (descriptor.type == FlightDescriptor::DescriptorType::CMD && + descriptor.cmd == "success") { + // Don't fail + std::shared_ptr schema = arrow::schema({}); + Location location; + // Return a fake location - the test doesn't read it + RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 10010, &location)); + std::vector endpoints{FlightEndpoint{{"foo"}, {location}}}; + ARROW_ASSIGN_OR_RAISE(auto info, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); + *result = std::unique_ptr(new FlightInfo(info)); + return Status::OK(); + } + // Fail the call immediately. In some gRPC implementations, this + // means that gRPC sends only HTTP/2 trailers and not headers. We want + // Flight middleware to be agnostic to this difference. + return Status::UnknownError("Unknown"); + } +}; + +/// \brief The middleware scenario. +/// +/// This tests that the server and client get expected header values. 
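Outside the Scenario harness, wiring the two factories defined above into a server and a client amounts to registering them on the respective options objects, which is exactly what MiddlewareScenario below does through MakeServer/MakeClient. A condensed sketch, assuming the options-taking FlightClient::Connect overload:

```cpp
#include <memory>

#include "arrow/flight/api.h"

namespace flight = arrow::flight;

arrow::Status WireUpMiddleware(const flight::Location& location) {
  // Server side: middleware is registered under a key so that handlers and
  // other middleware can refer to this instance.
  flight::FlightServerOptions server_options(location);
  server_options.middleware.push_back(
      {"grpc_trailers", std::make_shared<TestServerMiddlewareFactory>()});

  // Client side: every call made through the resulting client passes
  // through a TestClientMiddleware created by the factory.
  auto client_options = flight::FlightClientOptions::Defaults();
  client_options.middleware.push_back(
      std::make_shared<TestClientMiddlewareFactory>());

  std::unique_ptr<flight::FlightClient> client;
  ARROW_RETURN_NOT_OK(
      flight::FlightClient::Connect(location, client_options, &client));
  return arrow::Status::OK();
}
```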
+class MiddlewareScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + options->middleware.push_back( + {"grpc_trailers", std::make_shared()}); + server->reset(new MiddlewareServer()); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { + client_middleware_ = std::make_shared(); + options->middleware.push_back(client_middleware_); + return Status::OK(); + } + + Status RunClient(std::unique_ptr client) override { + std::unique_ptr info; + // This call is expected to fail. In gRPC/Java, this causes the + // server to combine headers and HTTP/2 trailers, so to read the + // expected header, Flight must check for both headers and + // trailers. + if (client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok()) { + return Status::Invalid("Expected call to fail"); + } + if (client_middleware_->received_header_ != "expected value") { + return Status::Invalid( + "Expected to receive header 'x-middleware: expected value', but instead got: '", + client_middleware_->received_header_, "'"); + } + std::cerr << "Headers received successfully on failing call." << std::endl; + + // This call should succeed + client_middleware_->received_header_ = ""; + RETURN_NOT_OK(client->GetFlightInfo(FlightDescriptor::Command("success"), &info)); + if (client_middleware_->received_header_ != "expected value") { + return Status::Invalid( + "Expected to receive header 'x-middleware: expected value', but instead got '", + client_middleware_->received_header_, "'"); + } + std::cerr << "Headers received successfully on passing call." << std::endl; + return Status::OK(); + } + + std::shared_ptr client_middleware_; +}; + +/// \brief Schema to be returned for mocking the statement/prepared statement results. +/// Must be the same across all languages. +std::shared_ptr GetQuerySchema() { + return arrow::schema({arrow::field("id", int64())}); +} + +constexpr int64_t kUpdateStatementExpectedRows = 10000L; +constexpr int64_t kUpdatePreparedStatementExpectedRows = 20000L; + +template +arrow::Status AssertEq(const T& expected, const T& actual) { + if (expected != actual) { + return Status::Invalid("Expected \"", expected, "\", got \'", actual, "\""); + } + return Status::OK(); +} + +/// \brief The server used for testing Flight SQL, this implements a static Flight SQL +/// server which only asserts that commands called during integration tests are being +/// parsed correctly and returns the expected schemas to be validated on client. 
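Before the server implementation, a brief sketch of how such a FlightSqlServerBase subclass gets hosted; the scenario harness does the equivalent through MakeServer, and Init/Serve are the standard FlightServerBase entry points. FlightSqlScenarioServer here refers to the class defined just below; the port handling is illustrative.

```cpp
#include "arrow/flight/api.h"

namespace flight = arrow::flight;

arrow::Status ServeFlightSqlScenario(int port) {
  flight::Location location;
  ARROW_RETURN_NOT_OK(flight::Location::ForGrpcTcp("0.0.0.0", port, &location));

  FlightSqlScenarioServer server;  // defined below in this file
  flight::FlightServerOptions options(location);
  ARROW_RETURN_NOT_OK(server.Init(options));
  // Serve() blocks until the server is shut down.
  return server.Serve();
}
```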
+class FlightSqlScenarioServer : public sql::FlightSqlServerBase { + public: + arrow::Result> GetFlightInfoStatement( + const ServerCallContext& context, const sql::StatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("SELECT STATEMENT", command.query)); + + ARROW_ASSIGN_OR_RAISE(auto handle, + sql::CreateStatementQueryTicket("SELECT STATEMENT HANDLE")); + + std::vector endpoints{FlightEndpoint{{handle}, {}}}; + ARROW_ASSIGN_OR_RAISE( + auto result, FlightInfo::Make(*GetQuerySchema(), descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); + } + + arrow::Result> DoGetStatement( + const ServerCallContext& context, + const sql::StatementQueryTicket& command) override { + return DoGetForTestCase(GetQuerySchema()); + } + + arrow::Result> GetFlightInfoPreparedStatement( + const ServerCallContext& context, const sql::PreparedStatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", + command.prepared_statement_handle)); + + return GetFlightInfoForCommand(descriptor, GetQuerySchema()); + } + + arrow::Result> DoGetPreparedStatement( + const ServerCallContext& context, + const sql::PreparedStatementQuery& command) override { + return DoGetForTestCase(GetQuerySchema()); + } + + arrow::Result> GetFlightInfoCatalogs( + const ServerCallContext& context, const FlightDescriptor& descriptor) override { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetCatalogsSchema()); + } + + arrow::Result> DoGetCatalogs( + const ServerCallContext& context) override { + return DoGetForTestCase(sql::SqlSchema::GetCatalogsSchema()); + } + + arrow::Result> GetFlightInfoSqlInfo( + const ServerCallContext& context, const sql::GetSqlInfo& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq(2, command.info.size())); + ARROW_RETURN_NOT_OK(AssertEq( + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, command.info[0])); + ARROW_RETURN_NOT_OK(AssertEq( + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, command.info[1])); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema()); + } + + arrow::Result> DoGetSqlInfo( + const ServerCallContext& context, const sql::GetSqlInfo& command) override { + return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema()); + } + + arrow::Result> GetFlightInfoSchemas( + const ServerCallContext& context, const sql::GetDbSchemas& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value())); + ARROW_RETURN_NOT_OK(AssertEq("db_schema_filter_pattern", + command.db_schema_filter_pattern.value())); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetDbSchemasSchema()); + } + + arrow::Result> DoGetDbSchemas( + const ServerCallContext& context, const sql::GetDbSchemas& command) override { + return DoGetForTestCase(sql::SqlSchema::GetDbSchemasSchema()); + } + + arrow::Result> GetFlightInfoTables( + const ServerCallContext& context, const sql::GetTables& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value())); + ARROW_RETURN_NOT_OK(AssertEq("db_schema_filter_pattern", + command.db_schema_filter_pattern.value())); + ARROW_RETURN_NOT_OK(AssertEq("table_filter_pattern", + command.table_name_filter_pattern.value())); + ARROW_RETURN_NOT_OK(AssertEq(2, command.table_types.size())); + ARROW_RETURN_NOT_OK(AssertEq("table", 
command.table_types[0])); + ARROW_RETURN_NOT_OK(AssertEq("view", command.table_types[1])); + ARROW_RETURN_NOT_OK(AssertEq(true, command.include_schema)); + + return GetFlightInfoForCommand(descriptor, + sql::SqlSchema::GetTablesSchemaWithIncludedSchema()); + } + + arrow::Result> DoGetTables( + const ServerCallContext& context, const sql::GetTables& command) override { + return DoGetForTestCase(sql::SqlSchema::GetTablesSchemaWithIncludedSchema()); + } + + arrow::Result> GetFlightInfoTableTypes( + const ServerCallContext& context, const FlightDescriptor& descriptor) override { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> DoGetTableTypes( + const ServerCallContext& context) override { + return DoGetForTestCase(sql::SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> GetFlightInfoPrimaryKeys( + const ServerCallContext& context, const sql::GetPrimaryKeys& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("catalog", command.table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("db_schema", command.table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetPrimaryKeysSchema()); + } + + arrow::Result> DoGetPrimaryKeys( + const ServerCallContext& context, const sql::GetPrimaryKeys& command) override { + return DoGetForTestCase(sql::SqlSchema::GetPrimaryKeysSchema()); + } + + arrow::Result> GetFlightInfoExportedKeys( + const ServerCallContext& context, const sql::GetExportedKeys& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("catalog", command.table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("db_schema", command.table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetExportedKeysSchema()); + } + + arrow::Result> DoGetExportedKeys( + const ServerCallContext& context, const sql::GetExportedKeys& command) override { + return DoGetForTestCase(sql::SqlSchema::GetExportedKeysSchema()); + } + + arrow::Result> GetFlightInfoImportedKeys( + const ServerCallContext& context, const sql::GetImportedKeys& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("catalog", command.table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("db_schema", command.table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetImportedKeysSchema()); + } + + arrow::Result> DoGetImportedKeys( + const ServerCallContext& context, const sql::GetImportedKeys& command) override { + return DoGetForTestCase(sql::SqlSchema::GetImportedKeysSchema()); + } + + arrow::Result> GetFlightInfoCrossReference( + const ServerCallContext& context, const sql::GetCrossReference& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("pk_catalog", command.pk_table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("pk_db_schema", command.pk_table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("pk_table", command.pk_table_ref.table)); + ARROW_RETURN_NOT_OK( + AssertEq("fk_catalog", command.fk_table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("fk_db_schema", command.fk_table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("fk_table", 
command.fk_table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> DoGetCrossReference( + const ServerCallContext& context, const sql::GetCrossReference& command) override { + return DoGetForTestCase(sql::SqlSchema::GetCrossReferenceSchema()); + } + + arrow::Result DoPutCommandStatementUpdate( + const ServerCallContext& context, const sql::StatementUpdate& command) override { + ARROW_RETURN_NOT_OK(AssertEq("UPDATE STATEMENT", command.query)); + + return kUpdateStatementExpectedRows; + } + + arrow::Result CreatePreparedStatement( + const ServerCallContext& context, + const sql::ActionCreatePreparedStatementRequest& request) override { + ARROW_RETURN_NOT_OK( + AssertEq(true, request.query == "SELECT PREPARED STATEMENT" || + request.query == "UPDATE PREPARED STATEMENT")); + + sql::ActionCreatePreparedStatementResult result; + result.prepared_statement_handle = request.query + " HANDLE"; + + return result; + } + + Status ClosePreparedStatement( + const ServerCallContext& context, + const sql::ActionClosePreparedStatementRequest& request) override { + return Status::OK(); + } + + Status DoPutPreparedStatementQuery(const ServerCallContext& context, + const sql::PreparedStatementQuery& command, + FlightMessageReader* reader, + FlightMetadataWriter* writer) override { + ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", + command.prepared_statement_handle)); + + ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); + ARROW_RETURN_NOT_OK(AssertEq(*GetQuerySchema(), *actual_schema)); + + return Status::OK(); + } + + arrow::Result DoPutPreparedStatementUpdate( + const ServerCallContext& context, const sql::PreparedStatementUpdate& command, + FlightMessageReader* reader) override { + ARROW_RETURN_NOT_OK(AssertEq("UPDATE PREPARED STATEMENT HANDLE", + command.prepared_statement_handle)); + + return kUpdatePreparedStatementExpectedRows; + } + + private: + arrow::Result> GetFlightInfoForCommand( + const FlightDescriptor& descriptor, const std::shared_ptr& schema) { + std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto result, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); + } + + arrow::Result> DoGetForTestCase( + const std::shared_ptr& schema) { + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({}, schema)); + return std::unique_ptr(new RecordBatchStream(reader)); + } +}; + +/// \brief Integration test scenario for validating Flight SQL specs across multiple +/// implementations. This should ensure that RPC objects are being built and parsed +/// correctly for multiple languages and that the Arrow schemas are returned as expected. 
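From the client side, every metadata call validated by this scenario follows the same two-step shape: fetch a FlightInfo for the command, then redeem the ticket it carries with DoGet. A minimal sketch (not part of the patch; it assumes FlightStreamReader::ReadAll for collecting the stream into a Table, and passes no filters):

```cpp
#include <memory>
#include <utility>

#include "arrow/flight/api.h"
#include "arrow/flight/sql/client.h"
#include "arrow/table.h"

namespace flight = arrow::flight;

arrow::Status ListTables(std::unique_ptr<flight::FlightClient> client) {
  flight::sql::FlightSqlClient sql_client(std::move(client));
  flight::FlightCallOptions options;

  // No catalog/schema/table filters, no embedded schemas, all table types.
  ARROW_ASSIGN_OR_RAISE(auto info,
                        sql_client.GetTables(options, /*catalog=*/NULLPTR,
                                             /*db_schema_filter_pattern=*/NULLPTR,
                                             /*table_filter_pattern=*/NULLPTR,
                                             /*include_schema=*/false,
                                             /*table_types=*/NULLPTR));

  // Redeem the first endpoint's ticket; the integration server above returns
  // a single endpoint whose stream is empty but carries the expected schema.
  ARROW_ASSIGN_OR_RAISE(auto reader,
                        sql_client.DoGet(options, info->endpoints()[0].ticket));
  std::shared_ptr<arrow::Table> table;
  ARROW_RETURN_NOT_OK(reader->ReadAll(&table));
  return arrow::Status::OK();
}
```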
+class FlightSqlScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + server->reset(new FlightSqlScenarioServer()); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } + + Status Validate(std::shared_ptr expected_schema, + arrow::Result> flight_info_result, + sql::FlightSqlClient* sql_client) { + FlightCallOptions call_options; + + ARROW_ASSIGN_OR_RAISE(auto flight_info, flight_info_result); + ARROW_ASSIGN_OR_RAISE( + auto reader, sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); + + ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); + + AssertSchemaEqual(expected_schema, actual_schema); + + return Status::OK(); + } + + Status RunClient(std::unique_ptr client) override { + sql::FlightSqlClient sql_client(std::move(client)); + + ARROW_RETURN_NOT_OK(ValidateMetadataRetrieval(&sql_client)); + + ARROW_RETURN_NOT_OK(ValidateStatementExecution(&sql_client)); + + ARROW_RETURN_NOT_OK(ValidatePreparedStatementExecution(&sql_client)); + + return Status::OK(); + } + + Status ValidateMetadataRetrieval(sql::FlightSqlClient* sql_client) { + FlightCallOptions options; + + std::string catalog = "catalog"; + std::string db_schema_filter_pattern = "db_schema_filter_pattern"; + std::string table_filter_pattern = "table_filter_pattern"; + std::string table = "table"; + std::string db_schema = "db_schema"; + std::vector table_types = {"table", "view"}; + + sql::TableRef table_ref = {catalog, db_schema, table}; + sql::TableRef pk_table_ref = {"pk_catalog", "pk_db_schema", "pk_table"}; + sql::TableRef fk_table_ref = {"fk_catalog", "fk_db_schema", "fk_table"}; + + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetCatalogsSchema(), + sql_client->GetCatalogs(options), sql_client)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetDbSchemasSchema(), + sql_client->GetDbSchemas(options, &catalog, &db_schema_filter_pattern), + sql_client)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetTablesSchemaWithIncludedSchema(), + sql_client->GetTables(options, &catalog, &db_schema_filter_pattern, + &table_filter_pattern, true, &table_types), + sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetTableTypesSchema(), + sql_client->GetTableTypes(options), sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetPrimaryKeysSchema(), + sql_client->GetPrimaryKeys(options, table_ref), + sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetExportedKeysSchema(), + sql_client->GetExportedKeys(options, table_ref), + sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetImportedKeysSchema(), + sql_client->GetImportedKeys(options, table_ref), + sql_client)); + ARROW_RETURN_NOT_OK(Validate( + sql::SqlSchema::GetCrossReferenceSchema(), + sql_client->GetCrossReference(options, pk_table_ref, fk_table_ref), sql_client)); + ARROW_RETURN_NOT_OK(Validate( + sql::SqlSchema::GetSqlInfoSchema(), + sql_client->GetSqlInfo( + options, {sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY}), + sql_client)); + + return Status::OK(); + } + + Status ValidateStatementExecution(sql::FlightSqlClient* sql_client) { + FlightCallOptions options; + + ARROW_RETURN_NOT_OK(Validate( + GetQuerySchema(), sql_client->Execute(options, "SELECT STATEMENT"), sql_client)); + ARROW_ASSIGN_OR_RAISE(auto update_statement_result, + sql_client->ExecuteUpdate(options, "UPDATE STATEMENT")); + if (update_statement_result 
!= kUpdateStatementExpectedRows) { + return Status::Invalid("Expected 'UPDATE STATEMENT' return ", + kUpdateStatementExpectedRows, ", got ", + update_statement_result); + } + + return Status::OK(); + } + + Status ValidatePreparedStatementExecution(sql::FlightSqlClient* sql_client) { + FlightCallOptions options; + + ARROW_ASSIGN_OR_RAISE(auto select_prepared_statement, + sql_client->Prepare(options, "SELECT PREPARED STATEMENT")); + + auto parameters = + RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); + ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters)); + + ARROW_RETURN_NOT_OK( + Validate(GetQuerySchema(), select_prepared_statement->Execute(), sql_client)); + ARROW_RETURN_NOT_OK(select_prepared_statement->Close()); + + ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement, + sql_client->Prepare(options, "UPDATE PREPARED STATEMENT")); + ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement_result, + update_prepared_statement->ExecuteUpdate()); + if (update_prepared_statement_result != kUpdatePreparedStatementExpectedRows) { + return Status::Invalid("Expected 'UPDATE STATEMENT' return ", + kUpdatePreparedStatementExpectedRows, ", got ", + update_prepared_statement_result); + } + ARROW_RETURN_NOT_OK(update_prepared_statement->Close()); + + return Status::OK(); + } +}; + +Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) { + if (scenario_name == "auth:basic_proto") { + *out = std::make_shared(); + return Status::OK(); + } else if (scenario_name == "middleware") { + *out = std::make_shared(); + return Status::OK(); + } else if (scenario_name == "flight_sql") { + *out = std::make_shared(); + return Status::OK(); + } + return Status::KeyError("Scenario not found: ", scenario_name); +} + +} // namespace integration_tests +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration.h b/cpp/src/arrow/flight/integration_tests/test_integration.h similarity index 95% rename from cpp/src/arrow/flight/test_integration.h rename to cpp/src/arrow/flight/integration_tests/test_integration.h index 5d9bd7fd7bd..74093f8cd23 100644 --- a/cpp/src/arrow/flight/test_integration.h +++ b/cpp/src/arrow/flight/integration_tests/test_integration.h @@ -17,6 +17,8 @@ // Integration test scenarios for Arrow Flight. +#pragma once + #include "arrow/flight/visibility.h" #include @@ -28,16 +30,20 @@ namespace arrow { namespace flight { +namespace integration_tests { /// \brief An integration test for Arrow Flight. class ARROW_FLIGHT_EXPORT Scenario { public: virtual ~Scenario() = default; + /// \brief Set up the server. virtual Status MakeServer(std::unique_ptr* server, FlightServerOptions* options) = 0; + /// \brief Set up the client. virtual Status MakeClient(FlightClientOptions* options) = 0; + /// \brief Run the scenario as the client. virtual Status RunClient(std::unique_ptr client) = 0; }; @@ -45,5 +51,6 @@ class ARROW_FLIGHT_EXPORT Scenario { /// \brief Get the implementation of an integration test scenario by name. 
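Condensed, the two integration binaries moved into integration_tests/ below drive a Scenario as follows; error handling and the retry loop from the real binaries are omitted, and the function names here are illustrative.

```cpp
#include <memory>
#include <string>
#include <utility>

#include "arrow/flight/api.h"
#include "arrow/flight/integration_tests/test_integration.h"

namespace flight = arrow::flight;
namespace itest = arrow::flight::integration_tests;

// Server half: let the scenario configure the server, then serve it.
arrow::Status RunScenarioServer(const std::string& name,
                                const flight::Location& location) {
  std::shared_ptr<itest::Scenario> scenario;
  ARROW_RETURN_NOT_OK(itest::GetScenario(name, &scenario));

  std::unique_ptr<flight::FlightServerBase> server;
  flight::FlightServerOptions options(location);
  ARROW_RETURN_NOT_OK(scenario->MakeServer(&server, &options));
  ARROW_RETURN_NOT_OK(server->Init(options));
  return server->Serve();
}

// Client half: let the scenario configure the client options, connect, and
// hand the client over to the scenario body.
arrow::Status RunScenarioClient(const std::string& name,
                                const flight::Location& location) {
  std::shared_ptr<itest::Scenario> scenario;
  ARROW_RETURN_NOT_OK(itest::GetScenario(name, &scenario));

  auto options = flight::FlightClientOptions::Defaults();
  ARROW_RETURN_NOT_OK(scenario->MakeClient(&options));

  std::unique_ptr<flight::FlightClient> client;
  ARROW_RETURN_NOT_OK(flight::FlightClient::Connect(location, options, &client));
  return scenario->RunClient(std::move(client));
}
```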
Status GetScenario(const std::string& scenario_name, std::shared_ptr* out); +} // namespace integration_tests } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration_client.cc b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc similarity index 94% rename from cpp/src/arrow/flight/test_integration_client.cc rename to cpp/src/arrow/flight/integration_tests/test_integration_client.cc index 6c1d6904603..366284389f1 100644 --- a/cpp/src/arrow/flight/test_integration_client.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc @@ -41,7 +41,7 @@ #include "arrow/util/logging.h" #include "arrow/flight/api.h" -#include "arrow/flight/test_integration.h" +#include "arrow/flight/integration_tests/test_integration.h" #include "arrow/flight/test_util.h" DEFINE_string(host, "localhost", "Server port to connect to"); @@ -51,6 +51,7 @@ DEFINE_string(scenario, "", "Integration test scenario to run"); namespace arrow { namespace flight { +namespace integration_tests { /// \brief Helper to read all batches from a JsonReader Status ReadBatches(std::unique_ptr& reader, @@ -133,7 +134,7 @@ Status ConsumeFlightLocation( return Status::OK(); } -class IntegrationTestScenario : public flight::Scenario { +class IntegrationTestScenario : public Scenario { public: Status MakeServer(std::unique_ptr* server, FlightServerOptions* options) override { @@ -201,12 +202,13 @@ class IntegrationTestScenario : public flight::Scenario { } }; +} // namespace integration_tests } // namespace flight } // namespace arrow constexpr int kRetries = 3; -arrow::Status RunScenario(arrow::flight::Scenario* scenario) { +arrow::Status RunScenario(arrow::flight::integration_tests::Scenario* scenario) { auto options = arrow::flight::FlightClientOptions::Defaults(); std::unique_ptr client; @@ -222,11 +224,13 @@ int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); - std::shared_ptr scenario; + std::shared_ptr scenario; if (!FLAGS_scenario.empty()) { - ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario)); + ARROW_CHECK_OK( + arrow::flight::integration_tests::GetScenario(FLAGS_scenario, &scenario)); } else { - scenario = std::make_shared(); + scenario = + std::make_shared(); } // ARROW-11908: retry a few times in case a client is slow to bring up the server diff --git a/cpp/src/arrow/flight/test_integration_server.cc b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc similarity index 94% rename from cpp/src/arrow/flight/test_integration_server.cc rename to cpp/src/arrow/flight/integration_tests/test_integration_server.cc index 4b904b0eba1..92b2241a872 100644 --- a/cpp/src/arrow/flight/test_integration_server.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc @@ -34,10 +34,10 @@ #include "arrow/testing/json_integration.h" #include "arrow/util/logging.h" +#include "arrow/flight/integration_tests/test_integration.h" #include "arrow/flight/internal.h" #include "arrow/flight/server.h" #include "arrow/flight/server_auth.h" -#include "arrow/flight/test_integration.h" #include "arrow/flight/test_util.h" DEFINE_int32(port, 31337, "Server port to listen on"); @@ -45,6 +45,7 @@ DEFINE_string(scenario, "", "Integration test senario to run"); namespace arrow { namespace flight { +namespace integration_tests { struct IntegrationDataset { std::shared_ptr schema; @@ -175,6 +176,7 @@ class IntegrationTestScenario : public Scenario { } }; 
+} // namespace integration_tests } // namespace flight } // namespace arrow @@ -184,12 +186,14 @@ int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing server for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); - std::shared_ptr scenario; + std::shared_ptr scenario; if (!FLAGS_scenario.empty()) { - ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario)); + ARROW_CHECK_OK( + arrow::flight::integration_tests::GetScenario(FLAGS_scenario, &scenario)); } else { - scenario = std::make_shared(); + scenario = + std::make_shared(); } arrow::flight::Location location; ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); diff --git a/cpp/src/arrow/flight/sql/ArrowFlightSqlConfig.cmake.in b/cpp/src/arrow/flight/sql/ArrowFlightSqlConfig.cmake.in new file mode 100644 index 00000000000..1658f44f418 --- /dev/null +++ b/cpp/src/arrow/flight/sql/ArrowFlightSqlConfig.cmake.in @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This config sets the following variables in your project:: +# +# ArrowFlightSql_FOUND - true if Arrow Flight SQL found on the system +# +# This config sets the following targets in your project:: +# +# arrow_flight_sql_shared - for linked as shared library if shared library is built +# arrow_flight_sql_static - for linked as static library if static library is built + +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) +find_dependency(ArrowFlight) + +# Load targets only once. If we load targets multiple times, CMake reports +# already existent target error. +if(NOT (TARGET arrow_flight_sql_shared OR TARGET arrow_flight_sql_static)) + include("${CMAKE_CURRENT_LIST_DIR}/ArrowFlightSqlTargets.cmake") +endif() diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt new file mode 100644 index 00000000000..4a31f5ba2e2 --- /dev/null +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +add_custom_target(arrow_flight_sql) + +arrow_install_all_headers("arrow/flight/sql") + +set(FLIGHT_SQL_PROTO_PATH "${ARROW_SOURCE_DIR}/../format") +set(FLIGHT_SQL_PROTO ${ARROW_SOURCE_DIR}/../format/FlightSql.proto) + +set(FLIGHT_SQL_GENERATED_PROTO_FILES "${CMAKE_CURRENT_BINARY_DIR}/FlightSql.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/FlightSql.pb.h") + +set(PROTO_DEPENDS ${FLIGHT_SQL_PROTO} ${ARROW_PROTOBUF_LIBPROTOBUF}) + +add_custom_command(OUTPUT ${FLIGHT_SQL_GENERATED_PROTO_FILES} + COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_SQL_PROTO_PATH}" + "--cpp_out=${CMAKE_CURRENT_BINARY_DIR}" "${FLIGHT_SQL_PROTO}" + DEPENDS ${PROTO_DEPENDS}) + +set_source_files_properties(${FLIGHT_SQL_GENERATED_PROTO_FILES} PROPERTIES GENERATED TRUE) + +add_custom_target(flight_sql_protobuf_gen ALL DEPENDS ${FLIGHT_SQL_GENERATED_PROTO_FILES}) + +set(ARROW_FLIGHT_SQL_SRCS server.cc sql_info_internal.cc client.cc + "${CMAKE_CURRENT_BINARY_DIR}/FlightSql.pb.cc") + +add_arrow_lib(arrow_flight_sql + CMAKE_PACKAGE_NAME + ArrowFlightSql + PKG_CONFIG_NAME + arrow-flight-sql + OUTPUTS + ARROW_FLIGHT_SQL_LIBRARIES + SOURCES + ${ARROW_FLIGHT_SQL_SRCS} + DEPENDENCIES + flight_sql_protobuf_gen + SHARED_LINK_FLAGS + ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt + SHARED_LINK_LIBS + arrow_flight_shared + STATIC_LINK_LIBS + arrow_flight_static) + +if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS + arrow_flight_sql_static arrow_flight_testing_static + ${ARROW_FLIGHT_STATIC_LINK_LIBS} ${ARROW_TEST_LINK_LIBS}) +else() + set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS arrow_flight_sql_shared arrow_flight_testing_shared + ${ARROW_TEST_LINK_LIBS}) +endif() + +# Build test server for unit tests +if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) + find_package(SQLite3Alt REQUIRED) + + set(ARROW_FLIGHT_SQL_TEST_SERVER_SRCS + example/sqlite_sql_info.cc + example/sqlite_statement.cc + example/sqlite_statement_batch_reader.cc + example/sqlite_server.cc + example/sqlite_tables_schema_batch_reader.cc) + + add_arrow_test(flight_sql_test + SOURCES + client_test.cc + server_test.cc + ${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS} + STATIC_LINK_LIBS + ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} + ${SQLite3_LIBRARIES} + LABELS + "arrow_flight_sql") + + add_executable(flight_sql_test_server test_server_cli.cc + ${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS}) + target_link_libraries(flight_sql_test_server + PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} + ${SQLite3_LIBRARIES}) + + add_executable(flight_sql_test_app test_app_cli.cc) + target_link_libraries(flight_sql_test_app PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES}) +endif() diff --git a/cpp/src/arrow/flight/sql/api.h b/cpp/src/arrow/flight/sql/api.h new file mode 100644 index 00000000000..3b909eedf29 --- /dev/null +++ b/cpp/src/arrow/flight/sql/api.h @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/flight/sql/client.h" diff --git a/cpp/src/arrow/flight/sql/arrow-flight-sql.pc.in b/cpp/src/arrow/flight/sql/arrow-flight-sql.pc.in new file mode 100644 index 00000000000..6d4eab0b4a0 --- /dev/null +++ b/cpp/src/arrow/flight/sql/arrow-flight-sql.pc.in @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: Apache Arrow Flight SQL +Description: Apache Arrow Flight SQL extension +Version: @ARROW_VERSION@ +Requires: arrow-flight +Libs: -L${libdir} -larrow_flight_sql diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc new file mode 100644 index 00000000000..50a5777cd9c --- /dev/null +++ b/cpp/src/arrow/flight/sql/client.cc @@ -0,0 +1,425 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
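The implementation below exposes ad-hoc statement execution as Execute, which returns a FlightInfo whose tickets are fetched with DoGet, and ExecuteUpdate, which returns an affected-row count carried back through DoPut metadata. Roughly, from application code (a sketch; the query strings are placeholders and ReadAll is the stream-to-Table convenience assumed above):

```cpp
#include <memory>
#include <utility>

#include "arrow/flight/api.h"
#include "arrow/flight/sql/client.h"
#include "arrow/table.h"

namespace flightsql = arrow::flight::sql;

arrow::Status RunStatements(std::unique_ptr<arrow::flight::FlightClient> client) {
  flightsql::FlightSqlClient sql_client(std::move(client));
  arrow::flight::FlightCallOptions options;

  // A query: the returned FlightInfo carries one or more endpoints whose
  // tickets are redeemed with DoGet.
  ARROW_ASSIGN_OR_RAISE(auto info, sql_client.Execute(options, "SELECT 1"));
  for (const auto& endpoint : info->endpoints()) {
    ARROW_ASSIGN_OR_RAISE(auto reader, sql_client.DoGet(options, endpoint.ticket));
    std::shared_ptr<arrow::Table> table;
    ARROW_RETURN_NOT_OK(reader->ReadAll(&table));
  }

  // An update: only the affected-row count comes back.
  ARROW_ASSIGN_OR_RAISE(int64_t updated_rows,
                        sql_client.ExecuteUpdate(options, "UPDATE t SET x = 1"));
  (void)updated_rows;
  return arrow::Status::OK();
}
```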
+ +#include "arrow/flight/sql/client.h" + +#include + +#include "arrow/buffer.h" +#include "arrow/flight/sql/FlightSql.pb.h" +#include "arrow/flight/types.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/result.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/logging.h" + +namespace flight_sql_pb = arrow::flight::protocol::sql; + +namespace arrow { +namespace flight { +namespace sql { + +FlightSqlClient::FlightSqlClient(std::shared_ptr client) + : impl_(std::move(client)) {} + +PreparedStatement::PreparedStatement(FlightSqlClient* client, std::string handle, + std::shared_ptr dataset_schema, + std::shared_ptr parameter_schema, + FlightCallOptions options) + : client_(client), + options_(std::move(options)), + handle_(std::move(handle)), + dataset_schema_(std::move(dataset_schema)), + parameter_schema_(std::move(parameter_schema)), + is_closed_(false) {} + +PreparedStatement::~PreparedStatement() { + if (IsClosed()) return; + + const Status status = Close(); + if (!status.ok()) { + ARROW_LOG(ERROR) << "Failed to delete PreparedStatement: " << status.ToString(); + } +} + +inline FlightDescriptor GetFlightDescriptorForCommand( + const google::protobuf::Message& command) { + google::protobuf::Any any; + any.PackFrom(command); + + const std::string& string = any.SerializeAsString(); + return FlightDescriptor::Command(string); +} + +arrow::Result> GetFlightInfoForCommand( + FlightSqlClient& client, const FlightCallOptions& options, + const google::protobuf::Message& command) { + const FlightDescriptor& descriptor = GetFlightDescriptorForCommand(command); + + ARROW_ASSIGN_OR_RAISE(auto flight_info, client.GetFlightInfo(options, descriptor)); + return std::move(flight_info); +} + +arrow::Result> FlightSqlClient::Execute( + const FlightCallOptions& options, const std::string& query) { + flight_sql_pb::CommandStatementQuery command; + command.set_query(query); + + return GetFlightInfoForCommand(*this, options, command); +} + +arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& options, + const std::string& query) { + flight_sql_pb::CommandStatementUpdate command; + command.set_query(query); + + const FlightDescriptor& descriptor = GetFlightDescriptorForCommand(command); + + std::unique_ptr writer; + std::unique_ptr reader; + + ARROW_RETURN_NOT_OK(DoPut(options, descriptor, NULLPTR, &writer, &reader)); + + std::shared_ptr metadata; + + ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata)); + + flight_sql_pb::DoPutUpdateResult doPutUpdateResult; + + flight_sql_pb::DoPutUpdateResult result; + if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) { + return Status::Invalid("Unable to parse DoPutUpdateResult object."); + } + + return result.record_count(); +} + +arrow::Result> FlightSqlClient::GetCatalogs( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetCatalogs command; + + return GetFlightInfoForCommand(*this, options, command); +} + +arrow::Result> FlightSqlClient::GetDbSchemas( + const FlightCallOptions& options, const std::string* catalog, + const std::string* db_schema_filter_pattern) { + flight_sql_pb::CommandGetDbSchemas command; + if (catalog != NULLPTR) { + command.set_catalog(*catalog); + } + if (db_schema_filter_pattern != NULLPTR) { + command.set_db_schema_filter_pattern(*db_schema_filter_pattern); + } + + return GetFlightInfoForCommand(*this, options, command); +} + +arrow::Result> FlightSqlClient::GetTables( + const FlightCallOptions& options, const std::string* catalog, + const 
std::string* db_schema_filter_pattern, const std::string* table_filter_pattern, + bool include_schema, const std::vector* table_types) { + flight_sql_pb::CommandGetTables command; + + if (catalog != NULLPTR) { + command.set_catalog(*catalog); + } + + if (db_schema_filter_pattern != NULLPTR) { + command.set_db_schema_filter_pattern(*db_schema_filter_pattern); + } + + if (table_filter_pattern != NULLPTR) { + command.set_table_name_filter_pattern(*table_filter_pattern); + } + + command.set_include_schema(include_schema); + + if (table_types != NULLPTR) { + for (const std::string& table_type : *table_types) { + command.add_table_types(table_type); + } + } + + return GetFlightInfoForCommand(*this, options, command); +} + +arrow::Result> FlightSqlClient::GetPrimaryKeys( + const FlightCallOptions& options, const TableRef& table_ref) { + flight_sql_pb::CommandGetPrimaryKeys command; + + if (table_ref.catalog.has_value()) { + command.set_catalog(table_ref.catalog.value()); + } + + if (table_ref.db_schema.has_value()) { + command.set_db_schema(table_ref.db_schema.value()); + } + + command.set_table(table_ref.table); + + return GetFlightInfoForCommand(*this, options, command); +} + +arrow::Result> FlightSqlClient::GetExportedKeys( + const FlightCallOptions& options, const TableRef& table_ref) { + flight_sql_pb::CommandGetExportedKeys command; + + if (table_ref.catalog.has_value()) { + command.set_catalog(table_ref.catalog.value()); + } + + if (table_ref.db_schema.has_value()) { + command.set_db_schema(table_ref.db_schema.value()); + } + + command.set_table(table_ref.table); + + return GetFlightInfoForCommand(*this, options, command); +} + +arrow::Result> FlightSqlClient::GetImportedKeys( + const FlightCallOptions& options, const TableRef& table_ref) { + flight_sql_pb::CommandGetImportedKeys command; + + if (table_ref.catalog.has_value()) { + command.set_catalog(table_ref.catalog.value()); + } + + if (table_ref.db_schema.has_value()) { + command.set_db_schema(table_ref.db_schema.value()); + } + + command.set_table(table_ref.table); + + return GetFlightInfoForCommand(*this, options, command); +} + +arrow::Result> FlightSqlClient::GetCrossReference( + const FlightCallOptions& options, const TableRef& pk_table_ref, + const TableRef& fk_table_ref) { + flight_sql_pb::CommandGetCrossReference command; + + if (pk_table_ref.catalog.has_value()) { + command.set_pk_catalog(pk_table_ref.catalog.value()); + } + if (pk_table_ref.db_schema.has_value()) { + command.set_pk_db_schema(pk_table_ref.db_schema.value()); + } + command.set_pk_table(pk_table_ref.table); + + if (fk_table_ref.catalog.has_value()) { + command.set_fk_catalog(fk_table_ref.catalog.value()); + } + if (fk_table_ref.db_schema.has_value()) { + command.set_fk_db_schema(fk_table_ref.db_schema.value()); + } + command.set_fk_table(fk_table_ref.table); + + return GetFlightInfoForCommand(*this, options, command); +} + +arrow::Result> FlightSqlClient::GetTableTypes( + const FlightCallOptions& options) { + flight_sql_pb::CommandGetTableTypes command; + + return GetFlightInfoForCommand(*this, options, command); +} + +arrow::Result> FlightSqlClient::DoGet( + const FlightCallOptions& options, const Ticket& ticket) { + std::unique_ptr stream; + ARROW_RETURN_NOT_OK(DoGet(options, ticket, &stream)); + + return std::move(stream); +} + +arrow::Result> FlightSqlClient::Prepare( + const FlightCallOptions& options, const std::string& query) { + google::protobuf::Any command; + flight_sql_pb::ActionCreatePreparedStatementRequest request; + request.set_query(query); + 
command.PackFrom(request); + + Action action; + action.type = "CreatePreparedStatement"; + action.body = Buffer::FromString(command.SerializeAsString()); + + std::unique_ptr results; + + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + std::unique_ptr result; + ARROW_RETURN_NOT_OK(results->Next(&result)); + + google::protobuf::Any prepared_result; + + std::shared_ptr message = std::move(result->body); + if (!prepared_result.ParseFromArray(message->data(), + static_cast(message->size()))) { + return Status::Invalid("Unable to parse packed ActionCreatePreparedStatementResult"); + } + + flight_sql_pb::ActionCreatePreparedStatementResult prepared_statement_result; + + if (!prepared_result.UnpackTo(&prepared_statement_result)) { + return Status::Invalid("Unable to unpack ActionCreatePreparedStatementResult"); + } + + const std::string& serialized_dataset_schema = + prepared_statement_result.dataset_schema(); + const std::string& serialized_parameter_schema = + prepared_statement_result.parameter_schema(); + + std::shared_ptr dataset_schema; + if (!serialized_dataset_schema.empty()) { + io::BufferReader dataset_schema_reader(serialized_dataset_schema); + ipc::DictionaryMemo in_memo; + ARROW_ASSIGN_OR_RAISE(dataset_schema, ReadSchema(&dataset_schema_reader, &in_memo)); + } + std::shared_ptr parameter_schema; + if (!serialized_parameter_schema.empty()) { + io::BufferReader parameter_schema_reader(serialized_parameter_schema); + ipc::DictionaryMemo in_memo; + ARROW_ASSIGN_OR_RAISE(parameter_schema, + ReadSchema(¶meter_schema_reader, &in_memo)); + } + auto handle = prepared_statement_result.prepared_statement_handle(); + + return std::make_shared(this, handle, dataset_schema, + parameter_schema, options); +} + +arrow::Result> PreparedStatement::Execute() { + if (is_closed_) { + return Status::Invalid("Statement already closed."); + } + + flight_sql_pb::CommandPreparedStatementQuery execute_query_command; + + execute_query_command.set_prepared_statement_handle(handle_); + + google::protobuf::Any any; + any.PackFrom(execute_query_command); + + const std::string& string = any.SerializeAsString(); + const FlightDescriptor descriptor = FlightDescriptor::Command(string); + + if (parameter_binding_ && parameter_binding_->num_rows() > 0) { + std::unique_ptr writer; + std::unique_ptr reader; + ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, parameter_binding_->schema(), + &writer, &reader)); + + ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_)); + ARROW_RETURN_NOT_OK(writer->DoneWriting()); + // Wait for the server to ack the result + std::shared_ptr buffer; + ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer)); + } + + ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options_, descriptor)); + return std::move(flight_info); +} + +arrow::Result PreparedStatement::ExecuteUpdate() { + if (is_closed_) { + return Status::Invalid("Statement already closed."); + } + + flight_sql_pb::CommandPreparedStatementUpdate command; + command.set_prepared_statement_handle(handle_); + const FlightDescriptor& descriptor = GetFlightDescriptorForCommand(command); + std::unique_ptr writer; + std::unique_ptr reader; + + if (parameter_binding_ && parameter_binding_->num_rows() > 0) { + ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, parameter_binding_->schema(), + &writer, &reader)); + ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_)); + } else { + const std::shared_ptr schema = arrow::schema({}); + ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, 
schema, &writer, &reader)); + const auto& record_batch = + arrow::RecordBatch::Make(schema, 0, (std::vector>){}); + ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch)); + } + + ARROW_RETURN_NOT_OK(writer->DoneWriting()); + std::shared_ptr metadata; + ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata)); + ARROW_RETURN_NOT_OK(writer->Close()); + + flight_sql_pb::DoPutUpdateResult result; + if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) { + return Status::Invalid("Unable to parse DoPutUpdateResult object."); + } + + return result.record_count(); +} + +Status PreparedStatement::SetParameters(std::shared_ptr parameter_binding) { + parameter_binding_ = std::move(parameter_binding); + + return Status::OK(); +} + +bool PreparedStatement::IsClosed() const { return is_closed_; } + +std::shared_ptr PreparedStatement::dataset_schema() const { + return dataset_schema_; +} + +std::shared_ptr PreparedStatement::parameter_schema() const { + return parameter_schema_; +} + +Status PreparedStatement::Close() { + if (is_closed_) { + return Status::Invalid("Statement already closed."); + } + google::protobuf::Any command; + flight_sql_pb::ActionClosePreparedStatementRequest request; + request.set_prepared_statement_handle(handle_); + + command.PackFrom(request); + + Action action; + action.type = "ClosePreparedStatement"; + action.body = Buffer::FromString(command.SerializeAsString()); + + std::unique_ptr results; + + ARROW_RETURN_NOT_OK(client_->DoAction(options_, action, &results)); + + is_closed_ = true; + + return Status::OK(); +} + +arrow::Result> FlightSqlClient::GetSqlInfo( + const FlightCallOptions& options, const std::vector& sql_info) { + flight_sql_pb::CommandGetSqlInfo command; + for (const int& info : sql_info) command.add_info(info); + + return GetFlightInfoForCommand(*this, options, command); +} + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h new file mode 100644 index 00000000000..5bf1b3e64a0 --- /dev/null +++ b/cpp/src/arrow/flight/sql/client.h @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/flight/client.h" +#include "arrow/flight/sql/types.h" +#include "arrow/flight/types.h" +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { +namespace flight { +namespace sql { + +class PreparedStatement; + +/// \brief Flight client with Flight SQL semantics. +class ARROW_EXPORT FlightSqlClient { + friend class PreparedStatement; + + private: + std::shared_ptr impl_; + + public: + explicit FlightSqlClient(std::shared_ptr client); + + virtual ~FlightSqlClient() = default; + + /// \brief Execute a query on the server. 
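+  ///
+  /// Internally the query is packed into a CommandStatementQuery protobuf and
+  /// sent to the server through GetFlightInfo.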
+ /// \param[in] options RPC-layer hints for this call. + /// \param[in] query The query to be executed in the UTF-8 format. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> Execute(const FlightCallOptions& options, + const std::string& query); + + /// \brief Execute an update query on the server. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] query The query to be executed in the UTF-8 format. + /// \return The quantity of rows affected by the operation. + arrow::Result ExecuteUpdate(const FlightCallOptions& options, + const std::string& query); + + /// \brief Request a list of catalogs. + /// \param[in] options RPC-layer hints for this call. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> GetCatalogs( + const FlightCallOptions& options); + + /// \brief Request a list of database schemas. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] catalog The catalog. + /// \param[in] db_schema_filter_pattern The schema filter pattern. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> GetDbSchemas( + const FlightCallOptions& options, const std::string* catalog, + const std::string* db_schema_filter_pattern); + + /// \brief Given a flight ticket and schema, request to be sent the + /// stream. Returns record batch stream reader + /// \param[in] options Per-RPC options + /// \param[in] ticket The flight ticket to use + /// \return The returned RecordBatchReader + arrow::Result> DoGet( + const FlightCallOptions& options, const Ticket& ticket); + + /// \brief Request a list of tables. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] catalog The catalog. + /// \param[in] db_schema_filter_pattern The schema filter pattern. + /// \param[in] table_filter_pattern The table filter pattern. + /// \param[in] include_schema True to include the schema upon return, + /// false to not include the schema. + /// \param[in] table_types The table types to include. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> GetTables( + const FlightCallOptions& options, const std::string* catalog, + const std::string* db_schema_filter_pattern, + const std::string* table_filter_pattern, bool include_schema, + const std::vector* table_types); + + /// \brief Request the primary keys for a table. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] table_ref The table reference. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> GetPrimaryKeys( + const FlightCallOptions& options, const TableRef& table_ref); + + /// \brief Retrieves a description about the foreign key columns that reference the + /// primary key columns of the given table. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] table_ref The table reference. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> GetExportedKeys( + const FlightCallOptions& options, const TableRef& table_ref); + + /// \brief Retrieves the foreign key columns for the given table. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] table_ref The table reference. + /// \return The FlightInfo describing where to access the dataset. 
+ arrow::Result> GetImportedKeys( + const FlightCallOptions& options, const TableRef& table_ref); + + /// \brief Retrieves a description of the foreign key columns in the given foreign key + /// table that reference the primary key or the columns representing a unique + /// constraint of the parent table (could be the same or a different table). + /// \param[in] options RPC-layer hints for this call. + /// \param[in] pk_table_ref The table reference that exports the key. + /// \param[in] fk_table_ref The table reference that imports the key. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> GetCrossReference( + const FlightCallOptions& options, const TableRef& pk_table_ref, + const TableRef& fk_table_ref); + + /// \brief Request a list of table types. + /// \param[in] options RPC-layer hints for this call. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> GetTableTypes( + const FlightCallOptions& options); + + /// \brief Request a list of SQL information. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] sql_info the SQL info required. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> GetSqlInfo(const FlightCallOptions& options, + const std::vector& sql_info); + + /// \brief Create a prepared statement object. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] query The query that will be executed. + /// \return The created prepared statement. + arrow::Result> Prepare( + const FlightCallOptions& options, const std::string& query); + + /// \brief Retrieve the FlightInfo. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] descriptor The flight descriptor. + /// \return The flight info with the metadata. + // NOTE: This is public because it is been used by the anonymous + // function GetFlightInfoForCommand. + virtual arrow::Result> GetFlightInfo( + const FlightCallOptions& options, const FlightDescriptor& descriptor) { + std::unique_ptr info; + ARROW_RETURN_NOT_OK(impl_->GetFlightInfo(options, descriptor, &info)); + + return info; + } + + protected: + virtual Status DoPut(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + const std::shared_ptr& schema, + std::unique_ptr* stream, + std::unique_ptr* reader) { + return impl_->DoPut(options, descriptor, schema, stream, reader); + } + + virtual Status DoGet(const FlightCallOptions& options, const Ticket& ticket, + std::unique_ptr* stream) { + return impl_->DoGet(options, ticket, stream); + } + + virtual Status DoAction(const FlightCallOptions& options, const Action& action, + std::unique_ptr* results) { + return impl_->DoAction(options, action, results); + } +}; + +/// \brief PreparedStatement class from flight sql. +class ARROW_EXPORT PreparedStatement { + FlightSqlClient* client_; + FlightCallOptions options_; + std::string handle_; + std::shared_ptr dataset_schema_; + std::shared_ptr parameter_schema_; + std::shared_ptr parameter_binding_; + bool is_closed_; + + public: + /// \brief Constructor for the PreparedStatement class. + /// \param[in] client Client object used to make the RPC requests. + /// \param[in] handle Handle for this prepared statement. + /// \param[in] dataset_schema Schema of the resulting dataset. + /// \param[in] parameter_schema Schema of the parameters (if any). + /// \param[in] options RPC-layer hints for this call. 
+  PreparedStatement(FlightSqlClient* client, std::string handle,
+                    std::shared_ptr<Schema> dataset_schema,
+                    std::shared_ptr<Schema> parameter_schema, FlightCallOptions options);
+
+  /// \brief Destructor for the PreparedStatement class.
+  /// The destructor calls Close() in order to send a request to close the
+  /// PreparedStatement on the server.
+  /// NOTE: It is best to close the PreparedStatement explicitly, otherwise
+  /// errors can't be caught.
+  ~PreparedStatement();
+
+  /// \brief Execute the prepared statement query on the server.
+  /// \return A FlightInfo object representing the stream(s) to fetch.
+  arrow::Result<std::unique_ptr<FlightInfo>> Execute();
+
+  /// \brief Execute the prepared statement update query on the server.
+  /// \return The number of rows affected.
+  arrow::Result<int64_t> ExecuteUpdate();
+
+  /// \brief Retrieve the parameter schema from the query.
+  /// \return The parameter schema from the query.
+  std::shared_ptr<Schema> parameter_schema() const;
+
+  /// \brief Retrieve the ResultSet schema from the query.
+  /// \return The ResultSet schema from the query.
+  std::shared_ptr<Schema> dataset_schema() const;
+
+  /// \brief Set a RecordBatch that contains the parameters that will be bound.
+  /// \param parameter_binding The parameters that will be bound.
+  /// \return Status.
+  Status SetParameters(std::shared_ptr<RecordBatch> parameter_binding);
+
+  /// \brief Close the prepared statement so that it can no longer be used and
+  /// the server can free up any resources.
+  /// \return Status.
+  Status Close();
+
+  /// \brief Check if the prepared statement is closed.
+  /// \return The state of the prepared statement.
+  bool IsClosed() const;
+};
+
+}  // namespace sql
+}  // namespace flight
+}  // namespace arrow
diff --git a/cpp/src/arrow/flight/sql/client_test.cc b/cpp/src/arrow/flight/sql/client_test.cc
new file mode 100644
index 00000000000..8c0c8333074
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/client_test.cc
@@ -0,0 +1,515 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
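+//
+// These tests exercise FlightSqlClient against a GoogleMock subclass that
+// stubs the underlying GetFlightInfo/DoGet/DoPut/DoAction calls, so they
+// verify the FlightDescriptor and Action payloads built by the client without
+// a running Flight SQL server.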
+ +#include "arrow/flight/client.h" + +#include +#include +#include + +#include + +#include "arrow/flight/sql/FlightSql.pb.h" +#include "arrow/flight/sql/api.h" +#include "arrow/testing/gtest_util.h" + +namespace pb = arrow::flight::protocol; +using ::testing::_; +using ::testing::Ref; + +namespace arrow { +namespace flight { +namespace sql { + +class FlightSqlClientMock : public FlightSqlClient { + public: + FlightSqlClientMock() : FlightSqlClient(nullptr) {} + + ~FlightSqlClientMock() = default; + + MOCK_METHOD(arrow::Result>, GetFlightInfo, + (const FlightCallOptions&, const FlightDescriptor&)); + MOCK_METHOD(Status, DoGet, + (const FlightCallOptions& options, const Ticket& ticket, + std::unique_ptr* stream)); + MOCK_METHOD(Status, DoPut, + (const FlightCallOptions&, const FlightDescriptor&, + const std::shared_ptr& schema, + std::unique_ptr*, + std::unique_ptr*)); + MOCK_METHOD(Status, DoAction, + (const FlightCallOptions& options, const Action& action, + std::unique_ptr* results)); +}; + +class TestFlightSqlClient : public ::testing::Test { + protected: + FlightSqlClientMock sql_client_; + FlightCallOptions call_options_; + + void SetUp() override {} + + void TearDown() override {} +}; + +class FlightMetadataReaderMock : public FlightMetadataReader { + public: + std::shared_ptr* buffer; + + explicit FlightMetadataReaderMock(std::shared_ptr* buffer) { + this->buffer = buffer; + } + + Status ReadMetadata(std::shared_ptr* out) override { + *out = *buffer; + return Status::OK(); + } +}; + +class FlightStreamWriterMock : public FlightStreamWriter { + public: + FlightStreamWriterMock() = default; + + Status DoneWriting() override { return Status::OK(); } + + Status WriteMetadata(std::shared_ptr app_metadata) override { + return Status::OK(); + } + + Status Begin(const std::shared_ptr& schema, + const ipc::IpcWriteOptions& options) override { + return Status::OK(); + } + + Status Begin(const std::shared_ptr& schema) override { + return MetadataRecordBatchWriter::Begin(schema); + } + + ipc::WriteStats stats() const override { return ipc::WriteStats(); } + + Status WriteWithMetadata(const RecordBatch& batch, + std::shared_ptr app_metadata) override { + return Status::OK(); + } + + Status Close() override { return Status::OK(); } + + Status WriteRecordBatch(const RecordBatch& batch) override { return Status::OK(); } +}; + +FlightDescriptor getDescriptor(google::protobuf::Message& command) { + google::protobuf::Any any; + any.PackFrom(command); + + const std::string& string = any.SerializeAsString(); + return FlightDescriptor::Command(string); +} + +auto ReturnEmptyFlightInfo = [](const FlightCallOptions& options, + const FlightDescriptor& descriptor) { + std::unique_ptr flight_info; + return flight_info; +}; + +TEST_F(TestFlightSqlClient, TestGetCatalogs) { + pb::sql::CommandGetCatalogs command; + FlightDescriptor descriptor = getDescriptor(command); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + ASSERT_OK(sql_client_.GetCatalogs(call_options_)); +} + +TEST_F(TestFlightSqlClient, TestGetDbSchemas) { + std::string schema_filter_pattern = "schema_filter_pattern"; + std::string catalog = "catalog"; + + pb::sql::CommandGetDbSchemas command; + command.set_catalog(catalog); + command.set_db_schema_filter_pattern(schema_filter_pattern); + FlightDescriptor descriptor = getDescriptor(command); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + 
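+  // Expect exactly one GetFlightInfo call whose descriptor matches the
+  // CommandGetDbSchemas packed above.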
EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + ASSERT_OK(sql_client_.GetDbSchemas(call_options_, &catalog, &schema_filter_pattern)); +} + +TEST_F(TestFlightSqlClient, TestGetTables) { + std::string catalog = "catalog"; + std::string schema_filter_pattern = "schema_filter_pattern"; + std::string table_name_filter_pattern = "table_name_filter_pattern"; + bool include_schema = true; + std::vector table_types = {"type1", "type2"}; + + pb::sql::CommandGetTables command; + command.set_catalog(catalog); + command.set_db_schema_filter_pattern(schema_filter_pattern); + command.set_table_name_filter_pattern(table_name_filter_pattern); + command.set_include_schema(include_schema); + for (const std::string& table_type : table_types) { + command.add_table_types(table_type); + } + FlightDescriptor descriptor = getDescriptor(command); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + ASSERT_OK(sql_client_.GetTables(call_options_, &catalog, &schema_filter_pattern, + &table_name_filter_pattern, include_schema, + &table_types)); +} + +TEST_F(TestFlightSqlClient, TestGetTableTypes) { + pb::sql::CommandGetTableTypes command; + FlightDescriptor descriptor = getDescriptor(command); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + ASSERT_OK(sql_client_.GetTableTypes(call_options_)); +} + +TEST_F(TestFlightSqlClient, TestGetExported) { + std::string catalog = "catalog"; + std::string schema = "schema"; + std::string table = "table"; + + pb::sql::CommandGetExportedKeys command; + command.set_catalog(catalog); + command.set_db_schema(schema); + command.set_table(table); + FlightDescriptor descriptor = getDescriptor(command); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + TableRef table_ref = {util::make_optional(catalog), util::make_optional(schema), table}; + ASSERT_OK(sql_client_.GetExportedKeys(call_options_, table_ref)); +} + +TEST_F(TestFlightSqlClient, TestGetImported) { + std::string catalog = "catalog"; + std::string schema = "schema"; + std::string table = "table"; + + pb::sql::CommandGetImportedKeys command; + command.set_catalog(catalog); + command.set_db_schema(schema); + command.set_table(table); + FlightDescriptor descriptor = getDescriptor(command); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + TableRef table_ref = {util::make_optional(catalog), util::make_optional(schema), table}; + ASSERT_OK(sql_client_.GetImportedKeys(call_options_, table_ref)); +} + +TEST_F(TestFlightSqlClient, TestGetPrimary) { + std::string catalog = "catalog"; + std::string schema = "schema"; + std::string table = "table"; + + pb::sql::CommandGetPrimaryKeys command; + command.set_catalog(catalog); + command.set_db_schema(schema); + command.set_table(table); + FlightDescriptor descriptor = getDescriptor(command); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + TableRef table_ref = {util::make_optional(catalog), util::make_optional(schema), table}; + ASSERT_OK(sql_client_.GetPrimaryKeys(call_options_, table_ref)); +} + +TEST_F(TestFlightSqlClient, 
TestGetCrossReference) { + std::string pk_catalog = "pk_catalog"; + std::string pk_schema = "pk_schema"; + std::string pk_table = "pk_table"; + std::string fk_catalog = "fk_catalog"; + std::string fk_schema = "fk_schema"; + std::string fk_table = "fk_table"; + + pb::sql::CommandGetCrossReference command; + command.set_pk_catalog(pk_catalog); + command.set_pk_db_schema(pk_schema); + command.set_pk_table(pk_table); + command.set_fk_catalog(fk_catalog); + command.set_fk_db_schema(fk_schema); + command.set_fk_table(fk_table); + FlightDescriptor descriptor = getDescriptor(command); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + TableRef pk_table_ref = {util::make_optional(pk_catalog), + util::make_optional(pk_schema), pk_table}; + TableRef fk_table_ref = {util::make_optional(fk_catalog), + util::make_optional(fk_schema), fk_table}; + ASSERT_OK(sql_client_.GetCrossReference(call_options_, pk_table_ref, fk_table_ref)); +} + +TEST_F(TestFlightSqlClient, TestExecute) { + std::string query = "query"; + + pb::sql::CommandStatementQuery command; + command.set_query(query); + FlightDescriptor descriptor = getDescriptor(command); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + ASSERT_OK(sql_client_.Execute(call_options_, query)); +} + +TEST_F(TestFlightSqlClient, TestPreparedStatementExecute) { + const std::string query = "query"; + + ON_CALL(sql_client_, DoAction) + .WillByDefault([](const FlightCallOptions& options, const Action& action, + std::unique_ptr* results) { + google::protobuf::Any command; + + pb::sql::ActionCreatePreparedStatementResult prepared_statement_result; + + prepared_statement_result.set_prepared_statement_handle("query"); + + command.PackFrom(prepared_statement_result); + + *results = std::unique_ptr(new SimpleResultStream( + {Result{Buffer::FromString(command.SerializeAsString())}})); + + return Status::OK(); + }); + + EXPECT_CALL(sql_client_, DoAction(_, _, _)).Times(2); + + ASSERT_OK_AND_ASSIGN(auto prepared_statement, + sql_client_.Prepare(call_options_, query)); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(_, _)); + + ASSERT_OK(prepared_statement->Execute()); +} + +TEST_F(TestFlightSqlClient, TestPreparedStatementExecuteParameterBinding) { + const std::string query = "query"; + + ON_CALL(sql_client_, DoAction) + .WillByDefault([](const FlightCallOptions& options, const Action& action, + std::unique_ptr* results) { + google::protobuf::Any command; + + pb::sql::ActionCreatePreparedStatementResult prepared_statement_result; + + prepared_statement_result.set_prepared_statement_handle("query"); + + auto schema = arrow::schema({arrow::field("id", int64())}); + + std::shared_ptr schema_buffer; + const arrow::Result>& result = + arrow::ipc::SerializeSchema(*schema); + + ARROW_ASSIGN_OR_RAISE(schema_buffer, result); + + prepared_statement_result.set_parameter_schema(schema_buffer->ToString()); + + command.PackFrom(prepared_statement_result); + + *results = std::unique_ptr(new SimpleResultStream( + {Result{Buffer::FromString(command.SerializeAsString())}})); + + return Status::OK(); + }); + + std::shared_ptr buffer_ptr; + ON_CALL(sql_client_, DoPut) + .WillByDefault([&buffer_ptr](const FlightCallOptions& options, + const FlightDescriptor& descriptor1, + const std::shared_ptr& schema, + 
std::unique_ptr* writer, + std::unique_ptr* reader) { + writer->reset(new FlightStreamWriterMock()); + reader->reset(new FlightMetadataReaderMock(&buffer_ptr)); + + return Status::OK(); + }); + + EXPECT_CALL(sql_client_, DoAction(_, _, _)).Times(2); + EXPECT_CALL(sql_client_, DoPut(_, _, _, _, _)); + + ASSERT_OK_AND_ASSIGN(auto prepared_statement, + sql_client_.Prepare(call_options_, query)); + + auto parameter_schema = prepared_statement->parameter_schema(); + + auto result = RecordBatchFromJSON(parameter_schema, "[[1]]"); + ASSERT_OK(prepared_statement->SetParameters(result)); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(_, _)); + + ASSERT_OK(prepared_statement->Execute()); +} + +TEST_F(TestFlightSqlClient, TestExecuteUpdate) { + std::string query = "query"; + + pb::sql::CommandStatementUpdate command; + + command.set_query(query); + + google::protobuf::Any any; + any.PackFrom(command); + + const FlightDescriptor& descriptor = FlightDescriptor::Command(any.SerializeAsString()); + + pb::sql::DoPutUpdateResult doPutUpdateResult; + doPutUpdateResult.set_record_count(100); + const std::string& string = doPutUpdateResult.SerializeAsString(); + + auto buffer_ptr = std::make_shared( + reinterpret_cast(string.data()), doPutUpdateResult.ByteSizeLong()); + + ON_CALL(sql_client_, DoPut) + .WillByDefault([&buffer_ptr](const FlightCallOptions& options, + const FlightDescriptor& descriptor1, + const std::shared_ptr& schema, + std::unique_ptr* writer, + std::unique_ptr* reader) { + reader->reset(new FlightMetadataReaderMock(&buffer_ptr)); + + return Status::OK(); + }); + + std::unique_ptr flight_info; + std::unique_ptr writer; + std::unique_ptr reader; + EXPECT_CALL(sql_client_, DoPut(Ref(call_options_), descriptor, _, _, _)); + + ASSERT_OK_AND_ASSIGN(auto num_rows, sql_client_.ExecuteUpdate(call_options_, query)); + + ASSERT_EQ(num_rows, 100); +} + +TEST_F(TestFlightSqlClient, TestGetSqlInfo) { + std::vector sql_info{pb::sql::SqlInfo::FLIGHT_SQL_SERVER_NAME, + pb::sql::SqlInfo::FLIGHT_SQL_SERVER_VERSION, + pb::sql::SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION}; + pb::sql::CommandGetSqlInfo command; + + for (const auto& info : sql_info) command.add_info(info); + google::protobuf::Any any; + any.PackFrom(command); + const FlightDescriptor& descriptor = FlightDescriptor::Command(any.SerializeAsString()); + + ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); + EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); + + ASSERT_OK(sql_client_.GetSqlInfo(call_options_, sql_info)); +} + +template +inline void AssertTestPreparedStatementExecuteUpdateOk( + Func func, const std::shared_ptr* schema, FlightSqlClientMock& sql_client_) { + const std::string query = "SELECT * FROM IRRELEVANT"; + int64_t expected_rows = 100L; + pb::sql::DoPutUpdateResult result; + result.set_record_count(expected_rows); + + ON_CALL(sql_client_, DoAction) + .WillByDefault([&query, &schema](const FlightCallOptions& options, + const Action& action, + std::unique_ptr* results) { + google::protobuf::Any command; + pb::sql::ActionCreatePreparedStatementResult prepared_statement_result; + + prepared_statement_result.set_prepared_statement_handle(query); + + if (schema != NULLPTR) { + std::shared_ptr schema_buffer; + const arrow::Result>& result = + arrow::ipc::SerializeSchema(**schema); + + ARROW_ASSIGN_OR_RAISE(schema_buffer, result); + prepared_statement_result.set_parameter_schema(schema_buffer->ToString()); + } + + 
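+        // Pack the fake ActionCreatePreparedStatementResult into a protobuf Any
+        // and return it through a SimpleResultStream, emulating the server's
+        // reply to the "CreatePreparedStatement" action.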
command.PackFrom(prepared_statement_result); + *results = std::unique_ptr(new SimpleResultStream( + {Result{Buffer::FromString(command.SerializeAsString())}})); + + return Status::OK(); + }); + EXPECT_CALL(sql_client_, DoAction(_, _, _)).Times(2); + + auto buffer = Buffer::FromString(result.SerializeAsString()); + ON_CALL(sql_client_, DoPut) + .WillByDefault([&buffer](const FlightCallOptions& options, + const FlightDescriptor& descriptor1, + const std::shared_ptr& schema, + std::unique_ptr* writer, + std::unique_ptr* reader) { + reader->reset(new FlightMetadataReaderMock(&buffer)); + writer->reset(new FlightStreamWriterMock()); + return Status::OK(); + }); + if (schema == NULLPTR) { + EXPECT_CALL(sql_client_, DoPut(_, _, _, _, _)); + } else { + EXPECT_CALL(sql_client_, DoPut(_, _, *schema, _, _)); + } + + ASSERT_OK_AND_ASSIGN(auto prepared_statement, sql_client_.Prepare({}, query)); + func(prepared_statement, sql_client_, schema, expected_rows); + ASSERT_OK_AND_ASSIGN(auto rows, prepared_statement->ExecuteUpdate()); + ASSERT_EQ(expected_rows, rows); + ASSERT_OK(prepared_statement->Close()); +} + +TEST_F(TestFlightSqlClient, TestPreparedStatementExecuteUpdateNoParameterBinding) { + AssertTestPreparedStatementExecuteUpdateOk( + [](const std::shared_ptr& prepared_statement, + FlightSqlClient& sql_client_, const std::shared_ptr* schema, + const int64_t& row_count) {}, + NULLPTR, sql_client_); +} + +TEST_F(TestFlightSqlClient, TestPreparedStatementExecuteUpdateWithParameterBinding) { + const auto schema = arrow::schema( + {arrow::field("field0", arrow::utf8()), arrow::field("field1", arrow::uint8())}); + AssertTestPreparedStatementExecuteUpdateOk( + [](const std::shared_ptr& prepared_statement, + FlightSqlClient& sql_client_, const std::shared_ptr* schema, + const int64_t& row_count) { + auto string_array = + ArrayFromJSON(utf8(), R"(["Lorem", "Ipsum", "Foo", "Bar", "Baz"])"); + auto uint8_array = ArrayFromJSON(uint8(), R"([0, 10, 15, 20, 25])"); + std::shared_ptr recordBatch = + RecordBatch::Make(*schema, row_count, {string_array, uint8_array}); + ASSERT_OK(prepared_statement->SetParameters(recordBatch)); + }, + &schema, sql_client_); +} + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc new file mode 100644 index 00000000000..dde364f64e3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc @@ -0,0 +1,813 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
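+//
+// Example Flight SQL server backed by an in-memory SQLite database: each
+// Flight SQL command is translated into plain SQL (or a PRAGMA-based query for
+// key/metadata requests) and the results are streamed back through
+// SqliteStatementBatchReader.
+//
+// Rough usage sketch (illustrative only; serving the instance on a Flight
+// Location and connecting a FlightSqlClient are assumed to happen elsewhere):
+//
+//   ARROW_ASSIGN_OR_RAISE(auto server, example::SQLiteFlightSqlServer::Create());
+//   ARROW_RETURN_NOT_OK(server->ExecuteSql("CREATE TABLE t (id INTEGER)"));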
+ +#include "arrow/flight/sql/example/sqlite_server.h" + +#include + +#include +#include +#include +#include + +#include "arrow/api.h" +#include "arrow/flight/sql/example/sqlite_sql_info.h" +#include "arrow/flight/sql/example/sqlite_statement.h" +#include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" +#include "arrow/flight/sql/example/sqlite_tables_schema_batch_reader.h" +#include "arrow/flight/sql/server.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +namespace { + +/// \brief Gets a SqliteStatement by given handle +arrow::Result> GetStatementByHandle( + const std::map>& prepared_statements, + const std::string& handle) { + auto search = prepared_statements.find(handle); + if (search == prepared_statements.end()) { + return Status::Invalid("Prepared statement not found"); + } + + return search->second; +} + +std::string PrepareQueryForGetTables(const GetTables& command) { + std::stringstream table_query; + + table_query << "SELECT null as catalog_name, null as schema_name, name as " + "table_name, type as table_type FROM sqlite_master where 1=1"; + + if (command.catalog.has_value()) { + table_query << " and catalog_name='" << command.catalog.value() << "'"; + } + + if (command.db_schema_filter_pattern.has_value()) { + table_query << " and schema_name LIKE '" << command.db_schema_filter_pattern.value() + << "'"; + } + + if (command.table_name_filter_pattern.has_value()) { + table_query << " and table_name LIKE '" << command.table_name_filter_pattern.value() + << "'"; + } + + if (!command.table_types.empty()) { + table_query << " and table_type IN ("; + size_t size = command.table_types.size(); + for (size_t i = 0; i < size; i++) { + table_query << "'" << command.table_types[i] << "'"; + if (size - 1 != i) { + table_query << ","; + } + } + + table_query << ")"; + } + + table_query << " order by table_name"; + return table_query.str(); +} + +Status SetParametersOnSQLiteStatement(sqlite3_stmt* stmt, FlightMessageReader* reader) { + FlightStreamChunk chunk; + while (true) { + RETURN_NOT_OK(reader->Next(&chunk)); + std::shared_ptr& record_batch = chunk.data; + if (record_batch == nullptr) break; + + const int64_t num_rows = record_batch->num_rows(); + const int& num_columns = record_batch->num_columns(); + + for (int i = 0; i < num_rows; ++i) { + for (int c = 0; c < num_columns; ++c) { + const std::shared_ptr& column = record_batch->column(c); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, column->GetScalar(i)); + + auto& holder = static_cast(*scalar).value; + + switch (holder->type->id()) { + case Type::INT64: { + int64_t value = static_cast(*holder).value; + sqlite3_bind_int64(stmt, c + 1, value); + break; + } + case Type::FLOAT: { + double value = static_cast(*holder).value; + sqlite3_bind_double(stmt, c + 1, value); + break; + } + case Type::STRING: { + std::shared_ptr buffer = static_cast(*holder).value; + sqlite3_bind_text(stmt, c + 1, reinterpret_cast(buffer->data()), + static_cast(buffer->size()), SQLITE_TRANSIENT); + break; + } + case Type::BINARY: { + std::shared_ptr buffer = static_cast(*holder).value; + sqlite3_bind_blob(stmt, c + 1, buffer->data(), + static_cast(buffer->size()), SQLITE_TRANSIENT); + break; + } + default: + return Status::Invalid("Received unsupported data type: ", + holder->type->ToString()); + } + } + } + } + + return Status::OK(); +} + +arrow::Result> DoGetSQLiteQuery( + sqlite3* db, const std::string& query, const std::shared_ptr& schema) { + std::shared_ptr statement; + + ARROW_ASSIGN_OR_RAISE(statement, 
SqliteStatement::Create(db, query)); + + std::shared_ptr reader; + ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create(statement, schema)); + + return std::unique_ptr(new RecordBatchStream(reader)); +} + +arrow::Result> GetFlightInfoForCommand( + const FlightDescriptor& descriptor, const std::shared_ptr& schema) { + std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto result, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); +} + +std::string PrepareQueryForGetImportedOrExportedKeys(const std::string& filter) { + return R"(SELECT * FROM (SELECT NULL AS pk_catalog_name, + NULL AS pk_schema_name, + p."table" AS pk_table_name, + p."to" AS pk_column_name, + NULL AS fk_catalog_name, + NULL AS fk_schema_name, + m.name AS fk_table_name, + p."from" AS fk_column_name, + p.seq AS key_sequence, + NULL AS pk_key_name, + NULL AS fk_key_name, + CASE + WHEN p.on_update = 'CASCADE' THEN 0 + WHEN p.on_update = 'RESTRICT' THEN 1 + WHEN p.on_update = 'SET NULL' THEN 2 + WHEN p.on_update = 'NO ACTION' THEN 3 + WHEN p.on_update = 'SET DEFAULT' THEN 4 + END AS update_rule, + CASE + WHEN p.on_delete = 'CASCADE' THEN 0 + WHEN p.on_delete = 'RESTRICT' THEN 1 + WHEN p.on_delete = 'SET NULL' THEN 2 + WHEN p.on_delete = 'NO ACTION' THEN 3 + WHEN p.on_delete = 'SET DEFAULT' THEN 4 + END AS delete_rule + FROM sqlite_master m + JOIN pragma_foreign_key_list(m.name) p ON m.name != p."table" + WHERE m.type = 'table') WHERE )" + + filter + R"( ORDER BY + pk_catalog_name, pk_schema_name, pk_table_name, pk_key_name, key_sequence)"; +} + +} // namespace + +std::shared_ptr GetArrowType(const char* sqlite_type) { + if (sqlite_type == NULLPTR) { + // SQLite may not know the column type yet. 
+ return null(); + } + + if (boost::iequals(sqlite_type, "int") || boost::iequals(sqlite_type, "integer")) { + return int64(); + } else if (boost::iequals(sqlite_type, "REAL")) { + return float64(); + } else if (boost::iequals(sqlite_type, "BLOB")) { + return binary(); + } else if (boost::iequals(sqlite_type, "TEXT") || + boost::istarts_with(sqlite_type, "char") || + boost::istarts_with(sqlite_type, "varchar")) { + return utf8(); + } else { + throw std::invalid_argument("Invalid SQLite type: " + std::string(sqlite_type)); + } +} + +class SQLiteFlightSqlServer::Impl { + sqlite3* db_; + std::map> prepared_statements_; + std::default_random_engine gen_; + + public: + explicit Impl(sqlite3* db) : db_(db) {} + + ~Impl() { sqlite3_close(db_); } + + std::string GenerateRandomString() { + uint32_t length = 16; + + std::uniform_int_distribution dist('0', 'z'); + std::string ret(length, 0); + auto get_random_char = [&]() { return dist(gen_); }; + std::generate_n(ret.begin(), length, get_random_char); + return ret; + } + + arrow::Result> GetFlightInfoStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor) { + const std::string& query = command.query; + + ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db_, query)); + + ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); + + ARROW_ASSIGN_OR_RAISE(auto ticket_string, CreateStatementQueryTicket(query)); + std::vector endpoints{FlightEndpoint{{ticket_string}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto result, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); + } + + arrow::Result> DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) { + const std::string& sql = command.statement_handle; + + std::shared_ptr statement; + ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, sql)); + + std::shared_ptr reader; + ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create(statement)); + + return std::unique_ptr(new RecordBatchStream(reader)); + } + + arrow::Result> GetFlightInfoCatalogs( + const ServerCallContext& context, const FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, SqlSchema::GetCatalogsSchema()); + } + + arrow::Result> DoGetCatalogs( + const ServerCallContext& context) { + // As SQLite doesn't support catalogs, this will return an empty record batch. + + const std::shared_ptr& schema = SqlSchema::GetCatalogsSchema(); + + StringBuilder catalog_name_builder; + ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish()); + + const std::shared_ptr& batch = + RecordBatch::Make(schema, 0, {catalog_name}); + + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); + + return std::unique_ptr(new RecordBatchStream(reader)); + } + + arrow::Result> GetFlightInfoSchemas( + const ServerCallContext& context, const GetDbSchemas& command, + const FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, SqlSchema::GetDbSchemasSchema()); + } + + arrow::Result> DoGetDbSchemas( + const ServerCallContext& context, const GetDbSchemas& command) { + // As SQLite doesn't support schemas, this will return an empty record batch. 
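+    // The batch below therefore carries the GetDbSchemasSchema() columns with
+    // zero rows.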
+ + const std::shared_ptr& schema = SqlSchema::GetDbSchemasSchema(); + + StringBuilder catalog_name_builder; + ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish()); + StringBuilder schema_name_builder; + ARROW_ASSIGN_OR_RAISE(auto schema_name, schema_name_builder.Finish()); + + const std::shared_ptr& batch = + RecordBatch::Make(schema, 0, {catalog_name, schema_name}); + + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); + + return std::unique_ptr(new RecordBatchStream(reader)); + } + + arrow::Result> GetFlightInfoTables( + const ServerCallContext& context, const GetTables& command, + const FlightDescriptor& descriptor) { + std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}}; + + bool include_schema = command.include_schema; + + ARROW_ASSIGN_OR_RAISE( + auto result, + FlightInfo::Make(include_schema ? *SqlSchema::GetTablesSchemaWithIncludedSchema() + : *SqlSchema::GetTablesSchema(), + descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); + } + + arrow::Result> DoGetTables( + const ServerCallContext& context, const GetTables& command) { + std::string query = PrepareQueryForGetTables(command); + + std::shared_ptr statement; + ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, query)); + + std::shared_ptr reader; + ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create( + statement, SqlSchema::GetTablesSchema())); + + if (command.include_schema) { + std::shared_ptr table_schema_reader = + std::make_shared(reader, query, db_); + return std::unique_ptr( + new RecordBatchStream(table_schema_reader)); + } else { + return std::unique_ptr(new RecordBatchStream(reader)); + } + } + + arrow::Result DoPutCommandStatementUpdate(const ServerCallContext& context, + const StatementUpdate& command) { + const std::string& sql = command.query; + + ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db_, sql)); + + return statement->ExecuteUpdate(); + } + + arrow::Result CreatePreparedStatement( + const ServerCallContext& context, + const ActionCreatePreparedStatementRequest& request) { + std::shared_ptr statement; + ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, request.query)); + const std::string handle = GenerateRandomString(); + prepared_statements_[handle] = statement; + + ARROW_ASSIGN_OR_RAISE(auto dataset_schema, statement->GetSchema()); + + sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); + const int parameter_count = sqlite3_bind_parameter_count(stmt); + std::vector> parameter_fields; + parameter_fields.reserve(parameter_count); + + // As SQLite doesn't know the parameter types before executing the query, the + // example server is accepting any SQLite supported type as input by using a dense + // union. 
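+    // (GetUnknownColumnDataType() is a dense_union of utf8, binary, int64 and
+    // float64 fields.)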
+ const std::shared_ptr& dense_union_type = GetUnknownColumnDataType(); + + for (int i = 0; i < parameter_count; i++) { + const char* parameter_name_chars = sqlite3_bind_parameter_name(stmt, i + 1); + std::string parameter_name; + if (parameter_name_chars == NULLPTR) { + parameter_name = std::string("parameter_") + std::to_string(i + 1); + } else { + parameter_name = parameter_name_chars; + } + parameter_fields.push_back(field(parameter_name, dense_union_type)); + } + + const std::shared_ptr& parameter_schema = arrow::schema(parameter_fields); + + ActionCreatePreparedStatementResult result{.dataset_schema = dataset_schema, + .parameter_schema = parameter_schema, + .prepared_statement_handle = handle}; + + return result; + } + + Status ClosePreparedStatement(const ServerCallContext& context, + const ActionClosePreparedStatementRequest& request) { + const std::string& prepared_statement_handle = request.prepared_statement_handle; + + auto search = prepared_statements_.find(prepared_statement_handle); + if (search != prepared_statements_.end()) { + prepared_statements_.erase(prepared_statement_handle); + } else { + return Status::Invalid("Prepared statement not found"); + } + + return Status::OK(); + } + + arrow::Result> GetFlightInfoPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command, + const FlightDescriptor& descriptor) { + const std::string& prepared_statement_handle = command.prepared_statement_handle; + + auto search = prepared_statements_.find(prepared_statement_handle); + if (search == prepared_statements_.end()) { + return Status::Invalid("Prepared statement not found"); + } + + std::shared_ptr statement = search->second; + + ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); + + return GetFlightInfoForCommand(descriptor, schema); + } + + arrow::Result> DoGetPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command) { + const std::string& prepared_statement_handle = command.prepared_statement_handle; + + auto search = prepared_statements_.find(prepared_statement_handle); + if (search == prepared_statements_.end()) { + return Status::Invalid("Prepared statement not found"); + } + + std::shared_ptr statement = search->second; + + std::shared_ptr reader; + ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create(statement)); + + return std::unique_ptr(new RecordBatchStream(reader)); + } + + Status DoPutPreparedStatementQuery(const ServerCallContext& context, + const PreparedStatementQuery& command, + FlightMessageReader* reader, + FlightMetadataWriter* writer) { + const std::string& prepared_statement_handle = command.prepared_statement_handle; + ARROW_ASSIGN_OR_RAISE( + auto statement, + GetStatementByHandle(prepared_statements_, prepared_statement_handle)); + + sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); + ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader)); + + return Status::OK(); + } + + arrow::Result DoPutPreparedStatementUpdate( + const ServerCallContext& context, const PreparedStatementUpdate& command, + FlightMessageReader* reader) { + const std::string& prepared_statement_handle = command.prepared_statement_handle; + ARROW_ASSIGN_OR_RAISE( + auto statement, + GetStatementByHandle(prepared_statements_, prepared_statement_handle)); + + sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); + ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader)); + + return statement->ExecuteUpdate(); + } + + arrow::Result> GetFlightInfoTableTypes( + const ServerCallContext& 
context, const FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> DoGetTableTypes( + const ServerCallContext& context) { + std::string query = "SELECT DISTINCT type as table_type FROM sqlite_master"; + + return DoGetSQLiteQuery(db_, query, SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> GetFlightInfoPrimaryKeys( + const ServerCallContext& context, const GetPrimaryKeys& command, + const FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, SqlSchema::GetPrimaryKeysSchema()); + } + + arrow::Result> DoGetPrimaryKeys( + const ServerCallContext& context, const GetPrimaryKeys& command) { + std::stringstream table_query; + + // The field key_name can not be recovered by the sqlite, so it is being set + // to null following the same pattern for catalog_name and schema_name. + table_query << "SELECT null as catalog_name, null as schema_name, table_name, " + "name as column_name, pk as key_sequence, null as key_name\n" + "FROM pragma_table_info(table_name)\n" + " JOIN (SELECT null as catalog_name, null as schema_name, name as " + "table_name, type as table_type\n" + "FROM sqlite_master) where 1=1 and pk != 0"; + + const TableRef& table_ref = command.table_ref; + if (table_ref.catalog.has_value()) { + table_query << " and catalog_name LIKE '" << table_ref.catalog.value() << "'"; + } + + if (table_ref.db_schema.has_value()) { + table_query << " and schema_name LIKE '" << table_ref.db_schema.value() << "'"; + } + + table_query << " and table_name LIKE '" << table_ref.table << "'"; + + return DoGetSQLiteQuery(db_, table_query.str(), SqlSchema::GetPrimaryKeysSchema()); + } + + arrow::Result> GetFlightInfoImportedKeys( + const ServerCallContext& context, const GetImportedKeys& command, + const FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, SqlSchema::GetImportedKeysSchema()); + } + + arrow::Result> DoGetImportedKeys( + const ServerCallContext& context, const GetImportedKeys& command) { + const TableRef& table_ref = command.table_ref; + std::string filter = "fk_table_name = '" + table_ref.table + "'"; + if (table_ref.catalog.has_value()) { + filter += " AND fk_catalog_name = '" + table_ref.catalog.value() + "'"; + } + if (table_ref.db_schema.has_value()) { + filter += " AND fk_schema_name = '" + table_ref.db_schema.value() + "'"; + } + std::string query = PrepareQueryForGetImportedOrExportedKeys(filter); + + return DoGetSQLiteQuery(db_, query, SqlSchema::GetImportedKeysSchema()); + } + + arrow::Result> GetFlightInfoExportedKeys( + const ServerCallContext& context, const GetExportedKeys& command, + const FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, SqlSchema::GetExportedKeysSchema()); + } + + arrow::Result> DoGetExportedKeys( + const ServerCallContext& context, const GetExportedKeys& command) { + const TableRef& table_ref = command.table_ref; + std::string filter = "pk_table_name = '" + table_ref.table + "'"; + if (table_ref.catalog.has_value()) { + filter += " AND pk_catalog_name = '" + table_ref.catalog.value() + "'"; + } + if (table_ref.db_schema.has_value()) { + filter += " AND pk_schema_name = '" + table_ref.db_schema.value() + "'"; + } + std::string query = PrepareQueryForGetImportedOrExportedKeys(filter); + + return DoGetSQLiteQuery(db_, query, SqlSchema::GetExportedKeysSchema()); + } + + arrow::Result> GetFlightInfoCrossReference( + const ServerCallContext& context, const GetCrossReference& command, + const 
FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, SqlSchema::GetCrossReferenceSchema()); + } + + arrow::Result> DoGetCrossReference( + const ServerCallContext& context, const GetCrossReference& command) { + const TableRef& pk_table_ref = command.pk_table_ref; + std::string filter = "pk_table_name = '" + pk_table_ref.table + "'"; + if (pk_table_ref.catalog.has_value()) { + filter += " AND pk_catalog_name = '" + pk_table_ref.catalog.value() + "'"; + } + if (pk_table_ref.db_schema.has_value()) { + filter += " AND pk_schema_name = '" + pk_table_ref.db_schema.value() + "'"; + } + + const TableRef& fk_table_ref = command.fk_table_ref; + filter += " AND fk_table_name = '" + fk_table_ref.table + "'"; + if (fk_table_ref.catalog.has_value()) { + filter += " AND fk_catalog_name = '" + fk_table_ref.catalog.value() + "'"; + } + if (fk_table_ref.db_schema.has_value()) { + filter += " AND fk_schema_name = '" + fk_table_ref.db_schema.value() + "'"; + } + std::string query = PrepareQueryForGetImportedOrExportedKeys(filter); + + return DoGetSQLiteQuery(db_, query, SqlSchema::GetCrossReferenceSchema()); + } + + Status ExecuteSql(const std::string& sql) { + char* err_msg = nullptr; + int rc = sqlite3_exec(db_, sql.c_str(), nullptr, nullptr, &err_msg); + if (rc != SQLITE_OK) { + std::string error_msg; + if (err_msg != nullptr) { + error_msg = err_msg; + } + sqlite3_free(err_msg); + return Status::ExecutionError(error_msg); + } + return Status::OK(); + } +}; + +SQLiteFlightSqlServer::SQLiteFlightSqlServer(std::shared_ptr impl) + : impl_(std::move(impl)) {} + +arrow::Result> SQLiteFlightSqlServer::Create() { + sqlite3* db = nullptr; + + if (sqlite3_open(":memory:", &db)) { + std::string err_msg = "Can't open database: "; + if (db != nullptr) { + err_msg += sqlite3_errmsg(db); + sqlite3_close(db); + } else { + err_msg += "Unable to start SQLite. 
Insufficient memory"; + } + + return Status::Invalid(err_msg); + } + + std::shared_ptr impl = std::make_shared(db); + + std::shared_ptr result(new SQLiteFlightSqlServer(impl)); + for (const auto& id_to_result : GetSqlInfoResultMap()) { + result->RegisterSqlInfo(id_to_result.first, id_to_result.second); + } + + ARROW_RETURN_NOT_OK(result->ExecuteSql(R"( + CREATE TABLE foreignTable ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + foreignName varchar(100), + value int); + + CREATE TABLE intTable ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + keyName varchar(100), + value int, + foreignId int references foreignTable(id)); + + INSERT INTO foreignTable (foreignName, value) VALUES ('keyOne', 1); + INSERT INTO foreignTable (foreignName, value) VALUES ('keyTwo', 0); + INSERT INTO foreignTable (foreignName, value) VALUES ('keyThree', -1); + INSERT INTO intTable (keyName, value, foreignId) VALUES ('one', 1, 1); + INSERT INTO intTable (keyName, value, foreignId) VALUES ('zero', 0, 1); + INSERT INTO intTable (keyName, value, foreignId) VALUES ('negative one', -1, 1); + INSERT INTO intTable (keyName, value, foreignId) VALUES (NULL, NULL, NULL); + )")); + + return result; +} + +SQLiteFlightSqlServer::~SQLiteFlightSqlServer() = default; + +Status SQLiteFlightSqlServer::ExecuteSql(const std::string& sql) { + return impl_->ExecuteSql(sql); +} + +arrow::Result> SQLiteFlightSqlServer::GetFlightInfoStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoStatement(context, command, descriptor); +} + +arrow::Result> SQLiteFlightSqlServer::DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) { + return impl_->DoGetStatement(context, command); +} + +arrow::Result> SQLiteFlightSqlServer::GetFlightInfoCatalogs( + const ServerCallContext& context, const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoCatalogs(context, descriptor); +} + +arrow::Result> SQLiteFlightSqlServer::DoGetCatalogs( + const ServerCallContext& context) { + return impl_->DoGetCatalogs(context); +} + +arrow::Result> SQLiteFlightSqlServer::GetFlightInfoSchemas( + const ServerCallContext& context, const GetDbSchemas& command, + const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoSchemas(context, command, descriptor); +} + +arrow::Result> SQLiteFlightSqlServer::DoGetDbSchemas( + const ServerCallContext& context, const GetDbSchemas& command) { + return impl_->DoGetDbSchemas(context, command); +} + +arrow::Result> SQLiteFlightSqlServer::GetFlightInfoTables( + const ServerCallContext& context, const GetTables& command, + const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoTables(context, command, descriptor); +} + +arrow::Result> SQLiteFlightSqlServer::DoGetTables( + const ServerCallContext& context, const GetTables& command) { + return impl_->DoGetTables(context, command); +} + +arrow::Result SQLiteFlightSqlServer::DoPutCommandStatementUpdate( + const ServerCallContext& context, const StatementUpdate& command) { + return impl_->DoPutCommandStatementUpdate(context, command); +} + +arrow::Result +SQLiteFlightSqlServer::CreatePreparedStatement( + const ServerCallContext& context, + const ActionCreatePreparedStatementRequest& request) { + return impl_->CreatePreparedStatement(context, request); +} + +Status SQLiteFlightSqlServer::ClosePreparedStatement( + const ServerCallContext& context, + const ActionClosePreparedStatementRequest& request) { + return impl_->ClosePreparedStatement(context, 
request); +} + +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command, + const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoPreparedStatement(context, command, descriptor); +} + +arrow::Result> +SQLiteFlightSqlServer::DoGetPreparedStatement(const ServerCallContext& context, + const PreparedStatementQuery& command) { + return impl_->DoGetPreparedStatement(context, command); +} + +Status SQLiteFlightSqlServer::DoPutPreparedStatementQuery( + const ServerCallContext& context, const PreparedStatementQuery& command, + FlightMessageReader* reader, FlightMetadataWriter* writer) { + return impl_->DoPutPreparedStatementQuery(context, command, reader, writer); +} + +arrow::Result SQLiteFlightSqlServer::DoPutPreparedStatementUpdate( + const ServerCallContext& context, const PreparedStatementUpdate& command, + FlightMessageReader* reader) { + return impl_->DoPutPreparedStatementUpdate(context, command, reader); +} + +arrow::Result> SQLiteFlightSqlServer::GetFlightInfoTableTypes( + const ServerCallContext& context, const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoTableTypes(context, descriptor); +} + +arrow::Result> SQLiteFlightSqlServer::DoGetTableTypes( + const ServerCallContext& context) { + return impl_->DoGetTableTypes(context); +} + +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoPrimaryKeys(const ServerCallContext& context, + const GetPrimaryKeys& command, + const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoPrimaryKeys(context, command, descriptor); +} + +arrow::Result> SQLiteFlightSqlServer::DoGetPrimaryKeys( + const ServerCallContext& context, const GetPrimaryKeys& command) { + return impl_->DoGetPrimaryKeys(context, command); +} + +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoImportedKeys(const ServerCallContext& context, + const GetImportedKeys& command, + const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoImportedKeys(context, command, descriptor); +} + +arrow::Result> SQLiteFlightSqlServer::DoGetImportedKeys( + const ServerCallContext& context, const GetImportedKeys& command) { + return impl_->DoGetImportedKeys(context, command); +} + +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoExportedKeys(const ServerCallContext& context, + const GetExportedKeys& command, + const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoExportedKeys(context, command, descriptor); +} + +arrow::Result> SQLiteFlightSqlServer::DoGetExportedKeys( + const ServerCallContext& context, const GetExportedKeys& command) { + return impl_->DoGetExportedKeys(context, command); +} + +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoCrossReference(const ServerCallContext& context, + const GetCrossReference& command, + const FlightDescriptor& descriptor) { + return impl_->GetFlightInfoCrossReference(context, command, descriptor); +} + +arrow::Result> +SQLiteFlightSqlServer::DoGetCrossReference(const ServerCallContext& context, + const GetCrossReference& command) { + return impl_->DoGetCrossReference(context, command); +} + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.h b/cpp/src/arrow/flight/sql/example/sqlite_server.h new file mode 100644 index 00000000000..b2954b8703e --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.h @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more 
contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include + +#include "arrow/api.h" +#include "arrow/flight/sql/example/sqlite_statement.h" +#include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" +#include "arrow/flight/sql/server.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +/// \brief Convert a column type to a ArrowType. +/// \param sqlite_type the sqlite type. +/// \return The equivalent ArrowType. +std::shared_ptr GetArrowType(const char* sqlite_type); + +/// \brief Get the DataType used when parameter type is not known. +/// \return DataType used when parameter type is not known. +inline std::shared_ptr GetUnknownColumnDataType() { + return dense_union({ + field("string", utf8()), + field("bytes", binary()), + field("bigint", int64()), + field("double", float64()), + }); +} + +/// \brief Example implementation of FlightSqlServerBase backed by an in-memory SQLite3 +/// database. +class SQLiteFlightSqlServer : public FlightSqlServerBase { + public: + ~SQLiteFlightSqlServer() override; + + static arrow::Result> Create(); + + /// \brief Auxiliary method used to execute an arbitrary SQL statement on the underlying + /// SQLite database. 
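+ /// \param[in] sql The SQL statement to execute.
+ /// \return Status::OK() on success, otherwise an ExecutionError carrying SQLite's error message.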
+ Status ExecuteSql(const std::string& sql); + + arrow::Result> GetFlightInfoStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor) override; + + arrow::Result> DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) override; + arrow::Result> GetFlightInfoCatalogs( + const ServerCallContext& context, const FlightDescriptor& descriptor) override; + arrow::Result> DoGetCatalogs( + const ServerCallContext& context) override; + arrow::Result> GetFlightInfoSchemas( + const ServerCallContext& context, const GetDbSchemas& command, + const FlightDescriptor& descriptor) override; + arrow::Result> DoGetDbSchemas( + const ServerCallContext& context, const GetDbSchemas& command) override; + arrow::Result DoPutCommandStatementUpdate( + const ServerCallContext& context, const StatementUpdate& update) override; + arrow::Result CreatePreparedStatement( + const ServerCallContext& context, + const ActionCreatePreparedStatementRequest& request) override; + Status ClosePreparedStatement( + const ServerCallContext& context, + const ActionClosePreparedStatementRequest& request) override; + arrow::Result> GetFlightInfoPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command, + const FlightDescriptor& descriptor) override; + arrow::Result> DoGetPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command) override; + Status DoPutPreparedStatementQuery(const ServerCallContext& context, + const PreparedStatementQuery& command, + FlightMessageReader* reader, + FlightMetadataWriter* writer) override; + arrow::Result DoPutPreparedStatementUpdate( + const ServerCallContext& context, const PreparedStatementUpdate& command, + FlightMessageReader* reader) override; + + arrow::Result> GetFlightInfoTables( + const ServerCallContext& context, const GetTables& command, + const FlightDescriptor& descriptor) override; + + arrow::Result> DoGetTables( + const ServerCallContext& context, const GetTables& command) override; + arrow::Result> GetFlightInfoTableTypes( + const ServerCallContext& context, const FlightDescriptor& descriptor) override; + arrow::Result> DoGetTableTypes( + const ServerCallContext& context) override; + arrow::Result> GetFlightInfoImportedKeys( + const ServerCallContext& context, const GetImportedKeys& command, + const FlightDescriptor& descriptor) override; + arrow::Result> DoGetImportedKeys( + const ServerCallContext& context, const GetImportedKeys& command) override; + arrow::Result> GetFlightInfoExportedKeys( + const ServerCallContext& context, const GetExportedKeys& command, + const FlightDescriptor& descriptor) override; + arrow::Result> DoGetExportedKeys( + const ServerCallContext& context, const GetExportedKeys& command) override; + arrow::Result> GetFlightInfoCrossReference( + const ServerCallContext& context, const GetCrossReference& command, + const FlightDescriptor& descriptor) override; + arrow::Result> DoGetCrossReference( + const ServerCallContext& context, const GetCrossReference& command) override; + + arrow::Result> GetFlightInfoPrimaryKeys( + const ServerCallContext& context, const GetPrimaryKeys& command, + const FlightDescriptor& descriptor) override; + + arrow::Result> DoGetPrimaryKeys( + const ServerCallContext& context, const GetPrimaryKeys& command) override; + + private: + class Impl; + std::shared_ptr impl_; + + explicit SQLiteFlightSqlServer(std::shared_ptr impl); +}; + +} // namespace example +} // namespace sql +} // 
namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc new file mode 100644 index 00000000000..94f25b39017 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc @@ -0,0 +1,223 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/sql/example/sqlite_sql_info.h" + +#include "arrow/flight/sql/types.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +/// \brief Gets the mapping from SQL info ids to SqlInfoResult instances. +/// \return the cache. +SqlInfoResultMap GetSqlInfoResultMap() { + return { + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, + SqlInfoResult(std::string("db_name"))}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_VERSION, + SqlInfoResult(std::string("sqlite 3"))}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION, + SqlInfoResult(std::string("7.0.0-SNAPSHOT" /* Only an example */))}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, SqlInfoResult(false)}, + {SqlInfoOptions::SqlInfo::SQL_DDL_CATALOG, + SqlInfoResult(false /* SQLite 3 does not support catalogs */)}, + {SqlInfoOptions::SqlInfo::SQL_DDL_SCHEMA, + SqlInfoResult(false /* SQLite 3 does not support schemas */)}, + {SqlInfoOptions::SqlInfo::SQL_DDL_TABLE, SqlInfoResult(true)}, + {SqlInfoOptions::SqlInfo::SQL_IDENTIFIER_CASE, + SqlInfoResult(int64_t(SqlInfoOptions::SqlSupportedCaseSensitivity:: + SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, + {SqlInfoOptions::SqlInfo::SQL_IDENTIFIER_QUOTE_CHAR, + SqlInfoResult(std::string("\""))}, + {SqlInfoOptions::SqlInfo::SQL_QUOTED_IDENTIFIER_CASE, + SqlInfoResult(int64_t(SqlInfoOptions::SqlSupportedCaseSensitivity:: + SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, + {SqlInfoOptions::SqlInfo::SQL_ALL_TABLES_ARE_SELECTABLE, SqlInfoResult(true)}, + {SqlInfoOptions::SqlInfo::SQL_NULL_ORDERING, + SqlInfoResult( + int64_t(SqlInfoOptions::SqlNullOrdering::SQL_NULLS_SORTED_AT_START))}, + {SqlInfoOptions::SqlInfo::SQL_KEYWORDS, + SqlInfoResult(std::vector({"ABORT", + "ACTION", + "ADD", + "AFTER", + "ALL", + "ALTER", + "ALWAYS", + "ANALYZE", + "AND", + "AS", + "ASC", + "ATTACH", + "AUTOINCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BY", + "CASCADE", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "COMMIT", + "CONFLICT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATABASE", + "DEFAULT", + "DEFERRABLE", + "DEFERRED", + "DELETE", + "DESC", + "DETACH", + "DISTINCT", + "DO", + "DROP", + "EACH", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXCLUSIVE", + "EXISTS", + "EXPLAIN", + "FAIL", + "FILTER", + "FIRST", + "FOLLOWING", + "FOR", + "FOREIGN", + "FROM", + "FULL", 
+ "GENERATED", + "GLOB", + "GROUP", + "GROUPS", + "HAVING", + "IF", + "IGNORE", + "IMMEDIATE", + "IN", + "INDEX", + "INDEXED", + "INITIALLY", + "INNER", + "INSERT", + "INSTEAD", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "KEY", + "LAST", + "LEFT", + "LIKE", + "LIMIT", + "MATCH", + "MATERIALIZED", + "NATURAL", + "NO", + "NOT", + "NOTHING", + "NOTNULL", + "NULL", + "NULLS", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER", + "OTHERS", + "OUTER", + "OVER", + "PARTITION", + "PLAN", + "PRAGMA", + "PRECEDING", + "PRIMARY", + "QUERY", + "RAISE", + "RANGE", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "REINDEX", + "RELEASE", + "RENAME", + "REPLACE", + "RESTRICT", + "RETURNING", + "RIGHT", + "ROLLBACK", + "ROW", + "ROWS", + "SAVEPOINT", + "SELECT", + "SET", + "TABLE", + "TEMP", + "TEMPORARY", + "THEN", + "TIES", + "TO", + "TRANSACTION", + "TRIGGER", + "UNBOUNDED", + "UNION", + "UNIQUE", + "UPDATE", + "USING", + "VACUUM", + "VALUES", + "VIEW", + "VIRTUAL", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHOUT"}))}, + {SqlInfoOptions::SqlInfo::SQL_NUMERIC_FUNCTIONS, + SqlInfoResult(std::vector( + {"ACOS", "ACOSH", "ASIN", "ASINH", "ATAN", "ATAN2", "ATANH", "CEIL", + "CEILING", "COS", "COSH", "DEGREES", "EXP", "FLOOR", "LN", "LOG", + "LOG", "LOG10", "LOG2", "MOD", "PI", "POW", "POWER", "RADIANS", + "SIN", "SINH", "SQRT", "TAN", "TANH", "TRUNC"}))}, + {SqlInfoOptions::SqlInfo::SQL_STRING_FUNCTIONS, + SqlInfoResult( + std::vector({"SUBSTR", "TRIM", "LTRIM", "RTRIM", "LENGTH", + "REPLACE", "UPPER", "LOWER", "INSTR"}))}, + {SqlInfoOptions::SqlInfo::SQL_SUPPORTS_CONVERT, + SqlInfoResult(std::unordered_map>( + {{SqlInfoOptions::SqlSupportsConvert::SQL_CONVERT_BIGINT, + std::vector( + {SqlInfoOptions::SqlSupportsConvert::SQL_CONVERT_INTEGER})}}))}}; +} + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_sql_info.h b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.h new file mode 100644 index 00000000000..3c6dd42135e --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.h @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/flight/sql/types.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +/// \brief Gets the mapping from SQL info ids to SqlInfoResult instances. +/// \return the cache. 
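+/// The returned map is consumed by SQLiteFlightSqlServer::Create(), which registers each
+/// entry with the server through FlightSqlServerBase::RegisterSqlInfo().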
+SqlInfoResultMap GetSqlInfoResultMap(); + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement.cc b/cpp/src/arrow/flight/sql/example/sqlite_statement.cc new file mode 100644 index 00000000000..018f8de37db --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement.cc @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/sql/example/sqlite_statement.h" + +#include + +#include + +#include "arrow/flight/sql/example/sqlite_server.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +std::shared_ptr GetDataTypeFromSqliteType(const int column_type) { + switch (column_type) { + case SQLITE_INTEGER: + return int64(); + case SQLITE_FLOAT: + return float64(); + case SQLITE_BLOB: + return binary(); + case SQLITE_TEXT: + return utf8(); + case SQLITE_NULL: + default: + return null(); + } +} + +arrow::Result> SqliteStatement::Create( + sqlite3* db, const std::string& sql) { + sqlite3_stmt* stmt = nullptr; + int rc = + sqlite3_prepare_v2(db, sql.c_str(), static_cast(sql.size()), &stmt, NULLPTR); + + if (rc != SQLITE_OK) { + std::string err_msg = "Can't prepare statement: " + std::string(sqlite3_errmsg(db)); + if (stmt != nullptr) { + rc = sqlite3_finalize(stmt); + if (rc != SQLITE_OK) { + err_msg += "; Failed to finalize SQLite statement: "; + err_msg += std::string(sqlite3_errmsg(db)); + } + } + return Status::Invalid(err_msg); + } + + std::shared_ptr result(new SqliteStatement(db, stmt)); + return result; +} + +arrow::Result> SqliteStatement::GetSchema() const { + std::vector> fields; + int column_count = sqlite3_column_count(stmt_); + for (int i = 0; i < column_count; i++) { + const char* column_name = sqlite3_column_name(stmt_, i); + + // SQLite does not always provide column types, especially when the statement has not + // been executed yet. Because of this behaviour this method tries to get the column + // types in two attempts: + // 1. Use sqlite3_column_type(), which return SQLITE_NULL if the statement has not + // been executed yet + // 2. Use sqlite3_column_decltype(), which returns correctly if given column is + // declared in the table. + // Because of this limitation, it is not possible to know the column types for some + // prepared statements, in this case it returns a dense_union type covering any type + // SQLite supports. 
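+ // For example, a prepared "SELECT ?" exposes neither a result value nor a declared
+ // table column type for the parameter, so the dense_union fallback below is used.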
+ const int column_type = sqlite3_column_type(stmt_, i); + std::shared_ptr data_type = GetDataTypeFromSqliteType(column_type); + if (data_type->id() == Type::NA) { + // Try to retrieve column type from sqlite3_column_decltype + const char* column_decltype = sqlite3_column_decltype(stmt_, i); + if (column_decltype != NULLPTR) { + data_type = GetArrowType(column_decltype); + } else { + // If it can not determine the actual column type, return a dense_union type + // covering any type SQLite supports. + data_type = GetUnknownColumnDataType(); + } + } + + fields.push_back(arrow::field(column_name, data_type)); + } + + return arrow::schema(fields); +} + +SqliteStatement::~SqliteStatement() { sqlite3_finalize(stmt_); } + +arrow::Result SqliteStatement::Step() { + int rc = sqlite3_step(stmt_); + if (rc == SQLITE_ERROR) { + return Status::ExecutionError("A SQLite runtime error has occurred: ", + sqlite3_errmsg(db_)); + } + + return rc; +} + +arrow::Result SqliteStatement::Reset() { + int rc = sqlite3_reset(stmt_); + if (rc == SQLITE_ERROR) { + return Status::ExecutionError("A SQLite runtime error has occurred: ", + sqlite3_errmsg(db_)); + } + + return rc; +} + +sqlite3_stmt* SqliteStatement::GetSqlite3Stmt() const { return stmt_; } + +arrow::Result SqliteStatement::ExecuteUpdate() { + ARROW_RETURN_NOT_OK(Step()); + return sqlite3_changes(db_); +} + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement.h b/cpp/src/arrow/flight/sql/example/sqlite_statement.h new file mode 100644 index 00000000000..a3f086abc47 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement.h @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include + +#include "arrow/type_fwd.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +class SqliteStatement { + public: + /// \brief Creates a SQLite3 statement. + /// \param[in] db SQLite3 database instance. + /// \param[in] sql SQL statement. + /// \return A SqliteStatement object. + static arrow::Result> Create(sqlite3* db, + const std::string& sql); + + ~SqliteStatement(); + + /// \brief Creates an Arrow Schema based on the results of this statement. + /// \return The resulting Schema. + arrow::Result> GetSchema() const; + + /// \brief Steps on underlying sqlite3_stmt. + /// \return The resulting return code from SQLite. + arrow::Result Step(); + + /// \brief Reset the state of the sqlite3_stmt. + /// \return The resulting return code from SQLite. + arrow::Result Reset(); + + /// \brief Returns the underlying sqlite3_stmt. + /// \return A sqlite statement. 
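+ /// The returned pointer is still owned by this SqliteStatement and is finalized by its
+ /// destructor, so callers must not call sqlite3_finalize() on it.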
+ sqlite3_stmt* GetSqlite3Stmt() const; + + /// \brief Executes an UPDATE, INSERT or DELETE statement. + /// \return The number of rows changed by execution. + arrow::Result ExecuteUpdate(); + + private: + sqlite3* db_; + sqlite3_stmt* stmt_; + + SqliteStatement(sqlite3* db, sqlite3_stmt* stmt) : db_(db), stmt_(stmt) {} +}; + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc new file mode 100644 index 00000000000..08a03c4ca60 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.cc @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" + +#include + +#include "arrow/builder.h" +#include "arrow/flight/sql/example/sqlite_statement.h" + +#define STRING_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case TYPE_CLASS##Type::type_id: { \ + int bytes = sqlite3_column_bytes(STMT, COLUMN); \ + const unsigned char* string = sqlite3_column_text(STMT, COLUMN); \ + if (string == nullptr) { \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).AppendNull()); \ + break; \ + } \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).Append(string, bytes)); \ + break; \ + } + +#define BINARY_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case TYPE_CLASS##Type::type_id: { \ + int bytes = sqlite3_column_bytes(STMT, COLUMN); \ + const void* blob = sqlite3_column_blob(STMT, COLUMN); \ + if (blob == nullptr) { \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).AppendNull()); \ + break; \ + } \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).Append((char*)blob, bytes)); \ + break; \ + } + +#define INT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case TYPE_CLASS##Type::type_id: { \ + if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).AppendNull()); \ + break; \ + } \ + sqlite3_int64 value = sqlite3_column_int64(STMT, COLUMN); \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).Append(value)); \ + break; \ + } + +#define FLOAT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case TYPE_CLASS##Type::type_id: { \ + if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).AppendNull()); \ + break; \ + } \ + double value = sqlite3_column_double(STMT, COLUMN); \ + ARROW_RETURN_NOT_OK( \ + (reinterpret_cast(builder)).Append(value)); \ + break; \ + } + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +// Batch size for SQLite statement results +static constexpr int kMaxBatchSize = 1024; + +std::shared_ptr SqliteStatementBatchReader::schema() 
const { return schema_; } + +SqliteStatementBatchReader::SqliteStatementBatchReader( + std::shared_ptr statement, std::shared_ptr schema) + : statement_(std::move(statement)), + schema_(std::move(schema)), + rc_(SQLITE_OK), + already_executed_(false) {} + +arrow::Result> +SqliteStatementBatchReader::Create(const std::shared_ptr& statement_) { + ARROW_RETURN_NOT_OK(statement_->Step()); + + ARROW_ASSIGN_OR_RAISE(auto schema, statement_->GetSchema()); + + std::shared_ptr result( + new SqliteStatementBatchReader(statement_, schema)); + + return result; +} + +arrow::Result> +SqliteStatementBatchReader::Create(const std::shared_ptr& statement, + const std::shared_ptr& schema) { + std::shared_ptr result( + new SqliteStatementBatchReader(statement, schema)); + + return result; +} + +Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { + sqlite3_stmt* stmt_ = statement_->GetSqlite3Stmt(); + + const int num_fields = schema_->num_fields(); + std::vector> builders(num_fields); + + for (int i = 0; i < num_fields; i++) { + const std::shared_ptr& field = schema_->field(i); + const std::shared_ptr& field_type = field->type(); + + ARROW_RETURN_NOT_OK(MakeBuilder(default_memory_pool(), field_type, &builders[i])); + } + + if (!already_executed_) { + ARROW_ASSIGN_OR_RAISE(rc_, statement_->Reset()); + ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step()); + already_executed_ = true; + } + + int64_t rows = 0; + while (rows < kMaxBatchSize && rc_ == SQLITE_ROW) { + rows++; + for (int i = 0; i < num_fields; i++) { + const std::shared_ptr& field = schema_->field(i); + const std::shared_ptr& field_type = field->type(); + ArrayBuilder& builder = *builders[i]; + + // NOTE: This is not the optimal way of building Arrow vectors. + // That would be to presize the builders to avoiding several resizing operations + // when appending values and also to build one vector at a time. + switch (field_type->id()) { + INT_BUILDER_CASE(Int64, stmt_, i) + INT_BUILDER_CASE(UInt64, stmt_, i) + INT_BUILDER_CASE(Int32, stmt_, i) + INT_BUILDER_CASE(UInt32, stmt_, i) + INT_BUILDER_CASE(Int16, stmt_, i) + INT_BUILDER_CASE(UInt16, stmt_, i) + INT_BUILDER_CASE(Int8, stmt_, i) + INT_BUILDER_CASE(UInt8, stmt_, i) + FLOAT_BUILDER_CASE(Double, stmt_, i) + FLOAT_BUILDER_CASE(Float, stmt_, i) + FLOAT_BUILDER_CASE(HalfFloat, stmt_, i) + BINARY_BUILDER_CASE(Binary, stmt_, i) + BINARY_BUILDER_CASE(LargeBinary, stmt_, i) + STRING_BUILDER_CASE(String, stmt_, i) + STRING_BUILDER_CASE(LargeString, stmt_, i) + default: + return Status::NotImplemented("Not implemented SQLite data conversion to ", + field_type->name()); + } + } + + ARROW_ASSIGN_OR_RAISE(rc_, statement_->Step()); + } + + if (rows > 0) { + std::vector> arrays(builders.size()); + for (int i = 0; i < num_fields; i++) { + ARROW_RETURN_NOT_OK(builders[i]->Finish(&arrays[i])); + } + + *out = RecordBatch::Make(schema_, rows, arrays); + } else { + *out = NULLPTR; + } + + return Status::OK(); +} + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h new file mode 100644 index 00000000000..8a6bc6078e7 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_statement_batch_reader.h @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include + +#include "arrow/flight/sql/example/sqlite_statement.h" +#include "arrow/record_batch.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +class SqliteStatementBatchReader : public RecordBatchReader { + public: + /// \brief Creates a RecordBatchReader backed by a SQLite statement. + /// \param[in] statement SQLite statement to be read. + /// \return A SqliteStatementBatchReader. + static arrow::Result> Create( + const std::shared_ptr& statement); + + /// \brief Creates a RecordBatchReader backed by a SQLite statement. + /// \param[in] statement SQLite statement to be read. + /// \param[in] schema Schema to be used on results. + /// \return A SqliteStatementBatchReader.. + static arrow::Result> Create( + const std::shared_ptr& statement, + const std::shared_ptr& schema); + + std::shared_ptr schema() const override; + + Status ReadNext(std::shared_ptr* out) override; + + private: + std::shared_ptr statement_; + std::shared_ptr schema_; + int rc_; + bool already_executed_; + + SqliteStatementBatchReader(std::shared_ptr statement, + std::shared_ptr schema); +}; + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc new file mode 100644 index 00000000000..7fb68a709f8 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
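+
+// Example reader that wraps the result of CommandGetTables and appends a "table_schema"
+// column containing each table's serialized Arrow schema.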
+ +#include "arrow/flight/sql/example/sqlite_tables_schema_batch_reader.h" + +#include + +#include + +#include "arrow/flight/sql/example/sqlite_server.h" +#include "arrow/flight/sql/example/sqlite_statement.h" +#include "arrow/flight/sql/server.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +std::shared_ptr SqliteTablesWithSchemaBatchReader::schema() const { + return SqlSchema::GetTablesSchemaWithIncludedSchema(); +} + +Status SqliteTablesWithSchemaBatchReader::ReadNext(std::shared_ptr* batch) { + std::stringstream schema_query; + + schema_query + << "SELECT table_name, name, type, [notnull] FROM pragma_table_info(table_name)" + << "JOIN(" << main_query_ << ") order by table_name"; + + std::shared_ptr schema_statement; + ARROW_ASSIGN_OR_RAISE(schema_statement, + example::SqliteStatement::Create(db_, schema_query.str())) + + std::shared_ptr first_batch; + + ARROW_RETURN_NOT_OK(reader_->ReadNext(&first_batch)); + + if (!first_batch) { + *batch = NULLPTR; + return Status::OK(); + } + + const std::shared_ptr table_name_array = + first_batch->GetColumnByName("table_name"); + + BinaryBuilder schema_builder; + + auto* string_array = reinterpret_cast(table_name_array.get()); + + std::vector> column_fields; + for (int i = 0; i < table_name_array->length(); i++) { + const std::string& table_name = string_array->GetString(i); + + while (sqlite3_step(schema_statement->GetSqlite3Stmt()) == SQLITE_ROW) { + std::string sqlite_table_name = std::string(reinterpret_cast( + sqlite3_column_text(schema_statement->GetSqlite3Stmt(), 0))); + if (sqlite_table_name == table_name) { + const char* column_name = reinterpret_cast( + sqlite3_column_text(schema_statement->GetSqlite3Stmt(), 1)); + const char* column_type = reinterpret_cast( + sqlite3_column_text(schema_statement->GetSqlite3Stmt(), 2)); + int nullable = sqlite3_column_int(schema_statement->GetSqlite3Stmt(), 3); + + column_fields.push_back( + arrow::field(column_name, GetArrowType(column_type), nullable == 0, NULL)); + } + } + const arrow::Result>& value = + ipc::SerializeSchema(*arrow::schema(column_fields)); + + std::shared_ptr schema_buffer; + ARROW_ASSIGN_OR_RAISE(schema_buffer, value); + + column_fields.clear(); + ARROW_RETURN_NOT_OK( + schema_builder.Append(schema_buffer->data(), schema_buffer->size())); + } + + std::shared_ptr schema_array; + ARROW_RETURN_NOT_OK(schema_builder.Finish(&schema_array)); + + ARROW_ASSIGN_OR_RAISE(*batch, first_batch->AddColumn(4, "table_schema", schema_array)); + + return Status::OK(); +} + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.h b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.h new file mode 100644 index 00000000000..ecba88efb2f --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.h @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include +#include + +#include "arrow/flight/sql/example/sqlite_statement.h" +#include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" +#include "arrow/record_batch.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace example { + +class SqliteTablesWithSchemaBatchReader : public RecordBatchReader { + private: + std::shared_ptr reader_; + std::string main_query_; + sqlite3* db_; + + public: + /// Constructor for SqliteTablesWithSchemaBatchReader class + /// \param reader an shared_ptr from a SqliteStatementBatchReader. + /// \param main_query SQL query that originated reader's data. + /// \param db a pointer to the sqlite3 db. + SqliteTablesWithSchemaBatchReader( + std::shared_ptr reader, std::string main_query, + sqlite3* db) + : reader_(std::move(reader)), main_query_(std::move(main_query)), db_(db) {} + + std::shared_ptr schema() const override; + + Status ReadNext(std::shared_ptr* batch) override; +}; + +} // namespace example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc new file mode 100644 index 00000000000..bbbe801ea24 --- /dev/null +++ b/cpp/src/arrow/flight/sql/server.cc @@ -0,0 +1,764 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Interfaces to use for defining Flight RPC servers. API should be considered +// experimental for now + +#include "arrow/flight/sql/server.h" + +#include + +#include "arrow/buffer.h" +#include "arrow/builder.h" +#include "arrow/flight/sql/FlightSql.pb.h" +#include "arrow/flight/sql/sql_info_internal.h" +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" + +#define PROPERTY_TO_OPTIONAL(COMMAND, PROPERTY) \ + COMMAND.has_##PROPERTY() ? 
util::make_optional(COMMAND.PROPERTY()) : util::nullopt + +namespace arrow { +namespace flight { +namespace sql { + +namespace pb = arrow::flight::protocol; + +using arrow::internal::checked_cast; +using arrow::internal::checked_pointer_cast; + +namespace { + +arrow::Result ParseCommandGetCrossReference( + const google::protobuf::Any& any) { + pb::sql::CommandGetCrossReference command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandGetCrossReference."); + } + + GetCrossReference result; + result.pk_table_ref = {PROPERTY_TO_OPTIONAL(command, pk_catalog), + PROPERTY_TO_OPTIONAL(command, pk_db_schema), command.pk_table()}; + result.fk_table_ref = {PROPERTY_TO_OPTIONAL(command, fk_catalog), + PROPERTY_TO_OPTIONAL(command, fk_db_schema), command.fk_table()}; + return result; +} + +arrow::Result ParseCommandGetImportedKeys( + const google::protobuf::Any& any) { + pb::sql::CommandGetImportedKeys command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandGetImportedKeys."); + } + + GetImportedKeys result; + result.table_ref = {PROPERTY_TO_OPTIONAL(command, catalog), + PROPERTY_TO_OPTIONAL(command, db_schema), command.table()}; + return result; +} + +arrow::Result ParseCommandGetExportedKeys( + const google::protobuf::Any& any) { + pb::sql::CommandGetExportedKeys command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandGetExportedKeys."); + } + + GetExportedKeys result; + result.table_ref = {PROPERTY_TO_OPTIONAL(command, catalog), + PROPERTY_TO_OPTIONAL(command, db_schema), command.table()}; + return result; +} + +arrow::Result ParseCommandGetPrimaryKeys( + const google::protobuf::Any& any) { + pb::sql::CommandGetPrimaryKeys command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandGetPrimaryKeys."); + } + + GetPrimaryKeys result; + result.table_ref = {PROPERTY_TO_OPTIONAL(command, catalog), + PROPERTY_TO_OPTIONAL(command, db_schema), command.table()}; + return result; +} + +arrow::Result ParseCommandGetSqlInfo( + const google::protobuf::Any& any, const SqlInfoResultMap& sql_info_id_to_result) { + pb::sql::CommandGetSqlInfo command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandGetSqlInfo."); + } + + GetSqlInfo result; + if (command.info_size() > 0) { + result.info.reserve(command.info_size()); + result.info.assign(command.info().begin(), command.info().end()); + } else { + result.info.reserve(sql_info_id_to_result.size()); + for (const auto& it : sql_info_id_to_result) { + result.info.push_back(it.first); + } + } + return result; +} + +arrow::Result ParseCommandGetDbSchemas(const google::protobuf::Any& any) { + pb::sql::CommandGetDbSchemas command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandGetDbSchemas."); + } + + GetDbSchemas result; + result.catalog = PROPERTY_TO_OPTIONAL(command, catalog); + result.db_schema_filter_pattern = + PROPERTY_TO_OPTIONAL(command, db_schema_filter_pattern); + return result; +} + +arrow::Result ParseCommandPreparedStatementQuery( + const google::protobuf::Any& any) { + pb::sql::CommandPreparedStatementQuery command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandPreparedStatementQuery."); + } + + PreparedStatementQuery result; + result.prepared_statement_handle = command.prepared_statement_handle(); + return result; +} + +arrow::Result ParseCommandStatementQuery( + const google::protobuf::Any& any) { + 
pb::sql::CommandStatementQuery command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandStatementQuery."); + } + + StatementQuery result; + result.query = command.query(); + return result; +} + +arrow::Result ParseCommandGetTables(const google::protobuf::Any& any) { + pb::sql::CommandGetTables command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandGetTables."); + } + + std::vector table_types(command.table_types_size()); + std::copy(command.table_types().begin(), command.table_types().end(), + table_types.begin()); + + GetTables result; + result.catalog = PROPERTY_TO_OPTIONAL(command, catalog); + result.db_schema_filter_pattern = + PROPERTY_TO_OPTIONAL(command, db_schema_filter_pattern); + result.table_name_filter_pattern = + PROPERTY_TO_OPTIONAL(command, table_name_filter_pattern); + result.table_types = table_types; + result.include_schema = command.include_schema(); + return result; +} + +arrow::Result ParseStatementQueryTicket( + const google::protobuf::Any& any) { + pb::sql::TicketStatementQuery command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack TicketStatementQuery."); + } + + StatementQueryTicket result; + result.statement_handle = command.statement_handle(); + return result; +} + +arrow::Result ParseCommandStatementUpdate( + const google::protobuf::Any& any) { + pb::sql::CommandStatementUpdate command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandStatementUpdate."); + } + + StatementUpdate result; + result.query = command.query(); + return result; +} + +arrow::Result ParseCommandPreparedStatementUpdate( + const google::protobuf::Any& any) { + pb::sql::CommandPreparedStatementUpdate command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandPreparedStatementUpdate."); + } + + PreparedStatementUpdate result; + result.prepared_statement_handle = command.prepared_statement_handle(); + return result; +} + +arrow::Result +ParseActionCreatePreparedStatementRequest(const google::protobuf::Any& any) { + pb::sql::ActionCreatePreparedStatementRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionCreatePreparedStatementRequest."); + } + + ActionCreatePreparedStatementRequest result; + result.query = command.query(); + return result; +} + +arrow::Result +ParseActionClosePreparedStatementRequest(const google::protobuf::Any& any) { + pb::sql::ActionClosePreparedStatementRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionClosePreparedStatementRequest."); + } + + ActionClosePreparedStatementRequest result; + result.prepared_statement_handle = command.prepared_statement_handle(); + return result; +} + +} // namespace + +arrow::Result CreateStatementQueryTicket( + const std::string& statement_handle) { + protocol::sql::TicketStatementQuery ticket_statement_query; + ticket_statement_query.set_statement_handle(statement_handle); + + google::protobuf::Any ticket; + ticket.PackFrom(ticket_statement_query); + + std::string ticket_string; + + if (!ticket.SerializeToString(&ticket_string)) { + return Status::IOError("Invalid ticket."); + } + return ticket_string; +} + +Status FlightSqlServerBase::GetFlightInfo(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* info) { + google::protobuf::Any any; + if (!any.ParseFromArray(request.cmd.data(), static_cast(request.cmd.size()))) { + return 
Status::Invalid("Unable to parse command"); + } + + if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementQuery internal_command, + ParseCommandStatementQuery(any)); + ARROW_ASSIGN_OR_RAISE(*info, + GetFlightInfoStatement(context, internal_command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, + ParseCommandPreparedStatementQuery(any)); + ARROW_ASSIGN_OR_RAISE( + *info, GetFlightInfoPreparedStatement(context, internal_command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*info, GetFlightInfoCatalogs(context, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetDbSchemas internal_command, ParseCommandGetDbSchemas(any)); + ARROW_ASSIGN_OR_RAISE(*info, + GetFlightInfoSchemas(context, internal_command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetTables command, ParseCommandGetTables(any)); + ARROW_ASSIGN_OR_RAISE(*info, GetFlightInfoTables(context, command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*info, GetFlightInfoTableTypes(context, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetSqlInfo internal_command, + ParseCommandGetSqlInfo(any, sql_info_id_to_result_)); + ARROW_ASSIGN_OR_RAISE(*info, + GetFlightInfoSqlInfo(context, internal_command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetPrimaryKeys internal_command, + ParseCommandGetPrimaryKeys(any)); + ARROW_ASSIGN_OR_RAISE(*info, + GetFlightInfoPrimaryKeys(context, internal_command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetExportedKeys internal_command, + ParseCommandGetExportedKeys(any)); + ARROW_ASSIGN_OR_RAISE(*info, + GetFlightInfoExportedKeys(context, internal_command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetImportedKeys internal_command, + ParseCommandGetImportedKeys(any)); + ARROW_ASSIGN_OR_RAISE(*info, + GetFlightInfoImportedKeys(context, internal_command, request)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetCrossReference internal_command, + ParseCommandGetCrossReference(any)); + ARROW_ASSIGN_OR_RAISE( + *info, GetFlightInfoCrossReference(context, internal_command, request)); + return Status::OK(); + } + + return Status::Invalid("The defined request is invalid."); +} + +Status FlightSqlServerBase::DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* stream) { + google::protobuf::Any any; + + if (!any.ParseFromArray(request.ticket.data(), + static_cast(request.ticket.size()))) { + return Status::Invalid("Unable to parse ticket."); + } + + if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementQueryTicket command, ParseStatementQueryTicket(any)); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetStatement(context, command)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, + ParseCommandPreparedStatementQuery(any)); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetPreparedStatement(context, internal_command)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*stream, DoGetCatalogs(context)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetDbSchemas internal_command, ParseCommandGetDbSchemas(any)); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetDbSchemas(context, internal_command)); + return Status::OK(); + } 
else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetTables command, ParseCommandGetTables(any)); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetTables(context, command)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(*stream, DoGetTableTypes(context)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetSqlInfo internal_command, + ParseCommandGetSqlInfo(any, sql_info_id_to_result_)); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetSqlInfo(context, internal_command)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetPrimaryKeys internal_command, + ParseCommandGetPrimaryKeys(any)); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetPrimaryKeys(context, internal_command)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetExportedKeys internal_command, + ParseCommandGetExportedKeys(any)); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetExportedKeys(context, internal_command)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetImportedKeys internal_command, + ParseCommandGetImportedKeys(any)); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetImportedKeys(context, internal_command)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(GetCrossReference internal_command, + ParseCommandGetCrossReference(any)); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetCrossReference(context, internal_command)); + return Status::OK(); + } + + return Status::Invalid("The defined request is invalid."); +} + +Status FlightSqlServerBase::DoPut(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) { + const FlightDescriptor& request = reader->descriptor(); + + google::protobuf::Any any; + if (!any.ParseFromArray(request.cmd.data(), static_cast(request.cmd.size()))) { + return Status::Invalid("Unable to parse command."); + } + + if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementUpdate internal_command, + ParseCommandStatementUpdate(any)); + ARROW_ASSIGN_OR_RAISE(auto record_count, + DoPutCommandStatementUpdate(context, internal_command)) + + pb::sql::DoPutUpdateResult result; + result.set_record_count(record_count); + + const auto buffer = Buffer::FromString(result.SerializeAsString()); + ARROW_RETURN_NOT_OK(writer->WriteMetadata(*buffer)); + + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, + ParseCommandPreparedStatementQuery(any)); + return DoPutPreparedStatementQuery(context, internal_command, reader.get(), + writer.get()); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(PreparedStatementUpdate internal_command, + ParseCommandPreparedStatementUpdate(any)); + ARROW_ASSIGN_OR_RAISE(auto record_count, DoPutPreparedStatementUpdate( + context, internal_command, reader.get())) + + pb::sql::DoPutUpdateResult result; + result.set_record_count(record_count); + + const auto buffer = Buffer::FromString(result.SerializeAsString()); + ARROW_RETURN_NOT_OK(writer->WriteMetadata(*buffer)); + + return Status::OK(); + } + + return Status::Invalid("The defined request is invalid."); +} + +Status FlightSqlServerBase::ListActions(const ServerCallContext& context, + std::vector* actions) { + *actions = {FlightSqlServerBase::kCreatePreparedStatementActionType, + FlightSqlServerBase::kClosePreparedStatementActionType}; + return Status::OK(); +} + +Status FlightSqlServerBase::DoAction(const ServerCallContext& context, + const Action& action, + std::unique_ptr* result_stream) { + if (action.type == FlightSqlServerBase::kCreatePreparedStatementActionType.type) { + 
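+ // CreatePreparedStatement action: unpack the request, delegate to
+ // CreatePreparedStatement(), and return a single Result containing the serialized
+ // ActionCreatePreparedStatementResult (with the dataset and parameter schemas, when provided).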
google::protobuf::Any any_command; + if (!any_command.ParseFromArray(action.body->data(), + static_cast<int>(action.body->size()))) { + return Status::Invalid("Unable to parse action."); + } + + ARROW_ASSIGN_OR_RAISE(ActionCreatePreparedStatementRequest internal_command, + ParseActionCreatePreparedStatementRequest(any_command)); + ARROW_ASSIGN_OR_RAISE(auto result, CreatePreparedStatement(context, internal_command)) + + pb::sql::ActionCreatePreparedStatementResult action_result; + action_result.set_prepared_statement_handle(result.prepared_statement_handle); + if (result.dataset_schema != nullptr) { + ARROW_ASSIGN_OR_RAISE(auto serialized_dataset_schema, + ipc::SerializeSchema(*result.dataset_schema)) + action_result.set_dataset_schema(serialized_dataset_schema->ToString()); + } + if (result.parameter_schema != nullptr) { + ARROW_ASSIGN_OR_RAISE(auto serialized_parameter_schema, + ipc::SerializeSchema(*result.parameter_schema)) + action_result.set_parameter_schema(serialized_parameter_schema->ToString()); + } + + google::protobuf::Any any; + any.PackFrom(action_result); + + auto buf = Buffer::FromString(any.SerializeAsString()); + *result_stream = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}})); + + return Status::OK(); + } else if (action.type == FlightSqlServerBase::kClosePreparedStatementActionType.type) { + google::protobuf::Any any; + if (!any.ParseFromArray(action.body->data(), static_cast<int>(action.body->size()))) { + return Status::Invalid("Unable to parse action."); + } + + ARROW_ASSIGN_OR_RAISE(ActionClosePreparedStatementRequest internal_command, + ParseActionClosePreparedStatementRequest(any)); + + ARROW_RETURN_NOT_OK(ClosePreparedStatement(context, internal_command)); + + // Need to instantiate a ResultStream, otherwise clients cannot wait for completion. 
+ *result_stream = std::unique_ptr(new SimpleResultStream({})); + return Status::OK(); + } + return Status::Invalid("The defined request is invalid."); +} + +arrow::Result> FlightSqlServerBase::GetFlightInfoCatalogs( + const ServerCallContext& context, const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoCatalogs not implemented"); +} + +arrow::Result> FlightSqlServerBase::DoGetCatalogs( + const ServerCallContext& context) { + return Status::NotImplemented("DoGetCatalogs not implemented"); +} + +arrow::Result> FlightSqlServerBase::GetFlightInfoStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoStatement not implemented"); +} + +arrow::Result> FlightSqlServerBase::DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) { + return Status::NotImplemented("DoGetStatement not implemented"); +} + +arrow::Result> +FlightSqlServerBase::GetFlightInfoPreparedStatement(const ServerCallContext& context, + const PreparedStatementQuery& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoPreparedStatement not implemented"); +} + +arrow::Result> +FlightSqlServerBase::DoGetPreparedStatement(const ServerCallContext& context, + const PreparedStatementQuery& command) { + return Status::NotImplemented("DoGetPreparedStatement not implemented"); +} + +arrow::Result> FlightSqlServerBase::GetFlightInfoSqlInfo( + const ServerCallContext& context, const GetSqlInfo& command, + const FlightDescriptor& descriptor) { + if (sql_info_id_to_result_.empty()) { + return Status::KeyError("No SQL information available."); + } + + std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*SqlSchema::GetSqlInfoSchema(), + descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); +} + +void FlightSqlServerBase::RegisterSqlInfo(int32_t id, const SqlInfoResult& result) { + sql_info_id_to_result_[id] = result; +} + +arrow::Result> FlightSqlServerBase::DoGetSqlInfo( + const ServerCallContext& context, const GetSqlInfo& command) { + MemoryPool* memory_pool = default_memory_pool(); + UInt32Builder name_field_builder(memory_pool); + std::unique_ptr value_field_builder; + const auto& value_field_type = checked_pointer_cast( + SqlSchema::GetSqlInfoSchema()->fields()[1]->type()); + ARROW_RETURN_NOT_OK(MakeBuilder(memory_pool, value_field_type, &value_field_builder)); + + internal::SqlInfoResultAppender sql_info_result_appender( + checked_cast(value_field_builder.get())); + + // Populate both name_field_builder and value_field_builder for each element + // on command.info. + // value_field_builder is populated differently depending on the data type (as it is + // a DenseUnionBuilder). The population for each data type is implemented on + // internal::SqlInfoResultAppender. 
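+ // For example, a string-valued entry such as FLIGHT_SQL_SERVER_NAME is appended to the
+ // union's string child, while an int64-valued entry such as SQL_IDENTIFIER_CASE is
+ // appended to its int64 child.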
+ for (const auto& info : command.info) { + const auto it = sql_info_id_to_result_.find(info); + if (it == sql_info_id_to_result_.end()) { + return Status::KeyError("No information for SQL info number ", info); + } + ARROW_RETURN_NOT_OK(name_field_builder.Append(info)); + ARROW_RETURN_NOT_OK(arrow::util::visit(sql_info_result_appender, it->second)); + } + + std::shared_ptr name; + ARROW_RETURN_NOT_OK(name_field_builder.Finish(&name)); + std::shared_ptr value; + ARROW_RETURN_NOT_OK(value_field_builder->Finish(&value)); + + auto row_count = static_cast(command.info.size()); + const std::shared_ptr& batch = + RecordBatch::Make(SqlSchema::GetSqlInfoSchema(), row_count, {name, value}); + ARROW_ASSIGN_OR_RAISE(const auto reader, RecordBatchReader::Make({batch})); + + return std::unique_ptr(new RecordBatchStream(reader)); +} + +arrow::Result> FlightSqlServerBase::GetFlightInfoSchemas( + const ServerCallContext& context, const GetDbSchemas& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoSchemas not implemented"); +} + +arrow::Result> FlightSqlServerBase::DoGetDbSchemas( + const ServerCallContext& context, const GetDbSchemas& command) { + return Status::NotImplemented("DoGetDbSchemas not implemented"); +} + +arrow::Result> FlightSqlServerBase::GetFlightInfoTables( + const ServerCallContext& context, const GetTables& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoTables not implemented"); +} + +arrow::Result> FlightSqlServerBase::DoGetTables( + const ServerCallContext& context, const GetTables& command) { + return Status::NotImplemented("DoGetTables not implemented"); +} + +arrow::Result> FlightSqlServerBase::GetFlightInfoTableTypes( + const ServerCallContext& context, const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoTableTypes not implemented"); +} + +arrow::Result> FlightSqlServerBase::DoGetTableTypes( + const ServerCallContext& context) { + return Status::NotImplemented("DoGetTableTypes not implemented"); +} + +arrow::Result> FlightSqlServerBase::GetFlightInfoPrimaryKeys( + const ServerCallContext& context, const GetPrimaryKeys& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoPrimaryKeys not implemented"); +} + +arrow::Result> FlightSqlServerBase::DoGetPrimaryKeys( + const ServerCallContext& context, const GetPrimaryKeys& command) { + return Status::NotImplemented("DoGetPrimaryKeys not implemented"); +} + +arrow::Result> FlightSqlServerBase::GetFlightInfoExportedKeys( + const ServerCallContext& context, const GetExportedKeys& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoExportedKeys not implemented"); +} + +arrow::Result> FlightSqlServerBase::DoGetExportedKeys( + const ServerCallContext& context, const GetExportedKeys& command) { + return Status::NotImplemented("DoGetExportedKeys not implemented"); +} + +arrow::Result> FlightSqlServerBase::GetFlightInfoImportedKeys( + const ServerCallContext& context, const GetImportedKeys& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoImportedKeys not implemented"); +} + +arrow::Result> FlightSqlServerBase::DoGetImportedKeys( + const ServerCallContext& context, const GetImportedKeys& command) { + return Status::NotImplemented("DoGetImportedKeys not implemented"); +} + +arrow::Result> +FlightSqlServerBase::GetFlightInfoCrossReference(const ServerCallContext& context, + const 
GetCrossReference& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoCrossReference not implemented"); +} + +arrow::Result> FlightSqlServerBase::DoGetCrossReference( + const ServerCallContext& context, const GetCrossReference& command) { + return Status::NotImplemented("DoGetCrossReference not implemented"); +} + +arrow::Result +FlightSqlServerBase::CreatePreparedStatement( + const ServerCallContext& context, + const ActionCreatePreparedStatementRequest& request) { + return Status::NotImplemented("CreatePreparedStatement not implemented"); +} + +Status FlightSqlServerBase::ClosePreparedStatement( + const ServerCallContext& context, + const ActionClosePreparedStatementRequest& request) { + return Status::NotImplemented("ClosePreparedStatement not implemented"); +} + +Status FlightSqlServerBase::DoPutPreparedStatementQuery( + const ServerCallContext& context, const PreparedStatementQuery& command, + FlightMessageReader* reader, FlightMetadataWriter* writer) { + return Status::NotImplemented("DoPutPreparedStatementQuery not implemented"); +} + +arrow::Result FlightSqlServerBase::DoPutPreparedStatementUpdate( + const ServerCallContext& context, const PreparedStatementUpdate& command, + FlightMessageReader* reader) { + return Status::NotImplemented("DoPutPreparedStatementUpdate not implemented"); +} + +arrow::Result FlightSqlServerBase::DoPutCommandStatementUpdate( + const ServerCallContext& context, const StatementUpdate& command) { + return Status::NotImplemented("DoPutCommandStatementUpdate not implemented"); +} + +std::shared_ptr SqlSchema::GetCatalogsSchema() { + return arrow::schema({field("catalog_name", utf8(), false)}); +} + +std::shared_ptr SqlSchema::GetDbSchemasSchema() { + return arrow::schema( + {field("catalog_name", utf8()), field("db_schema_name", utf8(), false)}); +} + +std::shared_ptr SqlSchema::GetTablesSchema() { + return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), + field("table_name", utf8(), false), + field("table_type", utf8(), false)}); +} + +std::shared_ptr SqlSchema::GetTablesSchemaWithIncludedSchema() { + return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), + field("table_name", utf8(), false), + field("table_type", utf8(), false), + field("table_schema", binary(), false)}); +} + +std::shared_ptr SqlSchema::GetTableTypesSchema() { + return arrow::schema({field("table_type", utf8(), false)}); +} + +std::shared_ptr SqlSchema::GetPrimaryKeysSchema() { + return arrow::schema( + {field("catalog_name", utf8()), field("db_schema_name", utf8()), + field("table_name", utf8(), false), field("column_name", utf8(), false), + field("key_sequence", int32(), false), field("key_name", utf8())}); +} + +std::shared_ptr GetImportedExportedKeysAndCrossReferenceSchema() { + return arrow::schema( + {field("pk_catalog_name", utf8(), true), field("pk_db_schema_name", utf8(), true), + field("pk_table_name", utf8(), false), field("pk_column_name", utf8(), false), + field("fk_catalog_name", utf8(), true), field("fk_db_schema_name", utf8(), true), + field("fk_table_name", utf8(), false), field("fk_column_name", utf8(), false), + field("key_sequence", int32(), false), field("fk_key_name", utf8(), true), + field("pk_key_name", utf8(), true), field("update_rule", uint8(), false), + field("delete_rule", uint8(), false)}); +} + +std::shared_ptr SqlSchema::GetImportedKeysSchema() { + return GetImportedExportedKeysAndCrossReferenceSchema(); +} + +std::shared_ptr 
SqlSchema::GetExportedKeysSchema() { + return GetImportedExportedKeysAndCrossReferenceSchema(); +} + +std::shared_ptr SqlSchema::GetCrossReferenceSchema() { + return GetImportedExportedKeysAndCrossReferenceSchema(); +} + +std::shared_ptr SqlSchema::GetSqlInfoSchema() { + return arrow::schema({field("info_name", uint32(), false), + field("value", + dense_union({field("string_value", utf8(), false), + field("bool_value", boolean(), false), + field("bigint_value", int64(), false), + field("int32_bitmask", int32(), false), + field("string_list", list(utf8()), false), + field("int32_to_int32_list_map", + map(int32(), list(int32())), false)}), + false)}); +} + +} // namespace sql +} // namespace flight +} // namespace arrow + +#undef PROPERTY_TO_OPTIONAL diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h new file mode 100644 index 00000000000..1d6101683c1 --- /dev/null +++ b/cpp/src/arrow/flight/sql/server.h @@ -0,0 +1,443 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Interfaces to use for defining Flight RPC servers. API should be considered +// experimental for now + +#pragma once + +#include +#include +#include + +#include "arrow/flight/server.h" +#include "arrow/flight/sql/server.h" +#include "arrow/flight/sql/types.h" +#include "arrow/util/optional.h" + +namespace arrow { +namespace flight { +namespace sql { + +struct StatementQuery { + std::string query; +}; + +struct StatementUpdate { + std::string query; +}; + +struct StatementQueryTicket { + std::string statement_handle; +}; + +struct PreparedStatementQuery { + std::string prepared_statement_handle; +}; + +struct PreparedStatementUpdate { + std::string prepared_statement_handle; +}; + +struct GetSqlInfo { + std::vector info; +}; + +struct GetDbSchemas { + util::optional catalog; + util::optional db_schema_filter_pattern; +}; + +struct GetTables { + util::optional catalog; + util::optional db_schema_filter_pattern; + util::optional table_name_filter_pattern; + std::vector table_types; + bool include_schema; +}; + +struct GetPrimaryKeys { + TableRef table_ref; +}; + +struct GetExportedKeys { + TableRef table_ref; +}; + +struct GetImportedKeys { + TableRef table_ref; +}; + +struct GetCrossReference { + TableRef pk_table_ref; + TableRef fk_table_ref; +}; + +struct ActionCreatePreparedStatementRequest { + std::string query; +}; + +struct ActionClosePreparedStatementRequest { + std::string prepared_statement_handle; +}; + +struct ActionCreatePreparedStatementResult { + std::shared_ptr dataset_schema; + std::shared_ptr parameter_schema; + std::string prepared_statement_handle; +}; + +/// \brief A utility function to create a ticket (a opaque binary token that the server +/// uses to identify this query) for a statement query. 
+/// Intended for Flight SQL server implementations. +/// \param[in] statement_handle The statement handle that will originate the ticket. +/// \return The parsed ticket as an string. +arrow::Result CreateStatementQueryTicket( + const std::string& statement_handle); + +class ARROW_EXPORT FlightSqlServerBase : public FlightServerBase { + private: + SqlInfoResultMap sql_info_id_to_result_; + + public: + Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* info) override; + + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* stream) override; + + Status DoPut(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override; + + const ActionType kCreatePreparedStatementActionType = + ActionType{"CreatePreparedStatement", + "Creates a reusable prepared statement resource on the server.\n" + "Request Message: ActionCreatePreparedStatementRequest\n" + "Response Message: ActionCreatePreparedStatementResult"}; + const ActionType kClosePreparedStatementActionType = + ActionType{"ClosePreparedStatement", + "Closes a reusable prepared statement resource on the server.\n" + "Request Message: ActionClosePreparedStatementRequest\n" + "Response Message: N/A"}; + + Status ListActions(const ServerCallContext& context, + std::vector* actions) override; + + Status DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* result) override; + + /// \brief Get a FlightInfo for executing a SQL query. + /// \param[in] context Per-call context. + /// \param[in] command The StatementQuery object containing the SQL statement. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the dataset. + virtual arrow::Result> GetFlightInfoStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the query results. + /// \param[in] context Per-call context. + /// \param[in] command The StatementQueryTicket containing the statement handle. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command); + + /// \brief Get a FlightInfo for executing an already created prepared statement. + /// \param[in] context Per-call context. + /// \param[in] command The PreparedStatementQuery object containing the + /// prepared statement handle. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the + /// dataset. + virtual arrow::Result> GetFlightInfoPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command, + const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the prepared statement query results. + /// \param[in] context Per-call context. + /// \param[in] command The PreparedStatementQuery object containing the + /// prepared statement handle. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command); + + /// \brief Get a FlightInfo for listing catalogs. + /// \param[in] context Per-call context. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the dataset. 
+ virtual arrow::Result> GetFlightInfoCatalogs( + const ServerCallContext& context, const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the list of catalogs. + /// \param[in] context Per-call context. + /// \return An interface for sending data back to the client. + virtual arrow::Result> DoGetCatalogs( + const ServerCallContext& context); + + /// \brief Get a FlightInfo for retrieving other information (See SqlInfo). + /// \param[in] context Per-call context. + /// \param[in] command The GetSqlInfo object containing the list of SqlInfo + /// to be returned. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the dataset. + virtual arrow::Result> GetFlightInfoSqlInfo( + const ServerCallContext& context, const GetSqlInfo& command, + const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the list of SqlInfo results. + /// \param[in] context Per-call context. + /// \param[in] command The GetSqlInfo object containing the list of SqlInfo + /// to be returned. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetSqlInfo( + const ServerCallContext& context, const GetSqlInfo& command); + + /// \brief Get a FlightInfo for listing schemas. + /// \param[in] context Per-call context. + /// \param[in] command The GetDbSchemas object which may contain filters for + /// catalog and schema name. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the dataset. + virtual arrow::Result> GetFlightInfoSchemas( + const ServerCallContext& context, const GetDbSchemas& command, + const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the list of schemas. + /// \param[in] context Per-call context. + /// \param[in] command The GetDbSchemas object which may contain filters for + /// catalog and schema name. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetDbSchemas( + const ServerCallContext& context, const GetDbSchemas& command); + + ///\brief Get a FlightInfo for listing tables. + /// \param[in] context Per-call context. + /// \param[in] command The GetTables object which may contain filters for + /// catalog, schema and table names. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the dataset. + virtual arrow::Result> GetFlightInfoTables( + const ServerCallContext& context, const GetTables& command, + const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the list of tables. + /// \param[in] context Per-call context. + /// \param[in] command The GetTables object which may contain filters for + /// catalog, schema and table names. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetTables( + const ServerCallContext& context, const GetTables& command); + + /// \brief Get a FlightInfo to extract information about the table types. + /// \param[in] context Per-call context. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the + /// dataset. + virtual arrow::Result> GetFlightInfoTableTypes( + const ServerCallContext& context, const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the data related to the table types. 
+ /// \param[in] context Per-call context. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetTableTypes( + const ServerCallContext& context); + + /// \brief Get a FlightInfo to extract information about primary and foreign keys. + /// \param[in] context Per-call context. + /// \param[in] command The GetPrimaryKeys object with necessary information + /// to execute the request. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the + /// dataset. + virtual arrow::Result> GetFlightInfoPrimaryKeys( + const ServerCallContext& context, const GetPrimaryKeys& command, + const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the data related to the primary and + /// foreign + /// keys. + /// \param[in] context Per-call context. + /// \param[in] command The GetPrimaryKeys object with necessary information + /// to execute the request. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetPrimaryKeys( + const ServerCallContext& context, const GetPrimaryKeys& command); + + /// \brief Get a FlightInfo to extract information about foreign and primary keys. + /// \param[in] context Per-call context. + /// \param[in] command The GetExportedKeys object with necessary information + /// to execute the request. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the + /// dataset. + virtual arrow::Result> GetFlightInfoExportedKeys( + const ServerCallContext& context, const GetExportedKeys& command, + const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the data related to the foreign and + /// primary + /// keys. + /// \param[in] context Per-call context. + /// \param[in] command The GetExportedKeys object with necessary information + /// to execute the request. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetExportedKeys( + const ServerCallContext& context, const GetExportedKeys& command); + + /// \brief Get a FlightInfo to extract information about foreign and primary keys. + /// \param[in] context Per-call context. + /// \param[in] command The GetImportedKeys object with necessary information + /// to execute the request. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the + /// dataset. + virtual arrow::Result> GetFlightInfoImportedKeys( + const ServerCallContext& context, const GetImportedKeys& command, + const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the data related to the foreign and + /// primary keys. + /// \param[in] context Per-call context. + /// \param[in] command The GetImportedKeys object with necessary information + /// to execute the request. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetImportedKeys( + const ServerCallContext& context, const GetImportedKeys& command); + + /// \brief Get a FlightInfo to extract information about foreign and primary keys. + /// \param[in] context Per-call context. + /// \param[in] command The GetCrossReference object with necessary + /// information + /// to execute the request. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the + /// dataset. 
+ virtual arrow::Result> GetFlightInfoCrossReference( + const ServerCallContext& context, const GetCrossReference& command, + const FlightDescriptor& descriptor); + + /// \brief Get a FlightDataStream containing the data related to the foreign and + /// primary keys. + /// \param[in] context Per-call context. + /// \param[in] command The GetCrossReference object with necessary information + /// to execute the request. + /// \return The FlightDataStream containing the results. + virtual arrow::Result> DoGetCrossReference( + const ServerCallContext& context, const GetCrossReference& command); + + /// \brief Execute an update SQL statement. + /// \param[in] context The call context. + /// \param[in] command The StatementUpdate object containing the SQL statement. + /// \return The changed record count. + virtual arrow::Result DoPutCommandStatementUpdate( + const ServerCallContext& context, const StatementUpdate& command); + + /// \brief Create a prepared statement from given SQL statement. + /// \param[in] context The call context. + /// \param[in] request The ActionCreatePreparedStatementRequest object containing the + /// SQL statement. + /// \return A ActionCreatePreparedStatementResult containing the dataset + /// and parameter schemas and a handle for created statement. + virtual arrow::Result CreatePreparedStatement( + const ServerCallContext& context, + const ActionCreatePreparedStatementRequest& request); + + /// \brief Close a prepared statement. + /// \param[in] context The call context. + /// \param[in] request The ActionClosePreparedStatementRequest object containing the + /// prepared statement handle. + virtual Status ClosePreparedStatement( + const ServerCallContext& context, + const ActionClosePreparedStatementRequest& request); + + /// \brief Bind parameters to given prepared statement. + /// \param[in] context The call context. + /// \param[in] command The PreparedStatementQuery object containing the + /// prepared statement handle. + /// \param[in] reader A sequence of uploaded record batches. + /// \param[in] writer Send metadata back to the client. + virtual Status DoPutPreparedStatementQuery(const ServerCallContext& context, + const PreparedStatementQuery& command, + FlightMessageReader* reader, + FlightMetadataWriter* writer); + + /// \brief Execute an update SQL prepared statement. + /// \param[in] context The call context. + /// \param[in] command The PreparedStatementUpdate object containing the + /// prepared statement handle. + /// \param[in] reader a sequence of uploaded record batches. + /// \return The changed record count. + virtual arrow::Result DoPutPreparedStatementUpdate( + const ServerCallContext& context, const PreparedStatementUpdate& command, + FlightMessageReader* reader); + + /// \brief Register a new SqlInfo result, making it available when calling GetSqlInfo. + /// \param[in] id the SqlInfo identifier. + /// \param[in] result the result. + void RegisterSqlInfo(int32_t id, const SqlInfoResult& result); +}; + +/// \brief Auxiliary class containing all Schemas used on Flight SQL. +class ARROW_EXPORT SqlSchema { + public: + /// \brief Get the Schema used on GetCatalogs response. + /// \return The default schema template. + static std::shared_ptr GetCatalogsSchema(); + + /// \brief Get the Schema used on GetDbSchemas response. + /// \return The default schema template. + static std::shared_ptr GetDbSchemasSchema(); + + /// \brief Get the Schema used on GetTables response when included schema + /// flags is set to false. 
+  /// \return The default schema template.
+  static std::shared_ptr<Schema> GetTablesSchema();
+
+  /// \brief Get the Schema used on GetTables response when the include_schema
+  /// flag is set to true.
+  /// \return The default schema template.
+  static std::shared_ptr<Schema> GetTablesSchemaWithIncludedSchema();
+
+  /// \brief Get the Schema used on GetTableTypes response.
+  /// \return The default schema template.
+  static std::shared_ptr<Schema> GetTableTypesSchema();
+
+  /// \brief Get the Schema used on GetPrimaryKeys response.
+  /// \return The default schema template.
+  static std::shared_ptr<Schema> GetPrimaryKeysSchema();
+
+  /// \brief Get the Schema used on GetExportedKeys response.
+  /// \return The default schema template.
+  static std::shared_ptr<Schema> GetExportedKeysSchema();
+
+  /// \brief Get the Schema used on GetImportedKeys response.
+  /// \return The default schema template.
+  static std::shared_ptr<Schema> GetImportedKeysSchema();
+
+  /// \brief Get the Schema used on GetCrossReference response.
+  /// \return The default schema template.
+  static std::shared_ptr<Schema> GetCrossReferenceSchema();
+
+  /// \brief Get the Schema used on GetSqlInfo response.
+  /// \return The default schema template.
+  static std::shared_ptr<Schema> GetSqlInfoSchema();
+};
+}  // namespace sql
+}  // namespace flight
+}  // namespace arrow
diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc
new file mode 100644
index 00000000000..d74b6d40137
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/server_test.cc
@@ -0,0 +1,767 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
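[Editor's sketch, not part of the patch: the FlightSqlServerBase interface declared in server.h above is meant to be subclassed; a minimal server that answers ad hoc statement queries could look roughly like the following. QuerySchema() and ExecuteQuery() are hypothetical hooks into a real query engine and are assumed here, not provided by Arrow.]

#include <memory>
#include <string>
#include <vector>

#include "arrow/flight/sql/server.h"

// Hypothetical hooks into an actual query engine (assumptions for this sketch).
arrow::Result<std::shared_ptr<arrow::Schema>> QuerySchema(const std::string& query);
arrow::Result<std::shared_ptr<arrow::RecordBatchReader>> ExecuteQuery(const std::string& query);

namespace flightsql = arrow::flight::sql;

class SketchServer : public flightsql::FlightSqlServerBase {
 public:
  arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoStatement(
      const arrow::flight::ServerCallContext& context,
      const flightsql::StatementQuery& command,
      const arrow::flight::FlightDescriptor& descriptor) override {
    // Pack the query into the opaque ticket the client later sends back on DoGet.
    ARROW_ASSIGN_OR_RAISE(std::string ticket,
                          flightsql::CreateStatementQueryTicket(command.query));
    std::vector<arrow::flight::FlightEndpoint> endpoints{
        arrow::flight::FlightEndpoint{{ticket}, {}}};
    ARROW_ASSIGN_OR_RAISE(auto dataset_schema, QuerySchema(command.query));
    ARROW_ASSIGN_OR_RAISE(auto info,
                          arrow::flight::FlightInfo::Make(*dataset_schema, descriptor,
                                                          endpoints, -1, -1));
    return std::unique_ptr<arrow::flight::FlightInfo>(new arrow::flight::FlightInfo(info));
  }

  arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>> DoGetStatement(
      const arrow::flight::ServerCallContext& context,
      const flightsql::StatementQueryTicket& command) override {
    // The statement handle is whatever was packed into the ticket above.
    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::RecordBatchReader> reader,
                          ExecuteQuery(command.statement_handle));
    return std::unique_ptr<arrow::flight::FlightDataStream>(
        new arrow::flight::RecordBatchStream(reader));
  }
};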
+ +#include "arrow/flight/sql/server.h" + +#include +#include +#include + +#include +#include + +#include "arrow/flight/api.h" +#include "arrow/flight/sql/api.h" +#include "arrow/flight/sql/example/sqlite_server.h" +#include "arrow/flight/sql/example/sqlite_sql_info.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/types.h" +#include "arrow/testing/gtest_util.h" + +using ::testing::_; +using ::testing::Ref; + +using arrow::internal::checked_cast; + +namespace arrow { +namespace flight { +namespace sql { + +/// \brief Auxiliary variant visitor used to assert that GetSqlInfo's values are +/// correctly placed on its DenseUnionArray +class SqlInfoDenseUnionValidator { + private: + const DenseUnionScalar& data; + + public: + /// \brief Asserts that the current DenseUnionScalar equals to given string value + void operator()(const std::string& string_value) const { + const auto& scalar = checked_cast(*data.value); + ASSERT_EQ(string_value, scalar.ToString()); + } + + /// \brief Asserts that the current DenseUnionScalar equals to given bool value + void operator()(const bool bool_value) const { + const auto& scalar = checked_cast(*data.value); + ASSERT_EQ(bool_value, scalar.value); + } + + /// \brief Asserts that the current DenseUnionScalar equals to given int64_t value + void operator()(const int64_t bigint_value) const { + const auto& scalar = checked_cast(*data.value); + ASSERT_EQ(bigint_value, scalar.value); + } + + /// \brief Asserts that the current DenseUnionScalar equals to given int32_t value + void operator()(const int32_t int32_bitmask) const { + const auto& scalar = checked_cast(*data.value); + ASSERT_EQ(int32_bitmask, scalar.value); + } + + /// \brief Asserts that the current DenseUnionScalar equals to given string list + void operator()(const std::vector& string_list) const { + const auto& array = checked_cast( + *(checked_cast(*data.value).value)); + + ASSERT_EQ(string_list.size(), array.length()); + + for (size_t index = 0; index < string_list.size(); index++) { + ASSERT_EQ(string_list[index], array.GetString(index)); + } + } + + /// \brief Asserts that the current DenseUnionScalar equals to given int32 to int32 list + /// map. 
+ void operator()(const std::unordered_map>& + int32_to_int32_list) const { + const auto& struct_array = checked_cast( + *checked_cast(*data.value).value); + const auto& keys = checked_cast(*struct_array.field(0)); + const auto& values = checked_cast(*struct_array.field(1)); + + // Assert that the given map has the right size + ASSERT_EQ(int32_to_int32_list.size(), keys.length()); + + // For each element on given MapScalar, assert it matches the argument + for (int i = 0; i < keys.length(); i++) { + ASSERT_OK_AND_ASSIGN(const auto& key_scalar, keys.GetScalar(i)); + int32_t sql_info_id = checked_cast(*key_scalar).value; + + // Assert the key (SqlInfo id) exists + ASSERT_TRUE(int32_to_int32_list.count(sql_info_id)); + + const std::vector& expected_int32_list = + int32_to_int32_list.at(sql_info_id); + + // Assert the value (int32 list) has the correct size + ASSERT_EQ(expected_int32_list.size(), values.value_length(i)); + + // For each element on current ListScalar, assert it matches with the argument + for (size_t j = 0; j < expected_int32_list.size(); j++) { + ASSERT_OK_AND_ASSIGN(auto list_item_scalar, + values.values()->GetScalar(values.value_offset(i) + j)); + const auto& list_item = checked_cast(*list_item_scalar).value; + ASSERT_EQ(expected_int32_list[j], list_item); + } + } + } + + explicit SqlInfoDenseUnionValidator(const DenseUnionScalar& data) : data(data) {} + + SqlInfoDenseUnionValidator(const SqlInfoDenseUnionValidator&) = delete; + SqlInfoDenseUnionValidator(SqlInfoDenseUnionValidator&&) = delete; + SqlInfoDenseUnionValidator& operator=(const SqlInfoDenseUnionValidator&) = delete; +}; + +class TestFlightSqlServer : public ::testing::Test { + public: + std::unique_ptr sql_client; + + arrow::Result ExecuteCountQuery(const std::string& query) { + ARROW_ASSIGN_OR_RAISE(auto flight_info, sql_client->Execute({}, query)); + + ARROW_ASSIGN_OR_RAISE(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr table; + ARROW_RETURN_NOT_OK(stream->ReadAll(&table)); + + const std::shared_ptr& result_array = table->column(0)->chunk(0); + ARROW_ASSIGN_OR_RAISE(auto count_scalar, result_array->GetScalar(0)); + + return reinterpret_cast(*count_scalar).value; + } + + protected: + void SetUp() override { + port = GetListenPort(); + server_thread.reset(new std::thread([&]() { RunServer(); })); + + std::unique_lock lk(server_ready_m); + server_ready_cv.wait(lk); + + std::stringstream ss; + ss << "grpc://localhost:" << port; + std::string uri = ss.str(); + + std::unique_ptr client; + Location location; + ASSERT_OK(Location::Parse(uri, &location)); + ASSERT_OK(FlightClient::Connect(location, &client)); + + sql_client.reset(new FlightSqlClient(std::move(client))); + } + + void TearDown() override { + sql_client.reset(); + + ASSERT_OK(server->Shutdown()); + server_thread->join(); + server_thread.reset(); + } + + private: + int port; + std::shared_ptr server; + std::unique_ptr server_thread; + std::condition_variable server_ready_cv; + std::mutex server_ready_m; + + void RunServer() { + arrow::flight::Location location; + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("localhost", port, &location)); + arrow::flight::FlightServerOptions options(location); + + ARROW_CHECK_OK(example::SQLiteFlightSqlServer::Create().Value(&server)); + + ARROW_CHECK_OK(server->Init(options)); + // Exit with a clean error code (0) on SIGTERM + ARROW_CHECK_OK(server->SetShutdownOnSignals({SIGTERM})); + + server_ready_cv.notify_all(); + ARROW_CHECK_OK(server->Serve()); + } +}; + 
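[Editor's sketch, not part of the patch: the tests that follow all repeat the same client-side flow, shown once here as a standalone function so the pattern is easy to follow. It assumes a FlightSqlClient connected the same way as in SetUp() above.]

#include <memory>
#include <string>

#include "arrow/flight/sql/api.h"
#include "arrow/table.h"

arrow::Status RunQuery(arrow::flight::sql::FlightSqlClient& client,
                       const std::string& query) {
  arrow::flight::FlightCallOptions options;
  // 1. Ask the server to plan the query; the returned FlightInfo describes the result set.
  ARROW_ASSIGN_OR_RAISE(auto flight_info, client.Execute(options, query));
  // 2. Fetch the data for each endpoint (the tests below read endpoints()[0]).
  for (const auto& endpoint : flight_info->endpoints()) {
    ARROW_ASSIGN_OR_RAISE(auto stream, client.DoGet(options, endpoint.ticket));
    // 3. Materialize the stream; the tests then compare the Table to an expected one.
    std::shared_ptr<arrow::Table> table;
    ARROW_RETURN_NOT_OK(stream->ReadAll(&table));
  }
  return arrow::Status::OK();
}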
+TEST_F(TestFlightSqlServer, TestCommandStatementQuery) {
+  ASSERT_OK_AND_ASSIGN(auto flight_info,
+                       sql_client->Execute({}, "SELECT * FROM intTable"));
+
+  ASSERT_OK_AND_ASSIGN(auto stream,
+                       sql_client->DoGet({}, flight_info->endpoints()[0].ticket));
+
+  std::shared_ptr<Table> table;
+  ASSERT_OK(stream->ReadAll(&table));
+
+  const std::shared_ptr<Schema>& expected_schema =
+      arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()),
+                     arrow::field("value", int64()), arrow::field("foreignId", int64())});
+
+  const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3, 4])");
+  const auto keyname_array =
+      ArrayFromJSON(utf8(), R"(["one", "zero", "negative one", null])");
+  const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1, null])");
+  const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1, null])");
+
+  const std::shared_ptr<Table>
& expected_table = Table::Make( + expected_schema, {id_array, keyname_array, value_array, foreignId_array}); + + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandGetTables) { + FlightCallOptions options = {}; + std::string* catalog = nullptr; + std::string* schema_filter_pattern = nullptr; + std::string* table_filter_pattern = nullptr; + bool include_schema = false; + std::vector* table_types = nullptr; + + ASSERT_OK_AND_ASSIGN( + auto flight_info, + sql_client->GetTables(options, catalog, schema_filter_pattern, table_filter_pattern, + include_schema, table_types)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + ASSERT_OK_AND_ASSIGN(auto catalog_name, MakeArrayOfNull(utf8(), 3)) + ASSERT_OK_AND_ASSIGN(auto schema_name, MakeArrayOfNull(utf8(), 3)) + + const auto table_name = + ArrayFromJSON(utf8(), R"(["foreignTable", "intTable", "sqlite_sequence"])"); + const auto table_type = ArrayFromJSON(utf8(), R"(["table", "table", "table"])"); + + const std::shared_ptr
& expected_table = Table::Make( + SqlSchema::GetTablesSchema(), {catalog_name, schema_name, table_name, table_type}); + + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandGetTablesWithTableFilter) { + FlightCallOptions options = {}; + std::string* catalog = nullptr; + std::string* schema_filter_pattern = nullptr; + std::string table_filter_pattern = "int%"; + bool include_schema = false; + std::vector* table_types = nullptr; + + ASSERT_OK_AND_ASSIGN( + auto flight_info, + sql_client->GetTables(options, catalog, schema_filter_pattern, + &table_filter_pattern, include_schema, table_types)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const auto catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto schema_name = ArrayFromJSON(utf8(), R"([null])"); + const auto table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); + const auto table_type = ArrayFromJSON(utf8(), R"(["table"])"); + + const std::shared_ptr
& expected_table = Table::Make( + SqlSchema::GetTablesSchema(), {catalog_name, schema_name, table_name, table_type}); + + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandGetTablesWithTableTypesFilter) { + FlightCallOptions options = {}; + std::string* catalog = nullptr; + std::string* schema_filter_pattern = nullptr; + std::string* table_filter_pattern = nullptr; + bool include_schema = false; + std::vector table_types{"index"}; + + ASSERT_OK_AND_ASSIGN( + auto flight_info, + sql_client->GetTables(options, catalog, schema_filter_pattern, table_filter_pattern, + include_schema, &table_types)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + AssertSchemaEqual(SqlSchema::GetTablesSchema(), table->schema()); + + ASSERT_EQ(table->num_rows(), 0); +} + +TEST_F(TestFlightSqlServer, TestCommandGetTablesWithUnexistenceTableTypeFilter) { + FlightCallOptions options = {}; + std::string* catalog = nullptr; + std::string* schema_filter_pattern = nullptr; + std::string* table_filter_pattern = nullptr; + bool include_schema = false; + std::vector table_types{"table"}; + + ASSERT_OK_AND_ASSIGN( + auto flight_info, + sql_client->GetTables(options, catalog, schema_filter_pattern, table_filter_pattern, + include_schema, &table_types)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const auto catalog_name = ArrayFromJSON(utf8(), R"([null, null, null])"); + const auto schema_name = ArrayFromJSON(utf8(), R"([null, null, null])"); + const auto table_name = + ArrayFromJSON(utf8(), R"(["foreignTable", "intTable", "sqlite_sequence"])"); + const auto table_type = ArrayFromJSON(utf8(), R"(["table", "table", "table"])"); + + const std::shared_ptr
& expected_table = Table::Make( + SqlSchema::GetTablesSchema(), {catalog_name, schema_name, table_name, table_type}); + + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandGetTablesWithIncludedSchemas) { + FlightCallOptions options = {}; + std::string* catalog = nullptr; + std::string* schema_filter_pattern = nullptr; + std::string table_filter_pattern = "int%"; + bool include_schema = true; + std::vector* table_types = nullptr; + + ASSERT_OK_AND_ASSIGN( + auto flight_info, + sql_client->GetTables(options, catalog, schema_filter_pattern, + &table_filter_pattern, include_schema, table_types)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const auto catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto schema_name = ArrayFromJSON(utf8(), R"([null])"); + const auto table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); + const auto table_type = ArrayFromJSON(utf8(), R"(["table"])"); + + const std::shared_ptr schema_table = arrow::schema( + {arrow::field("id", int64(), true), arrow::field("keyName", utf8(), true), + arrow::field("value", int64(), true), arrow::field("foreignId", int64(), true)}); + + ASSERT_OK_AND_ASSIGN(auto schema_buffer, ipc::SerializeSchema(*schema_table)); + + std::shared_ptr table_schema; + ArrayFromVector({schema_buffer->ToString()}, &table_schema); + + const std::shared_ptr
& expected_table = + Table::Make(SqlSchema::GetTablesSchemaWithIncludedSchema(), + {catalog_name, schema_name, table_name, table_type, table_schema}); + + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandGetCatalogs) { + ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetCatalogs({})); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const std::shared_ptr& expected_schema = SqlSchema::GetCatalogsSchema(); + + AssertSchemaEqual(expected_schema, table->schema()); + ASSERT_EQ(0, table->num_rows()); +} + +TEST_F(TestFlightSqlServer, TestCommandGetDbSchemas) { + FlightCallOptions options = {}; + std::string* catalog = nullptr; + std::string* schema_filter_pattern = nullptr; + ASSERT_OK_AND_ASSIGN(auto flight_info, + sql_client->GetDbSchemas(options, catalog, schema_filter_pattern)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const std::shared_ptr& expected_schema = SqlSchema::GetDbSchemasSchema(); + + AssertSchemaEqual(expected_schema, table->schema()); + ASSERT_EQ(0, table->num_rows()); +} + +TEST_F(TestFlightSqlServer, TestCommandGetTableTypes) { + ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetTableTypes({})); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const auto table_type = ArrayFromJSON(utf8(), R"(["table"])"); + + const std::shared_ptr
& expected_table = + Table::Make(SqlSchema::GetTableTypesSchema(), {table_type}); + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandStatementUpdate) { + int64_t result; + ASSERT_OK_AND_ASSIGN(result, + sql_client->ExecuteUpdate( + {}, + "INSERT INTO intTable (keyName, value) VALUES " + "('KEYNAME1', 1001), ('KEYNAME2', 1002), ('KEYNAME3', 1003)")); + ASSERT_EQ(3, result); + + ASSERT_OK_AND_ASSIGN(result, sql_client->ExecuteUpdate( + {}, + "UPDATE intTable SET keyName = 'KEYNAME1' " + "WHERE keyName = 'KEYNAME2' OR keyName = 'KEYNAME3'")); + ASSERT_EQ(2, result); + + ASSERT_OK_AND_ASSIGN( + result, + sql_client->ExecuteUpdate({}, "DELETE FROM intTable WHERE keyName = 'KEYNAME1'")); + ASSERT_EQ(3, result); +} + +TEST_F(TestFlightSqlServer, TestCommandPreparedStatementQuery) { + ASSERT_OK_AND_ASSIGN(auto prepared_statement, + sql_client->Prepare({}, "SELECT * FROM intTable")); + + ASSERT_OK_AND_ASSIGN(auto flight_info, prepared_statement->Execute()); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const std::shared_ptr& expected_schema = + arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()), + arrow::field("value", int64()), arrow::field("foreignId", int64())}); + + const auto id_array = ArrayFromJSON(int64(), R"([1, 2, 3, 4])"); + const auto keyname_array = + ArrayFromJSON(utf8(), R"(["one", "zero", "negative one", null])"); + const auto value_array = ArrayFromJSON(int64(), R"([1, 0, -1, null])"); + const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1, 1, null])"); + + const std::shared_ptr
& expected_table = Table::Make( + expected_schema, {id_array, keyname_array, value_array, foreignId_array}); + + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandPreparedStatementQueryWithParameterBinding) { + ASSERT_OK_AND_ASSIGN( + auto prepared_statement, + sql_client->Prepare({}, "SELECT * FROM intTable WHERE keyName LIKE ?")); + + auto parameter_schema = prepared_statement->parameter_schema(); + + const std::shared_ptr& expected_parameter_schema = + arrow::schema({arrow::field("parameter_1", example::GetUnknownColumnDataType())}); + + AssertSchemaEqual(expected_parameter_schema, parameter_schema); + + std::shared_ptr type_ids = ArrayFromJSON(int8(), R"([0])"); + std::shared_ptr offsets = ArrayFromJSON(int32(), R"([0])"); + std::shared_ptr string_array = ArrayFromJSON(utf8(), R"(["%one"])"); + std::shared_ptr bytes_array = ArrayFromJSON(binary(), R"([])"); + std::shared_ptr bigint_array = ArrayFromJSON(int64(), R"([])"); + std::shared_ptr double_array = ArrayFromJSON(float64(), R"([])"); + + ASSERT_OK_AND_ASSIGN( + auto parameter_1_array, + DenseUnionArray::Make(*type_ids, *offsets, + {string_array, bytes_array, bigint_array, double_array}, + {"string", "bytes", "bigint", "double"}, {0, 1, 2, 3})); + + const std::shared_ptr& record_batch = + RecordBatch::Make(parameter_schema, 1, {parameter_1_array}); + + ASSERT_OK(prepared_statement->SetParameters(record_batch)); + + ASSERT_OK_AND_ASSIGN(auto flight_info, prepared_statement->Execute()); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const std::shared_ptr& expected_schema = + arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()), + arrow::field("value", int64()), arrow::field("foreignId", int64())}); + + const auto id_array = ArrayFromJSON(int64(), R"([1, 3])"); + const auto keyname_array = ArrayFromJSON(utf8(), R"(["one", "negative one"])"); + const auto value_array = ArrayFromJSON(int64(), R"([1, -1])"); + const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1])"); + + const std::shared_ptr
& expected_table = Table::Make( + expected_schema, {id_array, keyname_array, value_array, foreignId_array}); + + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandPreparedStatementUpdateWithParameterBinding) { + ASSERT_OK_AND_ASSIGN( + auto prepared_statement, + sql_client->Prepare( + {}, "INSERT INTO INTTABLE (keyName, value) VALUES ('new_value', ?)")); + + auto parameter_schema = prepared_statement->parameter_schema(); + + const std::shared_ptr& expected_parameter_schema = + arrow::schema({arrow::field("parameter_1", example::GetUnknownColumnDataType())}); + + AssertSchemaEqual(expected_parameter_schema, parameter_schema); + + std::shared_ptr type_ids = ArrayFromJSON(int8(), R"([2])"); + std::shared_ptr offsets = ArrayFromJSON(int32(), R"([0])"); + std::shared_ptr string_array = ArrayFromJSON(utf8(), R"([])"); + std::shared_ptr bytes_array = ArrayFromJSON(binary(), R"([])"); + std::shared_ptr bigint_array = ArrayFromJSON(int64(), R"([999])"); + std::shared_ptr double_array = ArrayFromJSON(float64(), R"([])"); + + ASSERT_OK_AND_ASSIGN( + auto parameter_1_array, + DenseUnionArray::Make(*type_ids, *offsets, + {string_array, bytes_array, bigint_array, double_array}, + {"string", "bytes", "bigint", "double"}, {0, 1, 2, 3})); + + const std::shared_ptr& record_batch = + RecordBatch::Make(parameter_schema, 1, {parameter_1_array}); + + ASSERT_OK(prepared_statement->SetParameters(record_batch)); + + ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + + ASSERT_OK_AND_EQ(1, prepared_statement->ExecuteUpdate()); + + ASSERT_OK_AND_EQ(5, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + + ASSERT_OK_AND_EQ(1, sql_client->ExecuteUpdate( + {}, "DELETE FROM intTable WHERE keyName = 'new_value'")); + + ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); +} + +TEST_F(TestFlightSqlServer, TestCommandPreparedStatementUpdate) { + ASSERT_OK_AND_ASSIGN( + auto prepared_statement, + sql_client->Prepare( + {}, "INSERT INTO INTTABLE (keyName, value) VALUES ('new_value', 999)")); + + ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + + ASSERT_OK_AND_EQ(1, prepared_statement->ExecuteUpdate()); + + ASSERT_OK_AND_EQ(5, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); + + ASSERT_OK_AND_EQ(1, sql_client->ExecuteUpdate( + {}, "DELETE FROM intTable WHERE keyName = 'new_value'")); + + ASSERT_OK_AND_EQ(4, ExecuteCountQuery("SELECT COUNT(*) FROM intTable")); +} + +TEST_F(TestFlightSqlServer, TestCommandGetPrimaryKeys) { + FlightCallOptions options = {}; + TableRef table_ref = {util::nullopt, util::nullopt, "int%"}; + ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetPrimaryKeys(options, table_ref)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const auto catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto schema_name = ArrayFromJSON(utf8(), R"([null])"); + const auto key_name = ArrayFromJSON(utf8(), R"([null])"); + const auto table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); + const auto column_name = ArrayFromJSON(utf8(), R"(["id"])"); + const auto key_sequence = ArrayFromJSON(int32(), R"([1])"); + + const std::shared_ptr
& expected_table = Table::Make( + SqlSchema::GetPrimaryKeysSchema(), + {catalog_name, schema_name, table_name, column_name, key_sequence, key_name}); + + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandGetImportedKeys) { + FlightCallOptions options = {}; + TableRef table_ref = {util::nullopt, util::nullopt, "intTable"}; + ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetImportedKeys(options, table_ref)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const auto pk_catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto pk_schema_name = ArrayFromJSON(utf8(), R"([null])"); + const auto pk_table_name = ArrayFromJSON(utf8(), R"(["foreignTable"])"); + const auto pk_column_name = ArrayFromJSON(utf8(), R"(["id"])"); + const auto fk_catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto fk_schema_name = ArrayFromJSON(utf8(), R"([null])"); + const auto fk_table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); + const auto fk_column_name = ArrayFromJSON(utf8(), R"(["foreignId"])"); + const auto key_sequence = ArrayFromJSON(int32(), R"([0])"); + const auto fk_key_name = ArrayFromJSON(utf8(), R"([null])"); + const auto pk_key_name = ArrayFromJSON(utf8(), R"([null])"); + const auto update_rule = ArrayFromJSON(uint8(), R"([3])"); + const auto delete_rule = ArrayFromJSON(uint8(), R"([3])"); + + const std::shared_ptr
& expected_table = + Table::Make(SqlSchema::GetImportedKeysSchema(), + {pk_catalog_name, pk_schema_name, pk_table_name, pk_column_name, + fk_catalog_name, fk_schema_name, fk_table_name, fk_column_name, + key_sequence, fk_key_name, pk_key_name, update_rule, delete_rule}); + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandGetExportedKeys) { + FlightCallOptions options = {}; + TableRef table_ref = {util::nullopt, util::nullopt, "foreignTable"}; + ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetExportedKeys(options, table_ref)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const auto pk_catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto pk_schema_name = ArrayFromJSON(utf8(), R"([null])"); + const auto pk_table_name = ArrayFromJSON(utf8(), R"(["foreignTable"])"); + const auto pk_column_name = ArrayFromJSON(utf8(), R"(["id"])"); + const auto fk_catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto fk_schema_name = ArrayFromJSON(utf8(), R"([null])"); + const auto fk_table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); + const auto fk_column_name = ArrayFromJSON(utf8(), R"(["foreignId"])"); + const auto key_sequence = ArrayFromJSON(int32(), R"([0])"); + const auto fk_key_name = ArrayFromJSON(utf8(), R"([null])"); + const auto pk_key_name = ArrayFromJSON(utf8(), R"([null])"); + const auto update_rule = ArrayFromJSON(uint8(), R"([3])"); + const auto delete_rule = ArrayFromJSON(uint8(), R"([3])"); + + const std::shared_ptr
& expected_table = + Table::Make(SqlSchema::GetExportedKeysSchema(), + {pk_catalog_name, pk_schema_name, pk_table_name, pk_column_name, + fk_catalog_name, fk_schema_name, fk_table_name, fk_column_name, + key_sequence, fk_key_name, pk_key_name, update_rule, delete_rule}); + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandGetCrossReference) { + FlightCallOptions options = {}; + TableRef pk_table_ref = {util::nullopt, util::nullopt, "foreignTable"}; + TableRef fk_table_ref = {util::nullopt, util::nullopt, "intTable"}; + ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetCrossReference( + options, pk_table_ref, fk_table_ref)); + + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + const auto pk_catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto pk_schema_name = ArrayFromJSON(utf8(), R"([null])"); + const auto pk_table_name = ArrayFromJSON(utf8(), R"(["foreignTable"])"); + const auto pk_column_name = ArrayFromJSON(utf8(), R"(["id"])"); + const auto fk_catalog_name = ArrayFromJSON(utf8(), R"([null])"); + const auto fk_schema_name = ArrayFromJSON(utf8(), R"([null])"); + const auto fk_table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); + const auto fk_column_name = ArrayFromJSON(utf8(), R"(["foreignId"])"); + const auto key_sequence = ArrayFromJSON(int32(), R"([0])"); + const auto fk_key_name = ArrayFromJSON(utf8(), R"([null])"); + const auto pk_key_name = ArrayFromJSON(utf8(), R"([null])"); + const auto update_rule = ArrayFromJSON(uint8(), R"([3])"); + const auto delete_rule = ArrayFromJSON(uint8(), R"([3])"); + + const std::shared_ptr
& expected_table = + Table::Make(SqlSchema::GetCrossReferenceSchema(), + {pk_catalog_name, pk_schema_name, pk_table_name, pk_column_name, + fk_catalog_name, fk_schema_name, fk_table_name, fk_column_name, + key_sequence, fk_key_name, pk_key_name, update_rule, delete_rule}); + AssertTablesEqual(*expected_table, *table); +} + +TEST_F(TestFlightSqlServer, TestCommandGetSqlInfo) { + const auto& sql_info_expected_results = sql::example::GetSqlInfoResultMap(); + std::vector sql_info_ids; + sql_info_ids.reserve(sql_info_expected_results.size()); + for (const auto& sql_info_expected_result : sql_info_expected_results) { + sql_info_ids.push_back(sql_info_expected_result.first); + } + + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto flight_info, + sql_client->GetSqlInfo(call_options, sql_info_ids)); + ASSERT_OK_AND_ASSIGN( + auto reader, sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); + std::shared_ptr
results; + ASSERT_OK(reader->ReadAll(&results)); + ASSERT_EQ(2, results->num_columns()); + ASSERT_EQ(sql_info_ids.size(), results->num_rows()); + const auto& col_name = results->column(0); + const auto& col_value = results->column(1); + for (int32_t i = 0; i < col_name->num_chunks(); i++) { + const auto* col_name_chunk_data = + col_name->chunk(i)->data()->GetValuesSafe(1); + const auto& col_value_chunk = col_value->chunk(i); + for (int64_t row = 0; row < col_value->length(); row++) { + ASSERT_OK_AND_ASSIGN(const auto& scalar, col_value_chunk->GetScalar(row)); + const SqlInfoDenseUnionValidator validator( + reinterpret_cast(*scalar)); + const auto& expected_result = + sql_info_expected_results.at(col_name_chunk_data[row]); + arrow::util::visit(validator, expected_result); + } + } +} + +TEST_F(TestFlightSqlServer, TestCommandGetSqlInfoNoInfo) { + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetSqlInfo(call_options, {999999})); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + KeyError, ::testing::HasSubstr("No information for SQL info number 999999"), + sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); +} + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/sql_info_internal.cc b/cpp/src/arrow/flight/sql/sql_info_internal.cc new file mode 100644 index 00000000000..74718fb7cb5 --- /dev/null +++ b/cpp/src/arrow/flight/sql/sql_info_internal.cc @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
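[Editor's sketch, not part of the patch: the values serialized by SqlInfoResultAppender below are supplied by server implementations through FlightSqlServerBase::RegisterSqlInfo. The info ids used here are arbitrary placeholders and the values are made up; this only illustrates how the different alternatives of the SqlInfoResult variant map onto the dense union children.]

#include <string>
#include <vector>

#include "arrow/flight/sql/server.h"
#include "arrow/flight/sql/types.h"

void RegisterExampleSqlInfo(arrow::flight::sql::FlightSqlServerBase* server) {
  using arrow::flight::sql::SqlInfoResult;
  // Each alternative of the SqlInfoResult variant is appended to a different
  // child of the dense union built by SqlInfoResultAppender.
  server->RegisterSqlInfo(/*id=*/0, SqlInfoResult(std::string("example server")));  // string_value
  server->RegisterSqlInfo(/*id=*/1, SqlInfoResult(false));                          // bool_value
  server->RegisterSqlInfo(/*id=*/2, SqlInfoResult(int64_t(42)));                    // bigint_value
  server->RegisterSqlInfo(
      /*id=*/3, SqlInfoResult(std::vector<std::string>{"SELECT", "INSERT"}));       // string_list
}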
+ +#include "arrow/flight/sql/sql_info_internal.h" + +#include "arrow/buffer.h" +#include "arrow/builder.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace internal { + +Status SqlInfoResultAppender::operator()(const std::string& value) { + ARROW_RETURN_NOT_OK(value_builder_->Append(kStringValueIndex)); + ARROW_RETURN_NOT_OK(string_value_builder_->Append(value)); + return Status::OK(); +} + +Status SqlInfoResultAppender::operator()(const bool value) { + ARROW_RETURN_NOT_OK(value_builder_->Append(kBoolValueIndex)); + ARROW_RETURN_NOT_OK(bool_value_builder_->Append(value)); + return Status::OK(); +} + +Status SqlInfoResultAppender::operator()(const int64_t value) { + ARROW_RETURN_NOT_OK(value_builder_->Append(kBigIntValueIndex)); + ARROW_RETURN_NOT_OK(bigint_value_builder_->Append(value)); + return Status::OK(); +} + +Status SqlInfoResultAppender::operator()(const int32_t value) { + ARROW_RETURN_NOT_OK(value_builder_->Append(kInt32BitMaskIndex)); + ARROW_RETURN_NOT_OK(int32_bitmask_builder_->Append(value)); + return Status::OK(); +} + +Status SqlInfoResultAppender::operator()(const std::vector& value) { + ARROW_RETURN_NOT_OK(value_builder_->Append(kStringListIndex)); + ARROW_RETURN_NOT_OK(string_list_builder_->Append()); + auto* string_list_child = + reinterpret_cast(string_list_builder_->value_builder()); + for (const auto& string : value) { + ARROW_RETURN_NOT_OK(string_list_child->Append(string)); + } + return Status::OK(); +} + +Status SqlInfoResultAppender::operator()( + const std::unordered_map>& value) { + ARROW_RETURN_NOT_OK(value_builder_->Append(kInt32ToInt32ListIndex)); + ARROW_RETURN_NOT_OK(int32_to_int32_list_builder_->Append()); + for (const auto& pair : value) { + ARROW_RETURN_NOT_OK( + reinterpret_cast(int32_to_int32_list_builder_->key_builder()) + ->Append(pair.first)); + auto* int32_list_builder = + reinterpret_cast(int32_to_int32_list_builder_->item_builder()); + ARROW_RETURN_NOT_OK(int32_list_builder->Append()); + auto* int32_list_child = + reinterpret_cast(int32_list_builder->value_builder()); + for (const auto& int32 : pair.second) { + ARROW_RETURN_NOT_OK(int32_list_child->Append(int32)); + } + } + return Status::OK(); +} + +SqlInfoResultAppender::SqlInfoResultAppender(DenseUnionBuilder* value_builder) + : value_builder_(value_builder), + string_value_builder_( + reinterpret_cast(value_builder_->child(kStringValueIndex))), + bool_value_builder_( + reinterpret_cast(value_builder_->child(kBoolValueIndex))), + bigint_value_builder_( + reinterpret_cast(value_builder_->child(kBigIntValueIndex))), + int32_bitmask_builder_( + reinterpret_cast(value_builder_->child(kInt32BitMaskIndex))), + string_list_builder_( + reinterpret_cast(value_builder_->child(kStringListIndex))), + int32_to_int32_list_builder_( + reinterpret_cast(value_builder_->child(kInt32ToInt32ListIndex))) {} + +} // namespace internal +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/sql_info_internal.h b/cpp/src/arrow/flight/sql/sql_info_internal.h new file mode 100644 index 00000000000..b18789c2549 --- /dev/null +++ b/cpp/src/arrow/flight/sql/sql_info_internal.h @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/flight/sql/types.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace internal { + +/// \brief Auxiliary class used to populate GetSqlInfo's DenseUnionArray with different +/// data types. +class SqlInfoResultAppender { + public: + /// \brief Append a string to the DenseUnionBuilder. + /// \param[in] value Value to be appended. + Status operator()(const std::string& value); + + /// \brief Append a bool to the DenseUnionBuilder. + /// \param[in] value Value to be appended. + Status operator()(bool value); + + /// \brief Append a int64_t to the DenseUnionBuilder. + /// \param[in] value Value to be appended. + Status operator()(int64_t value); + + /// \brief Append a int32_t to the DenseUnionBuilder. + /// \param[in] value Value to be appended. + Status operator()(int32_t value); + + /// \brief Append a string list to the DenseUnionBuilder. + /// \param[in] value Value to be appended. + Status operator()(const std::vector& value); + + /// \brief Append a int32 to int32 list map to the DenseUnionBuilder. + /// \param[in] value Value to be appended. + Status operator()(const std::unordered_map>& value); + + /// \brief Create a Variant visitor that appends data to given + /// DenseUnionBuilder. \param[in] value_builder DenseUnionBuilder to append data to. + explicit SqlInfoResultAppender(DenseUnionBuilder* value_builder); + + SqlInfoResultAppender(const SqlInfoResultAppender&) = delete; + SqlInfoResultAppender(SqlInfoResultAppender&&) = delete; + SqlInfoResultAppender& operator=(const SqlInfoResultAppender&) = delete; + + private: + DenseUnionBuilder* value_builder_; + + // Builders for each child on dense union + StringBuilder* string_value_builder_; + BooleanBuilder* bool_value_builder_; + Int64Builder* bigint_value_builder_; + Int32Builder* int32_bitmask_builder_; + ListBuilder* string_list_builder_; + MapBuilder* int32_to_int32_list_builder_; + + enum : int8_t { + kStringValueIndex = 0, + kBoolValueIndex = 1, + kBigIntValueIndex = 2, + kInt32BitMaskIndex = 3, + kStringListIndex = 4, + kInt32ToInt32ListIndex = 5 + }; +}; + +} // namespace internal +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/test_app_cli.cc b/cpp/src/arrow/flight/sql/test_app_cli.cc new file mode 100644 index 00000000000..43c37bee2fe --- /dev/null +++ b/cpp/src/arrow/flight/sql/test_app_cli.cc @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
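SqlInfoResultAppender is a visitor over the SqlInfoResult variant: each operator() overload writes the union type code through value_builder_ and the payload into the matching child builder, keeping the dense union consistent with the CommandGetSqlInfo schema. A sketch of how a producer might drive it, assuming value_builder already has its six children registered in the slot order kStringValueIndex through kInt32ToInt32ListIndex (the AppendSqlInfoValue helper is illustrative, not part of this patch):

#include "arrow/builder.h"
#include "arrow/flight/sql/sql_info_internal.h"
#include "arrow/flight/sql/types.h"
#include "arrow/status.h"
#include "arrow/util/variant.h"

namespace flight_sql = arrow::flight::sql;

arrow::Status AppendSqlInfoValue(arrow::DenseUnionBuilder* value_builder,
                                 const flight_sql::SqlInfoResult& result) {
  // visit() dispatches to the operator() overload that matches the variant's
  // active alternative (string, bool, int64, int32 bitmask, list, or map).
  flight_sql::internal::SqlInfoResultAppender appender(value_builder);
  return arrow::util::visit(appender, result);
}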
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include +#include + +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/flight/api.h" +#include "arrow/flight/sql/api.h" +#include "arrow/io/memory.h" +#include "arrow/pretty_print.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/util/optional.h" + +using arrow::Result; +using arrow::Schema; +using arrow::Status; +using arrow::flight::ClientAuthHandler; +using arrow::flight::FlightCallOptions; +using arrow::flight::FlightClient; +using arrow::flight::FlightDescriptor; +using arrow::flight::FlightEndpoint; +using arrow::flight::FlightInfo; +using arrow::flight::FlightStreamChunk; +using arrow::flight::FlightStreamReader; +using arrow::flight::Location; +using arrow::flight::Ticket; +using arrow::flight::sql::FlightSqlClient; +using arrow::flight::sql::TableRef; + +DEFINE_string(host, "localhost", "Host to connect to"); +DEFINE_int32(port, 32010, "Port to connect to"); +DEFINE_string(username, "", "Username"); +DEFINE_string(password, "", "Password"); + +DEFINE_string(command, "", "Method to run"); +DEFINE_string(query, "", "Query"); +DEFINE_string(catalog, "", "Catalog"); +DEFINE_string(schema, "", "Schema"); +DEFINE_string(table, "", "Table"); + +Status PrintResultsForEndpoint(FlightSqlClient& client, + const FlightCallOptions& call_options, + const FlightEndpoint& endpoint) { + ARROW_ASSIGN_OR_RAISE(auto stream, client.DoGet(call_options, endpoint.ticket)); + + const arrow::Result>& schema = stream->GetSchema(); + ARROW_RETURN_NOT_OK(schema); + + std::cout << "Schema:" << std::endl; + std::cout << schema->get()->ToString() << std::endl << std::endl; + + std::cout << "Results:" << std::endl; + + FlightStreamChunk chunk; + int64_t num_rows = 0; + + while (true) { + ARROW_RETURN_NOT_OK(stream->Next(&chunk)); + if (chunk.data == nullptr) { + break; + } + std::cout << chunk.data->ToString() << std::endl; + num_rows += chunk.data->num_rows(); + } + + std::cout << "Total: " << num_rows << std::endl; + + return Status::OK(); +} + +Status PrintResults(FlightSqlClient& client, const FlightCallOptions& call_options, + const std::unique_ptr& info) { + const std::vector& endpoints = info->endpoints(); + + for (size_t i = 0; i < endpoints.size(); i++) { + std::cout << "Results from endpoint " << i + 1 << " of " << endpoints.size() + << std::endl; + ARROW_RETURN_NOT_OK(PrintResultsForEndpoint(client, call_options, endpoints[i])); + } + + return Status::OK(); +} + +Status RunMain() { + std::unique_ptr client; + Location location; + ARROW_RETURN_NOT_OK(Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location)); + ARROW_RETURN_NOT_OK(FlightClient::Connect(location, &client)); + + FlightCallOptions call_options; + + if (!FLAGS_username.empty() || !FLAGS_password.empty()) { + Result> bearer_result = + client->AuthenticateBasicToken({}, FLAGS_username, FLAGS_password); + ARROW_RETURN_NOT_OK(bearer_result); + + call_options.headers.push_back(bearer_result.ValueOrDie()); + } + + FlightSqlClient sql_client(std::move(client)); + + if (FLAGS_command == "ExecuteUpdate") { + ARROW_ASSIGN_OR_RAISE(auto 
rows, sql_client.ExecuteUpdate(call_options, FLAGS_query)); + + std::cout << "Result: " << rows << std::endl; + + return Status::OK(); + } + + std::unique_ptr info; + + if (FLAGS_command == "Execute") { + ARROW_ASSIGN_OR_RAISE(info, sql_client.Execute(call_options, FLAGS_query)); + } else if (FLAGS_command == "GetCatalogs") { + ARROW_ASSIGN_OR_RAISE(info, sql_client.GetCatalogs(call_options)); + } else if (FLAGS_command == "PreparedStatementExecute") { + ARROW_ASSIGN_OR_RAISE(auto prepared_statement, + sql_client.Prepare(call_options, FLAGS_query)); + ARROW_ASSIGN_OR_RAISE(info, prepared_statement->Execute()); + } else if (FLAGS_command == "PreparedStatementExecuteParameterBinding") { + ARROW_ASSIGN_OR_RAISE(auto prepared_statement, sql_client.Prepare({}, FLAGS_query)); + auto parameter_schema = prepared_statement->parameter_schema(); + auto result_set_schema = prepared_statement->dataset_schema(); + + std::cout << result_set_schema->ToString(false) << std::endl; + arrow::Int64Builder int_builder; + ARROW_RETURN_NOT_OK(int_builder.Append(1)); + std::shared_ptr int_array; + ARROW_RETURN_NOT_OK(int_builder.Finish(&int_array)); + std::shared_ptr result; + result = arrow::RecordBatch::Make(parameter_schema, 1, {int_array}); + + ARROW_RETURN_NOT_OK(prepared_statement->SetParameters(result)); + ARROW_ASSIGN_OR_RAISE(info, prepared_statement->Execute()); + } else if (FLAGS_command == "GetDbSchemas") { + ARROW_ASSIGN_OR_RAISE( + info, sql_client.GetDbSchemas(call_options, &FLAGS_catalog, &FLAGS_schema)); + } else if (FLAGS_command == "GetTableTypes") { + ARROW_ASSIGN_OR_RAISE(info, sql_client.GetTableTypes(call_options)); + } else if (FLAGS_command == "GetTables") { + ARROW_ASSIGN_OR_RAISE( + info, sql_client.GetTables(call_options, &FLAGS_catalog, &FLAGS_schema, + &FLAGS_table, false, nullptr)); + } else if (FLAGS_command == "GetExportedKeys") { + TableRef table_ref = {arrow::util::make_optional(FLAGS_catalog), + arrow::util::make_optional(FLAGS_schema), FLAGS_table}; + ARROW_ASSIGN_OR_RAISE(info, sql_client.GetExportedKeys(call_options, table_ref)); + } else if (FLAGS_command == "GetImportedKeys") { + TableRef table_ref = {arrow::util::make_optional(FLAGS_catalog), + arrow::util::make_optional(FLAGS_schema), FLAGS_table}; + ARROW_ASSIGN_OR_RAISE(info, sql_client.GetImportedKeys(call_options, table_ref)); + } else if (FLAGS_command == "GetPrimaryKeys") { + TableRef table_ref = {arrow::util::make_optional(FLAGS_catalog), + arrow::util::make_optional(FLAGS_schema), FLAGS_table}; + ARROW_ASSIGN_OR_RAISE(info, sql_client.GetPrimaryKeys(call_options, table_ref)); + } else if (FLAGS_command == "GetSqlInfo") { + ARROW_ASSIGN_OR_RAISE(info, sql_client.GetSqlInfo(call_options, {})); + } + + if (info != NULLPTR && + !boost::istarts_with(FLAGS_command, "PreparedStatementExecute")) { + return PrintResults(sql_client, call_options, info); + } + + return Status::OK(); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + Status st = RunMain(); + if (!st.ok()) { + std::cerr << st << std::endl; + return 1; + } + return 0; +} diff --git a/cpp/src/arrow/flight/sql/test_server_cli.cc b/cpp/src/arrow/flight/sql/test_server_cli.cc new file mode 100644 index 00000000000..e0ba5340e8d --- /dev/null +++ b/cpp/src/arrow/flight/sql/test_server_cli.cc @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
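The GetTables branch above passes nullptr for the table-type filter and false for include_schema, which lists every object without schemas. A hedged variation for comparison; the type strings are server-defined (SQLite itself labels objects 'table' and 'view'), and the ListTablesAndViews helper is illustrative only:

#include <memory>
#include <string>
#include <vector>

#include "arrow/flight/api.h"
#include "arrow/flight/sql/api.h"
#include "arrow/result.h"

namespace flight_sql = arrow::flight::sql;

arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> ListTablesAndViews(
    flight_sql::FlightSqlClient& client) {
  arrow::flight::FlightCallOptions options;
  // Restrict the listing to tables and views; null patterns match everything.
  std::vector<std::string> table_types = {"table", "view"};
  // include_schema=true asks the server to embed each table's serialized Arrow
  // schema in an additional binary column of the result.
  return client.GetTables(options, /*catalog=*/nullptr,
                          /*db_schema_filter_pattern=*/nullptr,
                          /*table_filter_pattern=*/nullptr,
                          /*include_schema=*/true, &table_types);
}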
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include +#include +#include + +#include "arrow/flight/server.h" +#include "arrow/flight/sql/example/sqlite_server.h" +#include "arrow/io/test_common.h" +#include "arrow/testing/json_integration.h" +#include "arrow/util/logging.h" + +DEFINE_int32(port, 31337, "Server port to listen on"); + +arrow::Status RunMain() { + arrow::flight::Location location; + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); + arrow::flight::FlightServerOptions options(location); + + std::shared_ptr server; + ARROW_ASSIGN_OR_RAISE(server, + arrow::flight::sql::example::SQLiteFlightSqlServer::Create()) + + ARROW_CHECK_OK(server->Init(options)); + // Exit with a clean error code (0) on SIGTERM + ARROW_CHECK_OK(server->SetShutdownOnSignals({SIGTERM})); + + std::cout << "Server listening on localhost:" << server->port() << std::endl; + ARROW_CHECK_OK(server->Serve()); + + return arrow::Status::OK(); +} + +int main(int argc, char** argv) { + gflags::SetUsageMessage("Integration testing server for Flight SQL."); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + arrow::Status st = RunMain(); + if (!st.ok()) { + std::cerr << st << std::endl; + return 1; + } + return 0; +} diff --git a/cpp/src/arrow/flight/sql/types.h b/cpp/src/arrow/flight/sql/types.h new file mode 100644 index 00000000000..44b8bca4718 --- /dev/null +++ b/cpp/src/arrow/flight/sql/types.h @@ -0,0 +1,890 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/type_fwd.h" +#include "arrow/util/optional.h" +#include "arrow/util/variant.h" + +namespace arrow { +namespace flight { +namespace sql { + +/// \brief Variant supporting all possible types on SQL info. +using SqlInfoResult = + arrow::util::Variant, + std::unordered_map>>; + +/// \brief Map SQL info identifier to its value. +using SqlInfoResultMap = std::unordered_map; + +/// \brief Options to be set in the SqlInfo. 
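SqlInfoResult and SqlInfoResultMap let a server register its static metadata up front and answer CommandGetSqlInfo by filtering that map by the requested ids. A small sketch under that assumption, with illustrative values (the MakeStaticSqlInfo helper is not part of this patch); the SqlInfoOptions struct that follows defines the identifiers used as keys:

#include <string>
#include <vector>

#include "arrow/flight/sql/types.h"

namespace flight_sql = arrow::flight::sql;

flight_sql::SqlInfoResultMap MakeStaticSqlInfo() {
  using flight_sql::SqlInfoOptions;
  using flight_sql::SqlInfoResult;
  flight_sql::SqlInfoResultMap info;
  // Each entry holds one alternative of the variant: utf8, bool, int64,
  // int32 bitmask, string list, or int32 -> list<int32> map.
  info[SqlInfoOptions::FLIGHT_SQL_SERVER_NAME] =
      SqlInfoResult(std::string("example server"));
  info[SqlInfoOptions::FLIGHT_SQL_SERVER_READ_ONLY] = SqlInfoResult(false);
  info[SqlInfoOptions::SQL_KEYWORDS] =
      SqlInfoResult(std::vector<std::string>{"SELECT", "INSERT", "UPDATE"});
  return info;
}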
+struct SqlInfoOptions { + enum SqlInfo { + // Server Information [0-500): Provides basic information about the Flight SQL Server. + + // Retrieves a UTF-8 string with the name of the Flight SQL Server. + FLIGHT_SQL_SERVER_NAME = 0, + + // Retrieves a UTF-8 string with the native version of the Flight SQL Server. + FLIGHT_SQL_SERVER_VERSION = 1, + + // Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server. + FLIGHT_SQL_SERVER_ARROW_VERSION = 2, + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server is read only. + * + * Returns: + * - false: if read-write + * - true: if read only + */ + FLIGHT_SQL_SERVER_READ_ONLY = 3, + + // SQL Syntax Information [500-1000): provides information about SQL syntax supported + // by the Flight SQL Server. + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE + * and DROP of catalogs. + * + * Returns: + * - false: if it doesn't support CREATE and DROP of catalogs. + * - true: if it supports CREATE and DROP of catalogs. + */ + SQL_DDL_CATALOG = 500, + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE + * and DROP of schemas. + * + * Returns: + * - false: if it doesn't support CREATE and DROP of schemas. + * - true: if it supports CREATE and DROP of schemas. + */ + SQL_DDL_SCHEMA = 501, + + /* + * Indicates whether the Flight SQL Server supports CREATE and DROP of tables. + * + * Returns: + * - false: if it doesn't support CREATE and DROP of tables. + * - true: if it supports CREATE and DROP of tables. + */ + SQL_DDL_TABLE = 502, + + /* + * Retrieves a uint32 value representing the enu uint32 ordinal for the case + * sensitivity of catalog, table and schema names. + * + * The possible values are listed in + * `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + */ + SQL_IDENTIFIER_CASE = 503, + + // Retrieves a UTF-8 string with the supported character(s) used to surround a + // delimited identifier. + SQL_IDENTIFIER_QUOTE_CHAR = 504, + + /* + * Retrieves a uint32 value representing the enu uint32 ordinal for the case + * sensitivity of quoted identifiers. + * + * The possible values are listed in + * `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + */ + SQL_QUOTED_IDENTIFIER_CASE = 505, + + /* + * Retrieves a boolean value indicating whether all tables are selectable. + * + * Returns: + * - false: if not all tables are selectable or if none are; + * - true: if all tables are selectable. + */ + SQL_ALL_TABLES_ARE_SELECTABLE = 506, + + /* + * Retrieves the null ordering. + * + * Returns a uint32 ordinal for the null ordering being used, as described in + * `arrow.flight.protocol.sql.SqlNullOrdering`. + */ + SQL_NULL_ORDERING = 507, + + // Retrieves a UTF-8 string list with values of the supported keywords. + SQL_KEYWORDS = 508, + + // Retrieves a UTF-8 string list with values of the supported numeric functions. + SQL_NUMERIC_FUNCTIONS = 509, + + // Retrieves a UTF-8 string list with values of the supported string functions. + SQL_STRING_FUNCTIONS = 510, + + // Retrieves a UTF-8 string list with values of the supported system functions. + SQL_SYSTEM_FUNCTIONS = 511, + + // Retrieves a UTF-8 string list with values of the supported datetime functions. + SQL_DATETIME_FUNCTIONS = 512, + + /* + * Retrieves the UTF-8 string that can be used to escape wildcard characters. 
+ * This is the string that can be used to escape '_' or '%' in the catalog search + * parameters that are a pattern (and therefore use one of the wildcard characters). + * The '_' character represents any single character; the '%' character represents any + * sequence of zero or more characters. + */ + SQL_SEARCH_STRING_ESCAPE = 513, + + /* + * Retrieves a UTF-8 string with all the "extra" characters that can be used in + * unquoted identifier names (those beyond a-z, A-Z, 0-9 and _). + */ + SQL_EXTRA_NAME_CHARACTERS = 514, + + /* + * Retrieves a boolean value indicating whether column aliasing is supported. + * If so, the SQL AS clause can be used to provide names for computed columns or to + * provide alias names for columns as required. + * + * Returns: + * - false: if column aliasing is unsupported; + * - true: if column aliasing is supported. + */ + SQL_SUPPORTS_COLUMN_ALIASING = 515, + + /* + * Retrieves a boolean value indicating whether concatenations between null and + * non-null values being null are supported. + * + * - Returns: + * - false: if concatenations between null and non-null values being null are + * unsupported; + * - true: if concatenations between null and non-null values being null are + * supported. + */ + SQL_NULL_PLUS_NULL_IS_NULL = 516, + + /* + * Retrieves a map where the key is the type to convert from and the value is a list + * with the types to convert to, indicating the supported conversions. Each key and + * each item on the list value is a value to a predefined type on SqlSupportsConvert + * enum. The returned map will be: map> + */ + SQL_SUPPORTS_CONVERT = 517, + + /* + * Retrieves a boolean value indicating whether, when table correlation names are + * supported, they are restricted to being different from the names of the tables. + * + * Returns: + * - false: if table correlation names are unsupported; + * - true: if table correlation names are supported. + */ + SQL_SUPPORTS_TABLE_CORRELATION_NAMES = 518, + + /* + * Retrieves a boolean value indicating whether, when table correlation names are + * supported, they are restricted to being different from the names of the tables. + * + * Returns: + * - false: if different table correlation names are unsupported; + * - true: if different table correlation names are supported + */ + SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES = 519, + + /* + * Retrieves a boolean value indicating whether expressions in ORDER BY lists are + * supported. + * + * Returns: + * - false: if expressions in ORDER BY are unsupported; + * - true: if expressions in ORDER BY are supported; + */ + SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY = 520, + + /* + * Retrieves a boolean value indicating whether using a column that is not in the + * SELECT statement in a GROUP BY clause is supported. + * + * Returns: + * - false: if using a column that is not in the SELECT statement in a GROUP BY clause + * is unsupported; + * - true: if using a column that is not in the SELECT statement in a GROUP BY clause + * is supported. + */ + SQL_SUPPORTS_ORDER_BY_UNRELATED = 521, + + /* + * Retrieves the supported GROUP BY commands; + * + * Returns an int32 bitmask value representing the supported commands. + * The returned bitmask should be parsed in order to retrieve the supported commands. + * + * For instance: + * - return 0 (\b0) => [] (GROUP BY is unsupported); + * - return 1 (\b1) => [SQL_GROUP_BY_UNRELATED]; + * - return 2 (\b10) => [SQL_GROUP_BY_BEYOND_SELECT]; + * - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT]. 
+ * Valid GROUP BY types are described under + * `arrow.flight.protocol.sql.SqlSupportedGroupBy`. + */ + SQL_SUPPORTED_GROUP_BY = 522, + + /* + * Retrieves a boolean value indicating whether specifying a LIKE escape clause is + * supported. + * + * Returns: + * - false: if specifying a LIKE escape clause is unsupported; + * - true: if specifying a LIKE escape clause is supported. + */ + SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE = 523, + + /* + * Retrieves a boolean value indicating whether columns may be defined as + * non-nullable. + * + * Returns: + * - false: if columns cannot be defined as non-nullable; + * - true: if columns may be defined as non-nullable. + */ + SQL_SUPPORTS_NON_NULLABLE_COLUMNS = 524, + + /* + * Retrieves the supported SQL grammar level as per the ODBC specification. + * + * Returns an int32 bitmask value representing the supported SQL grammar level. + * The returned bitmask should be parsed in order to retrieve the supported grammar + * levels. + * + * For instance: + * - return 0 (\b0) => [] (SQL grammar is unsupported); + * - return 1 (\b1) => [SQL_MINIMUM_GRAMMAR]; + * - return 2 (\b10) => [SQL_CORE_GRAMMAR]; + * - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR]; + * - return 4 (\b100) => [SQL_EXTENDED_GRAMMAR]; + * - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR]; + * - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]; + * - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, + * SQL_EXTENDED_GRAMMAR]. Valid SQL grammar levels are described under + * `arrow.flight.protocol.sql.SupportedSqlGrammar`. + */ + SQL_SUPPORTED_GRAMMAR = 525, + + /* + * Retrieves the supported ANSI92 SQL grammar level. + * + * Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level. + * The returned bitmask should be parsed in order to retrieve the supported commands. + * + * For instance: + * - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported); + * - return 1 (\b1) => [ANSI92_ENTRY_SQL]; + * - return 2 (\b10) => [ANSI92_INTERMEDIATE_SQL]; + * - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL]; + * - return 4 (\b100) => [ANSI92_FULL_SQL]; + * - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL]; + * - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]; + * - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]. + * Valid ANSI92 SQL grammar levels are described under + * `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`. + */ + SQL_ANSI92_SUPPORTED_LEVEL = 526, + + /* + * Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility + * is supported. + * + * Returns: + * - false: if the SQL Integrity Enhancement Facility is supported; + * - true: if the SQL Integrity Enhancement Facility is supported. + */ + SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY = 527, + + /* + * Retrieves the support level for SQL OUTER JOINs. + * + * Returns a uint3 uint32 ordinal for the SQL ordering being used, as described in + * `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`. + */ + SQL_OUTER_JOINS_SUPPORT_LEVEL = 528, + + // Retrieves a UTF-8 string with the preferred term for "schema". + SQL_SCHEMA_TERM = 529, + + // Retrieves a UTF-8 string with the preferred term for "procedure". + SQL_PROCEDURE_TERM = 530, + + // Retrieves a UTF-8 string with the preferred term for "catalog". + SQL_CATALOG_TERM = 531, + + /* + * Retrieves a boolean value indicating whether a catalog appears at the start of a + * fully qualified table name. 
+ * + * - false: if a catalog does not appear at the start of a fully qualified table name; + * - true: if a catalog appears at the start of a fully qualified table name. + */ + SQL_CATALOG_AT_START = 532, + + /* + * Retrieves the supported actions for a SQL schema. + * + * Returns an int32 bitmask value representing the supported actions for a SQL schema. + * The returned bitmask should be parsed in order to retrieve the supported actions + * for a SQL schema. + * + * For instance: + * - return 0 (\b0) => [] (no supported actions for SQL schema); + * - return 1 (\b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS]; + * - return 2 (\b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, + * SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 4 (\b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, + * SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, + * SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, + * SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. Valid + * actions for a SQL schema described under + * `arrow.flight.protocol.sql.SqlSupportedElementActions`. + */ + SQL_SCHEMAS_SUPPORTED_ACTIONS = 533, + + /* + * Retrieves the supported actions for a SQL schema. + * + * Returns an int32 bitmask value representing the supported actions for a SQL + * catalog. The returned bitmask should be parsed in order to retrieve the supported + * actions for a SQL catalog. + * + * For instance: + * - return 0 (\b0) => [] (no supported actions for SQL catalog); + * - return 1 (\b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS]; + * - return 2 (\b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, + * SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 4 (\b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, + * SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, + * SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, + * SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. Valid + * actions for a SQL catalog are described under + * `arrow.flight.protocol.sql.SqlSupportedElementActions`. + */ + SQL_CATALOGS_SUPPORTED_ACTIONS = 534, + + /* + * Retrieves the supported SQL positioned commands. + * + * Returns an int32 bitmask value representing the supported SQL positioned commands. + * The returned bitmask should be parsed in order to retrieve the supported SQL + * positioned commands. + * + * For instance: + * - return 0 (\b0) => [] (no supported SQL positioned commands); + * - return 1 (\b1) => [SQL_POSITIONED_DELETE]; + * - return 2 (\b10) => [SQL_POSITIONED_UPDATE]; + * - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE]. + * Valid SQL positioned commands are described under + * `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. + */ + SQL_SUPPORTED_POSITIONED_COMMANDS = 535, + + /* + * Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are + * supported. + * + * Returns: + * - false: if SELECT FOR UPDATE statements are unsupported; + * - true: if SELECT FOR UPDATE statements are supported. 
+ */ + SQL_SELECT_FOR_UPDATE_SUPPORTED = 536, + + /* + * Retrieves a boolean value indicating whether stored procedure calls that use the + * stored procedure escape syntax are supported. + * + * Returns: + * - false: if stored procedure calls that use the stored procedure escape syntax are + * unsupported; + * - true: if stored procedure calls that use the stored procedure escape syntax are + * supported. + */ + SQL_STORED_PROCEDURES_SUPPORTED = 537, + + /* + * Retrieves the supported SQL subqueries. + * + * Returns an int32 bitmask value representing the supported SQL subqueries. + * The returned bitmask should be parsed in order to retrieve the supported SQL + * subqueries. + * + * For instance: + * - return 0 (\b0) => [] (no supported SQL subqueries); + * - return 1 (\b1) => [SQL_SUBQUERIES_IN_COMPARISONS]; + * - return 2 (\b10) => [SQL_SUBQUERIES_IN_EXISTS]; + * - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS, + * SQL_SUBQUERIES_IN_EXISTS]; + * - return 4 (\b100) => [SQL_SUBQUERIES_IN_INS]; + * - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS]; + * - return 6 (\b110) => [SQL_SUBQUERIES_IN_COMPARISONS, + * SQL_SUBQUERIES_IN_EXISTS]; + * - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, + * SQL_SUBQUERIES_IN_INS]; + * - return 8 (\b1000) => [SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS, + * SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS, + * SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, + * SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, + * SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, + * SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, + * SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - ... + * Valid SQL subqueries are described under + * `arrow.flight.protocol.sql.SqlSupportedSubqueries`. + */ + SQL_SUPPORTED_SUBQUERIES = 538, + + /* + * Retrieves a boolean value indicating whether correlated subqueries are supported. + * + * Returns: + * - false: if correlated subqueries are unsupported; + * - true: if correlated subqueries are supported. + */ + SQL_CORRELATED_SUBQUERIES_SUPPORTED = 539, + + /* + * Retrieves the supported SQL UNIONs. + * + * Returns an int32 bitmask value representing the supported SQL UNIONs. + * The returned bitmask should be parsed in order to retrieve the supported SQL + * UNIONs. + * + * For instance: + * - return 0 (\b0) => [] (no supported SQL positioned commands); + * - return 1 (\b1) => [SQL_UNION]; + * - return 2 (\b10) => [SQL_UNION_ALL]; + * - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL]. + * Valid SQL positioned commands are described under + * `arrow.flight.protocol.sql.SqlSupportedUnions`. + */ + SQL_SUPPORTED_UNIONS = 540, + + // Retrieves a uint32 value representing the maximum number of hex characters allowed + // in an inline binary literal. + SQL_MAX_BINARY_LITERAL_LENGTH = 541, + + // Retrieves a uint32 value representing the maximum number of characters allowed for + // a character literal. + SQL_MAX_CHAR_LITERAL_LENGTH = 542, + + // Retrieves a uint32 value representing the maximum number of characters allowed for + // a column name. 
+ SQL_MAX_COLUMN_NAME_LENGTH = 543, + + // Retrieves a uint32 value representing the the maximum number of columns allowed in + // a GROUP BY clause. + SQL_MAX_COLUMNS_IN_GROUP_BY = 544, + + // Retrieves a uint32 value representing the maximum number of columns allowed in an + // index. + SQL_MAX_COLUMNS_IN_INDEX = 545, + + // Retrieves a uint32 value representing the maximum number of columns allowed in an + // ORDER BY clause. + SQL_MAX_COLUMNS_IN_ORDER_BY = 546, + + // Retrieves a uint32 value representing the maximum number of columns allowed in a + // SELECT list. + SQL_MAX_COLUMNS_IN_SELECT = 547, + + // Retrieves a uint32 value representing the maximum number of columns allowed in a + // table. + SQL_MAX_COLUMNS_IN_TABLE = 548, + + // Retrieves a uint32 value representing the maximum number of concurrent connections + // possible. + SQL_MAX_CONNECTIONS = 549, + + // Retrieves a uint32 value the maximum number of characters allowed in a cursor name. + SQL_MAX_CURSOR_NAME_LENGTH = 550, + + /* + * Retrieves a uint32 value representing the maximum number of bytes allowed for an + * index, including all of the parts of the index. + */ + SQL_MAX_INDEX_LENGTH = 551, + + // Retrieves a uint32 value representing the maximum number of characters allowed in a + // procedure name. + SQL_SCHEMA_NAME_LENGTH = 552, + + // Retrieves a uint32 value representing the maximum number of bytes allowed in a + // single row. + SQL_MAX_PROCEDURE_NAME_LENGTH = 553, + + // Retrieves a uint32 value representing the maximum number of characters allowed in a + // catalog name. + SQL_MAX_CATALOG_NAME_LENGTH = 554, + + // Retrieves a uint32 value representing the maximum number of bytes allowed in a + // single row. + SQL_MAX_ROW_SIZE = 555, + + /* + * Retrieves a boolean indicating whether the return value for the JDBC method + * getMaxRowSize includes the SQL data types LONGVARCHAR and LONGVARBINARY. + * + * Returns: + * - false: if return value for the JDBC method getMaxRowSize does + * not include the SQL data types LONGVARCHAR and LONGVARBINARY; + * - true: if return value for the JDBC method getMaxRowSize includes + * the SQL data types LONGVARCHAR and LONGVARBINARY. + */ + SQL_MAX_ROW_SIZE_INCLUDES_BLOBS = 556, + + /* + * Retrieves a uint32 value representing the maximum number of characters allowed for + * an SQL statement; a result of 0 (zero) means that there is no limit or the limit is + * not known. + */ + SQL_MAX_STATEMENT_LENGTH = 557, + + // Retrieves a uint32 value representing the maximum number of active statements that + // can be open at the same time. + SQL_MAX_STATEMENTS = 558, + + // Retrieves a uint32 value representing the maximum number of characters allowed in a + // table name. + SQL_MAX_TABLE_NAME_LENGTH = 559, + + // Retrieves a uint32 value representing the maximum number of tables allowed in a + // SELECT statement. + SQL_MAX_TABLES_IN_SELECT = 560, + + // Retrieves a uint32 value representing the maximum number of characters allowed in a + // user name. + SQL_MAX_USERNAME_LENGTH = 561, + + /* + * Retrieves this database's default transaction isolation level as described in + * `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + * + * Returns a uint32 ordinal for the SQL transaction isolation level. + */ + SQL_DEFAULT_TRANSACTION_ISOLATION = 562, + + /* + * Retrieves a boolean value indicating whether transactions are supported. 
If not, + * invoking the method commit is a noop, and the isolation level is + * `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. + * + * Returns: + * - false: if transactions are unsupported; + * - true: if transactions are supported. + */ + SQL_TRANSACTIONS_SUPPORTED = 563, + + /* + * Retrieves the supported transactions isolation levels. + * + * Returns an int32 bitmask value representing the supported transactions isolation + * levels. The returned bitmask should be parsed in order to retrieve the supported + * transactions isolation levels. + * + * For instance: + * - return 0 (\b0) => [] (no supported SQL transactions isolation levels); + * - return 1 (\b1) => [SQL_TRANSACTION_NONE]; + * - return 2 (\b10) => [SQL_TRANSACTION_READ_UNCOMMITTED]; + * - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED]; + * - return 4 (\b100) => [SQL_TRANSACTION_REPEATABLE_READ]; + * - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED, + * SQL_TRANSACTION_REPEATABLE_READ]; + * - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, + * SQL_TRANSACTION_REPEATABLE_READ]; + * - return 8 (\b1000) => [SQL_TRANSACTION_REPEATABLE_READ]; + * - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED, + * SQL_TRANSACTION_REPEATABLE_READ]; + * - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, + * SQL_TRANSACTION_REPEATABLE_READ]; + * - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ, + * SQL_TRANSACTION_REPEATABLE_READ]; + * - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, + * SQL_TRANSACTION_REPEATABLE_READ]; + * - return 14 (\b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED, + * SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, + * SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 16 (\b10000) => [SQL_TRANSACTION_SERIALIZABLE]; + * - ... + * Valid SQL positioned commands are described under + * `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + */ + SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS = 564, + + /* + * Retrieves a boolean value indicating whether a data definition statement within a + * transaction forces the transaction to commit. + * + * Returns: + * - false: if a data definition statement within a transaction does not force the + * transaction to commit; + * - true: if a data definition statement within a transaction forces the transaction + * to commit. + */ + SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT = 565, + + /* + * Retrieves a boolean value indicating whether a data definition statement within a + * transaction is ignored. + * + * Returns: + * - false: if a data definition statement within a transaction is taken into account; + * - true: a data definition statement within a transaction is ignored. + */ + SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED = 566, + + /* + * Retrieves an int32 bitmask value representing the supported result set types. + * The returned bitmask should be parsed in order to retrieve the supported result set + * types. 
+ * + * For instance: + * - return 0 (\b0) => [] (no supported result set types); + * - return 1 (\b1) => [SQL_RESULT_SET_TYPE_UNSPECIFIED]; + * - return 2 (\b10) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY]; + * - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, + * SQL_RESULT_SET_TYPE_FORWARD_ONLY]; + * - return 4 (\b100) => [SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, + * SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY, + * SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, + * SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 8 (\b1000) => [SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE]; + * - ... + * Valid result set types are described under + * `arrow.flight.protocol.sql.SqlSupportedResultSetType`. + */ + SQL_SUPPORTED_RESULT_SET_TYPES = 567, + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, + * SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] Valid + * result set types are described under + * `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED = 568, + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, + * SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] Valid + * result set types are described under + * `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY = 569, + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. 
+ * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, + * SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] Valid + * result set types are described under + * `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE = 570, + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, + * SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, + * SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] Valid + * result set types are described under + * `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE = 571, + + /* + * Retrieves a boolean value indicating whether this database supports batch updates. + * + * - false: if this database does not support batch updates; + * - true: if this database supports batch updates. + */ + SQL_BATCH_UPDATES_SUPPORTED = 572, + + /* + * Retrieves a boolean value indicating whether this database supports savepoints. + * + * Returns: + * - false: if this database does not support savepoints; + * - true: if this database supports savepoints. + */ + SQL_SAVEPOINTS_SUPPORTED = 573, + + /* + * Retrieves a boolean value indicating whether named parameters are supported in + * callable statements. + * + * Returns: + * - false: if named parameters in callable statements are unsupported; + * - true: if named parameters in callable statements are supported. + */ + SQL_NAMED_PARAMETERS_SUPPORTED = 574, + + /* + * Retrieves a boolean value indicating whether updates made to a LOB are made on a + * copy or directly to the LOB. + * + * Returns: + * - false: if updates made to a LOB are made directly to the LOB; + * - true: if updates made to a LOB are made on a copy. + */ + SQL_LOCATORS_UPDATE_COPY = 575, + + /* + * Retrieves a boolean value indicating whether invoking user-defined or vendor + * functions using the stored procedure escape syntax is supported. 
+ * + * Returns: + * - false: if invoking user-defined or vendor functions using the stored procedure + * escape syntax is unsupported; + * - true: if invoking user-defined or vendor functions using the stored procedure + * escape syntax is supported. + */ + SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED = 576, + }; + + enum SqlSupportedCaseSensitivity { + SQL_CASE_SENSITIVITY_UNKNOWN = 0, + SQL_CASE_SENSITIVITY_CASE_INSENSITIVE = 1, + SQL_CASE_SENSITIVITY_UPPERCASE = 2, + }; + + enum SqlNullOrdering { + SQL_NULLS_SORTED_HIGH = 0, + SQL_NULLS_SORTED_LOW = 1, + SQL_NULLS_SORTED_AT_START = 2, + SQL_NULLS_SORTED_AT_END = 3, + }; + + enum SqlSupportsConvert { + SQL_CONVERT_BIGINT = 0, + SQL_CONVERT_BINARY = 1, + SQL_CONVERT_BIT = 2, + SQL_CONVERT_CHAR = 3, + SQL_CONVERT_DATE = 4, + SQL_CONVERT_DECIMAL = 5, + SQL_CONVERT_FLOAT = 6, + SQL_CONVERT_INTEGER = 7, + SQL_CONVERT_INTERVAL_DAY_TIME = 8, + SQL_CONVERT_INTERVAL_YEAR_MONTH = 9, + SQL_CONVERT_LONGVARBINARY = 10, + SQL_CONVERT_LONGVARCHAR = 11, + SQL_CONVERT_NUMERIC = 12, + SQL_CONVERT_REAL = 13, + SQL_CONVERT_SMALLINT = 14, + SQL_CONVERT_TIME = 15, + SQL_CONVERT_TIMESTAMP = 16, + SQL_CONVERT_TINYINT = 17, + SQL_CONVERT_VARBINARY = 18, + SQL_CONVERT_VARCHAR = 19, + }; +}; + +/// \brief Table reference, optionally containing table's catalog and db_schema. +struct TableRef { + util::optional catalog; + util::optional db_schema; + std::string table; +}; + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration.cc b/cpp/src/arrow/flight/test_integration.cc deleted file mode 100644 index 29ce5601f37..00000000000 --- a/cpp/src/arrow/flight/test_integration.cc +++ /dev/null @@ -1,270 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/flight/test_integration.h" -#include "arrow/flight/client_middleware.h" -#include "arrow/flight/server_middleware.h" -#include "arrow/flight/test_util.h" -#include "arrow/flight/types.h" -#include "arrow/ipc/dictionary.h" - -#include -#include -#include -#include -#include - -namespace arrow { -namespace flight { - -/// \brief The server for the basic auth integration test. -class AuthBasicProtoServer : public FlightServerBase { - Status DoAction(const ServerCallContext& context, const Action& action, - std::unique_ptr* result) override { - // Respond with the authenticated username. - auto buf = Buffer::FromString(context.peer_identity()); - *result = std::unique_ptr(new SimpleResultStream({Result{buf}})); - return Status::OK(); - } -}; - -/// Validate the result of a DoAction. 
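Several of the identifiers above (SQL_SUPPORTED_GROUP_BY, SQL_SUPPORTED_UNIONS, the grammar, schema/catalog action, and result-set masks) return int32 bitmasks in which bit N corresponds to ordinal N of the protocol enum named in the comment. A client-side sketch of that decoding convention; the constants are derived from the worked examples above rather than from a published header:

#include <cstdint>

// Documented examples: return 1 (\b1)  => [SQL_GROUP_BY_UNRELATED],
//                      return 2 (\b10) => [SQL_GROUP_BY_BEYOND_SELECT].
constexpr int32_t kGroupByUnrelatedBit = 1 << 0;
constexpr int32_t kGroupByBeyondSelectBit = 1 << 1;

bool SupportsGroupByUnrelated(int32_t supported_group_by_mask) {
  return (supported_group_by_mask & kGroupByUnrelatedBit) != 0;
}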
-Status CheckActionResults(FlightClient* client, const Action& action, - std::vector results) { - std::unique_ptr stream; - RETURN_NOT_OK(client->DoAction(action, &stream)); - std::unique_ptr result; - for (const std::string& expected : results) { - RETURN_NOT_OK(stream->Next(&result)); - if (!result) { - return Status::Invalid("Action result stream ended early"); - } - const auto actual = result->body->ToString(); - if (expected != actual) { - return Status::Invalid("Got wrong result; expected", expected, "but got", actual); - } - } - RETURN_NOT_OK(stream->Next(&result)); - if (result) { - return Status::Invalid("Action result stream had too many entries"); - } - return Status::OK(); -} - -// The expected username for the basic auth integration test. -constexpr auto kAuthUsername = "arrow"; -// The expected password for the basic auth integration test. -constexpr auto kAuthPassword = "flight"; - -/// \brief A scenario testing the basic auth protobuf. -class AuthBasicProtoScenario : public Scenario { - Status MakeServer(std::unique_ptr* server, - FlightServerOptions* options) override { - server->reset(new AuthBasicProtoServer()); - options->auth_handler = - std::make_shared(kAuthUsername, kAuthPassword); - return Status::OK(); - } - - Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } - - Status RunClient(std::unique_ptr client) override { - Action action; - std::unique_ptr stream; - std::shared_ptr detail; - const auto& status = client->DoAction(action, &stream); - detail = FlightStatusDetail::UnwrapStatus(status); - // This client is unauthenticated and should fail. - if (detail == nullptr) { - return Status::Invalid("Expected UNAUTHENTICATED but got ", status.ToString()); - } - if (detail->code() != FlightStatusCode::Unauthenticated) { - return Status::Invalid("Expected UNAUTHENTICATED but got ", detail->ToString()); - } - - auto client_handler = std::unique_ptr( - new TestClientBasicAuthHandler(kAuthUsername, kAuthPassword)); - RETURN_NOT_OK(client->Authenticate({}, std::move(client_handler))); - return CheckActionResults(client.get(), action, {kAuthUsername}); - } -}; - -/// \brief Test middleware that echoes back the value of a particular -/// incoming header. -/// -/// In Java, gRPC may consolidate this header with HTTP/2 trailers if -/// the call fails, but C++ generally doesn't do this. The integration -/// test confirms the presence of this header to ensure we can read it -/// regardless of what gRPC does. 
-class TestServerMiddleware : public ServerMiddleware { - public: - explicit TestServerMiddleware(std::string received) : received_(received) {} - void SendingHeaders(AddCallHeaders* outgoing_headers) override { - outgoing_headers->AddHeader("x-middleware", received_); - } - void CallCompleted(const Status& status) override {} - - std::string name() const override { return "GrpcTrailersMiddleware"; } - - private: - std::string received_; -}; - -class TestServerMiddlewareFactory : public ServerMiddlewareFactory { - public: - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, - std::shared_ptr* middleware) override { - const std::pair& iter_pair = - incoming_headers.equal_range("x-middleware"); - std::string received = ""; - if (iter_pair.first != iter_pair.second) { - const util::string_view& value = (*iter_pair.first).second; - received = std::string(value); - } - *middleware = std::make_shared(received); - return Status::OK(); - } -}; - -/// \brief Test middleware that adds a header on every outgoing call, -/// and gets the value of the expected header sent by the server. -class TestClientMiddleware : public ClientMiddleware { - public: - explicit TestClientMiddleware(std::string* received_header) - : received_header_(received_header) {} - - void SendingHeaders(AddCallHeaders* outgoing_headers) { - outgoing_headers->AddHeader("x-middleware", "expected value"); - } - - void ReceivedHeaders(const CallHeaders& incoming_headers) { - // We expect the server to always send this header. gRPC/Java may - // send it in trailers instead of headers, so we expect Flight to - // account for this. - const std::pair& iter_pair = - incoming_headers.equal_range("x-middleware"); - if (iter_pair.first != iter_pair.second) { - const util::string_view& value = (*iter_pair.first).second; - *received_header_ = std::string(value); - } - } - - void CallCompleted(const Status& status) {} - - private: - std::string* received_header_; -}; - -class TestClientMiddlewareFactory : public ClientMiddlewareFactory { - public: - void StartCall(const CallInfo& info, std::unique_ptr* middleware) { - *middleware = - std::unique_ptr(new TestClientMiddleware(&received_header_)); - } - - std::string received_header_; -}; - -/// \brief The server used for testing middleware. Implements only one -/// endpoint, GetFlightInfo, in such a way that it either succeeds or -/// returns an error based on the input, in order to test both paths. -class MiddlewareServer : public FlightServerBase { - Status GetFlightInfo(const ServerCallContext& context, - const FlightDescriptor& descriptor, - std::unique_ptr* result) override { - if (descriptor.type == FlightDescriptor::DescriptorType::CMD && - descriptor.cmd == "success") { - // Don't fail - std::shared_ptr schema = arrow::schema({}); - Location location; - // Return a fake location - the test doesn't read it - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 10010, &location)); - std::vector endpoints{FlightEndpoint{{"foo"}, {location}}}; - ARROW_ASSIGN_OR_RAISE(auto info, - FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); - *result = std::unique_ptr(new FlightInfo(info)); - return Status::OK(); - } - // Fail the call immediately. In some gRPC implementations, this - // means that gRPC sends only HTTP/2 trailers and not headers. We want - // Flight middleware to be agnostic to this difference. - return Status::UnknownError("Unknown"); - } -}; - -/// \brief The middleware scenario. -/// -/// This tests that the server and client get expected header values. 
-class MiddlewareScenario : public Scenario { - Status MakeServer(std::unique_ptr* server, - FlightServerOptions* options) override { - options->middleware.push_back( - {"grpc_trailers", std::make_shared()}); - server->reset(new MiddlewareServer()); - return Status::OK(); - } - - Status MakeClient(FlightClientOptions* options) override { - client_middleware_ = std::make_shared(); - options->middleware.push_back(client_middleware_); - return Status::OK(); - } - - Status RunClient(std::unique_ptr client) override { - std::unique_ptr info; - // This call is expected to fail. In gRPC/Java, this causes the - // server to combine headers and HTTP/2 trailers, so to read the - // expected header, Flight must check for both headers and - // trailers. - if (client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok()) { - return Status::Invalid("Expected call to fail"); - } - if (client_middleware_->received_header_ != "expected value") { - return Status::Invalid( - "Expected to receive header 'x-middleware: expected value', but instead got: '", - client_middleware_->received_header_, "'"); - } - std::cerr << "Headers received successfully on failing call." << std::endl; - - // This call should succeed - client_middleware_->received_header_ = ""; - RETURN_NOT_OK(client->GetFlightInfo(FlightDescriptor::Command("success"), &info)); - if (client_middleware_->received_header_ != "expected value") { - return Status::Invalid( - "Expected to receive header 'x-middleware: expected value', but instead got '", - client_middleware_->received_header_, "'"); - } - std::cerr << "Headers received successfully on passing call." << std::endl; - return Status::OK(); - } - - std::shared_ptr client_middleware_; -}; - -Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) { - if (scenario_name == "auth:basic_proto") { - *out = std::make_shared(); - return Status::OK(); - } else if (scenario_name == "middleware") { - *out = std::make_shared(); - return Status::OK(); - } - return Status::KeyError("Scenario not found: ", scenario_name); -} - -} // namespace flight -} // namespace arrow diff --git a/cpp/vcpkg.json b/cpp/vcpkg.json index 64ece20926a..556643841a9 100644 --- a/cpp/vcpkg.json +++ b/cpp/vcpkg.json @@ -44,6 +44,7 @@ "rapidjson", "re2", "snappy", + "sqlite3", "thrift", "utf8proc", "zlib", diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index 96ebd48912b..74bbed1fc4f 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -382,6 +382,11 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, description="Ensure headers are propagated via middleware.", skip={"Rust"} # TODO(ARROW-10961): tonic upgrade needed ), + Scenario( + "flight_sql", + description="Ensure Flight SQL protocol is working as expected.", + skip={"Rust", "Go"} + ), ] runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py index 5104a0cc755..69c6e54e056 100644 --- a/dev/archery/archery/integration/tester_java.py +++ b/dev/archery/archery/integration/tester_java.py @@ -49,11 +49,12 @@ class JavaTester(Tester): ARROW_FLIGHT_JAR = os.environ.get( 'ARROW_FLIGHT_JAVA_INTEGRATION_JAR', os.path.join(ARROW_ROOT_DEFAULT, - 'java/flight/flight-core/target/flight-core-{}-' - 'jar-with-dependencies.jar'.format(_arrow_version))) - ARROW_FLIGHT_SERVER = ('org.apache.arrow.flight.example.integration.' 
+ 'java/flight/flight-integration-tests/target/' + 'flight-integration-tests-{}-jar-with-dependencies.jar' + .format(_arrow_version))) + ARROW_FLIGHT_SERVER = ('org.apache.arrow.flight.integration.tests.' 'IntegrationTestServer') - ARROW_FLIGHT_CLIENT = ('org.apache.arrow.flight.example.integration.' + ARROW_FLIGHT_CLIENT = ('org.apache.arrow.flight.integration.tests.' 'IntegrationTestClient') name = 'Java' diff --git a/docker-compose.yml b/docker-compose.yml index e681eb33bcd..63c7c911918 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -268,6 +268,7 @@ services: ARROW_CXXFLAGS: "-Og" # Shrink test runtime by enabling minimal optimizations ARROW_ENABLE_TIMING_TESTS: # inherit ARROW_FLIGHT: "OFF" + ARROW_FLIGHT_SQL: "OFF" ARROW_GANDIVA: "OFF" ARROW_JEMALLOC: "OFF" ARROW_RUNTIME_SIMD_LEVEL: "AVX2" # AVX512 not supported by Valgrind (ARROW-9851) @@ -1021,6 +1022,7 @@ services: environment: <<: *ccache ARROW_FLIGHT: "OFF" + ARROW_FLIGHT_SQL: "OFF" ARROW_GANDIVA: "OFF" volumes: *conda-volumes command: @@ -1610,6 +1612,7 @@ services: environment: <<: *ccache ARROW_FLIGHT: "OFF" + ARROW_FLIGHT_SQL: "OFF" ARROW_GANDIVA: "OFF" ARROW_PLASMA: "OFF" ARROW_HIVESERVER2: "ON" diff --git a/format/FlightSql.proto b/format/FlightSql.proto new file mode 100644 index 00000000000..23ada5c6e48 --- /dev/null +++ b/format/FlightSql.proto @@ -0,0 +1,1336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +import "google/protobuf/descriptor.proto"; + +option java_package = "org.apache.arrow.flight.sql.impl"; +package arrow.flight.protocol.sql; + +/* + * Represents a metadata request. Used in the command member of FlightDescriptor + * for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the metadata request. + * + * The returned Arrow schema will be: + * < + * info_name: uint32 not null, + * value: dense_union< + * string_value: utf8, + * bool_value: bool, + * bigint_value: int64, + * int32_bitmask: int32, + * string_list: list + * int32_to_int32_list_map: map> + * > + * where there is one row per requested piece of metadata information. + */ +message CommandGetSqlInfo { + option (experimental) = true; + + /* + * Values are modelled after ODBC's SQLGetInfo() function. This information is intended to provide + * Flight SQL clients with basic, SQL syntax and SQL functions related information. + * More information types can be added in future releases. + * E.g. more SQL syntax support types, scalar functions support, type conversion support etc. + * + * Note that the set of metadata may expand. + * + * Initially, Flight SQL will support the following information types: + * - Server Information - Range [0-500) + * - Syntax Information - Range [500-1000) + * Range [0-10,000) is reserved for defaults (see SqlInfo enum for default options). + * Custom options should start at 10,000. + * + * If omitted, then all metadata will be retrieved. + * Flight SQL Servers may choose to include additional metadata above and beyond the specified set, however they must + * at least return the specified set. IDs ranging from 0 to 10,000 (exclusive) are reserved for future use. + * If additional metadata is included, the metadata IDs should start from 10,000. + */ + repeated uint32 info = 1; +} + +// Options for CommandGetSqlInfo. +enum SqlInfo { + + // Server Information [0-500): Provides basic information about the Flight SQL Server. + + // Retrieves a UTF-8 string with the name of the Flight SQL Server. + FLIGHT_SQL_SERVER_NAME = 0; + + // Retrieves a UTF-8 string with the native version of the Flight SQL Server. + FLIGHT_SQL_SERVER_VERSION = 1; + + // Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server. + FLIGHT_SQL_SERVER_ARROW_VERSION = 2; + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server is read only. + * + * Returns: + * - false: if read-write + * - true: if read only + */ + FLIGHT_SQL_SERVER_READ_ONLY = 3; + + + // SQL Syntax Information [500-1000): provides information about SQL syntax supported by the Flight SQL Server. + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of catalogs. + * + * Returns: + * - false: if it doesn't support CREATE and DROP of catalogs. + * - true: if it supports CREATE and DROP of catalogs. + */ + SQL_DDL_CATALOG = 500; + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of schemas. + * + * Returns: + * - false: if it doesn't support CREATE and DROP of schemas. + * - true: if it supports CREATE and DROP of schemas. 
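A minimal Java sketch of how a client would issue the CommandGetSqlInfo request defined above: the command is packed into a google.protobuf.Any, and the serialized Any becomes the opaque cmd bytes of a FlightDescriptor that can then be passed to GetSchema or GetFlightInfo. The org.apache.arrow.flight.sql.impl.FlightSql class names assume the default protobuf codegen for this file (per the java_package option above); the particular info IDs requested are only examples.

import com.google.protobuf.Any;

import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo;
import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo;

public class GetSqlInfoRequestExample {
  public static void main(String[] args) {
    // Ask only for the server name and the read-only flag; leaving the info
    // list empty would request every piece of metadata the server exposes.
    CommandGetSqlInfo command = CommandGetSqlInfo.newBuilder()
        .addInfo(SqlInfo.FLIGHT_SQL_SERVER_NAME.getNumber())
        .addInfo(SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY.getNumber())
        .build();

    // Flight SQL commands travel as a packed google.protobuf.Any inside the
    // FlightDescriptor's cmd field.
    FlightDescriptor descriptor =
        FlightDescriptor.command(Any.pack(command).toByteArray());
    System.out.println(descriptor);
  }
}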
+ */ + SQL_DDL_SCHEMA = 501; + + /* + * Indicates whether the Flight SQL Server supports CREATE and DROP of tables. + * + * Returns: + * - false: if it doesn't support CREATE and DROP of tables. + * - true: if it supports CREATE and DROP of tables. + */ + SQL_DDL_TABLE = 502; + + /* + * Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of catalog, table, schema and table names. + * + * The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + */ + SQL_IDENTIFIER_CASE = 503; + + // Retrieves a UTF-8 string with the supported character(s) used to surround a delimited identifier. + SQL_IDENTIFIER_QUOTE_CHAR = 504; + + /* + * Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of quoted identifiers. + * + * The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + */ + SQL_QUOTED_IDENTIFIER_CASE = 505; + + /* + * Retrieves a boolean value indicating whether all tables are selectable. + * + * Returns: + * - false: if not all tables are selectable or if none are; + * - true: if all tables are selectable. + */ + SQL_ALL_TABLES_ARE_SELECTABLE = 506; + + /* + * Retrieves the null ordering. + * + * Returns a uint32 ordinal for the null ordering being used, as described in + * `arrow.flight.protocol.sql.SqlNullOrdering`. + */ + SQL_NULL_ORDERING = 507; + + // Retrieves a UTF-8 string list with values of the supported keywords. + SQL_KEYWORDS = 508; + + // Retrieves a UTF-8 string list with values of the supported numeric functions. + SQL_NUMERIC_FUNCTIONS = 509; + + // Retrieves a UTF-8 string list with values of the supported string functions. + SQL_STRING_FUNCTIONS = 510; + + // Retrieves a UTF-8 string list with values of the supported system functions. + SQL_SYSTEM_FUNCTIONS = 511; + + // Retrieves a UTF-8 string list with values of the supported datetime functions. + SQL_DATETIME_FUNCTIONS = 512; + + /* + * Retrieves the UTF-8 string that can be used to escape wildcard characters. + * This is the string that can be used to escape '_' or '%' in the catalog search parameters that are a pattern + * (and therefore use one of the wildcard characters). + * The '_' character represents any single character; the '%' character represents any sequence of zero or more + * characters. + */ + SQL_SEARCH_STRING_ESCAPE = 513; + + /* + * Retrieves a UTF-8 string with all the "extra" characters that can be used in unquoted identifier names + * (those beyond a-z, A-Z, 0-9 and _). + */ + SQL_EXTRA_NAME_CHARACTERS = 514; + + /* + * Retrieves a boolean value indicating whether column aliasing is supported. + * If so, the SQL AS clause can be used to provide names for computed columns or to provide alias names for columns + * as required. + * + * Returns: + * - false: if column aliasing is unsupported; + * - true: if column aliasing is supported. + */ + SQL_SUPPORTS_COLUMN_ALIASING = 515; + + /* + * Retrieves a boolean value indicating whether concatenations between null and non-null values being + * null are supported. + * + * - Returns: + * - false: if concatenations between null and non-null values being null are unsupported; + * - true: if concatenations between null and non-null values being null are supported. + */ + SQL_NULL_PLUS_NULL_IS_NULL = 516; + + /* + * Retrieves a map where the key is the type to convert from and the value is a list with the types to convert to, + * indicating the supported conversions. 
Each key and each item on the list value is a value to a predefined type on + * SqlSupportsConvert enum. + * The returned map will be: map> + */ + SQL_SUPPORTS_CONVERT = 517; + + /* + * Retrieves a boolean value indicating whether, when table correlation names are supported, + * they are restricted to being different from the names of the tables. + * + * Returns: + * - false: if table correlation names are unsupported; + * - true: if table correlation names are supported. + */ + SQL_SUPPORTS_TABLE_CORRELATION_NAMES = 518; + + /* + * Retrieves a boolean value indicating whether, when table correlation names are supported, + * they are restricted to being different from the names of the tables. + * + * Returns: + * - false: if different table correlation names are unsupported; + * - true: if different table correlation names are supported + */ + SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES = 519; + + /* + * Retrieves a boolean value indicating whether expressions in ORDER BY lists are supported. + * + * Returns: + * - false: if expressions in ORDER BY are unsupported; + * - true: if expressions in ORDER BY are supported; + */ + SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY = 520; + + /* + * Retrieves a boolean value indicating whether using a column that is not in the SELECT statement in a GROUP BY + * clause is supported. + * + * Returns: + * - false: if using a column that is not in the SELECT statement in a GROUP BY clause is unsupported; + * - true: if using a column that is not in the SELECT statement in a GROUP BY clause is supported. + */ + SQL_SUPPORTS_ORDER_BY_UNRELATED = 521; + + /* + * Retrieves the supported GROUP BY commands; + * + * Returns an int32 bitmask value representing the supported commands. + * The returned bitmask should be parsed in order to retrieve the supported commands. + * + * For instance: + * - return 0 (\b0) => [] (GROUP BY is unsupported); + * - return 1 (\b1) => [SQL_GROUP_BY_UNRELATED]; + * - return 2 (\b10) => [SQL_GROUP_BY_BEYOND_SELECT]; + * - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT]. + * Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. + */ + SQL_SUPPORTED_GROUP_BY = 522; + + /* + * Retrieves a boolean value indicating whether specifying a LIKE escape clause is supported. + * + * Returns: + * - false: if specifying a LIKE escape clause is unsupported; + * - true: if specifying a LIKE escape clause is supported. + */ + SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE = 523; + + /* + * Retrieves a boolean value indicating whether columns may be defined as non-nullable. + * + * Returns: + * - false: if columns cannot be defined as non-nullable; + * - true: if columns may be defined as non-nullable. + */ + SQL_SUPPORTS_NON_NULLABLE_COLUMNS = 524; + + /* + * Retrieves the supported SQL grammar level as per the ODBC specification. + * + * Returns an int32 bitmask value representing the supported SQL grammar level. + * The returned bitmask should be parsed in order to retrieve the supported grammar levels. + * + * For instance: + * - return 0 (\b0) => [] (SQL grammar is unsupported); + * - return 1 (\b1) => [SQL_MINIMUM_GRAMMAR]; + * - return 2 (\b10) => [SQL_CORE_GRAMMAR]; + * - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR]; + * - return 4 (\b100) => [SQL_EXTENDED_GRAMMAR]; + * - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR]; + * - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]; + * - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]. 
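Several of the infos above are encoded as int32 bitmasks in which bit position i corresponds to the enum value whose number is i. A small Java helper makes that decoding explicit; it is written against the generated SupportedSqlGrammar enum (assuming the default codegen layout for this file), and the same pattern applies to the other bitmask-valued infos.

import java.util.ArrayList;
import java.util.List;

import org.apache.arrow.flight.sql.impl.FlightSql.SupportedSqlGrammar;

public class SqlInfoBitmaskExample {
  /** Expands a bitmask info value into the enum constants it encodes. */
  static List<SupportedSqlGrammar> decodeGrammarLevels(int bitmask) {
    List<SupportedSqlGrammar> levels = new ArrayList<>();
    for (SupportedSqlGrammar level : SupportedSqlGrammar.values()) {
      // proto3 enums carry a synthetic UNRECOGNIZED constant with no number.
      if (level == SupportedSqlGrammar.UNRECOGNIZED) {
        continue;
      }
      // Bit i set <=> the enum value numbered i is supported,
      // e.g. 0b101 => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR].
      if ((bitmask & (1 << level.getNumber())) != 0) {
        levels.add(level);
      }
    }
    return levels;
  }

  public static void main(String[] args) {
    System.out.println(decodeGrammarLevels(0b101));
  }
}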
+ * Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`. + */ + SQL_SUPPORTED_GRAMMAR = 525; + + /* + * Retrieves the supported ANSI92 SQL grammar level. + * + * Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level. + * The returned bitmask should be parsed in order to retrieve the supported commands. + * + * For instance: + * - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported); + * - return 1 (\b1) => [ANSI92_ENTRY_SQL]; + * - return 2 (\b10) => [ANSI92_INTERMEDIATE_SQL]; + * - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL]; + * - return 4 (\b100) => [ANSI92_FULL_SQL]; + * - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL]; + * - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]; + * - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]. + * Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`. + */ + SQL_ANSI92_SUPPORTED_LEVEL = 526; + + /* + * Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility is supported. + * + * Returns: + * - false: if the SQL Integrity Enhancement Facility is supported; + * - true: if the SQL Integrity Enhancement Facility is supported. + */ + SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY = 527; + + /* + * Retrieves the support level for SQL OUTER JOINs. + * + * Returns a uint3 uint32 ordinal for the SQL ordering being used, as described in + * `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`. + */ + SQL_OUTER_JOINS_SUPPORT_LEVEL = 528; + + // Retrieves a UTF-8 string with the preferred term for "schema". + SQL_SCHEMA_TERM = 529; + + // Retrieves a UTF-8 string with the preferred term for "procedure". + SQL_PROCEDURE_TERM = 530; + + // Retrieves a UTF-8 string with the preferred term for "catalog". + SQL_CATALOG_TERM = 531; + + /* + * Retrieves a boolean value indicating whether a catalog appears at the start of a fully qualified table name. + * + * - false: if a catalog does not appear at the start of a fully qualified table name; + * - true: if a catalog appears at the start of a fully qualified table name. + */ + SQL_CATALOG_AT_START = 532; + + /* + * Retrieves the supported actions for a SQL schema. + * + * Returns an int32 bitmask value representing the supported actions for a SQL schema. + * The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema. + * + * For instance: + * - return 0 (\b0) => [] (no supported actions for SQL schema); + * - return 1 (\b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS]; + * - return 2 (\b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 4 (\b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + * Valid actions for a SQL schema described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + */ + SQL_SCHEMAS_SUPPORTED_ACTIONS = 533; + + /* + * Retrieves the supported actions for a SQL schema. + * + * Returns an int32 bitmask value representing the supported actions for a SQL catalog. 
+ * The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog. + * + * For instance: + * - return 0 (\b0) => [] (no supported actions for SQL catalog); + * - return 1 (\b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS]; + * - return 2 (\b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 4 (\b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + * Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + */ + SQL_CATALOGS_SUPPORTED_ACTIONS = 534; + + /* + * Retrieves the supported SQL positioned commands. + * + * Returns an int32 bitmask value representing the supported SQL positioned commands. + * The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. + * + * For instance: + * - return 0 (\b0) => [] (no supported SQL positioned commands); + * - return 1 (\b1) => [SQL_POSITIONED_DELETE]; + * - return 2 (\b10) => [SQL_POSITIONED_UPDATE]; + * - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE]. + * Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. + */ + SQL_SUPPORTED_POSITIONED_COMMANDS = 535; + + /* + * Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are supported. + * + * Returns: + * - false: if SELECT FOR UPDATE statements are unsupported; + * - true: if SELECT FOR UPDATE statements are supported. + */ + SQL_SELECT_FOR_UPDATE_SUPPORTED = 536; + + /* + * Retrieves a boolean value indicating whether stored procedure calls that use the stored procedure escape syntax + * are supported. + * + * Returns: + * - false: if stored procedure calls that use the stored procedure escape syntax are unsupported; + * - true: if stored procedure calls that use the stored procedure escape syntax are supported. + */ + SQL_STORED_PROCEDURES_SUPPORTED = 537; + + /* + * Retrieves the supported SQL subqueries. + * + * Returns an int32 bitmask value representing the supported SQL subqueries. + * The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. 
+ * + * For instance: + * - return 0 (\b0) => [] (no supported SQL subqueries); + * - return 1 (\b1) => [SQL_SUBQUERIES_IN_COMPARISONS]; + * - return 2 (\b10) => [SQL_SUBQUERIES_IN_EXISTS]; + * - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS]; + * - return 4 (\b100) => [SQL_SUBQUERIES_IN_INS]; + * - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS]; + * - return 6 (\b110) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS]; + * - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS]; + * - return 8 (\b1000) => [SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - ... + * Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. + */ + SQL_SUPPORTED_SUBQUERIES = 538; + + /* + * Retrieves a boolean value indicating whether correlated subqueries are supported. + * + * Returns: + * - false: if correlated subqueries are unsupported; + * - true: if correlated subqueries are supported. + */ + SQL_CORRELATED_SUBQUERIES_SUPPORTED = 539; + + /* + * Retrieves the supported SQL UNIONs. + * + * Returns an int32 bitmask value representing the supported SQL UNIONs. + * The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. + * + * For instance: + * - return 0 (\b0) => [] (no supported SQL positioned commands); + * - return 1 (\b1) => [SQL_UNION]; + * - return 2 (\b10) => [SQL_UNION_ALL]; + * - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL]. + * Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. + */ + SQL_SUPPORTED_UNIONS = 540; + + // Retrieves a uint32 value representing the maximum number of hex characters allowed in an inline binary literal. + SQL_MAX_BINARY_LITERAL_LENGTH = 541; + + // Retrieves a uint32 value representing the maximum number of characters allowed for a character literal. + SQL_MAX_CHAR_LITERAL_LENGTH = 542; + + // Retrieves a uint32 value representing the maximum number of characters allowed for a column name. + SQL_MAX_COLUMN_NAME_LENGTH = 543; + + // Retrieves a uint32 value representing the the maximum number of columns allowed in a GROUP BY clause. + SQL_MAX_COLUMNS_IN_GROUP_BY = 544; + + // Retrieves a uint32 value representing the maximum number of columns allowed in an index. + SQL_MAX_COLUMNS_IN_INDEX = 545; + + // Retrieves a uint32 value representing the maximum number of columns allowed in an ORDER BY clause. + SQL_MAX_COLUMNS_IN_ORDER_BY = 546; + + // Retrieves a uint32 value representing the maximum number of columns allowed in a SELECT list. + SQL_MAX_COLUMNS_IN_SELECT = 547; + + // Retrieves a uint32 value representing the maximum number of columns allowed in a table. 
+ SQL_MAX_COLUMNS_IN_TABLE = 548; + + // Retrieves a uint32 value representing the maximum number of concurrent connections possible. + SQL_MAX_CONNECTIONS = 549; + + // Retrieves a uint32 value the maximum number of characters allowed in a cursor name. + SQL_MAX_CURSOR_NAME_LENGTH = 550; + + /* + * Retrieves a uint32 value representing the maximum number of bytes allowed for an index, + * including all of the parts of the index. + */ + SQL_MAX_INDEX_LENGTH = 551; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a schema name. + SQL_DB_SCHEMA_NAME_LENGTH = 552; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a procedure name. + SQL_MAX_PROCEDURE_NAME_LENGTH = 553; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a catalog name. + SQL_MAX_CATALOG_NAME_LENGTH = 554; + + // Retrieves a uint32 value representing the maximum number of bytes allowed in a single row. + SQL_MAX_ROW_SIZE = 555; + + /* + * Retrieves a boolean indicating whether the return value for the JDBC method getMaxRowSize includes the SQL + * data types LONGVARCHAR and LONGVARBINARY. + * + * Returns: + * - false: if return value for the JDBC method getMaxRowSize does + * not include the SQL data types LONGVARCHAR and LONGVARBINARY; + * - true: if return value for the JDBC method getMaxRowSize includes + * the SQL data types LONGVARCHAR and LONGVARBINARY. + */ + SQL_MAX_ROW_SIZE_INCLUDES_BLOBS = 556; + + /* + * Retrieves a uint32 value representing the maximum number of characters allowed for an SQL statement; + * a result of 0 (zero) means that there is no limit or the limit is not known. + */ + SQL_MAX_STATEMENT_LENGTH = 557; + + // Retrieves a uint32 value representing the maximum number of active statements that can be open at the same time. + SQL_MAX_STATEMENTS = 558; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a table name. + SQL_MAX_TABLE_NAME_LENGTH = 559; + + // Retrieves a uint32 value representing the maximum number of tables allowed in a SELECT statement. + SQL_MAX_TABLES_IN_SELECT = 560; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a user name. + SQL_MAX_USERNAME_LENGTH = 561; + + /* + * Retrieves this database's default transaction isolation level as described in + * `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + * + * Returns a uint32 ordinal for the SQL transaction isolation level. + */ + SQL_DEFAULT_TRANSACTION_ISOLATION = 562; + + /* + * Retrieves a boolean value indicating whether transactions are supported. If not, invoking the method commit is a + * noop, and the isolation level is `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. + * + * Returns: + * - false: if transactions are unsupported; + * - true: if transactions are supported. + */ + SQL_TRANSACTIONS_SUPPORTED = 563; + + /* + * Retrieves the supported transactions isolation levels. + * + * Returns an int32 bitmask value representing the supported transactions isolation levels. + * The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels. 
+ * + * For instance: + * - return 0 (\b0) => [] (no supported SQL transactions isolation levels); + * - return 1 (\b1) => [SQL_TRANSACTION_NONE]; + * - return 2 (\b10) => [SQL_TRANSACTION_READ_UNCOMMITTED]; + * - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED]; + * - return 4 (\b100) => [SQL_TRANSACTION_REPEATABLE_READ]; + * - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 8 (\b1000) => [SQL_TRANSACTION_REPEATABLE_READ]; + * - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 14 (\b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 16 (\b10000) => [SQL_TRANSACTION_SERIALIZABLE]; + * - ... + * Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + */ + SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS = 564; + + /* + * Retrieves a boolean value indicating whether a data definition statement within a transaction forces + * the transaction to commit. + * + * Returns: + * - false: if a data definition statement within a transaction does not force the transaction to commit; + * - true: if a data definition statement within a transaction forces the transaction to commit. + */ + SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT = 565; + + /* + * Retrieves a boolean value indicating whether a data definition statement within a transaction is ignored. + * + * Returns: + * - false: if a data definition statement within a transaction is taken into account; + * - true: a data definition statement within a transaction is ignored. + */ + SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED = 566; + + /* + * Retrieves an int32 bitmask value representing the supported result set types. + * The returned bitmask should be parsed in order to retrieve the supported result set types. + * + * For instance: + * - return 0 (\b0) => [] (no supported result set types); + * - return 1 (\b1) => [SQL_RESULT_SET_TYPE_UNSPECIFIED]; + * - return 2 (\b10) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY]; + * - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY]; + * - return 4 (\b100) => [SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 8 (\b1000) => [SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE]; + * - ... 
+ * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. + */ + SQL_SUPPORTED_RESULT_SET_TYPES = 567; + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED = 568; + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY = 569; + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. 
+ */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE = 570; + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE = 571; + + /* + * Retrieves a boolean value indicating whether this database supports batch updates. + * + * - false: if this database does not support batch updates; + * - true: if this database supports batch updates. + */ + SQL_BATCH_UPDATES_SUPPORTED = 572; + + /* + * Retrieves a boolean value indicating whether this database supports savepoints. + * + * Returns: + * - false: if this database does not support savepoints; + * - true: if this database supports savepoints. + */ + SQL_SAVEPOINTS_SUPPORTED = 573; + + /* + * Retrieves a boolean value indicating whether named parameters are supported in callable statements. + * + * Returns: + * - false: if named parameters in callable statements are unsupported; + * - true: if named parameters in callable statements are supported. + */ + SQL_NAMED_PARAMETERS_SUPPORTED = 574; + + /* + * Retrieves a boolean value indicating whether updates made to a LOB are made on a copy or directly to the LOB. + * + * Returns: + * - false: if updates made to a LOB are made directly to the LOB; + * - true: if updates made to a LOB are made on a copy. + */ + SQL_LOCATORS_UPDATE_COPY = 575; + + /* + * Retrieves a boolean value indicating whether invoking user-defined or vendor functions + * using the stored procedure escape syntax is supported. + * + * Returns: + * - false: if invoking user-defined or vendor functions using the stored procedure escape syntax is unsupported; + * - true: if invoking user-defined or vendor functions using the stored procedure escape syntax is supported. 
+ */ + SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED = 576; +} + +enum SqlSupportedCaseSensitivity { + SQL_CASE_SENSITIVITY_UNKNOWN = 0; + SQL_CASE_SENSITIVITY_CASE_INSENSITIVE = 1; + SQL_CASE_SENSITIVITY_UPPERCASE = 2; + SQL_CASE_SENSITIVITY_LOWERCASE = 3; +} + +enum SqlNullOrdering { + SQL_NULLS_SORTED_HIGH = 0; + SQL_NULLS_SORTED_LOW = 1; + SQL_NULLS_SORTED_AT_START = 2; + SQL_NULLS_SORTED_AT_END = 3; +} + +enum SupportedSqlGrammar { + SQL_MINIMUM_GRAMMAR = 0; + SQL_CORE_GRAMMAR = 1; + SQL_EXTENDED_GRAMMAR = 2; +} + +enum SupportedAnsi92SqlGrammarLevel { + ANSI92_ENTRY_SQL = 0; + ANSI92_INTERMEDIATE_SQL = 1; + ANSI92_FULL_SQL = 2; +} + +enum SqlOuterJoinsSupportLevel { + SQL_JOINS_UNSUPPORTED = 0; + SQL_LIMITED_OUTER_JOINS = 1; + SQL_FULL_OUTER_JOINS = 2; +} + +enum SqlSupportedGroupBy { + SQL_GROUP_BY_UNRELATED = 0; + SQL_GROUP_BY_BEYOND_SELECT = 1; +} + +enum SqlSupportedElementActions { + SQL_ELEMENT_IN_PROCEDURE_CALLS = 0; + SQL_ELEMENT_IN_INDEX_DEFINITIONS = 1; + SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS = 2; +} + +enum SqlSupportedPositionedCommands { + SQL_POSITIONED_DELETE = 0; + SQL_POSITIONED_UPDATE = 1; +} + +enum SqlSupportedSubqueries { + SQL_SUBQUERIES_IN_COMPARISONS = 0; + SQL_SUBQUERIES_IN_EXISTS = 1; + SQL_SUBQUERIES_IN_INS = 2; + SQL_SUBQUERIES_IN_QUANTIFIEDS = 3; +} + +enum SqlSupportedUnions { + SQL_UNION = 0; + SQL_UNION_ALL = 1; +} + +enum SqlTransactionIsolationLevel { + SQL_TRANSACTION_NONE = 0; + SQL_TRANSACTION_READ_UNCOMMITTED = 1; + SQL_TRANSACTION_READ_COMMITTED = 2; + SQL_TRANSACTION_REPEATABLE_READ = 3; + SQL_TRANSACTION_SERIALIZABLE = 4; +} + +enum SqlSupportedTransactions { + SQL_TRANSACTION_UNSPECIFIED = 0; + SQL_DATA_DEFINITION_TRANSACTIONS = 1; + SQL_DATA_MANIPULATION_TRANSACTIONS = 2; +} + +enum SqlSupportedResultSetType { + SQL_RESULT_SET_TYPE_UNSPECIFIED = 0; + SQL_RESULT_SET_TYPE_FORWARD_ONLY = 1; + SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE = 2; + SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE = 3; +} + +enum SqlSupportedResultSetConcurrency { + SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED = 0; + SQL_RESULT_SET_CONCURRENCY_READ_ONLY = 1; + SQL_RESULT_SET_CONCURRENCY_UPDATABLE = 2; +} + +enum SqlSupportsConvert { + SQL_CONVERT_BIGINT = 0; + SQL_CONVERT_BINARY = 1; + SQL_CONVERT_BIT = 2; + SQL_CONVERT_CHAR = 3; + SQL_CONVERT_DATE = 4; + SQL_CONVERT_DECIMAL = 5; + SQL_CONVERT_FLOAT = 6; + SQL_CONVERT_INTEGER = 7; + SQL_CONVERT_INTERVAL_DAY_TIME = 8; + SQL_CONVERT_INTERVAL_YEAR_MONTH = 9; + SQL_CONVERT_LONGVARBINARY = 10; + SQL_CONVERT_LONGVARCHAR = 11; + SQL_CONVERT_NUMERIC = 12; + SQL_CONVERT_REAL = 13; + SQL_CONVERT_SMALLINT = 14; + SQL_CONVERT_TIME = 15; + SQL_CONVERT_TIMESTAMP = 16; + SQL_CONVERT_TINYINT = 17; + SQL_CONVERT_VARBINARY = 18; + SQL_CONVERT_VARCHAR = 19; +} + +/* + * Represents a request to retrieve the list of catalogs on a Flight SQL enabled backend. + * The definition of a catalog depends on vendor/implementation. It is usually the database itself + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * catalog_name: utf8 not null + * > + * The returned data should be ordered by catalog_name. + */ +message CommandGetCatalogs { + option (experimental) = true; +} + +/* + * Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. + * The definition of a database schema depends on vendor/implementation. 
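To make the catalog metadata round trip concrete, here is a hedged Java sketch that sends CommandGetCatalogs and prints the catalog_name column from the returned stream. The host, port, and single-endpoint handling are illustrative only; a production client would iterate every endpoint in the FlightInfo.

import com.google.protobuf.Any;

import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VarCharVector;

public class GetCatalogsExample {
  public static void main(String[] args) throws Exception {
    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
         FlightClient client = FlightClient.builder(
             allocator, Location.forGrpcInsecure("localhost", 31337)).build()) {
      // CommandGetCatalogs has no fields: it simply asks for the catalog list.
      FlightDescriptor descriptor = FlightDescriptor.command(
          Any.pack(CommandGetCatalogs.getDefaultInstance()).toByteArray());

      // GetFlightInfo plans the request; DoGet on an endpoint fetches the rows.
      FlightInfo info = client.getInfo(descriptor);
      try (FlightStream stream =
               client.getStream(info.getEndpoints().get(0).getTicket())) {
        while (stream.next()) {
          VarCharVector catalogs =
              (VarCharVector) stream.getRoot().getVector("catalog_name");
          for (int i = 0; i < catalogs.getValueCount(); i++) {
            System.out.println(catalogs.getObject(i));
          }
        }
      }
    }
  }
}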
It is usually a collection of tables. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * catalog_name: utf8, + * db_schema_name: utf8 not null + * > + * The returned data should be ordered by catalog_name, then db_schema_name. + */ +message CommandGetDbSchemas { + option (experimental) = true; + + /* + * Specifies the Catalog to search for the tables. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies a filter pattern for schemas to search for. + * When no db_schema_filter_pattern is provided, the pattern will not be used to narrow the search. + * In the pattern string, two special characters can be used to denote matching rules: + * - "%" means to match any substring with 0 or more characters. + * - "_" means to match any one character. + */ + optional string db_schema_filter_pattern = 2; +} + +/* + * Represents a request to retrieve the list of tables, and optionally their schemas, on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * catalog_name: utf8, + * db_schema_name: utf8, + * table_name: utf8 not null, + * table_type: utf8 not null, + * [optional] table_schema: bytes not null (schema of the table as described in Schema.fbs::Schema, + * it is serialized as an IPC message.) + * > + * The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. + */ +message CommandGetTables { + option (experimental) = true; + + /* + * Specifies the Catalog to search for the tables. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies a filter pattern for schemas to search for. + * When no db_schema_filter_pattern is provided, all schemas matching other filters are searched. + * In the pattern string, two special characters can be used to denote matching rules: + * - "%" means to match any substring with 0 or more characters. + * - "_" means to match any one character. + */ + optional string db_schema_filter_pattern = 2; + + /* + * Specifies a filter pattern for tables to search for. + * When no table_name_filter_pattern is provided, all tables matching other filters are searched. + * In the pattern string, two special characters can be used to denote matching rules: + * - "%" means to match any substring with 0 or more characters. + * - "_" means to match any one character. + */ + optional string table_name_filter_pattern = 3; + + /* + * Specifies a filter of table types which must match. + * The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + * TABLE, VIEW, and SYSTEM TABLE are commonly supported. + */ + repeated string table_types = 4; + + // Specifies if the Arrow schema should be returned for found tables. + bool include_schema = 5; +} + +/* + * Represents a request to retrieve the list of table types on a Flight SQL enabled backend. 
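The filter-pattern fields compose as shown in the short Java sketch below, which builds a CommandGetTables request for every table or view in schemas starting with "sales" whose name contains "order", and asks for each table's Arrow schema as well. The pattern values and table types here are illustrative, and the generated class names assume the default codegen for this file.

import com.google.protobuf.Any;

import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables;

public class GetTablesRequestExample {
  public static void main(String[] args) {
    CommandGetTables command = CommandGetTables.newBuilder()
        .setDbSchemaFilterPattern("sales%")    // '%' = any substring, '_' = any single char
        .setTableNameFilterPattern("%order%")
        .addTableTypes("TABLE")
        .addTableTypes("VIEW")
        .setIncludeSchema(true)                // also return table_schema as IPC bytes
        .build();

    // Leaving `catalog` unset (rather than setting it to "") means the catalog
    // is not used to narrow the search at all.
    FlightDescriptor descriptor =
        FlightDescriptor.command(Any.pack(command).toByteArray());
    System.out.println(descriptor);
  }
}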
+ * The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + * TABLE, VIEW, and SYSTEM TABLE are commonly supported. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * table_type: utf8 not null + * > + * The returned data should be ordered by table_type. + */ +message CommandGetTableTypes { + option (experimental) = true; +} + +/* + * Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * catalog_name: utf8, + * db_schema_name: utf8, + * table_name: utf8 not null, + * column_name: utf8 not null, + * key_name: utf8, + * key_sequence: int not null + * > + * The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. + */ +message CommandGetPrimaryKeys { + option (experimental) = true; + + /* + * Specifies the catalog to search for the table. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies the schema to search for the table. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. + */ + optional string db_schema = 2; + + // Specifies the table to get the primary keys for. + string table = 3; +} + +enum UpdateDeleteRules { + CASCADE = 0; + RESTRICT = 1; + SET_NULL = 2; + NO_ACTION = 3; + SET_DEFAULT = 4; +} + +/* + * Represents a request to retrieve a description of the foreign key columns that reference the given table's + * primary key columns (the foreign keys exported by a table) of a table on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * pk_catalog_name: utf8, + * pk_db_schema_name: utf8, + * pk_table_name: utf8 not null, + * pk_column_name: utf8 not null, + * fk_catalog_name: utf8, + * fk_db_schema_name: utf8, + * fk_table_name: utf8 not null, + * fk_column_name: utf8 not null, + * key_sequence: int not null, + * fk_key_name: utf8, + * pk_key_name: utf8, + * update_rule: uint1 not null, + * delete_rule: uint1 not null + * > + * The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. + * update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. + */ +message CommandGetExportedKeys { + option (experimental) = true; + + /* + * Specifies the catalog to search for the foreign key table. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies the schema to search for the foreign key table. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. 
+ */ + optional string db_schema = 2; + + // Specifies the foreign key table to get the foreign keys for. + string table = 3; +} + +/* + * Represents a request to retrieve the foreign keys of a table on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * pk_catalog_name: utf8, + * pk_db_schema_name: utf8, + * pk_table_name: utf8 not null, + * pk_column_name: utf8 not null, + * fk_catalog_name: utf8, + * fk_db_schema_name: utf8, + * fk_table_name: utf8 not null, + * fk_column_name: utf8 not null, + * key_sequence: int not null, + * fk_key_name: utf8, + * pk_key_name: utf8, + * update_rule: uint1 not null, + * delete_rule: uint1 not null + * > + * The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. + * update_rule and delete_rule returns a byte that is equivalent to actions: + * - 0 = CASCADE + * - 1 = RESTRICT + * - 2 = SET NULL + * - 3 = NO ACTION + * - 4 = SET DEFAULT + */ +message CommandGetImportedKeys { + option (experimental) = true; + + /* + * Specifies the catalog to search for the primary key table. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies the schema to search for the primary key table. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. + */ + optional string db_schema = 2; + + // Specifies the primary key table to get the foreign keys for. + string table = 3; +} + +/* + * Represents a request to retrieve a description of the foreign key columns in the given foreign key table that + * reference the primary key or the columns representing a unique constraint of the parent table (could be the same + * or a different table) on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * pk_catalog_name: utf8, + * pk_db_schema_name: utf8, + * pk_table_name: utf8 not null, + * pk_column_name: utf8 not null, + * fk_catalog_name: utf8, + * fk_db_schema_name: utf8, + * fk_table_name: utf8 not null, + * fk_column_name: utf8 not null, + * key_sequence: int not null, + * fk_key_name: utf8, + * pk_key_name: utf8, + * update_rule: uint1 not null, + * delete_rule: uint1 not null + * > + * The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. + * update_rule and delete_rule returns a byte that is equivalent to actions: + * - 0 = CASCADE + * - 1 = RESTRICT + * - 2 = SET NULL + * - 3 = NO ACTION + * - 4 = SET DEFAULT + */ +message CommandGetCrossReference { + option (experimental) = true; + + /** + * The catalog name where the parent table is. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string pk_catalog = 1; + + /** + * The Schema name where the parent table is. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. 
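Since the update_rule and delete_rule columns come back as single-byte ordinals, a client can map them onto the UpdateDeleteRules enum defined above. A minimal sketch, assuming the default generated-code layout for this file:

import org.apache.arrow.flight.sql.impl.FlightSql.UpdateDeleteRules;

public class KeyRuleDecodeExample {
  /** Maps an update_rule / delete_rule byte from a key-metadata row back onto the enum. */
  static UpdateDeleteRules decodeRule(int ruleByte) {
    // 0 = CASCADE, 1 = RESTRICT, 2 = SET_NULL, 3 = NO_ACTION, 4 = SET_DEFAULT
    UpdateDeleteRules rule = UpdateDeleteRules.forNumber(ruleByte);
    if (rule == null) {
      throw new IllegalArgumentException("Unknown referential action: " + ruleByte);
    }
    return rule;
  }

  public static void main(String[] args) {
    System.out.println(decodeRule(3)); // NO_ACTION
  }
}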
+ */ + optional string pk_db_schema = 2; + + /** + * The parent table name. It cannot be null. + */ + string pk_table = 3; + + /** + * The catalog name where the foreign table is. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string fk_catalog = 4; + + /** + * The schema name where the foreign table is. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. + */ + optional string fk_db_schema = 5; + + /** + * The foreign table name. It cannot be null. + */ + string fk_table = 6; +} + +// SQL Execution Action Messages + +/* + * Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. + */ +message ActionCreatePreparedStatementRequest { + option (experimental) = true; + + // The valid SQL string to create a prepared statement for. + string query = 1; +} + +/* + * Wrap the result of a "GetPreparedStatement" action. + * + * The resultant PreparedStatement can be closed either: + * - Manually, through the "ClosePreparedStatement" action; + * - Automatically, by a server timeout. + */ +message ActionCreatePreparedStatementResult { + option (experimental) = true; + + // Opaque handle for the prepared statement on the server. + bytes prepared_statement_handle = 1; + + // If a result set generating query was provided, dataset_schema contains the + // schema of the dataset as described in Schema.fbs::Schema, it is serialized as an IPC message. + bytes dataset_schema = 2; + + // If the query provided contained parameters, parameter_schema contains the + // schema of the expected parameters as described in Schema.fbs::Schema, it is serialized as an IPC message. + bytes parameter_schema = 3; +} + +/* + * Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend. + * Closes server resources associated with the prepared statement handle. + */ +message ActionClosePreparedStatementRequest { + option (experimental) = true; + + // Opaque handle for the prepared statement on the server. + bytes prepared_statement_handle = 1; +} + + +// SQL Execution Messages. + +/* + * Represents a SQL query. Used in the command member of FlightDescriptor + * for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the query. + */ +message CommandStatementQuery { + option (experimental) = true; + + // The SQL syntax. + string query = 1; +} + +/** + * Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery. + * This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this. + */ +message TicketStatementQuery { + option (experimental) = true; + + // Unique identifier for the instance of the statement to execute. + bytes statement_handle = 1; +} + +/* + * Represents an instance of executing a prepared statement. Used in the command member of FlightDescriptor for + * the following RPC calls: + * - DoPut: bind parameter values. All of the bound parameter sets will be executed as a single atomic execution. + * - GetFlightInfo: execute the prepared statement instance. + */ +message CommandPreparedStatementQuery { + option (experimental) = true; + + // Opaque handle for the prepared statement on the server. + bytes prepared_statement_handle = 1; +} + +/* + * Represents a SQL update query. 
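Putting the query-execution pieces together, the hedged Java sketch below packs a CommandStatementQuery into a FlightDescriptor, calls GetFlightInfo, and then fetches every endpoint's data with DoGet, treating each ticket (a serialized TicketStatementQuery on the wire) as an opaque value. The host, port, and SQL text are placeholders.

import com.google.protobuf.Any;

import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery;
import org.apache.arrow.memory.RootAllocator;

public class StatementQueryExample {
  public static void main(String[] args) throws Exception {
    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
         FlightClient client = FlightClient.builder(
             allocator, Location.forGrpcInsecure("localhost", 31337)).build()) {
      CommandStatementQuery command = CommandStatementQuery.newBuilder()
          .setQuery("SELECT id, name FROM users")   // illustrative SQL
          .build();
      FlightDescriptor descriptor =
          FlightDescriptor.command(Any.pack(command).toByteArray());

      // GetFlightInfo plans the query; each endpoint's ticket is handed back
      // verbatim via DoGet, never parsed by the client.
      FlightInfo info = client.getInfo(descriptor);
      long rows = 0;
      for (FlightEndpoint endpoint : info.getEndpoints()) {
        try (FlightStream stream = client.getStream(endpoint.getTicket())) {
          while (stream.next()) {
            rows += stream.getRoot().getRowCount();
          }
        }
      }
      System.out.println("fetched " + rows + " rows");
    }
  }
}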
Used in the command member of FlightDescriptor + * for the the RPC call DoPut to cause the server to execute the included SQL update. + */ +message CommandStatementUpdate { + option (experimental) = true; + + // The SQL syntax. + string query = 1; +} + +/* + * Represents a SQL update query. Used in the command member of FlightDescriptor + * for the the RPC call DoPut to cause the server to execute the included + * prepared statement handle as an update. + */ +message CommandPreparedStatementUpdate { + option (experimental) = true; + + // Opaque handle for the prepared statement on the server. + bytes prepared_statement_handle = 1; +} + +/* + * Returned from the RPC call DoPut when a CommandStatementUpdate + * CommandPreparedStatementUpdate was in the request, containing + * results from the update. + */ +message DoPutUpdateResult { + option (experimental) = true; + + // The number of records updated. A return value of -1 represents + // an unknown updated record count. + int64 record_count = 1; +} + +extend google.protobuf.MessageOptions { + bool experimental = 1000; +} diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java index 250b0edd2d3..a1bb8b667f4 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java @@ -17,18 +17,12 @@ package org.apache.arrow.adapter.jdbc; -import static org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE; -import static org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE; - -import java.sql.Types; import java.util.Calendar; import java.util.Map; import java.util.function.Function; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.Preconditions; -import org.apache.arrow.vector.types.DateUnit; -import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; /** @@ -55,16 +49,14 @@ */ public final class JdbcToArrowConfig { + public static final int DEFAULT_TARGET_BATCH_SIZE = 1024; + public static final int NO_LIMIT_BATCH_SIZE = -1; private final Calendar calendar; private final BufferAllocator allocator; private final boolean includeMetadata; private final boolean reuseVectorSchemaRoot; private final Map arraySubTypesByColumnIndex; private final Map arraySubTypesByColumnName; - - public static final int DEFAULT_TARGET_BATCH_SIZE = 1024; - public static final int NO_LIMIT_BATCH_SIZE = -1; - /** * The maximum rowCount to read each time when partially convert data. * Default value is 1024 and -1 means disable partial read. @@ -82,7 +74,7 @@ public final class JdbcToArrowConfig { /** * Constructs a new configuration from the provided allocator and calendar. The allocator * is used when constructing the Arrow vectors from the ResultSet, and the calendar is used to define - * Arrow Timestamp fields, and to read time-based fields from the JDBC ResultSet. + * Arrow Timestamp fields, and to read time-based fields from the JDBC ResultSet. * * @param allocator The memory allocator to construct the Arrow vectors with. * @param calendar The calendar to use when constructing Timestamp fields and reading time-based results. @@ -99,7 +91,7 @@ public final class JdbcToArrowConfig { /** * Constructs a new configuration from the provided allocator and calendar. 
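Before moving on to the JDBC adapter changes, a short sketch of consuming the DoPutUpdateResult message defined above. The proto itself does not pin down the framing; this example simply assumes the serialized message bytes have already been obtained from the DoPut result and shows the -1 "unknown count" convention.

import com.google.protobuf.InvalidProtocolBufferException;

import org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult;

public class UpdateResultExample {
  /** Interprets a serialized DoPutUpdateResult returned for an update statement. */
  static void reportUpdate(byte[] serializedResult) throws InvalidProtocolBufferException {
    DoPutUpdateResult result = DoPutUpdateResult.parseFrom(serializedResult);
    if (result.getRecordCount() == -1) {
      // -1 is the sentinel for "the server does not know how many rows changed".
      System.out.println("update applied, affected row count unknown");
    } else {
      System.out.println("update applied to " + result.getRecordCount() + " rows");
    }
  }

  public static void main(String[] args) throws Exception {
    // Round-trip a locally built message just to exercise the parser.
    reportUpdate(DoPutUpdateResult.newBuilder().setRecordCount(42).build().toByteArray());
  }
}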
The allocator * is used when constructing the Arrow vectors from the ResultSet, and the calendar is used to define - * Arrow Timestamp fields, and to read time-based fields from the JDBC ResultSet. + * Arrow Timestamp fields, and to read time-based fields from the JDBC ResultSet. * * @param allocator The memory allocator to construct the Arrow vectors with. * @param calendar The calendar to use when constructing Timestamp fields and reading time-based results. @@ -134,6 +126,8 @@ public final class JdbcToArrowConfig { *

  • TIMESTAMP --> ArrowType.Timestamp(TimeUnit.MILLISECOND, calendar timezone)
  • *
  • CLOB --> ArrowType.Utf8
  • *
  • BLOB --> ArrowType.Binary
  • + *
  • ARRAY --> ArrowType.List
  • + *
  • STRUCT --> ArrowType.Struct
  • *
  • NULL --> ArrowType.Null
  • * */ @@ -157,64 +151,7 @@ public final class JdbcToArrowConfig { // set up type converter this.jdbcToArrowTypeConverter = jdbcToArrowTypeConverter != null ? jdbcToArrowTypeConverter : - fieldInfo -> { - final String timezone; - if (calendar != null) { - timezone = calendar.getTimeZone().getID(); - } else { - timezone = null; - } - - switch (fieldInfo.getJdbcType()) { - case Types.BOOLEAN: - case Types.BIT: - return new ArrowType.Bool(); - case Types.TINYINT: - return new ArrowType.Int(8, true); - case Types.SMALLINT: - return new ArrowType.Int(16, true); - case Types.INTEGER: - return new ArrowType.Int(32, true); - case Types.BIGINT: - return new ArrowType.Int(64, true); - case Types.NUMERIC: - case Types.DECIMAL: - int precision = fieldInfo.getPrecision(); - int scale = fieldInfo.getScale(); - return new ArrowType.Decimal(precision, scale, 128); - case Types.REAL: - case Types.FLOAT: - return new ArrowType.FloatingPoint(SINGLE); - case Types.DOUBLE: - return new ArrowType.FloatingPoint(DOUBLE); - case Types.CHAR: - case Types.NCHAR: - case Types.VARCHAR: - case Types.NVARCHAR: - case Types.LONGVARCHAR: - case Types.LONGNVARCHAR: - case Types.CLOB: - return new ArrowType.Utf8(); - case Types.DATE: - return new ArrowType.Date(DateUnit.DAY); - case Types.TIME: - return new ArrowType.Time(TimeUnit.MILLISECOND, 32); - case Types.TIMESTAMP: - return new ArrowType.Timestamp(TimeUnit.MILLISECOND, timezone); - case Types.BINARY: - case Types.VARBINARY: - case Types.LONGVARBINARY: - case Types.BLOB: - return new ArrowType.Binary(); - case Types.ARRAY: - return new ArrowType.List(); - case Types.NULL: - return new ArrowType.Null(); - default: - // no-op, shouldn't get here - return null; - } - }; + jdbcFieldInfo -> JdbcToArrowUtils.getArrowTypeFromJdbcType(jdbcFieldInfo, calendar); } /** @@ -230,6 +167,7 @@ public Calendar getCalendar() { /** * The Arrow memory allocator. + * * @return the allocator. 
*/ public BufferAllocator getAllocator() { diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java index e05f21d48cf..db528af4486 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java @@ -17,13 +17,18 @@ package org.apache.arrow.adapter.jdbc; +import static org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE; +import static org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE; + import java.io.IOException; import java.sql.Date; +import java.sql.ParameterMetaData; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Time; import java.sql.Timestamp; +import java.sql.Types; import java.util.ArrayList; import java.util.Calendar; import java.util.HashMap; @@ -70,6 +75,8 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; @@ -106,6 +113,101 @@ public static Schema jdbcToArrowSchema(ResultSetMetaData rsmd, Calendar calendar return jdbcToArrowSchema(rsmd, new JdbcToArrowConfig(new RootAllocator(0), calendar)); } + /** + * Create Arrow {@link Schema} object for the given JDBC {@link ResultSetMetaData}. + * + * @param parameterMetaData The ResultSetMetaData containing the results, to read the JDBC metadata from. + * @param calendar The calendar to use the time zone field of, to construct Timestamp fields from. + * @return {@link Schema} + * @throws SQLException on error + */ + public static Schema jdbcToArrowSchema(final ParameterMetaData parameterMetaData, final Calendar calendar) + throws SQLException { + Preconditions.checkNotNull(calendar, "Calendar object can't be null"); + Preconditions.checkNotNull(parameterMetaData); + final List parameterFields = new ArrayList<>(parameterMetaData.getParameterCount()); + for (int parameterCounter = 1; parameterCounter <= parameterMetaData.getParameterCount(); + parameterCounter++) { + final int jdbcDataType = parameterMetaData.getParameterType(parameterCounter); + final int jdbcIsNullable = parameterMetaData.isNullable(parameterCounter); + final boolean arrowIsNullable = jdbcIsNullable != ParameterMetaData.parameterNoNulls; + final int precision = parameterMetaData.getPrecision(parameterCounter); + final int scale = parameterMetaData.getScale(parameterCounter); + final ArrowType arrowType = getArrowTypeFromJdbcType(new JdbcFieldInfo(jdbcDataType, precision, scale), calendar); + final FieldType fieldType = new FieldType(arrowIsNullable, arrowType, /*dictionary=*/null); + parameterFields.add(new Field(null, fieldType, null)); + } + + return new Schema(parameterFields); + } + + /** + * Converts the provided JDBC type to its respective {@link ArrowType} counterpart. + * + * @param fieldInfo the {@link JdbcFieldInfo} with information about the original JDBC type. + * @param calendar the {@link Calendar} to use for datetime data types. + * @return a new {@link ArrowType}. 
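+   *
+   * <p>Illustrative usage sketch (added for this write-up, not part of the original patch); the
+   * DECIMAL(10, 2) column is hypothetical, and the calendar only affects TIMESTAMP columns:
+   * <pre>{@code
+   * ArrowType type = JdbcToArrowUtils.getArrowTypeFromJdbcType(
+   *     new JdbcFieldInfo(java.sql.Types.DECIMAL, 10, 2), Calendar.getInstance());
+   * // yields new ArrowType.Decimal(10, 2, 128) per the switch below
+   * }</pre>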
+ */ + public static ArrowType getArrowTypeFromJdbcType(final JdbcFieldInfo fieldInfo, final Calendar calendar) { + switch (fieldInfo.getJdbcType()) { + case Types.BOOLEAN: + case Types.BIT: + return new ArrowType.Bool(); + case Types.TINYINT: + return new ArrowType.Int(8, true); + case Types.SMALLINT: + return new ArrowType.Int(16, true); + case Types.INTEGER: + return new ArrowType.Int(32, true); + case Types.BIGINT: + return new ArrowType.Int(64, true); + case Types.NUMERIC: + case Types.DECIMAL: + int precision = fieldInfo.getPrecision(); + int scale = fieldInfo.getScale(); + return new ArrowType.Decimal(precision, scale, 128); + case Types.REAL: + case Types.FLOAT: + return new ArrowType.FloatingPoint(SINGLE); + case Types.DOUBLE: + return new ArrowType.FloatingPoint(DOUBLE); + case Types.CHAR: + case Types.NCHAR: + case Types.VARCHAR: + case Types.NVARCHAR: + case Types.LONGVARCHAR: + case Types.LONGNVARCHAR: + case Types.CLOB: + return new ArrowType.Utf8(); + case Types.DATE: + return new ArrowType.Date(DateUnit.DAY); + case Types.TIME: + return new ArrowType.Time(TimeUnit.MILLISECOND, 32); + case Types.TIMESTAMP: + final String timezone; + if (calendar != null) { + timezone = calendar.getTimeZone().getID(); + } else { + timezone = null; + } + return new ArrowType.Timestamp(TimeUnit.MILLISECOND, timezone); + case Types.BINARY: + case Types.VARBINARY: + case Types.LONGVARBINARY: + case Types.BLOB: + return new ArrowType.Binary(); + case Types.ARRAY: + return new ArrowType.List(); + case Types.NULL: + return new ArrowType.Null(); + case Types.STRUCT: + return new ArrowType.Struct(); + default: + // no-op, shouldn't get here + return null; + } + } + /** * Create Arrow {@link Schema} object for the given JDBC {@link java.sql.ResultSetMetaData}. * diff --git a/java/flight/flight-core/pom.xml b/java/flight/flight-core/pom.xml index b1f00eb83f9..d870faf9c50 100644 --- a/java/flight/flight-core/pom.xml +++ b/java/flight/flight-core/pom.xml @@ -12,10 +12,10 @@ 4.0.0 + arrow-flight org.apache.arrow - arrow-java-root 7.0.0-SNAPSHOT - ../../pom.xml + ../pom.xml flight-core @@ -24,8 +24,6 @@ jar - 1.41.0 - 3.7.1 1 @@ -95,11 +93,6 @@ com.google.guava guava - - commons-cli - commons-cli - 1.4 - io.grpc grpc-stub diff --git a/java/flight/flight-grpc/pom.xml b/java/flight/flight-grpc/pom.xml index c567b7cada5..a12e4e26652 100644 --- a/java/flight/flight-grpc/pom.xml +++ b/java/flight/flight-grpc/pom.xml @@ -11,10 +11,10 @@ language governing permissions and limitations under the License. --> - arrow-java-root + arrow-flight org.apache.arrow 7.0.0-SNAPSHOT - ../../pom.xml + ../pom.xml 4.0.0 @@ -24,8 +24,6 @@ jar - 1.41.0 - 3.7.1 1 diff --git a/java/flight/flight-integration-tests/pom.xml b/java/flight/flight-integration-tests/pom.xml new file mode 100644 index 00000000000..1958c3bd504 --- /dev/null +++ b/java/flight/flight-integration-tests/pom.xml @@ -0,0 +1,86 @@ + + + + 4.0.0 + + arrow-flight + org.apache.arrow + 7.0.0-SNAPSHOT + ../pom.xml + + + flight-integration-tests + Arrow Flight Integration Tests + Integration tests for Flight RPC. 
+ jar + + + + org.apache.arrow + arrow-vector + ${project.version} + + + org.apache.arrow + arrow-memory-core + ${project.version} + + + org.apache.arrow + flight-core + ${project.version} + + + org.apache.arrow + flight-sql + ${project.version} + + + com.google.protobuf + protobuf-java + ${dep.protobuf.version} + + + commons-cli + commons-cli + 1.4 + + + org.slf4j + slf4j-api + + + + + + + maven-assembly-plugin + 3.0.0 + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + + + + diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/AuthBasicProtoScenario.java similarity index 98% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/AuthBasicProtoScenario.java index 3955d7d21bf..1c95d4d5593 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/AuthBasicProtoScenario.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.nio.charset.StandardCharsets; import java.util.Arrays; diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java new file mode 100644 index 00000000000..374e634e8a3 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.integration.tests; + +import java.util.Arrays; + +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Integration test scenario for validating Flight SQL specs across multiple implementations. + * This should ensure that RPC objects are being built and parsed correctly for multiple languages + * and that the Arrow schemas are returned as expected. + */ +public class FlightSqlScenario implements Scenario { + + public static final long UPDATE_STATEMENT_EXPECTED_ROWS = 10000L; + public static final long UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS = 20000L; + + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { + return new FlightSqlScenarioProducer(allocator); + } + + @Override + public void buildServer(FlightServer.Builder builder) throws Exception { + + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) + throws Exception { + final FlightSqlClient sqlClient = new FlightSqlClient(client); + + validateMetadataRetrieval(sqlClient); + + validateStatementExecution(sqlClient); + + validatePreparedStatementExecution(sqlClient, allocator); + } + + private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { + final CallOption[] options = new CallOption[0]; + + validate(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, sqlClient.getCatalogs(options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA, + sqlClient.getSchemas("catalog", "db_schema_filter_pattern", options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA, + sqlClient.getTables("catalog", "db_schema_filter_pattern", "table_filter_pattern", + Arrays.asList("table", "view"), true, options), sqlClient); + validate(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA, sqlClient.getTableTypes(options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_PRIMARY_KEYS_SCHEMA, + sqlClient.getPrimaryKeys(TableRef.of("catalog", "db_schema", "table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_EXPORTED_KEYS_SCHEMA, + sqlClient.getExportedKeys(TableRef.of("catalog", "db_schema", "table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_IMPORTED_KEYS_SCHEMA, + sqlClient.getImportedKeys(TableRef.of("catalog", "db_schema", "table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_CROSS_REFERENCE_SCHEMA, + sqlClient.getCrossReference(TableRef.of("pk_catalog", "pk_db_schema", "pk_table"), + TableRef.of("fk_catalog", "fk_db_schema", "fk_table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, + sqlClient.getSqlInfo(new FlightSql.SqlInfo[] {FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY}, options), sqlClient); + } + + private void 
validateStatementExecution(FlightSqlClient sqlClient) throws Exception { + final CallOption[] options = new CallOption[0]; + + validate(FlightSqlScenarioProducer.getQuerySchema(), + sqlClient.execute("SELECT STATEMENT", options), sqlClient); + + IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options), + UPDATE_STATEMENT_EXPECTED_ROWS); + } + + private void validatePreparedStatementExecution(FlightSqlClient sqlClient, + BufferAllocator allocator) throws Exception { + final CallOption[] options = new CallOption[0]; + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( + "SELECT PREPARED STATEMENT"); + VectorSchemaRoot parameters = VectorSchemaRoot.create( + FlightSqlScenarioProducer.getQuerySchema(), allocator)) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + + validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(options), + sqlClient); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( + "UPDATE PREPARED STATEMENT")) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(options), + UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + } + } + + private void validate(Schema expectedSchema, FlightInfo flightInfo, + FlightSqlClient sqlClient) throws Exception { + Ticket ticket = flightInfo.getEndpoints().get(0).getTicket(); + try (FlightStream stream = sqlClient.getStream(ticket)) { + Schema actualSchema = stream.getSchema(); + IntegrationAssertions.assertEquals(expectedSchema, actualSchema); + } + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java new file mode 100644 index 00000000000..f3554e1d3d8 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.integration.tests; + +import static com.google.protobuf.Any.pack; +import static java.util.Collections.singletonList; + +import java.util.List; + +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; + +/** + * Hardcoded Flight SQL producer used for cross-language integration tests. + */ +public class FlightSqlScenarioProducer implements FlightSqlProducer { + private final BufferAllocator allocator; + + public FlightSqlScenarioProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + /** + * Schema to be returned for mocking the statement/prepared statement results. + * Must be the same across all languages. + */ + static Schema getQuerySchema() { + return new Schema( + singletonList( + new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null) + ) + ); + } + + @Override + public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request, + CallContext context, StreamListener listener) { + IntegrationAssertions.assertTrue("Expect to be one of the two queries used on tests", + request.getQuery().equals("SELECT PREPARED STATEMENT") || + request.getQuery().equals("UPDATE PREPARED STATEMENT")); + + final FlightSql.ActionCreatePreparedStatementResult + result = FlightSql.ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(ByteString.copyFromUtf8(request.getQuery() + " HANDLE")) + .build(); + listener.onNext(new Result(pack(result).toByteArray())); + listener.onCompleted(); + } + + @Override + public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request, + CallContext context, StreamListener listener) { + IntegrationAssertions.assertTrue("Expect to be one of the two queries used on tests", + request.getPreparedStatementHandle().toStringUtf8().equals("SELECT PREPARED STATEMENT HANDLE") || + request.getPreparedStatementHandle().toStringUtf8().equals("UPDATE PREPARED STATEMENT HANDLE")); + + listener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); + + ByteString handle = ByteString.copyFromUtf8("SELECT STATEMENT HANDLE"); + + FlightSql.TicketStatementQuery ticket = FlightSql.TicketStatementQuery.newBuilder() + .setStatementHandle(handle) + .build(); + return getFlightInfoForSchema(ticket, descriptor, getQuerySchema()); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, + FlightDescriptor descriptor) { 
+ IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), + "SELECT PREPARED STATEMENT HANDLE"); + + return getFlightInfoForSchema(command, descriptor, getQuerySchema()); + } + + @Override + public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + return new SchemaResult(getQuerySchema()); + } + + @Override + public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, getQuerySchema()); + } + + @Override + public void getStreamPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, getQuerySchema()); + } + + private Runnable acceptPutReturnConstant(StreamListener ackStream, long value) { + return () -> { + final FlightSql.DoPutUpdateResult build = + FlightSql.DoPutUpdateResult.newBuilder().setRecordCount(value).build(); + + try (final ArrowBuf buffer = allocator.buffer(build.getSerializedSize())) { + buffer.writeBytes(build.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + ackStream.onCompleted(); + } + }; + } + + @Override + public Runnable acceptPutStatement(FlightSql.CommandStatementUpdate command, CallContext context, + FlightStream flightStream, + StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getQuery(), "UPDATE STATEMENT"); + + return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_STATEMENT_EXPECTED_ROWS); + } + + @Override + public Runnable acceptPutPreparedStatementUpdate(FlightSql.CommandPreparedStatementUpdate command, + CallContext context, FlightStream flightStream, + StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), + "UPDATE PREPARED STATEMENT HANDLE"); + + return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + } + + @Override + public Runnable acceptPutPreparedStatementQuery(FlightSql.CommandPreparedStatementQuery command, + CallContext context, FlightStream flightStream, + StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), + "SELECT PREPARED STATEMENT HANDLE"); + + IntegrationAssertions.assertEquals(getQuerySchema(), flightStream.getSchema()); + + return ackStream::onCompleted; + } + + @Override + public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getInfoCount(), 2); + IntegrationAssertions.assertEquals(request.getInfo(0), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE); + IntegrationAssertions.assertEquals(request.getInfo(1), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA); + } + + @Override + public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoCatalogs(FlightSql.CommandGetCatalogs request, CallContext context, + FlightDescriptor descriptor) { + return getFlightInfoForSchema(request, descriptor, Schemas.GET_CATALOGS_SCHEMA); + } + + private void putEmptyBatchToStreamListener(ServerStreamListener stream, Schema 
schema) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + stream.start(root); + stream.putNext(); + stream.completed(); + } + } + + @Override + public void getStreamCatalogs(CallContext context, ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_CATALOGS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoSchemas(FlightSql.CommandGetDbSchemas request, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchemaFilterPattern(), + "db_schema_filter_pattern"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_SCHEMAS_SCHEMA); + } + + @Override + public void getStreamSchemas(FlightSql.CommandGetDbSchemas command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_SCHEMAS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoTables(FlightSql.CommandGetTables request, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchemaFilterPattern(), + "db_schema_filter_pattern"); + IntegrationAssertions.assertEquals(request.getTableNameFilterPattern(), "table_filter_pattern"); + IntegrationAssertions.assertEquals(request.getTableTypesCount(), 2); + IntegrationAssertions.assertEquals(request.getTableTypes(0), "table"); + IntegrationAssertions.assertEquals(request.getTableTypes(1), "view"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLES_SCHEMA); + } + + @Override + public void getStreamTables(FlightSql.CommandGetTables command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_TABLES_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoTableTypes(FlightSql.CommandGetTableTypes request, + CallContext context, FlightDescriptor descriptor) { + return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLE_TYPES_SCHEMA); + } + + @Override + public void getStreamTableTypes(CallContext context, ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_TABLE_TYPES_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoPrimaryKeys(FlightSql.CommandGetPrimaryKeys request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchema(), "db_schema"); + IntegrationAssertions.assertEquals(request.getTable(), "table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_PRIMARY_KEYS_SCHEMA); + } + + @Override + public void getStreamPrimaryKeys(FlightSql.CommandGetPrimaryKeys command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_PRIMARY_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoExportedKeys(FlightSql.CommandGetExportedKeys request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchema(), "db_schema"); + IntegrationAssertions.assertEquals(request.getTable(), "table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_EXPORTED_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoImportedKeys(FlightSql.CommandGetImportedKeys 
request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchema(), "db_schema"); + IntegrationAssertions.assertEquals(request.getTable(), "table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_IMPORTED_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoCrossReference(FlightSql.CommandGetCrossReference request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getPkCatalog(), "pk_catalog"); + IntegrationAssertions.assertEquals(request.getPkDbSchema(), "pk_db_schema"); + IntegrationAssertions.assertEquals(request.getPkTable(), "pk_table"); + IntegrationAssertions.assertEquals(request.getFkCatalog(), "fk_catalog"); + IntegrationAssertions.assertEquals(request.getFkDbSchema(), "fk_db_schema"); + IntegrationAssertions.assertEquals(request.getFkTable(), "fk_table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_CROSS_REFERENCE_SCHEMA); + } + + @Override + public void getStreamExportedKeys(FlightSql.CommandGetExportedKeys command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_EXPORTED_KEYS_SCHEMA); + } + + @Override + public void getStreamImportedKeys(FlightSql.CommandGetImportedKeys command, CallContext context, + ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_IMPORTED_KEYS_SCHEMA); + } + + @Override + public void getStreamCrossReference(FlightSql.CommandGetCrossReference command, + CallContext context, ServerStreamListener listener) { + putEmptyBatchToStreamListener(listener, Schemas.GET_CROSS_REFERENCE_SCHEMA); + } + + @Override + public void close() throws Exception { + + } + + @Override + public void listFlights(CallContext context, Criteria criteria, + StreamListener listener) { + + } + + private FlightInfo getFlightInfoForSchema(final T request, + final FlightDescriptor descriptor, + final Schema schema) { + final Ticket ticket = new Ticket(pack(request).toByteArray()); + final List endpoints = singletonList(new FlightEndpoint(ticket)); + + return new FlightInfo(schema, descriptor, endpoints, -1, -1); + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java similarity index 88% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java index 576d1887f39..e124ed0ea74 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.util.Objects; @@ -63,6 +63,15 @@ static void assertFalse(String message, boolean value) { } } + /** + * Assert that the value is true, using the given message as an error otherwise. 
+ */ + static void assertTrue(String message, boolean value) { + if (!value) { + throw new AssertionError("Expected true: " + message); + } + } + /** * An interface used with {@link #assertThrows(Class, AssertThrows)}. */ diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java similarity index 93% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java index 27a545f84fd..2a36747b618 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; @@ -91,7 +91,7 @@ private void run(String[] args) throws Exception { final Location defaultLocation = Location.forGrpcInsecure(host, port); try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) { + final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) { if (cmd.hasOption("scenario")) { Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client); @@ -109,7 +109,7 @@ private static void testStream(BufferAllocator allocator, Location server, Fligh // 1. Read data from JSON and upload to server. 
FlightDescriptor descriptor = FlightDescriptor.path(inputPath); try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator); - VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { + VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, reader, new AsyncPutListener() { int counter = 0; @@ -157,10 +157,10 @@ public void onNext(PutResult val) { for (Location location : locations) { System.out.println("Verifying location " + location.getUri()); try (FlightClient readClient = FlightClient.builder(allocator, location).build(); - FlightStream stream = readClient.getStream(endpoint.getTicket()); - VectorSchemaRoot root = stream.getRoot(); - VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); - JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator)) { + FlightStream stream = readClient.getStream(endpoint.getTicket()); + VectorSchemaRoot root = stream.getRoot(); + VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator)) { VectorLoader loader = new VectorLoader(downloadedRoot); VectorUnloader unloader = new VectorUnloader(root); diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestServer.java similarity index 98% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestServer.java index da336c5024a..7f5e15fe376 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestServer.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/MiddlewareScenario.java similarity index 99% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/MiddlewareScenario.java index c710ce98b56..c284a577c08 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/MiddlewareScenario.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.nio.charset.StandardCharsets; import java.util.Arrays; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenario.java similarity index 96% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenario.java index b3b962d2e73..bcc657b765c 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenario.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightProducer; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java similarity index 96% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index cd9859b4f36..16cc856daf5 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.util.Map; import java.util.TreeMap; @@ -41,6 +41,7 @@ private Scenarios() { scenarios = new TreeMap<>(); scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new); scenarios.put("middleware", MiddlewareScenario::new); + scenarios.put("flight_sql", FlightSqlScenario::new); } private static Scenarios getInstance() { diff --git a/java/flight/flight-sql/pom.xml b/java/flight/flight-sql/pom.xml new file mode 100644 index 00000000000..b17ab9b7c48 --- /dev/null +++ b/java/flight/flight-sql/pom.xml @@ -0,0 +1,151 @@ + + + + 4.0.0 + + arrow-flight + org.apache.arrow + 7.0.0-SNAPSHOT + ../pom.xml + + + flight-sql + Arrow Flight SQL + (Experimental)Contains utility classes to expose Flight SQL semantics for clients and servers over Arrow Flight + jar + + + 1 + + + + + org.apache.arrow + flight-core + ${project.version} + + + io.netty + netty-transport-native-unix-common + + + io.netty + netty-transport-native-kqueue + + + io.netty + netty-transport-native-epoll + + + + + org.apache.arrow + arrow-memory-core + ${project.version} + + + org.apache.arrow + arrow-jdbc + ${project.version} + + + io.grpc + grpc-protobuf + ${dep.grpc.version} + + + com.google.guava + guava + + + io.grpc + grpc-stub + ${dep.grpc.version} + + + com.google.protobuf + protobuf-java + ${dep.protobuf.version} + + + io.grpc + grpc-api + ${dep.grpc.version} + + + org.apache.arrow + arrow-vector + ${project.version} + ${arrow.vector.classifier} + + + org.slf4j + slf4j-api + + + org.apache.derby + derby + 10.14.2.0 + test + + + org.apache.commons + commons-dbcp2 + 2.9.0 + test + + + commons-logging + commons-logging + + + + + org.apache.commons + commons-pool2 + 2.11.1 + test + + + org.hamcrest + hamcrest + + + commons-cli + commons-cli + 1.4 + + + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + + + proto-compile + generate-sources + + ${basedir}/../../../format/ + + + compile + compile-custom + + + + + + + + diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java new file mode 100644 index 00000000000..069d59edd48 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -0,0 +1,632 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.sql; + +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; +import static org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.SyncPutListener; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; + +/** + * Flight client with Flight SQL semantics. + */ +public class FlightSqlClient implements AutoCloseable { + private final FlightClient client; + + public FlightSqlClient(final FlightClient client) { + this.client = Objects.requireNonNull(client, "Client cannot be null!"); + } + + /** + * Execute a query on the server. + * + * @param query The query to execute. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo execute(final String query, final CallOption... 
options) { + final CommandStatementQuery.Builder builder = CommandStatementQuery.newBuilder(); + builder.setQuery(query); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Execute an update query on the server. + * + * @param query The query to execute. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeUpdate(final String query, final CallOption... options) { + final CommandStatementUpdate.Builder builder = CommandStatementUpdate.newBuilder(); + builder.setQuery(query); + + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + final SyncPutListener putListener = new SyncPutListener(); + client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); + + try { + final PutResult read = putListener.read(); + try (final ArrowBuf metadata = read.getApplicationMetadata()) { + final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom(metadata.nioBuffer()); + return doPutUpdateResult.getRecordCount(); + } + } catch (final InterruptedException | ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); + } + } + + /** + * Request a list of catalogs. + * + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getCatalogs(final CallOption... options) { + final CommandGetCatalogs.Builder builder = CommandGetCatalogs.newBuilder(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Request a list of schemas. + * + * @param catalog The catalog. + * @param dbSchemaFilterPattern The schema filter pattern. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getSchemas(final String catalog, final String dbSchemaFilterPattern, final CallOption... options) { + final CommandGetDbSchemas.Builder builder = CommandGetDbSchemas.newBuilder(); + + if (catalog != null) { + builder.setCatalog(catalog); + } + + if (dbSchemaFilterPattern != null) { + builder.setDbSchemaFilterPattern(dbSchemaFilterPattern); + } + + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Get schema for a stream. + * + * @param descriptor The descriptor for the stream. + * @param options RPC-layer hints for this call. + */ + public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) { + return client.getSchema(descriptor, options); + } + + /** + * Retrieve a stream from the server. + * + * @param ticket The ticket granting access to the data stream. + * @param options RPC-layer hints for this call. + */ + public FlightStream getStream(Ticket ticket, CallOption... options) { + return client.getStream(ticket, options); + } + + /** + * Request a set of Flight SQL metadata. + * + * @param info The set of metadata to retrieve. None to retrieve all metadata. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getSqlInfo(final SqlInfo... 
info) { + return getSqlInfo(info, new CallOption[0]); + } + + /** + * Request a set of Flight SQL metadata. + * + * @param info The set of metadata to retrieve. None to retrieve all metadata. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getSqlInfo(final SqlInfo[] info, final CallOption... options) { + final int[] infoNumbers = Arrays.stream(info).mapToInt(SqlInfo::getNumber).toArray(); + return getSqlInfo(infoNumbers, options); + } + + /** + * Request a set of Flight SQL metadata. + * Use this method if you would like to retrieve custom metadata, where the custom metadata key values start + * from 10_000. + * + * @param info The set of metadata to retrieve. None to retrieve all metadata. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getSqlInfo(final int[] info, final CallOption... options) { + return getSqlInfo(Arrays.stream(info).boxed().collect(Collectors.toList()), options); + } + + /** + * Request a set of Flight SQL metadata. + * Use this method if you would like to retrieve custom metadata, where the custom metadata key values start + * from 10_000. + * + * @param info The set of metadata to retrieve. None to retrieve all metadata. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getSqlInfo(final Iterable info, final CallOption... options) { + final CommandGetSqlInfo.Builder builder = CommandGetSqlInfo.newBuilder(); + builder.addAllInfo(info); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Request a list of tables. + * + * @param catalog The catalog. + * @param dbSchemaFilterPattern The schema filter pattern. + * @param tableFilterPattern The table filter pattern. + * @param tableTypes The table types to include. + * @param includeSchema True to include the schema upon return, false to not include the schema. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getTables(final String catalog, final String dbSchemaFilterPattern, + final String tableFilterPattern, final List tableTypes, + final boolean includeSchema, final CallOption... options) { + final CommandGetTables.Builder builder = CommandGetTables.newBuilder(); + + if (catalog != null) { + builder.setCatalog(catalog); + } + + if (dbSchemaFilterPattern != null) { + builder.setDbSchemaFilterPattern(dbSchemaFilterPattern); + } + + if (tableFilterPattern != null) { + builder.setTableNameFilterPattern(tableFilterPattern); + } + + if (tableTypes != null) { + builder.addAllTableTypes(tableTypes); + } + builder.setIncludeSchema(includeSchema); + + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Request the primary keys for a table. + * + * @param tableRef An object which hold info about catalog, dbSchema and table. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getPrimaryKeys(final TableRef tableRef, final CallOption... 
options) { + final CommandGetPrimaryKeys.Builder builder = CommandGetPrimaryKeys.newBuilder(); + + if (tableRef.getCatalog() != null) { + builder.setCatalog(tableRef.getCatalog()); + } + + if (tableRef.getDbSchema() != null) { + builder.setDbSchema(tableRef.getDbSchema()); + } + + Objects.requireNonNull(tableRef.getTable()); + builder.setTable(tableRef.getTable()).build(); + + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Retrieves a description about the foreign key columns that reference the primary key columns of the given table. + * + * @param tableRef An object which hold info about catalog, dbSchema and table. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getExportedKeys(final TableRef tableRef, final CallOption... options) { + Objects.requireNonNull(tableRef.getTable(), "Table cannot be null."); + + final CommandGetExportedKeys.Builder builder = CommandGetExportedKeys.newBuilder(); + + if (tableRef.getCatalog() != null) { + builder.setCatalog(tableRef.getCatalog()); + } + + if (tableRef.getDbSchema() != null) { + builder.setDbSchema(tableRef.getDbSchema()); + } + + Objects.requireNonNull(tableRef.getTable()); + builder.setTable(tableRef.getTable()).build(); + + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Retrieves the foreign key columns for the given table. + * + * @param tableRef An object which hold info about catalog, dbSchema and table. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getImportedKeys(final TableRef tableRef, + final CallOption... options) { + Objects.requireNonNull(tableRef.getTable(), "Table cannot be null."); + + final CommandGetImportedKeys.Builder builder = CommandGetImportedKeys.newBuilder(); + + if (tableRef.getCatalog() != null) { + builder.setCatalog(tableRef.getCatalog()); + } + + if (tableRef.getDbSchema() != null) { + builder.setDbSchema(tableRef.getDbSchema()); + } + + Objects.requireNonNull(tableRef.getTable()); + builder.setTable(tableRef.getTable()).build(); + + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Retrieves a description of the foreign key columns that reference the given table's + * primary key columns (the foreign keys exported by a table). + * + * @param pkTableRef An object which hold info about catalog, dbSchema and table from a primary table. + * @param fkTableRef An object which hold info about catalog, dbSchema and table from a foreign table. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getCrossReference(final TableRef pkTableRef, + final TableRef fkTableRef, final CallOption... 
options) { + Objects.requireNonNull(pkTableRef.getTable(), "Parent Table cannot be null."); + Objects.requireNonNull(fkTableRef.getTable(), "Foreign Table cannot be null."); + + final CommandGetCrossReference.Builder builder = CommandGetCrossReference.newBuilder(); + + if (pkTableRef.getCatalog() != null) { + builder.setPkCatalog(pkTableRef.getCatalog()); + } + + if (pkTableRef.getDbSchema() != null) { + builder.setPkDbSchema(pkTableRef.getDbSchema()); + } + + if (fkTableRef.getCatalog() != null) { + builder.setFkCatalog(fkTableRef.getCatalog()); + } + + if (fkTableRef.getDbSchema() != null) { + builder.setFkDbSchema(fkTableRef.getDbSchema()); + } + + builder.setPkTable(pkTableRef.getTable()); + builder.setFkTable(fkTableRef.getTable()); + + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Request a list of table types. + * + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo getTableTypes(final CallOption... options) { + final CommandGetTableTypes.Builder builder = CommandGetTableTypes.newBuilder(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Create a prepared statement on the server. + * + * @param query The query to prepare. + * @param options RPC-layer hints for this call. + * @return The representation of the prepared statement which exists on the server. + */ + public PreparedStatement prepare(final String query, final CallOption... options) { + return new PreparedStatement(client, query, options); + } + + @Override + public void close() throws SQLException { + try { + AutoCloseables.close(client); + } catch (final Exception e) { + throw new SQLException(e); + } + } + + /** + * Helper class to encapsulate Flight SQL prepared statement logic. + */ + public static class PreparedStatement implements AutoCloseable { + private final FlightClient client; + private final ActionCreatePreparedStatementResult preparedStatementResult; + private VectorSchemaRoot parameterBindingRoot; + private boolean isClosed; + private Schema resultSetSchema; + private Schema parameterSchema; + + /** + * Constructor. + * + * @param client The client. PreparedStatement does not maintain this resource. + * @param sql The query. + * @param options RPC-layer hints for this call. + */ + public PreparedStatement(final FlightClient client, final String sql, final CallOption... options) { + this.client = client; + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), + Any.pack(ActionCreatePreparedStatementRequest + .newBuilder() + .setQuery(sql) + .build()) + .toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + + preparedStatementResult = FlightSqlUtils.unpackAndParseOrThrow( + preparedStatementResults.next().getBody(), + ActionCreatePreparedStatementResult.class); + + isClosed = false; + } + + /** + * Set the {@link #parameterBindingRoot} containing the parameter binding from a {@link PreparedStatement} + * operation. + * + * @param parameterBindingRoot a {@code VectorSchemaRoot} object containing the values to be used in the + * {@code PreparedStatement} setters. 
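+   *
+   * <p>Illustrative sketch (added for this write-up, not part of the original patch), assuming a
+   * prepared statement whose parameter schema has a single BIGINT column; {@code statement} and
+   * {@code allocator} are placeholder names:
+   * <pre>{@code
+   * try (VectorSchemaRoot parameters = VectorSchemaRoot.create(statement.getParameterSchema(), allocator)) {
+   *   ((BigIntVector) parameters.getVector(0)).setSafe(0, 42L);
+   *   parameters.setRowCount(1);
+   *   statement.setParameters(parameters);
+   *   FlightInfo info = statement.execute();
+   * }
+   * }</pre>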
+ */ + public void setParameters(final VectorSchemaRoot parameterBindingRoot) { + if (this.parameterBindingRoot != null) { + if (this.parameterBindingRoot.equals(parameterBindingRoot)) { + return; + } + this.parameterBindingRoot.close(); + } + this.parameterBindingRoot = parameterBindingRoot; + } + + /** + * Closes the {@link #parameterBindingRoot}, which contains the parameter binding from + * a {@link PreparedStatement} operation, releasing its resources. + */ + public void clearParameters() { + if (parameterBindingRoot != null) { + parameterBindingRoot.close(); + } + } + + /** + * Returns the Schema of the resultset. + * + * @return the Schema of the resultset. + */ + public Schema getResultSetSchema() { + if (resultSetSchema == null) { + final ByteString bytes = preparedStatementResult.getDatasetSchema(); + resultSetSchema = deserializeSchema(bytes); + } + return resultSetSchema; + } + + /** + * Returns the Schema of the parameters. + * + * @return the Schema of the parameters. + */ + public Schema getParameterSchema() { + if (parameterSchema == null) { + final ByteString bytes = preparedStatementResult.getParameterSchema(); + parameterSchema = deserializeSchema(bytes); + } + return parameterSchema; + } + + private Schema deserializeSchema(final ByteString bytes) { + try { + return bytes.isEmpty() ? + new Schema(Collections.emptyList()) : + MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel( + new ByteArrayInputStream(bytes.toByteArray())))); + } catch (final IOException e) { + throw new RuntimeException("Failed to deserialize schema", e); + } + } + + /** + * Executes the prepared statement query on the server. + * + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo execute(final CallOption... options) throws SQLException { + checkOpen(); + + final FlightDescriptor descriptor = FlightDescriptor + .command(Any.pack(CommandPreparedStatementQuery.newBuilder() + .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) + .build()) + .toByteArray()); + + if (parameterBindingRoot != null && parameterBindingRoot.getRowCount() > 0) { + final SyncPutListener putListener = new SyncPutListener(); + + FlightClient.ClientStreamListener listener = + client.startPut(descriptor, parameterBindingRoot, putListener, options); + + listener.putNext(); + listener.completed(); + listener.getResult(); + } + + return client.getInfo(descriptor, options); + } + + /** + * Checks whether this client is open. + * + * @throws IllegalStateException if client is closed. + */ + protected final void checkOpen() { + Preconditions.checkState(!isClosed, "Statement closed"); + } + + /** + * Executes the prepared statement update on the server. + * + * @param options RPC-layer hints for this call. + * @return the count of updated records + */ + public long executeUpdate(final CallOption... options) { + checkOpen(); + final FlightDescriptor descriptor = FlightDescriptor + .command(Any.pack(CommandPreparedStatementUpdate.newBuilder() + .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) + .build()) + .toByteArray()); + setParameters(parameterBindingRoot == null ? 
VectorSchemaRoot.of() : parameterBindingRoot); + final SyncPutListener putListener = new SyncPutListener(); + final FlightClient.ClientStreamListener listener = + client.startPut(descriptor, parameterBindingRoot, putListener, options); + listener.putNext(); + listener.completed(); + try { + final PutResult read = putListener.read(); + try (final ArrowBuf metadata = read.getApplicationMetadata()) { + final DoPutUpdateResult doPutUpdateResult = + DoPutUpdateResult.parseFrom(metadata.nioBuffer()); + return doPutUpdateResult.getRecordCount(); + } + } catch (final InterruptedException | ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INVALID_ARGUMENT.withCause(e).toRuntimeException(); + } + } + + /** + * Closes the client. + * + * @param options RPC-layer hints for this call. + */ + public void close(final CallOption... options) { + if (isClosed) { + return; + } + isClosed = true; + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_CLOSE_PREPARED_STATEMENT.getType(), + Any.pack(ActionClosePreparedStatementRequest.newBuilder() + .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) + .build()) + .toByteArray()); + final Iterator closePreparedStatementResults = client.doAction(action, options); + closePreparedStatementResults.forEachRemaining(result -> { + }); + if (parameterBindingRoot != null) { + parameterBindingRoot.close(); + } + } + + @Override + public void close() { + close(new CallOption[0]); + } + + /** + * Returns if the prepared statement is already closed. + * + * @return true if the prepared statement is already closed. + */ + public boolean isClosed() { + return isClosed; + } + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java new file mode 100644 index 00000000000..f1eaf2f8988 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java @@ -0,0 +1,668 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
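// A minimal sketch of end-to-end use of the client API above, assuming a Flight SQL server
// is reachable on localhost:32010; the location, port and query are illustrative only and
// are not defined anywhere in this patch.
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;

public class FlightSqlClientExample {
  public static void main(String[] args) throws Exception {
    final Location location = Location.forGrpcInsecure("localhost", 32010); // assumed endpoint
    try (BufferAllocator allocator = new RootAllocator();
         FlightClient flightClient = FlightClient.builder(allocator, location).build();
         FlightSqlClient sqlClient = new FlightSqlClient(flightClient)) {
      // Prepare a statement on the server, execute it, then read back the result stream.
      try (FlightSqlClient.PreparedStatement statement = sqlClient.prepare("SELECT 1")) {
        final FlightInfo info = statement.execute();
        try (FlightStream stream = flightClient.getStream(info.getEndpoints().get(0).getTicket())) {
          while (stream.next()) {
            System.out.println(stream.getRoot().contentToTSVString());
          }
        }
      }
    }
  }
}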
+ */ + +package org.apache.arrow.flight.sql; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static java.util.stream.IntStream.range; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; +import static org.apache.arrow.vector.complex.MapVector.DATA_VECTOR_NAME; +import static org.apache.arrow.vector.complex.MapVector.KEY_NAME; +import static org.apache.arrow.vector.complex.MapVector.VALUE_NAME; +import static org.apache.arrow.vector.types.Types.MinorType.BIGINT; +import static org.apache.arrow.vector.types.Types.MinorType.BIT; +import static org.apache.arrow.vector.types.Types.MinorType.INT; +import static org.apache.arrow.vector.types.Types.MinorType.LIST; +import static org.apache.arrow.vector.types.Types.MinorType.STRUCT; +import static org.apache.arrow.vector.types.Types.MinorType.UINT4; +import static org.apache.arrow.vector.types.Types.MinorType.VARCHAR; + +import java.util.List; + +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.ActionType; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; +import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.UnionMode; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.Union; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; + +/** + * API to Implement an Arrow Flight SQL producer. + */ +public interface FlightSqlProducer extends FlightProducer, AutoCloseable { + /** + * Depending on the provided command, method either: + * 1. 
Return information about a SQL query, or + * 2. Return information about a prepared statement. In this case, parameters binding is allowed. + * + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return information about the given SQL query, or the given prepared statement. + */ + @Override + default FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + final Any command = FlightSqlUtils.parseOrThrow(descriptor.getCommand()); + + if (command.is(CommandStatementQuery.class)) { + return getFlightInfoStatement( + FlightSqlUtils.unpackOrThrow(command, CommandStatementQuery.class), context, descriptor); + } else if (command.is(CommandPreparedStatementQuery.class)) { + return getFlightInfoPreparedStatement( + FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementQuery.class), context, descriptor); + } else if (command.is(CommandGetCatalogs.class)) { + return getFlightInfoCatalogs( + FlightSqlUtils.unpackOrThrow(command, CommandGetCatalogs.class), context, descriptor); + } else if (command.is(CommandGetDbSchemas.class)) { + return getFlightInfoSchemas( + FlightSqlUtils.unpackOrThrow(command, CommandGetDbSchemas.class), context, descriptor); + } else if (command.is(CommandGetTables.class)) { + return getFlightInfoTables( + FlightSqlUtils.unpackOrThrow(command, CommandGetTables.class), context, descriptor); + } else if (command.is(CommandGetTableTypes.class)) { + return getFlightInfoTableTypes( + FlightSqlUtils.unpackOrThrow(command, CommandGetTableTypes.class), context, descriptor); + } else if (command.is(CommandGetSqlInfo.class)) { + return getFlightInfoSqlInfo( + FlightSqlUtils.unpackOrThrow(command, CommandGetSqlInfo.class), context, descriptor); + } else if (command.is(CommandGetPrimaryKeys.class)) { + return getFlightInfoPrimaryKeys( + FlightSqlUtils.unpackOrThrow(command, CommandGetPrimaryKeys.class), context, descriptor); + } else if (command.is(CommandGetExportedKeys.class)) { + return getFlightInfoExportedKeys( + FlightSqlUtils.unpackOrThrow(command, CommandGetExportedKeys.class), context, descriptor); + } else if (command.is(CommandGetImportedKeys.class)) { + return getFlightInfoImportedKeys( + FlightSqlUtils.unpackOrThrow(command, CommandGetImportedKeys.class), context, descriptor); + } else if (command.is(CommandGetCrossReference.class)) { + return getFlightInfoCrossReference( + FlightSqlUtils.unpackOrThrow(command, CommandGetCrossReference.class), context, descriptor); + } + + throw CallStatus.INVALID_ARGUMENT.withDescription("The defined request is invalid.").toRuntimeException(); + } + + /** + * Returns the schema of the result produced by the SQL query. + * + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return the result set schema. 
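// An implementation of this interface is hosted like any other FlightProducer: it is handed
// to a FlightServer. A brief sketch; the allocator, port and concrete producer instance are
// assumptions for illustration, not part of this patch.
static void serve(BufferAllocator allocator, FlightSqlProducer producer) throws Exception {
  final Location listenLocation = Location.forGrpcInsecure("0.0.0.0", 32010); // assumed port
  try (FlightServer server = FlightServer.builder(allocator, listenLocation, producer).build()) {
    server.start();
    server.awaitTermination();
  }
}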
+ */ + @Override + default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) { + final Any command = FlightSqlUtils.parseOrThrow(descriptor.getCommand()); + + if (command.is(CommandStatementQuery.class)) { + return getSchemaStatement( + FlightSqlUtils.unpackOrThrow(command, CommandStatementQuery.class), context, descriptor); + } else if (command.is(CommandGetCatalogs.class)) { + return new SchemaResult(Schemas.GET_CATALOGS_SCHEMA); + } else if (command.is(CommandGetDbSchemas.class)) { + return new SchemaResult(Schemas.GET_SCHEMAS_SCHEMA); + } else if (command.is(CommandGetTables.class)) { + return new SchemaResult(Schemas.GET_TABLES_SCHEMA); + } else if (command.is(CommandGetTableTypes.class)) { + return new SchemaResult(Schemas.GET_TABLE_TYPES_SCHEMA); + } else if (command.is(CommandGetSqlInfo.class)) { + return new SchemaResult(Schemas.GET_SQL_INFO_SCHEMA); + } else if (command.is(CommandGetPrimaryKeys.class)) { + return new SchemaResult(Schemas.GET_PRIMARY_KEYS_SCHEMA); + } else if (command.is(CommandGetImportedKeys.class)) { + return new SchemaResult(Schemas.GET_IMPORTED_KEYS_SCHEMA); + } else if (command.is(CommandGetExportedKeys.class)) { + return new SchemaResult(Schemas.GET_EXPORTED_KEYS_SCHEMA); + } else if (command.is(CommandGetCrossReference.class)) { + return new SchemaResult(Schemas.GET_CROSS_REFERENCE_SCHEMA); + } + + throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid command provided.").toRuntimeException(); + } + + /** + * Depending on the provided command, method either: + * 1. Return data for a stream produced by executing the provided SQL query, or + * 2. Return data for a prepared statement. In this case, parameters binding is allowed. + * + * @param context Per-call context. + * @param ticket The application-defined ticket identifying this stream. + * @param listener An interface for sending data back to the client. 
+ */ + @Override + default void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + final Any command; + + try { + command = Any.parseFrom(ticket.getBytes()); + } catch (InvalidProtocolBufferException e) { + listener.error(e); + return; + } + + if (command.is(TicketStatementQuery.class)) { + getStreamStatement( + FlightSqlUtils.unpackOrThrow(command, TicketStatementQuery.class), context, listener); + } else if (command.is(CommandPreparedStatementQuery.class)) { + getStreamPreparedStatement( + FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementQuery.class), context, listener); + } else if (command.is(CommandGetCatalogs.class)) { + getStreamCatalogs(context, listener); + } else if (command.is(CommandGetDbSchemas.class)) { + getStreamSchemas(FlightSqlUtils.unpackOrThrow(command, CommandGetDbSchemas.class), context, listener); + } else if (command.is(CommandGetTables.class)) { + getStreamTables(FlightSqlUtils.unpackOrThrow(command, CommandGetTables.class), context, listener); + } else if (command.is(CommandGetTableTypes.class)) { + getStreamTableTypes(context, listener); + } else if (command.is(CommandGetSqlInfo.class)) { + getStreamSqlInfo(FlightSqlUtils.unpackOrThrow(command, CommandGetSqlInfo.class), context, listener); + } else if (command.is(CommandGetPrimaryKeys.class)) { + getStreamPrimaryKeys(FlightSqlUtils.unpackOrThrow(command, CommandGetPrimaryKeys.class), context, listener); + } else if (command.is(CommandGetExportedKeys.class)) { + getStreamExportedKeys(FlightSqlUtils.unpackOrThrow(command, CommandGetExportedKeys.class), context, listener); + } else if (command.is(CommandGetImportedKeys.class)) { + getStreamImportedKeys(FlightSqlUtils.unpackOrThrow(command, CommandGetImportedKeys.class), context, listener); + } else if (command.is(CommandGetCrossReference.class)) { + getStreamCrossReference(FlightSqlUtils.unpackOrThrow(command, CommandGetCrossReference.class), context, listener); + } else { + throw CallStatus.INVALID_ARGUMENT.withDescription("The defined request is invalid.").toRuntimeException(); + } + } + + /** + * Depending on the provided command, method either: + * 1. Execute provided SQL query as an update statement, or + * 2. Execute provided update SQL query prepared statement. In this case, parameters binding + * is allowed, or + * 3. Binds parameters to the provided prepared statement. + * + * @param context Per-call context. + * @param flightStream The data stream being uploaded. + * @param ackStream The data stream listener for update result acknowledgement. + * @return a Runnable to process the stream. 
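// Sketch of the server side of the getStream() dispatch above: a getStreamStatement()
// implementation that streams back a single record batch. The resultSchema and allocator
// fields and the populateBatchFor() helper are assumptions for this sketch only.
@Override
public void getStreamStatement(TicketStatementQuery ticket, CallContext context,
                               ServerStreamListener listener) {
  try (VectorSchemaRoot root = VectorSchemaRoot.create(resultSchema, allocator)) {
    listener.start(root);
    populateBatchFor(ticket.getStatementHandle(), root); // assumed: fills the vectors for this handle
    listener.putNext();
    listener.completed();
  }
}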
+ */ + @Override + default Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener ackStream) { + final Any command = FlightSqlUtils.parseOrThrow(flightStream.getDescriptor().getCommand()); + + if (command.is(CommandStatementUpdate.class)) { + return acceptPutStatement( + FlightSqlUtils.unpackOrThrow(command, CommandStatementUpdate.class), + context, flightStream, ackStream); + } else if (command.is(CommandPreparedStatementUpdate.class)) { + return acceptPutPreparedStatementUpdate( + FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementUpdate.class), + context, flightStream, ackStream); + } else if (command.is(CommandPreparedStatementQuery.class)) { + return acceptPutPreparedStatementQuery( + FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementQuery.class), + context, flightStream, ackStream); + } + + throw CallStatus.INVALID_ARGUMENT.withDescription("The defined request is invalid.").toRuntimeException(); + } + + /** + * Lists all available Flight SQL actions. + * + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + @Override + default void listActions(CallContext context, StreamListener listener) { + FlightSqlUtils.FLIGHT_SQL_ACTIONS.forEach(listener::onNext); + listener.onCompleted(); + } + + /** + * Performs the requested Flight SQL action. + * + * @param context Per-call context. + * @param action Client-supplied parameters. + * @param listener A stream of responses. + */ + @Override + default void doAction(CallContext context, Action action, StreamListener listener) { + final String actionType = action.getType(); + if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType())) { + final ActionCreatePreparedStatementRequest request = FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), + ActionCreatePreparedStatementRequest.class); + createPreparedStatement(request, context, listener); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CLOSE_PREPARED_STATEMENT.getType())) { + final ActionClosePreparedStatementRequest request = FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), + ActionClosePreparedStatementRequest.class); + closePreparedStatement(request, context, listener); + } + + throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid action provided.").toRuntimeException(); + } + + /** + * Creates a prepared statement on the server and returns a handle and metadata for in a + * {@link ActionCreatePreparedStatementResult} object in a {@link Result} + * object. + * + * @param request The sql command to generate the prepared statement. + * @param context Per-call context. + * @param listener A stream of responses. + */ + void createPreparedStatement(ActionCreatePreparedStatementRequest request, CallContext context, + StreamListener listener); + + /** + * Closes a prepared statement on the server. No result is expected. + * + * @param request The sql command to generate the prepared statement. + * @param context Per-call context. + * @param listener A stream of responses. + */ + void closePreparedStatement(ActionClosePreparedStatementRequest request, CallContext context, + StreamListener listener); + + /** + * Gets information about a particular SQL query based data stream. + * + * @param command The sql command to generate the data stream. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. 
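// One possible shape for getFlightInfoStatement(): pack a TicketStatementQuery into the
// endpoint ticket so that getStream() can later route it back to getStreamStatement(). The
// statement-handle encoding, serverLocation field and schemaForQuery() helper are
// assumptions for this sketch only.
@Override
public FlightInfo getFlightInfoStatement(CommandStatementQuery command, CallContext context,
                                         FlightDescriptor descriptor) {
  final TicketStatementQuery ticketContent = TicketStatementQuery.newBuilder()
      .setStatementHandle(ByteString.copyFromUtf8(command.getQuery()))
      .build();
  final Ticket ticket = new Ticket(Any.pack(ticketContent).toByteArray());
  final FlightEndpoint endpoint = new FlightEndpoint(ticket, serverLocation);
  return new FlightInfo(schemaForQuery(command.getQuery()), descriptor,
      Collections.singletonList(endpoint), /*bytes*/ -1, /*records*/ -1);
}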
+ */ + FlightInfo getFlightInfoStatement(CommandStatementQuery command, CallContext context, + FlightDescriptor descriptor); + + /** + * Gets information about a particular prepared statement data stream. + * + * @param command The prepared statement to generate the data stream. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + FlightInfo getFlightInfoPreparedStatement(CommandPreparedStatementQuery command, + CallContext context, FlightDescriptor descriptor); + + /** + * Gets schema about a particular SQL query based data stream. + * + * @param command The sql command to generate the data stream. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Schema for the stream. + */ + SchemaResult getSchemaStatement(CommandStatementQuery command, CallContext context, + FlightDescriptor descriptor); + + /** + * Returns data for a SQL query based data stream. + * @param ticket Ticket message containing the statement handle. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamStatement(TicketStatementQuery ticket, CallContext context, + ServerStreamListener listener); + + /** + * Returns data for a particular prepared statement query instance. + * + * @param command The prepared statement to generate the data stream. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamPreparedStatement(CommandPreparedStatementQuery command, CallContext context, + ServerStreamListener listener); + + /** + * Accepts uploaded data for a particular SQL query based data stream. + *

<p>`PutResult`s must be in the form of a {@link DoPutUpdateResult}. + * + * @param command The sql command to generate the data stream. + * @param context Per-call context. + * @param flightStream The data stream being uploaded. + * @param ackStream The result data stream. + * @return A runnable to process the stream. + */ + Runnable acceptPutStatement(CommandStatementUpdate command, CallContext context, + FlightStream flightStream, StreamListener ackStream); + + /** + * Accepts uploaded data for a particular prepared statement data stream. + *
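// Sketch of the acknowledgement contract described above: after draining the uploaded
// stream, the server answers with a PutResult whose application metadata is a serialized
// DoPutUpdateResult. The allocator field and the affected-row accounting are placeholders.
@Override
public Runnable acceptPutStatement(CommandStatementUpdate command, CallContext context,
                                   FlightStream flightStream, StreamListener<PutResult> ackStream) {
  return () -> {
    long affectedRows = 0;
    while (flightStream.next()) {
      affectedRows += flightStream.getRoot().getRowCount(); // placeholder for real update logic
    }
    final byte[] metadata =
        DoPutUpdateResult.newBuilder().setRecordCount(affectedRows).build().toByteArray();
    try (ArrowBuf buffer = allocator.buffer(metadata.length)) {
      buffer.writeBytes(metadata);
      ackStream.onNext(PutResult.metadata(buffer));
    }
    ackStream.onCompleted();
  };
}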

    `PutResult`s must be in the form of a {@link DoPutUpdateResult}. + * + * @param command The prepared statement to generate the data stream. + * @param context Per-call context. + * @param flightStream The data stream being uploaded. + * @param ackStream The result data stream. + * @return A runnable to process the stream. + */ + Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate command, + CallContext context, FlightStream flightStream, + StreamListener ackStream); + + /** + * Accepts uploaded parameter values for a particular prepared statement query. + * + * @param command The prepared statement the parameter values will bind to. + * @param context Per-call context. + * @param flightStream The data stream being uploaded. + * @param ackStream The result data stream. + * @return A runnable to process the stream. + */ + Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery command, + CallContext context, FlightStream flightStream, + StreamListener ackStream); + + /** + * Returns the SQL Info of the server by returning a + * {@link CommandGetSqlInfo} in a {@link Result}. + * + * @param request request filter parameters. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + FlightInfo getFlightInfoSqlInfo(CommandGetSqlInfo request, CallContext context, + FlightDescriptor descriptor); + + /** + * Returns data for SQL info based data stream. + * + * @param command The command to generate the data stream. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamSqlInfo(CommandGetSqlInfo command, CallContext context, ServerStreamListener listener); + + /** + * Returns the available catalogs by returning a stream of + * {@link CommandGetCatalogs} objects in {@link Result} objects. + * + * @param request request filter parameters. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + FlightInfo getFlightInfoCatalogs(CommandGetCatalogs request, CallContext context, + FlightDescriptor descriptor); + + /** + * Returns data for catalogs based data stream. + * + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamCatalogs(CallContext context, ServerStreamListener listener); + + /** + * Returns the available schemas by returning a stream of + * {@link CommandGetDbSchemas} objects in {@link Result} objects. + * + * @param request request filter parameters. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + FlightInfo getFlightInfoSchemas(CommandGetDbSchemas request, CallContext context, + FlightDescriptor descriptor); + + /** + * Returns data for schemas based data stream. + * + * @param command The command to generate the data stream. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamSchemas(CommandGetDbSchemas command, CallContext context, ServerStreamListener listener); + + /** + * Returns the available tables by returning a stream of + * {@link CommandGetTables} objects in {@link Result} objects. + * + * @param request request filter parameters. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. 
+ * @return Metadata about the stream. + */ + FlightInfo getFlightInfoTables(CommandGetTables request, CallContext context, + FlightDescriptor descriptor); + + /** + * Returns data for tables based data stream. + * + * @param command The command to generate the data stream. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamTables(CommandGetTables command, CallContext context, ServerStreamListener listener); + + /** + * Returns the available table types by returning a stream of + * {@link CommandGetTableTypes} objects in {@link Result} objects. + * + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + FlightInfo getFlightInfoTableTypes(CommandGetTableTypes request, CallContext context, + FlightDescriptor descriptor); + + /** + * Returns data for table types based data stream. + * + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamTableTypes(CallContext context, ServerStreamListener listener); + + /** + * Returns the available primary keys by returning a stream of + * {@link CommandGetPrimaryKeys} objects in {@link Result} objects. + * + * @param request request filter parameters. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + FlightInfo getFlightInfoPrimaryKeys(CommandGetPrimaryKeys request, CallContext context, + FlightDescriptor descriptor); + + /** + * Returns data for primary keys based data stream. + * + * @param command The command to generate the data stream. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamPrimaryKeys(CommandGetPrimaryKeys command, CallContext context, + ServerStreamListener listener); + + /** + * Retrieves a description of the foreign key columns that reference the given table's primary key columns + * {@link CommandGetExportedKeys} objects in {@link Result} objects. + * + * @param request request filter parameters. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + FlightInfo getFlightInfoExportedKeys(CommandGetExportedKeys request, CallContext context, + FlightDescriptor descriptor); + + /** + * Retrieves a description of the primary key columns that are referenced by given table's foreign key columns + * {@link CommandGetImportedKeys} objects in {@link Result} objects. + * + * @param request request filter parameters. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + FlightInfo getFlightInfoImportedKeys(CommandGetImportedKeys request, CallContext context, + FlightDescriptor descriptor); + + /** + * Retrieve a description of the foreign key columns that reference the given table's primary key columns + * {@link CommandGetCrossReference} objects in {@link Result} objects. + * + * @param request request filter parameters. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. 
+ */ + FlightInfo getFlightInfoCrossReference(CommandGetCrossReference request, CallContext context, + FlightDescriptor descriptor); + + /** + * Returns data for foreign keys based data stream. + * + * @param command The command to generate the data stream. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamExportedKeys(CommandGetExportedKeys command, CallContext context, + ServerStreamListener listener); + + /** + * Returns data for foreign keys based data stream. + * + * @param command The command to generate the data stream. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamImportedKeys(CommandGetImportedKeys command, CallContext context, + ServerStreamListener listener); + + /** + * Returns data for cross reference based data stream. + * + * @param command The command to generate the data stream. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void getStreamCrossReference(CommandGetCrossReference command, CallContext context, + ServerStreamListener listener); + + + /** + * Default schema templates for the {@link FlightSqlProducer}. + */ + final class Schemas { + public static final Schema GET_TABLES_SCHEMA = new Schema(asList( + Field.nullable("catalog_name", VARCHAR.getType()), + Field.nullable("db_schema_name", VARCHAR.getType()), + Field.notNullable("table_name", VARCHAR.getType()), + Field.notNullable("table_type", VARCHAR.getType()), + Field.notNullable("table_schema", MinorType.VARBINARY.getType()))); + public static final Schema GET_TABLES_SCHEMA_NO_SCHEMA = new Schema(asList( + Field.nullable("catalog_name", VARCHAR.getType()), + Field.nullable("db_schema_name", VARCHAR.getType()), + Field.notNullable("table_name", VARCHAR.getType()), + Field.notNullable("table_type", VARCHAR.getType()))); + public static final Schema GET_CATALOGS_SCHEMA = new Schema( + singletonList(Field.notNullable("catalog_name", VARCHAR.getType()))); + public static final Schema GET_TABLE_TYPES_SCHEMA = + new Schema(singletonList(Field.notNullable("table_type", VARCHAR.getType()))); + public static final Schema GET_SCHEMAS_SCHEMA = + new Schema(asList( + Field.nullable("catalog_name", VARCHAR.getType()), + Field.notNullable("db_schema_name", VARCHAR.getType()))); + private static final Schema GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA = + new Schema(asList( + Field.nullable("pk_catalog_name", VARCHAR.getType()), + Field.nullable("pk_db_schema_name", VARCHAR.getType()), + Field.notNullable("pk_table_name", VARCHAR.getType()), + Field.notNullable("pk_column_name", VARCHAR.getType()), + Field.nullable("fk_catalog_name", VARCHAR.getType()), + Field.nullable("fk_db_schema_name", VARCHAR.getType()), + Field.notNullable("fk_table_name", VARCHAR.getType()), + Field.notNullable("fk_column_name", VARCHAR.getType()), + Field.notNullable("key_sequence", INT.getType()), + Field.nullable("fk_key_name", VARCHAR.getType()), + Field.nullable("pk_key_name", VARCHAR.getType()), + Field.notNullable("update_rule", MinorType.UINT1.getType()), + Field.notNullable("delete_rule", MinorType.UINT1.getType()))); + public static final Schema GET_IMPORTED_KEYS_SCHEMA = GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA; + public static final Schema GET_EXPORTED_KEYS_SCHEMA = GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA; + public static final Schema GET_CROSS_REFERENCE_SCHEMA = 
GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA; + private static final List GET_SQL_INFO_DENSE_UNION_SCHEMA_FIELDS = asList( + Field.notNullable("string_value", VARCHAR.getType()), + Field.notNullable("bool_value", BIT.getType()), + Field.notNullable("bigint_value", BIGINT.getType()), + Field.notNullable("int32_bitmask", INT.getType()), + new Field( + "string_list", FieldType.notNullable(LIST.getType()), + singletonList(Field.nullable("item", VARCHAR.getType()))), + new Field( + "int32_to_int32_list_map", FieldType.notNullable(new ArrowType.Map(false)), + singletonList(new Field(DATA_VECTOR_NAME, new FieldType(false, STRUCT.getType(), null), + ImmutableList.of( + Field.notNullable(KEY_NAME, INT.getType()), + new Field( + VALUE_NAME, FieldType.nullable(LIST.getType()), + singletonList(Field.nullable("item", INT.getType())))))))); + public static final Schema GET_SQL_INFO_SCHEMA = + new Schema(asList( + Field.notNullable("info_name", UINT4.getType()), + new Field("value", + FieldType.notNullable( + new Union(UnionMode.Dense, range(0, GET_SQL_INFO_DENSE_UNION_SCHEMA_FIELDS.size()).toArray())), + GET_SQL_INFO_DENSE_UNION_SCHEMA_FIELDS))); + public static final Schema GET_PRIMARY_KEYS_SCHEMA = + new Schema(asList( + Field.nullable("catalog_name", VARCHAR.getType()), + Field.nullable("db_schema_name", VARCHAR.getType()), + Field.notNullable("table_name", VARCHAR.getType()), + Field.notNullable("column_name", VARCHAR.getType()), + Field.notNullable("key_sequence", INT.getType()), + Field.nullable("key_name", VARCHAR.getType()))); + + private Schemas() { + // Prevent instantiation. + } + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java new file mode 100644 index 00000000000..25affa8f08a --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import java.util.List; + +import org.apache.arrow.flight.ActionType; +import org.apache.arrow.flight.CallStatus; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; + +/** + * Utilities to work with Flight SQL semantics. + */ +public final class FlightSqlUtils { + public static final ActionType FLIGHT_SQL_CREATE_PREPARED_STATEMENT = new ActionType("CreatePreparedStatement", + "Creates a reusable prepared statement resource on the server. 
\n" + + "Request Message: ActionCreatePreparedStatementRequest\n" + + "Response Message: ActionCreatePreparedStatementResult"); + + public static final ActionType FLIGHT_SQL_CLOSE_PREPARED_STATEMENT = new ActionType("ClosePreparedStatement", + "Closes a reusable prepared statement resource on the server. \n" + + "Request Message: ActionClosePreparedStatementRequest\n" + + "Response Message: N/A"); + + public static final List FLIGHT_SQL_ACTIONS = ImmutableList.of( + FLIGHT_SQL_CREATE_PREPARED_STATEMENT, + FLIGHT_SQL_CLOSE_PREPARED_STATEMENT + ); + + /** + * Helper to parse {@link com.google.protobuf.Any} objects to the specific protobuf object. + * + * @param source the raw bytes source value. + * @return the materialized protobuf object. + */ + public static Any parseOrThrow(byte[] source) { + try { + return Any.parseFrom(source); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INVALID_ARGUMENT + .withDescription("Received invalid message from remote.") + .withCause(e) + .toRuntimeException(); + } + } + + /** + * Helper to unpack {@link com.google.protobuf.Any} objects to the specific protobuf object. + * + * @param source the parsed Source value. + * @param as the class to unpack as. + * @param the class to unpack as. + * @return the materialized protobuf object. + */ + public static T unpackOrThrow(Any source, Class as) { + try { + return source.unpack(as); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INVALID_ARGUMENT + .withDescription("Provided message cannot be unpacked as desired type.") + .withCause(e) + .toRuntimeException(); + } + } + + /** + * Helper to parse and unpack {@link com.google.protobuf.Any} objects to the specific protobuf object. + * + * @param source the raw bytes source value. + * @param as the class to unpack as. + * @param the class to unpack as. + * @return the materialized protobuf object. + */ + public static T unpackAndParseOrThrow(byte[] source, Class as) { + return unpackOrThrow(parseOrThrow(source), as); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java new file mode 100644 index 00000000000..3866cb89b1f --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java @@ -0,0 +1,1024 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.sql; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.stream.IntStream.range; +import static org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import static org.apache.arrow.flight.sql.util.SqlInfoOptionsUtils.createBitmaskFromEnums; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.ObjIntConsumer; + +import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlOuterJoinsSupportLevel; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedElementActions; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedGroupBy; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedPositionedCommands; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedResultSetType; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedSubqueries; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedUnions; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlTransactionIsolationLevel; +import org.apache.arrow.flight.sql.impl.FlightSql.SupportedAnsi92SqlGrammarLevel; +import org.apache.arrow.flight.sql.impl.FlightSql.SupportedSqlGrammar; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.complex.impl.UnionMapWriter; +import org.apache.arrow.vector.complex.writer.BaseWriter; +import org.apache.arrow.vector.holders.NullableBigIntHolder; +import org.apache.arrow.vector.holders.NullableBitHolder; +import org.apache.arrow.vector.holders.NullableIntHolder; +import org.apache.arrow.vector.holders.NullableVarCharHolder; + +import com.google.protobuf.ProtocolMessageEnum; + +/** + * Auxiliary class meant to facilitate the implementation of {@link FlightSqlProducer#getStreamSqlInfo}. + *

    + * Usage requires the user to add the required SqlInfo values using the {@code with*} methods + * like {@link SqlInfoBuilder#withFlightSqlServerName(String)}, and request it back + * through the {@link SqlInfoBuilder#send(List, ServerStreamListener)} method. + */ +@SuppressWarnings({"unused"}) +public class SqlInfoBuilder { + private final Map> providers = new HashMap<>(); + + /** + * Gets a {@link NullableVarCharHolder} from the provided {@code string} using the provided {@code buf}. + * + * @param string the {@link StandardCharsets#UTF_8}-encoded text input to store onto the holder. + * @param buf the {@link ArrowBuf} from which to create the new holder. + * @return a new {@link NullableVarCharHolder} with the provided input data {@code string}. + */ + public static NullableVarCharHolder getHolderForUtf8(final String string, final ArrowBuf buf) { + final byte[] bytes = string.getBytes(UTF_8); + buf.setBytes(0, bytes); + final NullableVarCharHolder holder = new NullableVarCharHolder(); + holder.buffer = buf; + holder.end = bytes.length; + holder.isSet = 1; + return holder; + } + + /** + * Sets a value for {@link SqlInfo#FLIGHT_SQL_SERVER_NAME} in the builder. + * + * @param value the value for {@link SqlInfo#FLIGHT_SQL_SERVER_NAME} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withFlightSqlServerName(final String value) { + return withStringProvider(SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#FLIGHT_SQL_SERVER_VERSION} in the builder. + * + * @param value the value for {@link SqlInfo#FLIGHT_SQL_SERVER_VERSION} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withFlightSqlServerVersion(final String value) { + return withStringProvider(SqlInfo.FLIGHT_SQL_SERVER_VERSION_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#FLIGHT_SQL_SERVER_ARROW_VERSION} in the builder. + * + * @param value the value for {@link SqlInfo#FLIGHT_SQL_SERVER_ARROW_VERSION} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withFlightSqlServerArrowVersion(final String value) { + return withStringProvider(SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_IDENTIFIER_QUOTE_CHAR} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_IDENTIFIER_QUOTE_CHAR} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlIdentifierQuoteChar(final String value) { + return withStringProvider(SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SEARCH_STRING_ESCAPE} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SEARCH_STRING_ESCAPE} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSearchStringEscape(final String value) { + return withStringProvider(SqlInfo.SQL_SEARCH_STRING_ESCAPE_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_EXTRA_NAME_CHARACTERS} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_EXTRA_NAME_CHARACTERS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlExtraNameCharacters(final String value) { + return withStringProvider(SqlInfo.SQL_EXTRA_NAME_CHARACTERS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SCHEMA_TERM} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SCHEMA_TERM} to be set. + * @return the SqlInfoBuilder itself. 
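// Sketch of how the builder described above can back getStreamSqlInfo(): values are
// registered once through the with* methods and replayed for each request via send(). The
// server name/version strings are illustrative only.
private final SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder()
    .withFlightSqlServerName("example-flight-sql-server")   // assumed value
    .withFlightSqlServerVersion("1.0.0")                     // assumed value
    .withFlightSqlServerReadOnly(true)
    .withSqlIdentifierQuoteChar("\"");

@Override
public void getStreamSqlInfo(CommandGetSqlInfo command, CallContext context,
                             ServerStreamListener listener) {
  sqlInfoBuilder.send(command.getInfoList(), listener);
}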
+ */ + public SqlInfoBuilder withSqlSchemaTerm(final String value) { + return withStringProvider(SqlInfo.SQL_SCHEMA_TERM_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_CATALOG_TERM} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_CATALOG_TERM} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlCatalogTerm(final String value) { + return withStringProvider(SqlInfo.SQL_CATALOG_TERM_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_PROCEDURE_TERM} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_PROCEDURE_TERM} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlProcedureTerm(final String value) { + return withStringProvider(SqlInfo.SQL_PROCEDURE_TERM_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_DDL_CATALOG} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_DDL_CATALOG} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlDdlCatalog(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_DDL_CATALOG_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_DDL_SCHEMA} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_DDL_SCHEMA} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlDdlSchema(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_DDL_SCHEMA_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_DDL_TABLE} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_DDL_TABLE} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlDdlTable(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_DDL_TABLE_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#FLIGHT_SQL_SERVER_READ_ONLY} in the builder. + * + * @param value the value for {@link SqlInfo#FLIGHT_SQL_SERVER_READ_ONLY} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withFlightSqlServerReadOnly(final boolean value) { + return withBooleanProvider(SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTS_COLUMN_ALIASING} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SUPPORTS_COLUMN_ALIASING} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportsColumnAliasing(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SUPPORTS_COLUMN_ALIASING_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_NULL_PLUS_NULL_IS_NULL} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_NULL_PLUS_NULL_IS_NULL} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlNullPlusNullIsNull(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_NULL_PLUS_NULL_IS_NULL_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTS_TABLE_CORRELATION_NAMES} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SUPPORTS_TABLE_CORRELATION_NAMES} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportsTableCorrelationNames(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SUPPORTS_TABLE_CORRELATION_NAMES_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES} in the builder. 
+ * + * @param value the value for {@link SqlInfo#SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportsDifferentTableCorrelationNames(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportsExpressionsInOrderBy(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTS_ORDER_BY_UNRELATED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SUPPORTS_ORDER_BY_UNRELATED} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportsOrderByUnrelated(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SUPPORTS_ORDER_BY_UNRELATED_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportsLikeEscapeClause(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTS_NON_NULLABLE_COLUMNS} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SUPPORTS_NON_NULLABLE_COLUMNS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportsNonNullableColumns(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SUPPORTS_NON_NULLABLE_COLUMNS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportsIntegrityEnhancementFacility(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_CATALOG_AT_START} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_CATALOG_AT_START} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlCatalogAtStart(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_CATALOG_AT_START_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SELECT_FOR_UPDATE_SUPPORTED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SELECT_FOR_UPDATE_SUPPORTED} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSelectForUpdateSupported(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SELECT_FOR_UPDATE_SUPPORTED_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_STORED_PROCEDURES_SUPPORTED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_STORED_PROCEDURES_SUPPORTED} to be set. + * @return the SqlInfoBuilder itself. 
+ */ + public SqlInfoBuilder withSqlStoredProceduresSupported(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_STORED_PROCEDURES_SUPPORTED_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_CORRELATED_SUBQUERIES_SUPPORTED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_CORRELATED_SUBQUERIES_SUPPORTED} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlCorrelatedSubqueriesSupported(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_CORRELATED_SUBQUERIES_SUPPORTED_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_ROW_SIZE_INCLUDES_BLOBS} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_ROW_SIZE_INCLUDES_BLOBS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxRowSizeIncludesBlobs(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_MAX_ROW_SIZE_INCLUDES_BLOBS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_TRANSACTIONS_SUPPORTED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_TRANSACTIONS_SUPPORTED} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlTransactionsSupported(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_TRANSACTIONS_SUPPORTED_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlDataDefinitionCausesTransactionCommit(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT_VALUE, + value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlDataDefinitionsInTransactionsIgnored(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED_VALUE, + value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_BATCH_UPDATES_SUPPORTED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_BATCH_UPDATES_SUPPORTED} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlBatchUpdatesSupported(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_BATCH_UPDATES_SUPPORTED_VALUE, value); + } + + /** + * Sets a value for { @link SqlInfo#SQL_SAVEPOINTS_SUPPORTED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_SAVEPOINTS_SUPPORTED} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSavepointsSupported(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_SAVEPOINTS_SUPPORTED_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_NAMED_PARAMETERS_SUPPORTED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_NAMED_PARAMETERS_SUPPORTED} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlNamedParametersSupported(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_NAMED_PARAMETERS_SUPPORTED_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_LOCATORS_UPDATE_COPY} in the builder. 
+ * + * @param value the value for {@link SqlInfo#SQL_LOCATORS_UPDATE_COPY} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlLocatorsUpdateCopy(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_LOCATORS_UPDATE_COPY_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlStoredFunctionsUsingCallSyntaxSupported(final boolean value) { + return withBooleanProvider(SqlInfo.SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED_VALUE, + value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_IDENTIFIER_CASE} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_IDENTIFIER_CASE} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlIdentifierCase(final SqlSupportedCaseSensitivity value) { + return withBitIntProvider(SqlInfo.SQL_IDENTIFIER_CASE_VALUE, value.getNumber()); + } + + /** + * Sets a value for {@link SqlInfo#SQL_QUOTED_IDENTIFIER_CASE} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_QUOTED_IDENTIFIER_CASE} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlQuotedIdentifierCase(final SqlSupportedCaseSensitivity value) { + return withBitIntProvider(SqlInfo.SQL_QUOTED_IDENTIFIER_CASE_VALUE, value.getNumber()); + } + + /** + * Sets a value SqlInf @link SqlInfo#SQL_MAX_BINARY_LITERAL_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_BINARY_LITERAL_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxBinaryLiteralLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_BINARY_LITERAL_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_CHAR_LITERAL_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_CHAR_LITERAL_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxCharLiteralLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_CHAR_LITERAL_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_COLUMN_NAME_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_COLUMN_NAME_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxColumnNameLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_COLUMN_NAME_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_GROUP_BY} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_GROUP_BY} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxColumnsInGroupBy(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_COLUMNS_IN_GROUP_BY_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_INDEX} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_INDEX} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxColumnsInIndex(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_COLUMNS_IN_INDEX_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_ORDER_BY} in the builder. 
+ * + * @param value the value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_ORDER_BY} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxColumnsInOrderBy(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_COLUMNS_IN_ORDER_BY_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_SELECT} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_COLUMNS_IN_SELECT} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxColumnsInSelect(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_COLUMNS_IN_SELECT_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_CONNECTIONS} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_CONNECTIONS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxConnections(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_CONNECTIONS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_CURSOR_NAME_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_CURSOR_NAME_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxCursorNameLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_CURSOR_NAME_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_INDEX_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_INDEX_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxIndexLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_INDEX_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_DB_SCHEMA_NAME_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_DB_SCHEMA_NAME_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlDbSchemaNameLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_DB_SCHEMA_NAME_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_PROCEDURE_NAME_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_PROCEDURE_NAME_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxProcedureNameLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_PROCEDURE_NAME_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_CATALOG_NAME_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_CATALOG_NAME_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxCatalogNameLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_CATALOG_NAME_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_ROW_SIZE} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_ROW_SIZE} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxRowSize(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_ROW_SIZE_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_STATEMENT_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_STATEMENT_LENGTH} to be set. + * @return the SqlInfoBuilder itself. 
+ */ + public SqlInfoBuilder withSqlMaxStatementLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_STATEMENT_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_STATEMENTS} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_STATEMENTS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxStatements(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_STATEMENTS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_TABLE_NAME_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_TABLE_NAME_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxTableNameLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_TABLE_NAME_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_TABLES_IN_SELECT} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_TABLES_IN_SELECT} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxTablesInSelect(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_TABLES_IN_SELECT_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_MAX_USERNAME_LENGTH} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_MAX_USERNAME_LENGTH} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlMaxUsernameLength(final long value) { + return withBitIntProvider(SqlInfo.SQL_MAX_USERNAME_LENGTH_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_DEFAULT_TRANSACTION_ISOLATION} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_DEFAULT_TRANSACTION_ISOLATION} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlDefaultTransactionIsolation(final long value) { + return withBitIntProvider(SqlInfo.SQL_DEFAULT_TRANSACTION_ISOLATION_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTED_GROUP_BY} in the builder. + * + * @param values the value for {@link SqlInfo#SQL_SUPPORTED_GROUP_BY} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportedGroupBy(final SqlSupportedGroupBy... values) { + return withEnumProvider(SqlInfo.SQL_SUPPORTED_GROUP_BY_VALUE, values); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTED_GRAMMAR} in the builder. + * + * @param values the value for {@link SqlInfo#SQL_SUPPORTED_GRAMMAR} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportedGrammar(final SupportedSqlGrammar... values) { + return withEnumProvider(SqlInfo.SQL_SUPPORTED_GRAMMAR_VALUE, values); + } + + /** + * Sets a value for {@link SqlInfo#SQL_ANSI92_SUPPORTED_LEVEL} in the builder. + * + * @param values the value for {@link SqlInfo#SQL_ANSI92_SUPPORTED_LEVEL} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlAnsi92SupportedLevel(final SupportedAnsi92SqlGrammarLevel... values) { + return withEnumProvider(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL_VALUE, values); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SCHEMAS_SUPPORTED_ACTIONS} in the builder. + * + * @param values the value for {@link SqlInfo#SQL_SCHEMAS_SUPPORTED_ACTIONS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSchemasSupportedActions(final SqlSupportedElementActions... 
values) { + return withEnumProvider(SqlInfo.SQL_SCHEMAS_SUPPORTED_ACTIONS_VALUE, values); + } + + /** + * Sets a value for {@link SqlInfo#SQL_CATALOGS_SUPPORTED_ACTIONS} in the builder. + * + * @param values the value for {@link SqlInfo#SQL_CATALOGS_SUPPORTED_ACTIONS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlCatalogsSupportedActions(final SqlSupportedElementActions... values) { + return withEnumProvider(SqlInfo.SQL_CATALOGS_SUPPORTED_ACTIONS_VALUE, values); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTED_POSITIONED_COMMANDS} in the builder. + * + * @param values the value for {@link SqlInfo#SQL_SUPPORTED_POSITIONED_COMMANDS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportedPositionedCommands(final SqlSupportedPositionedCommands... values) { + return withEnumProvider(SqlInfo.SQL_SUPPORTED_POSITIONED_COMMANDS_VALUE, values); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTED_SUBQUERIES} in the builder. + * + * @param values the value for {@link SqlInfo#SQL_SUPPORTED_SUBQUERIES} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSubQueriesSupported(final SqlSupportedSubqueries... values) { + return withEnumProvider(SqlInfo.SQL_SUPPORTED_SUBQUERIES_VALUE, values); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTED_UNIONS} in the builder. + * + * @param values the values for {@link SqlInfo#SQL_SUPPORTED_UNIONS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportedUnions(final SqlSupportedUnions... values) { + return withEnumProvider(SqlInfo.SQL_SUPPORTED_UNIONS_VALUE, values); + } + + /** + * Sets a value for {@link SqlInfo#SQL_OUTER_JOINS_SUPPORT_LEVEL} in the builder. + * + * @param value the value for {@link SqlInfo#SQL_OUTER_JOINS_SUPPORT_LEVEL} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlOuterJoinSupportLevel(final SqlOuterJoinsSupportLevel... value) { + return withEnumProvider(SqlInfo.SQL_OUTER_JOINS_SUPPORT_LEVEL_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS} in the builder. + * + * @param values the values for {@link SqlInfo#SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportedTransactionsIsolationLevels(final SqlTransactionIsolationLevel... values) { + return withEnumProvider(SqlInfo.SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS_VALUE, values); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTED_RESULT_SET_TYPES} in the builder. + * + * @param values the values for {@link SqlInfo#SQL_SUPPORTED_RESULT_SET_TYPES} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportedResultSetTypes(final SqlSupportedResultSetType... values) { + return withEnumProvider(SqlInfo.SQL_SUPPORTED_RESULT_SET_TYPES_VALUE, values + ); + } + + /** + * Sets a value for {@link SqlInfo#SQL_KEYWORDS} in the builder. + * + * @param value the values for {@link SqlInfo#SQL_KEYWORDS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlKeywords(final String[] value) { + return withStringArrayProvider(SqlInfo.SQL_KEYWORDS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_NUMERIC_FUNCTIONS} in the builder. + * + * @param value the values for {@link SqlInfo#SQL_NUMERIC_FUNCTIONS} to be set. + * @return the SqlInfoBuilder itself. 
+ */ + public SqlInfoBuilder withSqlNumericFunctions(final String[] value) { + return withStringArrayProvider(SqlInfo.SQL_NUMERIC_FUNCTIONS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_STRING_FUNCTIONS} in the builder. + * + * @param value the values for {@link SqlInfo#SQL_STRING_FUNCTIONS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlStringFunctions(final String[] value) { + return withStringArrayProvider(SqlInfo.SQL_STRING_FUNCTIONS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SYSTEM_FUNCTIONS} in the builder. + * + * @param value the values for {@link SqlInfo#SQL_SYSTEM_FUNCTIONS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSystemFunctions(final String[] value) { + return withStringArrayProvider(SqlInfo.SQL_SYSTEM_FUNCTIONS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_DATETIME_FUNCTIONS} in the builder. + * + * @param value the values for {@link SqlInfo#SQL_DATETIME_FUNCTIONS} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlDatetimeFunctions(final String[] value) { + return withStringArrayProvider(SqlInfo.SQL_DATETIME_FUNCTIONS_VALUE, value); + } + + /** + * Sets a value for {@link SqlInfo#SQL_SUPPORTS_CONVERT} in the builder. + * + * @param value the values for {@link SqlInfo#SQL_SUPPORTS_CONVERT} to be set. + * @return the SqlInfoBuilder itself. + */ + public SqlInfoBuilder withSqlSupportsConvert(final Map> value) { + return withIntToIntListMapProvider(SqlInfo.SQL_SUPPORTS_CONVERT_VALUE, value); + } + + private void addProvider(final int sqlInfo, final ObjIntConsumer provider) { + providers.put(sqlInfo, provider); + } + + private SqlInfoBuilder withEnumProvider(final int sqlInfo, final ProtocolMessageEnum[] values) { + return withIntProvider(sqlInfo, (int) createBitmaskFromEnums(values)); + } + + private SqlInfoBuilder withIntProvider(final int sqlInfo, final int value) { + addProvider(sqlInfo, (root, index) -> setDataForIntField(root, index, sqlInfo, value)); + return this; + } + + private SqlInfoBuilder withBitIntProvider(final int sqlInfo, final long value) { + addProvider(sqlInfo, (root, index) -> setDataForBigIntField(root, index, sqlInfo, value)); + return this; + } + + private SqlInfoBuilder withBooleanProvider(final int sqlInfo, + final boolean value) { + addProvider(sqlInfo, (root, index) -> setDataForBooleanField(root, index, sqlInfo, value)); + return this; + } + + private SqlInfoBuilder withStringProvider(final int sqlInfo, final String value) { + addProvider(sqlInfo, (root, index) -> setDataForUtf8Field(root, index, sqlInfo, value)); + return this; + } + + private SqlInfoBuilder withStringArrayProvider(final int sqlInfo, + final String[] value) { + addProvider(sqlInfo, (root, index) -> setDataVarCharListField(root, index, sqlInfo, value)); + return this; + } + + private SqlInfoBuilder withIntToIntListMapProvider(final int sqlInfo, + final Map> value) { + addProvider(sqlInfo, (root, index) -> setIntToIntListMapField(root, index, sqlInfo, value)); + return this; + } + + /** + * Send the requested information to given ServerStreamListener. + * + * @param infos List of SqlInfo to be sent. + * @param listener ServerStreamListener to send data to. 
+ */ + public void send(List infos, final ServerStreamListener listener) { + if (infos == null || infos.isEmpty()) { + infos = new ArrayList<>(providers.keySet()); + } + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create( + FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, + allocator)) { + final int rows = infos.size(); + for (int i = 0; i < rows; i++) { + providers.get(infos.get(i)).accept(root, i); + } + root.setRowCount(rows); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + } + + private void setInfoName(final VectorSchemaRoot root, final int index, final int info) { + final UInt4Vector infoName = (UInt4Vector) root.getVector("info_name"); + infoName.setSafe(index, info); + } + + private void setValues(final VectorSchemaRoot root, final int index, final byte typeId, + final Consumer dataSetter) { + final DenseUnionVector values = (DenseUnionVector) root.getVector("value"); + values.setTypeId(index, typeId); + dataSetter.accept(values); + } + + /** + * Executes the given action on an ad-hoc, newly created instance of {@link ArrowBuf}. + * + * @param executor the action to take. + */ + private void onCreateArrowBuf(final Consumer executor) { + try (final BufferAllocator allocator = new RootAllocator(); + final ArrowBuf buf = allocator.buffer(1024)) { + executor.accept(buf); + } + } + + private void setDataForUtf8Field(final VectorSchemaRoot root, final int index, + final int sqlInfo, final String value) { + setInfoName(root, index, sqlInfo); + onCreateArrowBuf(buf -> { + final Consumer producer = + values -> values.setSafe(index, getHolderForUtf8(value, buf)); + setValues(root, index, (byte) 0, producer); + }); + } + + private void setDataForIntField(final VectorSchemaRoot root, final int index, + final int sqlInfo, final int value) { + setInfoName(root, index, sqlInfo); + final NullableIntHolder dataHolder = new NullableIntHolder(); + dataHolder.isSet = 1; + dataHolder.value = value; + setValues(root, index, (byte) 3, values -> values.setSafe(index, dataHolder)); + } + + private void setDataForBigIntField(final VectorSchemaRoot root, final int index, + final int sqlInfo, final long value) { + setInfoName(root, index, sqlInfo); + final NullableBigIntHolder dataHolder = new NullableBigIntHolder(); + dataHolder.isSet = 1; + dataHolder.value = value; + setValues(root, index, (byte) 2, values -> values.setSafe(index, dataHolder)); + } + + private void setDataForBooleanField(final VectorSchemaRoot root, final int index, + final int sqlInfo, final boolean value) { + setInfoName(root, index, sqlInfo); + final NullableBitHolder dataHolder = new NullableBitHolder(); + dataHolder.isSet = 1; + dataHolder.value = value ? 
1 : 0; + setValues(root, index, (byte) 1, values -> values.setSafe(index, dataHolder)); + } + + private void setDataVarCharListField(final VectorSchemaRoot root, final int index, + final int sqlInfo, + final String[] values) { + final DenseUnionVector denseUnion = (DenseUnionVector) root.getVector("value"); + final ListVector listVector = denseUnion.getList((byte) 4); + final int listIndex = listVector.getValueCount(); + final int denseUnionValueCount = index + 1; + final int listVectorValueCount = listIndex + 1; + denseUnion.setValueCount(denseUnionValueCount); + listVector.setValueCount(listVectorValueCount); + + final UnionListWriter writer = listVector.getWriter(); + writer.setPosition(listIndex); + writer.startList(); + final int length = values.length; + range(0, length) + .forEach(i -> onCreateArrowBuf(buf -> { + final byte[] bytes = values[i].getBytes(UTF_8); + buf.setBytes(0, bytes); + writer.writeVarChar(0, bytes.length, buf); + })); + writer.endList(); + writer.setValueCount(listVectorValueCount); + + denseUnion.setTypeId(index, (byte) 4); + denseUnion.getOffsetBuffer().setInt(index * 4L, listIndex); + setInfoName(root, index, sqlInfo); + } + + private void setIntToIntListMapField(final VectorSchemaRoot root, final int index, + final int sqlInfo, + final Map> values) { + final DenseUnionVector denseUnion = (DenseUnionVector) root.getVector("value"); + final MapVector mapVector = denseUnion.getMap((byte) 5); + final int mapIndex = mapVector.getValueCount(); + denseUnion.setValueCount(index + 1); + mapVector.setValueCount(mapIndex + 1); + + final UnionMapWriter mapWriter = mapVector.getWriter(); + mapWriter.setPosition(mapIndex); + mapWriter.startMap(); + values.forEach((key, value) -> { + mapWriter.startEntry(); + mapWriter.key().integer().writeInt(key); + final BaseWriter.ListWriter listWriter = mapWriter.value().list(); + listWriter.startList(); + for (final int v : value) { + listWriter.integer().writeInt(v); + } + listWriter.endList(); + mapWriter.endEntry(); + }); + mapWriter.endMap(); + mapWriter.setValueCount(mapIndex + 1); + + denseUnion.setTypeId(index, (byte) 5); + denseUnion.getOffsetBuffer().setInt(index * 4L, mapIndex); + setInfoName(root, index, sqlInfo); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/example/FlightSqlClientDemoApp.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/example/FlightSqlClientDemoApp.java new file mode 100644 index 00000000000..f3774a8a500 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/example/FlightSqlClientDemoApp.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
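Taken together, the with* setters above register one value provider per SqlInfo ordinal, and send() materializes the registered providers into the GET_SQL_INFO_SCHEMA root for a stream listener. Below is a minimal sketch of how a server implementation might use the builder; the no-arg constructor, the import locations, and the surrounding handler are assumptions for illustration and are not part of this diff.

    // Sketch only; wiring outside the with*/send calls is assumed.
    import java.util.List;

    import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
    import org.apache.arrow.flight.sql.SqlInfoBuilder;
    import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity;

    class SqlInfoSketch {
      // Streams the requested SqlInfo values back to the client; a null or empty
      // list makes send() fall back to every value registered on the builder.
      static void answerGetSqlInfo(final List<Integer> requested, final ServerStreamListener listener) {
        new SqlInfoBuilder()
            .withSqlTransactionsSupported(true)
            .withSqlIdentifierCase(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE)
            .withSqlMaxColumnNameLength(128)
            .withSqlKeywords(new String[] {"ABORT", "ABSOLUTE"})
            .send(requested, listener);
      }
    }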
+ */ + +package org.apache.arrow.flight.sql.example; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; + +/** + * Flight SQL Client Demo CLI Application. + */ +public class FlightSqlClientDemoApp implements AutoCloseable { + public final List callOptions = new ArrayList<>(); + public final BufferAllocator allocator; + public FlightSqlClient flightSqlClient; + + public FlightSqlClientDemoApp(final BufferAllocator bufferAllocator) { + allocator = bufferAllocator; + } + + public static void main(final String[] args) throws Exception { + final Options options = new Options(); + + options.addRequiredOption("host", "host", true, "Host to connect to"); + options.addRequiredOption("port", "port", true, "Port to connect to"); + options.addRequiredOption("command", "command", true, "Method to run"); + + options.addOption("query", "query", false, "Query"); + options.addOption("catalog", "catalog", false, "Catalog"); + options.addOption("schema", "schema", false, "Schema"); + options.addOption("table", "table", false, "Table"); + + CommandLineParser parser = new DefaultParser(); + HelpFormatter formatter = new HelpFormatter(); + CommandLine cmd; + + try { + cmd = parser.parse(options, args); + try (final FlightSqlClientDemoApp thisApp = new FlightSqlClientDemoApp(new RootAllocator(Integer.MAX_VALUE))) { + thisApp.executeApp(cmd); + } + + } catch (final ParseException e) { + System.out.println(e.getMessage()); + formatter.printHelp("FlightSqlClientDemoApp -host localhost -port 32010 ...", options); + throw e; + } + } + + /** + * Gets the current {@link CallOption} as an array; usually used as an + * argument in {@link FlightSqlClient} methods. + * + * @return current {@link CallOption} array. + */ + public CallOption[] getCallOptions() { + return callOptions.toArray(new CallOption[0]); + } + + /** + * Calls {@link FlightSqlClientDemoApp#createFlightSqlClient(String, int)} + * in order to create a {@link FlightSqlClient} to be used in future calls, + * and then calls {@link FlightSqlClientDemoApp#executeCommand(CommandLine)} + * to execute the command parsed at execution. + * + * @param cmd parsed {@link CommandLine}; often the result of {@link DefaultParser#parse(Options, String[])}. + */ + public void executeApp(final CommandLine cmd) throws Exception { + final String host = cmd.getOptionValue("host").trim(); + final int port = Integer.parseInt(cmd.getOptionValue("port").trim()); + + createFlightSqlClient(host, port); + executeCommand(cmd); + } + + /** + * Parses the "{@code command}" CLI argument and redirects to the appropriate method. + * + * @param cmd parsed {@link CommandLine}; often the result of + * {@link DefaultParser#parse(Options, String[])}. 
+ */ + public void executeCommand(CommandLine cmd) throws Exception { + switch (cmd.getOptionValue("command").trim()) { + case "Execute": + exampleExecute( + cmd.getOptionValue("query") + ); + break; + case "ExecuteUpdate": + exampleExecuteUpdate( + cmd.getOptionValue("query") + ); + break; + case "GetCatalogs": + exampleGetCatalogs(); + break; + case "GetSchemas": + exampleGetSchemas( + cmd.getOptionValue("catalog"), + cmd.getOptionValue("schema") + ); + break; + case "GetTableTypes": + exampleGetTableTypes(); + break; + case "GetTables": + exampleGetTables( + cmd.getOptionValue("catalog"), + cmd.getOptionValue("schema"), + cmd.getOptionValue("table") + ); + break; + case "GetExportedKeys": + exampleGetExportedKeys( + cmd.getOptionValue("catalog"), + cmd.getOptionValue("schema"), + cmd.getOptionValue("table") + ); + break; + case "GetImportedKeys": + exampleGetImportedKeys( + cmd.getOptionValue("catalog"), + cmd.getOptionValue("schema"), + cmd.getOptionValue("table") + ); + break; + case "GetPrimaryKeys": + exampleGetPrimaryKeys( + cmd.getOptionValue("catalog"), + cmd.getOptionValue("schema"), + cmd.getOptionValue("table") + ); + break; + default: + System.out.println("Command used is not valid! Please use one of: \n" + + "[\"ExecuteUpdate\",\n" + + "\"Execute\",\n" + + "\"GetCatalogs\",\n" + + "\"GetSchemas\",\n" + + "\"GetTableTypes\",\n" + + "\"GetTables\",\n" + + "\"GetExportedKeys\",\n" + + "\"GetImportedKeys\",\n" + + "\"GetPrimaryKeys\"]"); + } + } + + /** + * Creates a {@link FlightSqlClient} to be used with the example methods. + * + * @param host client's hostname. + * @param port client's port. + */ + public void createFlightSqlClient(final String host, final int port) { + final Location clientLocation = Location.forGrpcInsecure(host, port); + flightSqlClient = new FlightSqlClient(FlightClient.builder(allocator, clientLocation).build()); + } + + private void exampleExecute(final String query) throws Exception { + printFlightInfoResults(flightSqlClient.execute(query, getCallOptions())); + } + + private void exampleExecuteUpdate(final String query) { + System.out.println("Updated: " + flightSqlClient.executeUpdate(query, getCallOptions()) + "rows."); + } + + private void exampleGetCatalogs() throws Exception { + printFlightInfoResults(flightSqlClient.getCatalogs(getCallOptions())); + } + + private void exampleGetSchemas(final String catalog, final String schema) throws Exception { + printFlightInfoResults(flightSqlClient.getSchemas(catalog, schema, getCallOptions())); + } + + private void exampleGetTableTypes() throws Exception { + printFlightInfoResults(flightSqlClient.getTableTypes(getCallOptions())); + } + + private void exampleGetTables(final String catalog, final String schema, final String table) throws Exception { + // For now, this won't filter by table types. 
+ printFlightInfoResults(flightSqlClient.getTables( + catalog, schema, table, null, false, getCallOptions())); + } + + private void exampleGetExportedKeys(final String catalog, final String schema, final String table) throws Exception { + printFlightInfoResults(flightSqlClient.getExportedKeys(TableRef.of(catalog, schema, table), getCallOptions())); + } + + private void exampleGetImportedKeys(final String catalog, final String schema, final String table) throws Exception { + printFlightInfoResults(flightSqlClient.getImportedKeys(TableRef.of(catalog, schema, table), getCallOptions())); + } + + private void exampleGetPrimaryKeys(final String catalog, final String schema, final String table) throws Exception { + printFlightInfoResults(flightSqlClient.getPrimaryKeys(TableRef.of(catalog, schema, table), getCallOptions())); + } + + private void printFlightInfoResults(final FlightInfo flightInfo) throws Exception { + final FlightStream stream = + flightSqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket(), getCallOptions()); + while (stream.next()) { + try (final VectorSchemaRoot root = stream.getRoot()) { + System.out.println(root.contentToTSVString()); + } + } + stream.close(); + } + + @Override + public void close() throws Exception { + flightSqlClient.close(); + allocator.close(); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtils.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtils.java new file mode 100644 index 00000000000..c43c48eb8e0 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtils.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql.util; + +import java.util.Arrays; +import java.util.Collection; + +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; + +import com.google.protobuf.ProtocolMessageEnum; + +/** + * Utility class for {@link SqlInfo} and {@link FlightSqlClient#getSqlInfo} option parsing. + */ +public final class SqlInfoOptionsUtils { + private SqlInfoOptionsUtils() { + // Prevent instantiation. + } + + /** + * Returns whether the provided {@code bitmask} points to the provided {@link ProtocolMessageEnum} by comparing + * {@link ProtocolMessageEnum#getNumber} with the respective bit index of the {@code bitmask}. + * + * @param enumInstance the protobuf message enum to use. + * @param bitmask the bitmask response from {@link FlightSqlClient#getSqlInfo}. + * @return whether the provided {@code bitmask} points to the specified {@code enumInstance}. 
+ */ + public static boolean doesBitmaskTranslateToEnum(final ProtocolMessageEnum enumInstance, final long bitmask) { + return ((bitmask >> enumInstance.getNumber()) & 1) == 1; + } + + /** + * Creates a bitmask that translates to the specified {@code enums}. + * + * @param enums the {@link ProtocolMessageEnum} instances to represent as bitmask. + * @return the bitmask. + */ + public static long createBitmaskFromEnums(final ProtocolMessageEnum... enums) { + return createBitmaskFromEnums(Arrays.asList(enums)); + } + + /** + * Creates a bitmask that translates to the specified {@code enums}. + * + * @param enums the {@link ProtocolMessageEnum} instances to represent as bitmask. + * @return the bitmask. + */ + public static long createBitmaskFromEnums(final Collection enums) { + return enums.stream() + .mapToInt(ProtocolMessageEnum::getNumber) + .map(bitIndexToSet -> 1 << bitIndexToSet) + .reduce((firstBitmask, secondBitmask) -> firstBitmask | secondBitmask) + .orElse(0); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/util/TableRef.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/util/TableRef.java new file mode 100644 index 00000000000..315f17ee911 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/util/TableRef.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql.util; + +/** + * A helper class to reference a table to be passed to the flight + * sql client. + */ +public class TableRef { + private final String catalog; + private final String dbSchema; + private final String table; + + /** + * The complete constructor for the TableRef class. + * @param catalog the catalog from a table. + * @param dbSchema the database schema from a table. + * @param table the table name from a table. + */ + public TableRef(String catalog, String dbSchema, String table) { + this.catalog = catalog; + this.dbSchema = dbSchema; + this.table = table; + } + + /** + * A static initializer of the TableRef with all the arguments. + * @param catalog the catalog from a table. + * @param dbSchema the database schema from a table. + * @param table the table name from a table. + * @return A TableRef object. + */ + public static TableRef of(String catalog, String dbSchema, String table) { + return new TableRef(catalog, dbSchema, table); + } + + /** + * Retrieve the catalog from the object. + * @return the catalog. + */ + public String getCatalog() { + return catalog; + } + + /** + * Retrieves the db schema from the object. + * @return the dbSchema + */ + public String getDbSchema() { + return dbSchema; + } + + /** + * Retreives the table from the object. + * @return the table. 
+ */ + public String getTable() { + return table; + } +} + diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java new file mode 100644 index 00000000000..159ef72401f --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java @@ -0,0 +1,706 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static java.util.Objects.isNull; +import static org.apache.arrow.util.AutoCloseables.close; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.nullValue; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.IntStream; + +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.example.FlightSqlExample; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.hamcrest.Matcher; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +import com.google.common.collect.ImmutableList; + +/** + * Test 
direct usage of Flight SQL workflows. + */ +public class TestFlightSql { + + protected static final Schema SCHEMA_INT_TABLE = new Schema(asList( + new Field("ID", new FieldType(false, MinorType.INT.getType(), null), null), + Field.nullable("KEYNAME", MinorType.VARCHAR.getType()), + Field.nullable("VALUE", MinorType.INT.getType()), + Field.nullable("FOREIGNID", MinorType.INT.getType()))); + private static final List> EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY = ImmutableList.of( + asList("1", "one", "1", "1"), asList("2", "zero", "0", "1"), asList("3", "negative one", "-1", "1")); + private static final List> EXPECTED_RESULTS_FOR_PARAMETER_BINDING = ImmutableList.of( + asList("1", "one", "1", "1")); + private static final Map GET_SQL_INFO_EXPECTED_RESULTS_MAP = new LinkedHashMap<>(); + private static final String LOCALHOST = "localhost"; + private static BufferAllocator allocator; + private static FlightServer server; + private static FlightSqlClient sqlClient; + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @BeforeClass + public static void setUp() throws Exception { + allocator = new RootAllocator(Integer.MAX_VALUE); + + final Location serverLocation = Location.forGrpcInsecure(LOCALHOST, 0); + server = FlightServer.builder(allocator, serverLocation, new FlightSqlExample(serverLocation)) + .build() + .start(); + + final Location clientLocation = Location.forGrpcInsecure(LOCALHOST, server.getPort()); + sqlClient = new FlightSqlClient(FlightClient.builder(allocator, clientLocation).build()); + + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE), "Apache Derby"); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION_VALUE), "10.14.2.0 - (1828579)"); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE), "10.14.2.0 - (1828579)"); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE), "false"); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_CATALOG_VALUE), "false"); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_SCHEMA_VALUE), "true"); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_TABLE_VALUE), "true"); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put( + Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_CASE_VALUE), + Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE_VALUE)); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put(Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR_VALUE), "\""); + GET_SQL_INFO_EXPECTED_RESULTS_MAP + .put( + Integer.toString(FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE_VALUE), + Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE_VALUE)); + } + + @AfterClass + public static void tearDown() throws Exception { + close(sqlClient, server, allocator); + } + + private static List> getNonConformingResultsForGetSqlInfo(final List> results) { + return getNonConformingResultsForGetSqlInfo(results, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY, + FlightSql.SqlInfo.SQL_DDL_CATALOG, + FlightSql.SqlInfo.SQL_DDL_SCHEMA, + FlightSql.SqlInfo.SQL_DDL_TABLE, + FlightSql.SqlInfo.SQL_IDENTIFIER_CASE, + 
FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR, + FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE); + } + + private static List> getNonConformingResultsForGetSqlInfo( + final List> results, + final FlightSql.SqlInfo... args) { + final List> nonConformingResults = new ArrayList<>(); + if (results.size() == args.length) { + for (int index = 0; index < results.size(); index++) { + final List result = results.get(index); + final String providedName = result.get(0); + final String expectedName = Integer.toString(args[index].getNumber()); + if (!(GET_SQL_INFO_EXPECTED_RESULTS_MAP.get(providedName).equals(result.get(1)) && + providedName.equals(expectedName))) { + nonConformingResults.add(result); + break; + } + } + } + return nonConformingResults; + } + + @Test + public void testGetTablesSchema() { + final FlightInfo info = sqlClient.getTables(null, null, null, null, true); + collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); + } + + @Test + public void testGetTablesResultNoSchema() throws Exception { + try (final FlightStream stream = + sqlClient.getStream( + sqlClient.getTables(null, null, null, null, false) + .getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); + final List> results = getResults(stream); + final List> expectedResults = ImmutableList.of( + // catalog_name | schema_name | table_name | table_type | table_schema + asList(null /* TODO No catalog yet */, "SYS", "SYSALIASES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSCHECKS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSCOLPERMS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSCOLUMNS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSCONGLOMERATES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSCONSTRAINTS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSDEPENDS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSFILES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSFOREIGNKEYS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSKEYS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSPERMS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSROLES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSROUTINEPERMS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSSCHEMAS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSSEQUENCES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSSTATEMENTS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSSTATISTICS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSTABLEPERMS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSTABLES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSTRIGGERS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSUSERS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSVIEWS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYSIBM", "SYSDUMMY1", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "APP", "FOREIGNTABLE", "TABLE"), + asList(null /* TODO No catalog yet */, "APP", "INTTABLE", "TABLE")); + collector.checkThat(results, is(expectedResults)); + } + } + + @Test + public void 
testGetTablesResultFilteredNoSchema() throws Exception { + try (final FlightStream stream = + sqlClient.getStream( + sqlClient.getTables(null, null, null, singletonList("TABLE"), false) + .getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); + final List> results = getResults(stream); + final List> expectedResults = ImmutableList.of( + // catalog_name | schema_name | table_name | table_type | table_schema + asList(null /* TODO No catalog yet */, "APP", "FOREIGNTABLE", "TABLE"), + asList(null /* TODO No catalog yet */, "APP", "INTTABLE", "TABLE")); + collector.checkThat(results, is(expectedResults)); + } + } + + @Test + public void testGetTablesResultFilteredWithSchema() throws Exception { + try (final FlightStream stream = + sqlClient.getStream( + sqlClient.getTables(null, null, null, singletonList("TABLE"), true) + .getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); + final List> results = getResults(stream); + final List> expectedResults = ImmutableList.of( + // catalog_name | schema_name | table_name | table_type | table_schema + asList( + null /* TODO No catalog yet */, + "APP", + "FOREIGNTABLE", + "TABLE", + new Schema(asList( + new Field("ID", new FieldType(false, MinorType.INT.getType(), null), null), + Field.nullable("FOREIGNNAME", MinorType.VARCHAR.getType()), + Field.nullable("VALUE", MinorType.INT.getType()))).toJson()), + asList( + null /* TODO No catalog yet */, + "APP", + "INTTABLE", + "TABLE", + new Schema(asList( + new Field("ID", new FieldType(false, MinorType.INT.getType(), null), null), + Field.nullable("KEYNAME", MinorType.VARCHAR.getType()), + Field.nullable("VALUE", MinorType.INT.getType()), + Field.nullable("FOREIGNID", MinorType.INT.getType()))).toJson())); + collector.checkThat(results, is(expectedResults)); + } + } + + @Test + public void testSimplePreparedStatementSchema() throws Exception { + try (final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable")) { + final Schema actualSchema = preparedStatement.getResultSetSchema(); + collector.checkThat(actualSchema, is(SCHEMA_INT_TABLE)); + + final FlightInfo info = preparedStatement.execute(); + collector.checkThat(info.getSchema(), is(SCHEMA_INT_TABLE)); + } + } + + @Test + public void testSimplePreparedStatementResults() throws Exception { + try (final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable"); + final FlightStream stream = sqlClient.getStream( + preparedStatement.execute().getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); + collector.checkThat(getResults(stream), is(EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY)); + } + } + + @Test + public void testSimplePreparedStatementResultsWithParameterBinding() throws Exception { + try (PreparedStatement prepare = sqlClient.prepare("SELECT * FROM intTable WHERE id = ?")) { + final Schema parameterSchema = prepare.getParameterSchema(); + try (final VectorSchemaRoot insertRoot = VectorSchemaRoot.create(parameterSchema, allocator)) { + insertRoot.allocateNew(); + + final IntVector valueVector = (IntVector) insertRoot.getVector(0); + valueVector.setSafe(0, 1); + insertRoot.setRowCount(1); + + prepare.setParameters(insertRoot); + FlightInfo flightInfo = prepare.execute(); + + FlightStream stream = sqlClient.getStream(flightInfo + .getEndpoints() + .get(0).getTicket()); + + 
collector.checkThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); + collector.checkThat(getResults(stream), is(EXPECTED_RESULTS_FOR_PARAMETER_BINDING)); + } + } + } + + @Test + public void testSimplePreparedStatementUpdateResults() throws SQLException { + try (PreparedStatement prepare = sqlClient.prepare("INSERT INTO INTTABLE (keyName, value ) VALUES (?, ?)"); + PreparedStatement deletePrepare = sqlClient.prepare("DELETE FROM INTTABLE WHERE keyName = ?")) { + final Schema parameterSchema = prepare.getParameterSchema(); + try (final VectorSchemaRoot insertRoot = VectorSchemaRoot.create(parameterSchema, allocator)) { + final VarCharVector varCharVector = (VarCharVector) insertRoot.getVector(0); + final IntVector valueVector = (IntVector) insertRoot.getVector(1); + final int counter = 10; + insertRoot.allocateNew(); + + final IntStream range = IntStream.range(0, counter); + + range.forEach(i -> { + valueVector.setSafe(i, i * counter); + varCharVector.setSafe(i, new Text("value" + i)); + }); + + insertRoot.setRowCount(counter); + + prepare.setParameters(insertRoot); + final long updatedRows = prepare.executeUpdate(); + + final long deletedRows; + try (final VectorSchemaRoot deleteRoot = VectorSchemaRoot.of(varCharVector)) { + deletePrepare.setParameters(deleteRoot); + deletedRows = deletePrepare.executeUpdate(); + } + + collector.checkThat(updatedRows, is(10L)); + collector.checkThat(deletedRows, is(10L)); + } + } + } + + @Test + public void testSimplePreparedStatementUpdateResultsWithoutParameters() throws SQLException { + try (PreparedStatement prepare = sqlClient + .prepare("INSERT INTO INTTABLE (keyName, value ) VALUES ('test', 1000)"); + PreparedStatement deletePrepare = sqlClient.prepare("DELETE FROM INTTABLE WHERE keyName = 'test'")) { + final long updatedRows = prepare.executeUpdate(); + + final long deletedRows = deletePrepare.executeUpdate(); + + collector.checkThat(updatedRows, is(1L)); + collector.checkThat(deletedRows, is(1L)); + } + } + + @Test + public void testSimplePreparedStatementClosesProperly() { + final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable"); + collector.checkThat(preparedStatement.isClosed(), is(false)); + preparedStatement.close(); + collector.checkThat(preparedStatement.isClosed(), is(true)); + } + + @Test + public void testGetCatalogsSchema() { + final FlightInfo info = sqlClient.getCatalogs(); + collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)); + } + + @Test + public void testGetCatalogsResults() throws Exception { + try (final FlightStream stream = + sqlClient.getStream(sqlClient.getCatalogs().getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)); + List> catalogs = getResults(stream); + collector.checkThat(catalogs, is(emptyList())); + } + } + + @Test + public void testGetTableTypesSchema() { + final FlightInfo info = sqlClient.getTableTypes(); + collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); + } + + @Test + public void testGetTableTypesResult() throws Exception { + try (final FlightStream stream = + sqlClient.getStream(sqlClient.getTableTypes().getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); + final List> tableTypes = getResults(stream); + final List> expectedTableTypes = ImmutableList.of( + // table_type + singletonList("SYNONYM"), + singletonList("SYSTEM TABLE"), + 
singletonList("TABLE"), + singletonList("VIEW") + ); + collector.checkThat(tableTypes, is(expectedTableTypes)); + } + } + + @Test + public void testGetSchemasSchema() { + final FlightInfo info = sqlClient.getSchemas(null, null); + collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); + } + + @Test + public void testGetSchemasResult() throws Exception { + try (final FlightStream stream = + sqlClient.getStream(sqlClient.getSchemas(null, null).getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); + final List> schemas = getResults(stream); + final List> expectedSchemas = ImmutableList.of( + // catalog_name | schema_name + asList(null /* TODO Add catalog. */, "APP"), + asList(null /* TODO Add catalog. */, "NULLID"), + asList(null /* TODO Add catalog. */, "SQLJ"), + asList(null /* TODO Add catalog. */, "SYS"), + asList(null /* TODO Add catalog. */, "SYSCAT"), + asList(null /* TODO Add catalog. */, "SYSCS_DIAG"), + asList(null /* TODO Add catalog. */, "SYSCS_UTIL"), + asList(null /* TODO Add catalog. */, "SYSFUN"), + asList(null /* TODO Add catalog. */, "SYSIBM"), + asList(null /* TODO Add catalog. */, "SYSPROC"), + asList(null /* TODO Add catalog. */, "SYSSTAT")); + collector.checkThat(schemas, is(expectedSchemas)); + } + } + + @Test + public void testGetPrimaryKey() { + final FlightInfo flightInfo = sqlClient.getPrimaryKeys(TableRef.of(null, null, "INTTABLE")); + final FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); + + final List> results = getResults(stream); + collector.checkThat(results.size(), is(1)); + + final List result = results.get(0); + + collector.checkThat(result.get(0), is("")); + collector.checkThat(result.get(1), is("APP")); + collector.checkThat(result.get(2), is("INTTABLE")); + collector.checkThat(result.get(3), is("ID")); + collector.checkThat(result.get(4), is("1")); + collector.checkThat(result.get(5), notNullValue()); + } + + @Test + public void testGetSqlInfoSchema() { + final FlightInfo info = sqlClient.getSqlInfo(); + collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); + } + + @Test + public void testGetSqlInfoResults() throws Exception { + final FlightInfo info = sqlClient.getSqlInfo(); + try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); + collector.checkThat(getNonConformingResultsForGetSqlInfo(getResults(stream)), is(emptyList())); + } + } + + @Test + public void testGetSqlInfoResultsWithSingleArg() throws Exception { + final FlightSql.SqlInfo arg = FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME; + final FlightInfo info = sqlClient.getSqlInfo(arg); + try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); + collector.checkThat(getNonConformingResultsForGetSqlInfo(getResults(stream), arg), is(emptyList())); + } + } + + @Test + public void testGetSqlInfoResultsWithTwoArgs() throws Exception { + final FlightSql.SqlInfo[] args = { + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION}; + final FlightInfo info = sqlClient.getSqlInfo(args); + try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), 
is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); + collector.checkThat(getNonConformingResultsForGetSqlInfo(getResults(stream), args), is(emptyList())); + } + } + + @Test + public void testGetSqlInfoResultsWithThreeArgs() throws Exception { + final FlightSql.SqlInfo[] args = { + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION, + FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR}; + final FlightInfo info = sqlClient.getSqlInfo(args); + try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); + collector.checkThat(getNonConformingResultsForGetSqlInfo(getResults(stream), args), is(emptyList())); + } + } + + @Test + public void testGetCommandExportedKeys() { + final FlightStream stream = + sqlClient.getStream( + sqlClient.getExportedKeys(TableRef.of(null, null, "FOREIGNTABLE")) + .getEndpoints().get(0).getTicket()); + + final List> results = getResults(stream); + + final List> matchers = asList( + nullValue(String.class), // pk_catalog_name + is("APP"), // pk_schema_name + is("FOREIGNTABLE"), // pk_table_name + is("ID"), // pk_column_name + nullValue(String.class), // fk_catalog_name + is("APP"), // fk_schema_name + is("INTTABLE"), // fk_table_name + is("FOREIGNID"), // fk_column_name + is("1"), // key_sequence + containsString("SQL"), // fk_key_name + containsString("SQL"), // pk_key_name + is("3"), // update_rule + is("3")); // delete_rule + + Assert.assertEquals(1, results.size()); + for (int i = 0; i < matchers.size(); i++) { + collector.checkThat(results.get(0).get(i), matchers.get(i)); + } + } + + @Test + public void testGetCommandImportedKeys() { + final FlightStream stream = + sqlClient.getStream( + sqlClient.getImportedKeys(TableRef.of(null, null, "INTTABLE")) + .getEndpoints().get(0).getTicket()); + + final List> results = getResults(stream); + + final List> matchers = asList( + nullValue(String.class), // pk_catalog_name + is("APP"), // pk_schema_name + is("FOREIGNTABLE"), // pk_table_name + is("ID"), // pk_column_name + nullValue(String.class), // fk_catalog_name + is("APP"), // fk_schema_name + is("INTTABLE"), // fk_table_name + is("FOREIGNID"), // fk_column_name + is("1"), // key_sequence + containsString("SQL"), // fk_key_name + containsString("SQL"), // pk_key_name + is("3"), // update_rule + is("3")); // delete_rule + + Assert.assertEquals(1, results.size()); + for (int i = 0; i < matchers.size(); i++) { + collector.checkThat(results.get(0).get(i), matchers.get(i)); + } + } + + @Test + public void testGetCommandCrossReference() { + final FlightInfo flightInfo = sqlClient.getCrossReference(TableRef.of(null, null, + "FOREIGNTABLE"), TableRef.of(null, null, "INTTABLE")); + final FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); + + final List> results = getResults(stream); + + final List> matchers = asList( + nullValue(String.class), // pk_catalog_name + is("APP"), // pk_schema_name + is("FOREIGNTABLE"), // pk_table_name + is("ID"), // pk_column_name + nullValue(String.class), // fk_catalog_name + is("APP"), // fk_schema_name + is("INTTABLE"), // fk_table_name + is("FOREIGNID"), // fk_column_name + is("1"), // key_sequence + containsString("SQL"), // fk_key_name + containsString("SQL"), // pk_key_name + is("3"), // update_rule + is("3")); // delete_rule + + Assert.assertEquals(1, results.size()); + for (int i = 0; i < matchers.size(); i++) { + 
collector.checkThat(results.get(0).get(i), matchers.get(i)); + } + } + + @Test + public void testCreateStatementSchema() throws Exception { + final FlightInfo info = sqlClient.execute("SELECT * FROM intTable"); + collector.checkThat(info.getSchema(), is(SCHEMA_INT_TABLE)); + + // Consume statement to close connection before cache eviction + try (FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { + while (stream.next()) { + // Do nothing + } + } + } + + @Test + public void testCreateStatementResults() throws Exception { + try (final FlightStream stream = sqlClient + .getStream(sqlClient.execute("SELECT * FROM intTable").getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); + collector.checkThat(getResults(stream), is(EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY)); + } + } + + List> getResults(FlightStream stream) { + final List> results = new ArrayList<>(); + while (stream.next()) { + try (final VectorSchemaRoot root = stream.getRoot()) { + final long rowCount = root.getRowCount(); + for (int i = 0; i < rowCount; ++i) { + results.add(new ArrayList<>()); + } + + root.getSchema().getFields().forEach(field -> { + try (final FieldVector fieldVector = root.getVector(field.getName())) { + if (fieldVector instanceof VarCharVector) { + final VarCharVector varcharVector = (VarCharVector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final Text data = varcharVector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? null : data.toString()); + } + } else if (fieldVector instanceof IntVector) { + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + results.get(rowIndex).add(String.valueOf(((IntVector) fieldVector).get(rowIndex))); + } + } else if (fieldVector instanceof VarBinaryVector) { + final VarBinaryVector varbinaryVector = (VarBinaryVector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final byte[] data = varbinaryVector.getObject(rowIndex); + final String output; + try { + output = isNull(data) ? + null : + MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data)))).toJson(); + } catch (final IOException e) { + throw new RuntimeException("Failed to deserialize schema", e); + } + results.get(rowIndex).add(output); + } + } else if (fieldVector instanceof DenseUnionVector) { + final DenseUnionVector denseUnionVector = (DenseUnionVector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final Object data = denseUnionVector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); + } + } else if (fieldVector instanceof UInt4Vector) { + final UInt4Vector uInt4Vector = (UInt4Vector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final Object data = uInt4Vector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); + } + } else if (fieldVector instanceof UInt1Vector) { + final UInt1Vector uInt1Vector = (UInt1Vector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final Object data = uInt1Vector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? 
null : Objects.toString(data)); + } + } else { + throw new UnsupportedOperationException("Not yet implemented"); + } + } + }); + } + } + + return results; + } + + @Test + public void testExecuteUpdate() { + long insertedCount = sqlClient.executeUpdate("INSERT INTO INTTABLE (keyName, value) VALUES " + + "('KEYNAME1', 1001), ('KEYNAME2', 1002), ('KEYNAME3', 1003)"); + collector.checkThat(insertedCount, is(3L)); + + long updatedCount = sqlClient.executeUpdate("UPDATE INTTABLE SET keyName = 'KEYNAME1' " + + "WHERE keyName = 'KEYNAME2' OR keyName = 'KEYNAME3'"); + collector.checkThat(updatedCount, is(2L)); + + long deletedCount = sqlClient.executeUpdate("DELETE FROM INTTABLE WHERE keyName = 'KEYNAME1'"); + collector.checkThat(deletedCount, is(3L)); + } + + @Test + public void testQueryWithNoResultsShouldNotHang() throws Exception { + try (final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable WHERE 1 = 0"); + final FlightStream stream = sqlClient + .getStream(preparedStatement.execute().getEndpoints().get(0).getTicket())) { + collector.checkThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); + + final List> result = getResults(stream); + collector.checkThat(result, is(emptyList())); + } + } +} diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java new file mode 100644 index 00000000000..90a2aaf1004 --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -0,0 +1,1625 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.sql.example; + +import static com.google.common.base.Strings.emptyToNull; +import static com.google.protobuf.Any.pack; +import static com.google.protobuf.ByteString.copyFrom; +import static java.lang.String.format; +import static java.util.Collections.singletonList; +import static java.util.Objects.isNull; +import static java.util.UUID.randomUUID; +import static org.apache.arrow.adapter.jdbc.JdbcToArrow.sqlToArrowVectorIterator; +import static org.apache.arrow.adapter.jdbc.JdbcToArrowUtils.jdbcToArrowSchema; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery; +import static org.apache.arrow.util.Preconditions.checkState; +import static org.slf4j.LoggerFactory.getLogger; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.Date; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; +import org.apache.arrow.adapter.jdbc.JdbcFieldInfo; +import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.SqlInfoBuilder; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; +import 
org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampSecTZVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.apache.commons.dbcp2.ConnectionFactory; +import org.apache.commons.dbcp2.DriverManagerConnectionFactory; +import org.apache.commons.dbcp2.PoolableConnection; +import org.apache.commons.dbcp2.PoolableConnectionFactory; +import org.apache.commons.dbcp2.PoolingDataSource; +import org.apache.commons.pool2.ObjectPool; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.slf4j.Logger; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.RemovalListener; +import com.google.common.cache.RemovalNotification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; +import com.google.protobuf.ProtocolStringList; + +/** + * Proof of concept {@link 
FlightSqlProducer} implementation showing an Apache Derby backed Flight SQL server capable + * of the following workflows: + * + * - returning a list of tables from the action `GetTables`. + * - creation of a prepared statement from the action `CreatePreparedStatement`. + * - execution of a prepared statement by using a {@link CommandPreparedStatementQuery} + * with {@link #getFlightInfo} and {@link #getStream}. + */ +public class FlightSqlExample implements FlightSqlProducer, AutoCloseable { + private static final String DATABASE_URI = "jdbc:derby:target/derbyDB"; + private static final Logger LOGGER = getLogger(FlightSqlExample.class); + private static final Calendar DEFAULT_CALENDAR = JdbcToArrowUtils.getUtcCalendar(); + private final Location location; + private final PoolingDataSource dataSource; + private final BufferAllocator rootAllocator = new RootAllocator(); + private final Cache> preparedStatementLoadingCache; + private final Cache> statementLoadingCache; + private final SqlInfoBuilder sqlInfoBuilder; + + public FlightSqlExample(final Location location) { + // TODO Constructor should not be doing work. + checkState( + removeDerbyDatabaseIfExists() && populateDerbyDatabase(), + "Failed to reset Derby database!"); + final ConnectionFactory connectionFactory = + new DriverManagerConnectionFactory(DATABASE_URI, new Properties()); + final PoolableConnectionFactory poolableConnectionFactory = + new PoolableConnectionFactory(connectionFactory, null); + final ObjectPool connectionPool = new GenericObjectPool<>(poolableConnectionFactory); + + poolableConnectionFactory.setPool(connectionPool); + // PoolingDataSource takes ownership of `connectionPool` + dataSource = new PoolingDataSource<>(connectionPool); + + preparedStatementLoadingCache = + CacheBuilder.newBuilder() + .maximumSize(100) + .expireAfterWrite(10, TimeUnit.MINUTES) + .removalListener(new StatementRemovalListener()) + .build(); + + statementLoadingCache = + CacheBuilder.newBuilder() + .maximumSize(100) + .expireAfterWrite(10, TimeUnit.MINUTES) + .removalListener(new StatementRemovalListener<>()) + .build(); + + this.location = location; + + sqlInfoBuilder = new SqlInfoBuilder(); + try (final Connection connection = dataSource.getConnection()) { + final DatabaseMetaData metaData = connection.getMetaData(); + + sqlInfoBuilder.withFlightSqlServerName(metaData.getDatabaseProductName()) + .withFlightSqlServerVersion(metaData.getDatabaseProductVersion()) + .withFlightSqlServerArrowVersion(metaData.getDriverVersion()) + .withFlightSqlServerReadOnly(metaData.isReadOnly()) + .withSqlIdentifierQuoteChar(metaData.getIdentifierQuoteString()) + .withSqlDdlCatalog(metaData.supportsCatalogsInDataManipulation()) + .withSqlDdlSchema( metaData.supportsSchemasInDataManipulation()) + .withSqlDdlTable( metaData.allTablesAreSelectable()) + .withSqlIdentifierCase(metaData.storesMixedCaseIdentifiers() ? + SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE : + metaData.storesUpperCaseIdentifiers() ? + SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE : + metaData.storesLowerCaseIdentifiers() ? + SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_LOWERCASE : + SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UNKNOWN) + .withSqlQuotedIdentifierCase(metaData.storesMixedCaseQuotedIdentifiers() ? + SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE : + metaData.storesUpperCaseQuotedIdentifiers() ? + SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE : + metaData.storesLowerCaseQuotedIdentifiers() ? 
+ SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_LOWERCASE : + SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UNKNOWN); + } catch (SQLException e) { + throw new RuntimeException(e); + } + + } + + private static boolean removeDerbyDatabaseIfExists() { + boolean wasSuccess; + final Path path = Paths.get("target" + File.separator + "derbyDB"); + + try (final Stream walk = Files.walk(path)) { + /* + * Iterate over all paths to delete, mapping each path to the outcome of its own + * deletion as a boolean representing whether or not each individual operation was + * successful; then reduce all booleans into a single answer, and store that into + * `wasSuccess`, which will later be returned by this method. + * If for whatever reason the resulting `Stream` is empty, throw an `IOException`; + * this not expected. + */ + wasSuccess = walk.sorted(Comparator.reverseOrder()).map(Path::toFile).map(File::delete) + .reduce(Boolean::logicalAnd).orElseThrow(IOException::new); + } catch (IOException e) { + /* + * The only acceptable scenario for an `IOException` to be thrown here is if + * an attempt to delete an non-existing file takes place -- which should be + * alright, since they would be deleted anyway. + */ + if (!(wasSuccess = e instanceof NoSuchFileException)) { + LOGGER.error(format("Failed attempt to clear DerbyDB: <%s>", e.getMessage()), e); + } + } + + return wasSuccess; + } + + private static boolean populateDerbyDatabase() { + try (final Connection connection = DriverManager.getConnection("jdbc:derby:target/derbyDB;create=true"); + Statement statement = connection.createStatement()) { + statement.execute("CREATE TABLE foreignTable (" + + "id INT not null primary key GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), " + + "foreignName varchar(100), " + + "value int)"); + statement.execute("CREATE TABLE intTable (" + + "id INT not null primary key GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), " + + "keyName varchar(100), " + + "value int, " + + "foreignId int references foreignTable(id))"); + statement.execute("INSERT INTO foreignTable (foreignName, value) VALUES ('keyOne', 1)"); + statement.execute("INSERT INTO foreignTable (foreignName, value) VALUES ('keyTwo', 0)"); + statement.execute("INSERT INTO foreignTable (foreignName, value) VALUES ('keyThree', -1)"); + statement.execute("INSERT INTO intTable (keyName, value, foreignId) VALUES ('one', 1, 1)"); + statement.execute("INSERT INTO intTable (keyName, value, foreignId) VALUES ('zero', 0, 1)"); + statement.execute("INSERT INTO intTable (keyName, value, foreignId) VALUES ('negative one', -1, 1)"); + } catch (final SQLException e) { + LOGGER.error(format("Failed attempt to populate DerbyDB: <%s>", e.getMessage()), e); + return false; + } + return true; + } + + private static ArrowType getArrowTypeFromJdbcType(final int jdbcDataType, final int precision, final int scale) { + final ArrowType type = + JdbcToArrowUtils.getArrowTypeFromJdbcType(new JdbcFieldInfo(jdbcDataType, precision, scale), DEFAULT_CALENDAR); + return isNull(type) ? 
ArrowType.Utf8.INSTANCE : type; + } + + private static void saveToVector(final Byte data, final UInt1Vector vector, final int index) { + vectorConsumer( + data, + vector, + fieldVector -> fieldVector.setNull(index), + (theData, fieldVector) -> fieldVector.setSafe(index, theData)); + } + + private static void saveToVector(final String data, final VarCharVector vector, final int index) { + preconditionCheckSaveToVector(vector, index); + vectorConsumer(data, vector, fieldVector -> fieldVector.setNull(index), + (theData, fieldVector) -> fieldVector.setSafe(index, new Text(theData))); + } + + private static void saveToVector(final Integer data, final IntVector vector, final int index) { + preconditionCheckSaveToVector(vector, index); + vectorConsumer(data, vector, fieldVector -> fieldVector.setNull(index), + (theData, fieldVector) -> fieldVector.setSafe(index, theData)); + } + + private static void saveToVector(final byte[] data, final VarBinaryVector vector, final int index) { + preconditionCheckSaveToVector(vector, index); + vectorConsumer(data, vector, fieldVector -> fieldVector.setNull(index), + (theData, fieldVector) -> fieldVector.setSafe(index, theData)); + } + + private static void preconditionCheckSaveToVector(final FieldVector vector, final int index) { + Objects.requireNonNull(vector, "vector cannot be null."); + checkState(index >= 0, "Index must be a positive number!"); + } + + private static void vectorConsumer(final T data, final V vector, + final Consumer consumerIfNullable, + final BiConsumer defaultConsumer) { + if (isNull(data)) { + consumerIfNullable.accept(vector); + return; + } + defaultConsumer.accept(data, vector); + } + + private static VectorSchemaRoot getSchemasRoot(final ResultSet data, final BufferAllocator allocator) + throws SQLException { + final VarCharVector catalogs = new VarCharVector("catalog_name", allocator); + final VarCharVector schemas = + new VarCharVector("db_schema_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); + final List vectors = ImmutableList.of(catalogs, schemas); + vectors.forEach(FieldVector::allocateNew); + final Map vectorToColumnName = ImmutableMap.of( + catalogs, "TABLE_CATALOG", + schemas, "TABLE_SCHEM"); + saveToVectors(vectorToColumnName, data); + final int rows = vectors.stream().map(FieldVector::getValueCount).findAny().orElseThrow(IllegalStateException::new); + vectors.forEach(vector -> vector.setValueCount(rows)); + return new VectorSchemaRoot(vectors); + } + + private static int saveToVectors(final Map vectorToColumnName, + final ResultSet data, boolean emptyToNull) + throws SQLException { + Objects.requireNonNull(vectorToColumnName, "vectorToColumnName cannot be null."); + Objects.requireNonNull(data, "data cannot be null."); + final Set> entrySet = vectorToColumnName.entrySet(); + int rows = 0; + for (; data.next(); rows++) { + for (final Entry vectorToColumn : entrySet) { + final T vector = vectorToColumn.getKey(); + final String columnName = vectorToColumn.getValue(); + if (vector instanceof VarCharVector) { + String thisData = data.getString(columnName); + saveToVector(emptyToNull ? emptyToNull(thisData) : thisData, (VarCharVector) vector, rows); + continue; + } else if (vector instanceof IntVector) { + final int intValue = data.getInt(columnName); + saveToVector(data.wasNull() ? null : intValue, (IntVector) vector, rows); + continue; + } else if (vector instanceof UInt1Vector) { + final byte byteValue = data.getByte(columnName); + saveToVector(data.wasNull() ? 
null : byteValue, (UInt1Vector) vector, rows); + continue; + } + throw CallStatus.INVALID_ARGUMENT.withDescription("Provided vector not supported").toRuntimeException(); + } + } + for (final Entry vectorToColumn : entrySet) { + vectorToColumn.getKey().setValueCount(rows); + } + + return rows; + } + + private static void saveToVectors(final Map vectorToColumnName, + final ResultSet data) + throws SQLException { + saveToVectors(vectorToColumnName, data, false); + } + + private static VectorSchemaRoot getTableTypesRoot(final ResultSet data, final BufferAllocator allocator) + throws SQLException { + return getRoot(data, allocator, "table_type", "TABLE_TYPE"); + } + + private static VectorSchemaRoot getCatalogsRoot(final ResultSet data, final BufferAllocator allocator) + throws SQLException { + return getRoot(data, allocator, "catalog_name", "TABLE_CATALOG"); + } + + private static VectorSchemaRoot getRoot(final ResultSet data, final BufferAllocator allocator, + final String fieldVectorName, final String columnName) + throws SQLException { + final VarCharVector dataVector = + new VarCharVector(fieldVectorName, FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); + saveToVectors(ImmutableMap.of(dataVector, columnName), data); + final int rows = dataVector.getValueCount(); + dataVector.setValueCount(rows); + return new VectorSchemaRoot(singletonList(dataVector)); + } + + private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMetaData, + final BufferAllocator allocator, + final boolean includeSchema, + final String catalog, + final String schemaFilterPattern, + final String tableFilterPattern, + final String... tableTypes) + throws SQLException, IOException { + /* + * TODO Fix DerbyDB inconsistency if possible. + * During the early development of this prototype, an inconsistency has been found in the database + * used for this demonstration; as DerbyDB does not operate with the concept of catalogs, fetching + * the catalog name for a given table from `DatabaseMetadata#getColumns` and `DatabaseMetadata#getSchemas` + * returns null, as expected. However, the inconsistency lies in the fact that accessing the same + * information -- that is, the catalog name for a given table -- from `DatabaseMetadata#getSchemas` + * returns an empty String.The temporary workaround for this was making sure we convert the empty Strings + * to null using `com.google.common.base.Strings#emptyToNull`. 
+ */ + Objects.requireNonNull(allocator, "BufferAllocator cannot be null."); + final VarCharVector catalogNameVector = new VarCharVector("catalog_name", allocator); + final VarCharVector schemaNameVector = new VarCharVector("db_schema_name", allocator); + final VarCharVector tableNameVector = + new VarCharVector("table_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); + final VarCharVector tableTypeVector = + new VarCharVector("table_type", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); + + final List vectors = new ArrayList<>(4); + vectors.add(catalogNameVector); + vectors.add(schemaNameVector); + vectors.add(tableNameVector); + vectors.add(tableTypeVector); + + vectors.forEach(FieldVector::allocateNew); + + final Map vectorToColumnName = ImmutableMap.of( + catalogNameVector, "TABLE_CAT", + schemaNameVector, "TABLE_SCHEM", + tableNameVector, "TABLE_NAME", + tableTypeVector, "TABLE_TYPE"); + + try (final ResultSet data = + Objects.requireNonNull( + databaseMetaData, + format("%s cannot be null.", databaseMetaData.getClass().getName())) + .getTables(catalog, schemaFilterPattern, tableFilterPattern, tableTypes)) { + + saveToVectors(vectorToColumnName, data, true); + final int rows = + vectors.stream().map(FieldVector::getValueCount).findAny().orElseThrow(IllegalStateException::new); + vectors.forEach(vector -> vector.setValueCount(rows)); + + if (includeSchema) { + final VarBinaryVector tableSchemaVector = + new VarBinaryVector("table_schema", FieldType.notNullable(MinorType.VARBINARY.getType()), allocator); + tableSchemaVector.allocateNew(rows); + + try (final ResultSet columnsData = + databaseMetaData.getColumns(catalog, schemaFilterPattern, tableFilterPattern, null)) { + final Map> tableToFields = new HashMap<>(); + + while (columnsData.next()) { + final String tableName = columnsData.getString("TABLE_NAME"); + final String fieldName = columnsData.getString("COLUMN_NAME"); + final int dataType = columnsData.getInt("DATA_TYPE"); + final boolean isNullable = columnsData.getInt("NULLABLE") != DatabaseMetaData.columnNoNulls; + final int precision = columnsData.getInt("NUM_PREC_RADIX"); + final int scale = columnsData.getInt("DECIMAL_DIGITS"); + final List fields = tableToFields.computeIfAbsent(tableName, tableName_ -> new ArrayList<>()); + final Field field = + new Field( + fieldName, + new FieldType( + isNullable, + getArrowTypeFromJdbcType(dataType, precision, scale), + null), + null); + fields.add(field); + } + + for (int index = 0; index < rows; index++) { + final String tableName = tableNameVector.getObject(index).toString(); + final Schema schema = new Schema(tableToFields.get(tableName)); + saveToVector( + copyFrom(serializeMetadata(schema)).toByteArray(), + tableSchemaVector, index); + } + } + + tableSchemaVector.setValueCount(rows); + vectors.add(tableSchemaVector); + } + } + + return new VectorSchemaRoot(vectors); + } + + private static ByteBuffer serializeMetadata(final Schema schema) { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(outputStream)), schema); + + return ByteBuffer.wrap(outputStream.toByteArray()); + } catch (final IOException e) { + throw new RuntimeException("Failed to serialize schema", e); + } + } + + @Override + public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, + final ServerStreamListener listener) { + final ByteString handle = command.getPreparedStatementHandle(); 
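+    // Look up the statement context that createPreparedStatement cached under this handle;
+    // its prepared statement is executed below and the result set is streamed back as Arrow record batches.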
+ StatementContext statementContext = preparedStatementLoadingCache.getIfPresent(handle); + Objects.requireNonNull(statementContext); + final PreparedStatement statement = statementContext.getStatement(); + try (final ResultSet resultSet = statement.executeQuery()) { + final Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), DEFAULT_CALENDAR); + try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) { + final VectorLoader loader = new VectorLoader(vectorSchemaRoot); + listener.start(vectorSchemaRoot); + + final ArrowVectorIterator iterator = sqlToArrowVectorIterator(resultSet, rootAllocator); + while (iterator.hasNext()) { + final VectorSchemaRoot batch = iterator.next(); + if (batch.getRowCount() == 0) { + break; + } + final VectorUnloader unloader = new VectorUnloader(batch); + loader.load(unloader.getRecordBatch()); + listener.putNext(); + vectorSchemaRoot.clear(); + } + + listener.putNext(); + } + } catch (final SQLException | IOException e) { + LOGGER.error(format("Failed to getStreamPreparedStatement: <%s>.", e.getMessage()), e); + listener.error(e); + } finally { + listener.completed(); + } + } + + @Override + public void closePreparedStatement(final ActionClosePreparedStatementRequest request, final CallContext context, + final StreamListener listener) { + try { + preparedStatementLoadingCache.invalidate(request.getPreparedStatementHandle()); + } catch (final Exception e) { + listener.onError(e); + return; + } + listener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context, + final FlightDescriptor descriptor) { + ByteString handle = copyFrom(randomUUID().toString().getBytes(StandardCharsets.UTF_8)); + + try { + // Ownership of the connection will be passed to the context. Do NOT close! 
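+      // The statement (and its connection) is cached under the random UUID handle generated above, and the
+      // handle travels back to the client inside the TicketStatementQuery so the later DoGet for this query can find it.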
+ final Connection connection = dataSource.getConnection(); + final Statement statement = connection.createStatement( + ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + final String query = request.getQuery(); + final StatementContext statementContext = new StatementContext<>(statement, query); + + statementLoadingCache.put(handle, statementContext); + final ResultSet resultSet = statement.executeQuery(query); + + TicketStatementQuery ticket = TicketStatementQuery.newBuilder() + .setStatementHandle(handle) + .build(); + return getFlightInfoForSchema(ticket, descriptor, + jdbcToArrowSchema(resultSet.getMetaData(), DEFAULT_CALENDAR)); + } catch (final SQLException e) { + LOGGER.error( + format("There was a problem executing the prepared statement: <%s>.", e.getMessage()), + e); + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); + } + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command, + final CallContext context, + final FlightDescriptor descriptor) { + final ByteString preparedStatementHandle = command.getPreparedStatementHandle(); + StatementContext statementContext = + preparedStatementLoadingCache.getIfPresent(preparedStatementHandle); + try { + assert statementContext != null; + PreparedStatement statement = statementContext.getStatement(); + + ResultSetMetaData metaData = statement.getMetaData(); + return getFlightInfoForSchema(command, descriptor, + jdbcToArrowSchema(metaData, DEFAULT_CALENDAR)); + } catch (final SQLException e) { + LOGGER.error( + format("There was a problem executing the prepared statement: <%s>.", e.getMessage()), + e); + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); + } + } + + @Override + public SchemaResult getSchemaStatement(final CommandStatementQuery command, final CallContext context, + final FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + + @Override + public void close() throws Exception { + try { + preparedStatementLoadingCache.cleanUp(); + } catch (Throwable t) { + LOGGER.error(format("Failed to close resources: <%s>", t.getMessage()), t); + } + + AutoCloseables.close(dataSource, rootAllocator); + } + + @Override + public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { + // TODO - build example implementation + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + + @Override + public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, final CallContext context, + final StreamListener listener) { + try { + final ByteString preparedStatementHandle = copyFrom(randomUUID().toString().getBytes(StandardCharsets.UTF_8)); + // Ownership of the connection will be passed to the context. Do NOT close! + final Connection connection = dataSource.getConnection(); + final PreparedStatement preparedStatement = connection.prepareStatement(request.getQuery(), + ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + final StatementContext preparedStatementContext = + new StatementContext<>(preparedStatement, request.getQuery()); + + preparedStatementLoadingCache.put(preparedStatementHandle, preparedStatementContext); + + final Schema parameterSchema = + jdbcToArrowSchema(preparedStatement.getParameterMetaData(), DEFAULT_CALENDAR); + + final ResultSetMetaData metaData = preparedStatement.getMetaData(); + final ByteString bytes = isNull(metaData) ? 
+ ByteString.EMPTY : + ByteString.copyFrom( + serializeMetadata(jdbcToArrowSchema(metaData, DEFAULT_CALENDAR))); + final ActionCreatePreparedStatementResult result = ActionCreatePreparedStatementResult.newBuilder() + .setDatasetSchema(bytes) + .setParameterSchema(copyFrom(serializeMetadata(parameterSchema))) + .setPreparedStatementHandle(preparedStatementHandle) + .build(); + listener.onNext(new Result(pack(result).toByteArray())); + } catch (final Throwable t) { + listener.onError(t); + } finally { + listener.onCompleted(); + } + } + + @Override + public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { + // TODO - build example implementation + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + + @Override + public Runnable acceptPutStatement(CommandStatementUpdate command, + CallContext context, FlightStream flightStream, + StreamListener ackStream) { + final String query = command.getQuery(); + + return () -> { + try (final Connection connection = dataSource.getConnection(); + final Statement statement = connection.createStatement()) { + final int result = statement.executeUpdate(query); + + final DoPutUpdateResult build = + DoPutUpdateResult.newBuilder().setRecordCount(result).build(); + + try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) { + buffer.writeBytes(build.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + ackStream.onCompleted(); + } + } catch (SQLException e) { + ackStream.onError(e); + } + }; + } + + @Override + public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate command, CallContext context, + FlightStream flightStream, StreamListener ackStream) { + final StatementContext statement = + preparedStatementLoadingCache.getIfPresent(command.getPreparedStatementHandle()); + + return () -> { + assert statement != null; + try { + final PreparedStatement preparedStatement = statement.getStatement(); + + while (flightStream.next()) { + final VectorSchemaRoot root = flightStream.getRoot(); + + final int rowCount = root.getRowCount(); + final int recordCount; + + if (rowCount == 0) { + preparedStatement.execute(); + recordCount = preparedStatement.getUpdateCount(); + } else { + setDataPreparedStatement(preparedStatement, root, true); + int[] recordCount1 = preparedStatement.executeBatch(); + recordCount = Arrays.stream(recordCount1).sum(); + } + + final DoPutUpdateResult build = + DoPutUpdateResult.newBuilder().setRecordCount(recordCount).build(); + + try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) { + buffer.writeBytes(build.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + } + } + } catch (SQLException e) { + ackStream.onError(e); + return; + } + ackStream.onCompleted(); + }; + } + + /** + * Method responsible to set the parameters, to the preparedStatement object, sent via doPut request. + * + * @param preparedStatement the preparedStatement object for the operation. + * @param root a {@link VectorSchemaRoot} object contain the values to be used in the + * PreparedStatement setters. + * @param isUpdate a flag to indicate if is an update or query operation. + * @throws SQLException in case of error. 
+ */ + private void setDataPreparedStatement(PreparedStatement preparedStatement, VectorSchemaRoot root, + boolean isUpdate) + throws SQLException { + for (int i = 0; i < root.getRowCount(); i++) { + for (FieldVector vector : root.getFieldVectors()) { + final int vectorPosition = root.getFieldVectors().indexOf(vector); + final int position = vectorPosition + 1; + + if (vector instanceof UInt1Vector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (UInt1Vector) vector); + } else if (vector instanceof TimeStampNanoTZVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TimeStampNanoTZVector) vector); + } else if (vector instanceof TimeStampMicroTZVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TimeStampMicroTZVector) vector); + } else if (vector instanceof TimeStampMilliTZVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TimeStampMilliTZVector) vector); + } else if (vector instanceof TimeStampSecTZVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TimeStampSecTZVector) vector); + } else if (vector instanceof UInt2Vector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (UInt2Vector) vector); + } else if (vector instanceof UInt4Vector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (UInt4Vector) vector); + } else if (vector instanceof UInt8Vector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (UInt8Vector) vector); + } else if (vector instanceof TinyIntVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TinyIntVector) vector); + } else if (vector instanceof SmallIntVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (SmallIntVector) vector); + } else if (vector instanceof IntVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (IntVector) vector); + } else if (vector instanceof BigIntVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (BigIntVector) vector); + } else if (vector instanceof Float4Vector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (Float4Vector) vector); + } else if (vector instanceof Float8Vector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (Float8Vector) vector); + } else if (vector instanceof BitVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (BitVector) vector); + } else if (vector instanceof DecimalVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (DecimalVector) vector); + } else if (vector instanceof Decimal256Vector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (Decimal256Vector) vector); + } else if (vector instanceof TimeStampVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TimeStampVector) vector); + } else if (vector instanceof TimeNanoVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TimeNanoVector) vector); + } else if (vector instanceof TimeMicroVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TimeMicroVector) vector); + } else if (vector instanceof TimeMilliVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TimeMilliVector) vector); + } else if (vector instanceof TimeSecVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (TimeSecVector) vector); 
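+        // Date and (large) varchar vectors are handled by the branches that follow; any vector type without a branch is silently skipped for this column.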
+ } else if (vector instanceof DateDayVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (DateDayVector) vector); + } else if (vector instanceof DateMilliVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (DateMilliVector) vector); + } else if (vector instanceof VarCharVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (VarCharVector) vector); + } else if (vector instanceof LargeVarCharVector) { + setOnPreparedStatement(preparedStatement, position, vectorPosition, (LargeVarCharVector) vector); + } + } + if (isUpdate) { + preparedStatement.addBatch(); + } + } + } + + protected TimeZone getTimeZoneForVector(TimeStampVector vector) { + ArrowType.Timestamp arrowType = (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + String timezoneName = arrowType.getTimezone(); + if (timezoneName == null) { + return TimeZone.getDefault(); + } + + return TimeZone.getTimeZone(timezoneName); + } + + /** + * Set a string parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, VarCharVector vector) + throws SQLException { + final Text object = vector.getObject(vectorIndex); + statement.setObject(column, object.toString()); + } + + /** + * Set a string parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, + LargeVarCharVector vector) + throws SQLException { + final Text object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a byte parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, TinyIntVector vector) + throws SQLException { + final Byte object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a short parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. 
+ */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, SmallIntVector vector) + throws SQLException { + final Short object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set an integer parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, IntVector vector) + throws SQLException { + final Integer object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a long parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, BigIntVector vector) + throws SQLException { + final Long object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a float parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, Float4Vector vector) + throws SQLException { + final Float object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a double parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, Float8Vector vector) + throws SQLException { + final Double object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a BigDecimal parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, DecimalVector vector) + throws SQLException { + final BigDecimal object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a BigDecimal parameter to the preparedStatement object. 
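+   * The value is read as a {@link BigDecimal} and bound via {@link PreparedStatement#setObject(int, Object)}.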
+ * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, Decimal256Vector vector) + throws SQLException { + final BigDecimal object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a timestamp parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, TimeStampVector vector) + throws SQLException { + final Object object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a time parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, TimeNanoVector vector) + throws SQLException { + final Long object = vector.getObject(vectorIndex); + statement.setTime(column, new Time(object * 1000L)); + } + + /** + * Set a time parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, TimeMicroVector vector) + throws SQLException { + final Long object = vector.getObject(vectorIndex); + statement.setTime(column, new Time(object / 1000L)); + } + + /** + * Set a time parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, TimeMilliVector vector) + throws SQLException { + final LocalDateTime object = vector.getObject(vectorIndex); + statement.setTime(column, Time.valueOf(object.toLocalTime())); + } + + /** + * Set a time parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. 
+ * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, TimeSecVector vector) + throws SQLException { + final Integer object = vector.getObject(vectorIndex); + statement.setTime(column, new Time(object)); + } + + /** + * Set a date parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, DateDayVector vector) + throws SQLException { + final Integer object = vector.getObject(vectorIndex); + statement.setDate(column, new Date(TimeUnit.DAYS.toMillis(object))); + } + + /** + * Set a date parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, DateMilliVector vector) + throws SQLException { + final LocalDateTime object = vector.getObject(vectorIndex); + statement.setDate(column, Date.valueOf(object.toLocalDate())); + + } + + /** + * Set an unsigned 1 byte number parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, UInt1Vector vector) + throws SQLException { + final Byte object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set an unsigned 2 bytes number parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, UInt2Vector vector) + throws SQLException { + final Character object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set an unsigned 4 bytes number parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. 
+ */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, UInt4Vector vector) + throws SQLException { + final Integer object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set an unsigned 8 bytes number parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, UInt8Vector vector) + throws SQLException { + final Long object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a boolean parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, BitVector vector) + throws SQLException { + final Boolean object = vector.getObject(vectorIndex); + statement.setObject(column, object); + } + + /** + * Set a timestamp parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, + TimeStampNanoTZVector vector) + throws SQLException { + final Long object = vector.getObject(vectorIndex); + statement.setTimestamp(column, new Timestamp(object / 1000000L), + Calendar.getInstance(getTimeZoneForVector(vector))); + } + + /** + * Set a timestamp parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. + */ + public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex, + TimeStampMicroTZVector vector) + throws SQLException { + final Long object = vector.getObject(vectorIndex); + statement.setTimestamp(column, new Timestamp(object / 1000L), + Calendar.getInstance(getTimeZoneForVector(vector))); + } + + /** + * Set a timestamp parameter to the preparedStatement object. + * + * @param statement an instance of the {@link PreparedStatement} class. + * @param column the index of the column in the {@link PreparedStatement}. + * @param vectorIndex the index from the vector which contain the value. + * @param vector an instance of the vector the will be accessed. + * @throws SQLException in case of error. 
+  /**
+   * Set a timestamp parameter to the preparedStatement object.
+   *
+   * @param statement an instance of the {@link PreparedStatement} class.
+   * @param column the index of the column in the {@link PreparedStatement}.
+   * @param vectorIndex the index from the vector which contains the value.
+   * @param vector an instance of the vector that will be accessed.
+   * @throws SQLException in case of error.
+   */
+  public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex,
+                                     TimeStampMilliTZVector vector)
+      throws SQLException {
+    final Long object = vector.getObject(vectorIndex);
+    statement.setTimestamp(column, new Timestamp(object),
+        Calendar.getInstance(getTimeZoneForVector(vector)));
+  }
+
+  /**
+   * Set a timestamp parameter to the preparedStatement object.
+   *
+   * @param statement an instance of the {@link PreparedStatement} class.
+   * @param column the index of the column in the {@link PreparedStatement}.
+   * @param vectorIndex the index from the vector which contains the value.
+   * @param vector an instance of the vector that will be accessed.
+   * @throws SQLException in case of error.
+   */
+  public void setOnPreparedStatement(PreparedStatement statement, int column, int vectorIndex,
+                                     TimeStampSecTZVector vector)
+      throws SQLException {
+    final Long object = vector.getObject(vectorIndex);
+    statement.setTimestamp(column, new Timestamp(object * 1000L),
+        Calendar.getInstance(getTimeZoneForVector(vector)));
+  }
+
+  @Override
+  public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery command, CallContext context,
+                                                  FlightStream flightStream, StreamListener<PutResult> ackStream) {
+    final StatementContext<PreparedStatement> statementContext =
+        preparedStatementLoadingCache.getIfPresent(command.getPreparedStatementHandle());
+
+    return () -> {
+      assert statementContext != null;
+      PreparedStatement preparedStatement = statementContext.getStatement();
+
+      try {
+        while (flightStream.next()) {
+          final VectorSchemaRoot root = flightStream.getRoot();
+          setDataPreparedStatement(preparedStatement, root, false);
+        }
+
+      } catch (SQLException e) {
+        ackStream.onError(CallStatus.INTERNAL
+            .withDescription("Failed to bind parameters: " + e.getMessage())
+            .withCause(e)
+            .toRuntimeException());
+        return;
+      }
+      ackStream.onCompleted();
+    };
+  }
+
+  @Override
+  public FlightInfo getFlightInfoSqlInfo(final CommandGetSqlInfo request, final CallContext context,
+                                         final FlightDescriptor descriptor) {
+    return getFlightInfoForSchema(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA);
+  }
+
+  @Override
+  public void getStreamSqlInfo(final CommandGetSqlInfo command, final CallContext context,
+                               final ServerStreamListener listener) {
+    this.sqlInfoBuilder.send(command.getInfoList(), listener);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoCatalogs(final CommandGetCatalogs request, final CallContext context,
+                                          final FlightDescriptor descriptor) {
+    return getFlightInfoForSchema(request, descriptor, Schemas.GET_CATALOGS_SCHEMA);
+  }
+
+  @Override
+  public void getStreamCatalogs(final CallContext context, final ServerStreamListener listener) {
+    try (final Connection connection = dataSource.getConnection();
+         final ResultSet catalogs = connection.getMetaData().getCatalogs();
+         final VectorSchemaRoot vectorSchemaRoot = getCatalogsRoot(catalogs, rootAllocator)) {
+      listener.start(vectorSchemaRoot);
+      listener.putNext();
+    } catch (SQLException e) {
+      LOGGER.error(format("Failed to getStreamCatalogs: <%s>.", e.getMessage()), e);
+      listener.error(e);
+    } finally {
+      listener.completed();
+    }
+  }
+
+  @Override
+  public FlightInfo getFlightInfoSchemas(final CommandGetDbSchemas request, final CallContext context,
+                                         final FlightDescriptor descriptor) {
+    return getFlightInfoForSchema(request, descriptor, Schemas.GET_SCHEMAS_SCHEMA);
+  }
+
+  @Override
+  public void getStreamSchemas(final CommandGetDbSchemas command, final CallContext context,
+                               final ServerStreamListener listener) {
+    final
String catalog = command.hasCatalog() ? command.getCatalog() : null; + final String schemaFilterPattern = command.hasDbSchemaFilterPattern() ? command.getDbSchemaFilterPattern() : null; + try (final Connection connection = dataSource.getConnection(); + final ResultSet schemas = connection.getMetaData().getSchemas(catalog, schemaFilterPattern); + final VectorSchemaRoot vectorSchemaRoot = getSchemasRoot(schemas, rootAllocator)) { + listener.start(vectorSchemaRoot); + listener.putNext(); + } catch (SQLException e) { + LOGGER.error(format("Failed to getStreamSchemas: <%s>.", e.getMessage()), e); + listener.error(e); + } finally { + listener.completed(); + } + } + + @Override + public FlightInfo getFlightInfoTables(final CommandGetTables request, final CallContext context, + final FlightDescriptor descriptor) { + return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLES_SCHEMA); + } + + @Override + public void getStreamTables(final CommandGetTables command, final CallContext context, + final ServerStreamListener listener) { + final String catalog = command.hasCatalog() ? command.getCatalog() : null; + final String schemaFilterPattern = + command.hasDbSchemaFilterPattern() ? command.getDbSchemaFilterPattern() : null; + final String tableFilterPattern = + command.hasTableNameFilterPattern() ? command.getTableNameFilterPattern() : null; + + final ProtocolStringList protocolStringList = command.getTableTypesList(); + final int protocolSize = protocolStringList.size(); + final String[] tableTypes = + protocolSize == 0 ? null : protocolStringList.toArray(new String[protocolSize]); + + try (final Connection connection = DriverManager.getConnection(DATABASE_URI); + final VectorSchemaRoot vectorSchemaRoot = getTablesRoot( + connection.getMetaData(), + rootAllocator, + command.getIncludeSchema(), + catalog, schemaFilterPattern, tableFilterPattern, tableTypes)) { + listener.start(vectorSchemaRoot); + listener.putNext(); + } catch (SQLException | IOException e) { + LOGGER.error(format("Failed to getStreamTables: <%s>.", e.getMessage()), e); + listener.error(e); + } finally { + listener.completed(); + } + } + + @Override + public FlightInfo getFlightInfoTableTypes(final CommandGetTableTypes request, final CallContext context, + final FlightDescriptor descriptor) { + return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLE_TYPES_SCHEMA); + } + + @Override + public void getStreamTableTypes(final CallContext context, final ServerStreamListener listener) { + try (final Connection connection = dataSource.getConnection(); + final ResultSet tableTypes = connection.getMetaData().getTableTypes(); + final VectorSchemaRoot vectorSchemaRoot = getTableTypesRoot(tableTypes, rootAllocator)) { + listener.start(vectorSchemaRoot); + listener.putNext(); + } catch (SQLException e) { + LOGGER.error(format("Failed to getStreamTableTypes: <%s>.", e.getMessage()), e); + listener.error(e); + } finally { + listener.completed(); + } + } + + @Override + public FlightInfo getFlightInfoPrimaryKeys(final CommandGetPrimaryKeys request, final CallContext context, + final FlightDescriptor descriptor) { + return getFlightInfoForSchema(request, descriptor, Schemas.GET_PRIMARY_KEYS_SCHEMA); + } + + @Override + public void getStreamPrimaryKeys(final CommandGetPrimaryKeys command, final CallContext context, + final ServerStreamListener listener) { + + final String catalog = command.hasCatalog() ? command.getCatalog() : null; + final String schema = command.hasDbSchema() ? 
command.getDbSchema() : null;
+    final String table = command.getTable();
+
+    try (Connection connection = DriverManager.getConnection(DATABASE_URI)) {
+      final ResultSet primaryKeys = connection.getMetaData().getPrimaryKeys(catalog, schema, table);
+
+      final VarCharVector catalogNameVector = new VarCharVector("catalog_name", rootAllocator);
+      final VarCharVector schemaNameVector = new VarCharVector("db_schema_name", rootAllocator);
+      final VarCharVector tableNameVector = new VarCharVector("table_name", rootAllocator);
+      final VarCharVector columnNameVector = new VarCharVector("column_name", rootAllocator);
+      final IntVector keySequenceVector = new IntVector("key_sequence", rootAllocator);
+      final VarCharVector keyNameVector = new VarCharVector("key_name", rootAllocator);
+
+      final List<FieldVector> vectors =
+          new ArrayList<>(
+              ImmutableList.of(
+                  catalogNameVector, schemaNameVector, tableNameVector, columnNameVector, keySequenceVector,
+                  keyNameVector));
+      vectors.forEach(FieldVector::allocateNew);
+
+      int rows = 0;
+      for (; primaryKeys.next(); rows++) {
+        saveToVector(primaryKeys.getString("TABLE_CAT"), catalogNameVector, rows);
+        saveToVector(primaryKeys.getString("TABLE_SCHEM"), schemaNameVector, rows);
+        saveToVector(primaryKeys.getString("TABLE_NAME"), tableNameVector, rows);
+        saveToVector(primaryKeys.getString("COLUMN_NAME"), columnNameVector, rows);
+        final int keySeq = primaryKeys.getInt("KEY_SEQ");
+        saveToVector(primaryKeys.wasNull() ? null : keySeq, keySequenceVector, rows);
+        saveToVector(primaryKeys.getString("PK_NAME"), keyNameVector, rows);
+      }
+
+      try (final VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(vectors)) {
+        vectorSchemaRoot.setRowCount(rows);
+
+        listener.start(vectorSchemaRoot);
+        listener.putNext();
+      }
+    } catch (SQLException e) {
+      listener.error(e);
+    } finally {
+      listener.completed();
+    }
+  }
+
+  @Override
+  public FlightInfo getFlightInfoExportedKeys(final CommandGetExportedKeys request, final CallContext context,
+                                              final FlightDescriptor descriptor) {
+    return getFlightInfoForSchema(request, descriptor, Schemas.GET_EXPORTED_KEYS_SCHEMA);
+  }
+
+  @Override
+  public void getStreamExportedKeys(final CommandGetExportedKeys command, final CallContext context,
+                                    final ServerStreamListener listener) {
+    String catalog = command.hasCatalog() ? command.getCatalog() : null;
+    String schema = command.hasDbSchema() ? command.getDbSchema() : null;
+    String table = command.getTable();
+
+    try (Connection connection = DriverManager.getConnection(DATABASE_URI);
+         ResultSet keys = connection.getMetaData().getExportedKeys(catalog, schema, table);
+         VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) {
+      listener.start(vectorSchemaRoot);
+      listener.putNext();
+    } catch (SQLException e) {
+      listener.error(e);
+    } finally {
+      listener.completed();
+    }
+  }
+
+  @Override
+  public FlightInfo getFlightInfoImportedKeys(final CommandGetImportedKeys request, final CallContext context,
+                                              final FlightDescriptor descriptor) {
+    return getFlightInfoForSchema(request, descriptor, Schemas.GET_IMPORTED_KEYS_SCHEMA);
+  }
+
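+  // Illustrative only: a hypothetical client could reach the imported-keys endpoint roughly as
+  //   FlightInfo info = client.getInfo(FlightDescriptor.command(
+  //       Any.pack(CommandGetImportedKeys.newBuilder().setTable("FOO").build()).toByteArray()));
+  // and then fetch the single endpoint's ticket with client.getStream(...); the client instance and the
+  // table name above are assumptions, not part of this change.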
+  @Override
+  public void getStreamImportedKeys(final CommandGetImportedKeys command, final CallContext context,
+                                    final ServerStreamListener listener) {
+    String catalog = command.hasCatalog() ? command.getCatalog() : null;
+    String schema = command.hasDbSchema() ? command.getDbSchema() : null;
+    String table = command.getTable();
+
+    try (Connection connection = DriverManager.getConnection(DATABASE_URI);
+         ResultSet keys = connection.getMetaData().getImportedKeys(catalog, schema, table);
+         VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) {
+      listener.start(vectorSchemaRoot);
+      listener.putNext();
+    } catch (final SQLException e) {
+      listener.error(e);
+    } finally {
+      listener.completed();
+    }
+  }
+
+  @Override
+  public FlightInfo getFlightInfoCrossReference(CommandGetCrossReference request, CallContext context,
+                                                FlightDescriptor descriptor) {
+    return getFlightInfoForSchema(request, descriptor, Schemas.GET_CROSS_REFERENCE_SCHEMA);
+  }
+
+  @Override
+  public void getStreamCrossReference(CommandGetCrossReference command, CallContext context,
+                                      ServerStreamListener listener) {
+    final String pkCatalog = command.hasPkCatalog() ? command.getPkCatalog() : null;
+    final String pkSchema = command.hasPkDbSchema() ? command.getPkDbSchema() : null;
+    final String fkCatalog = command.hasFkCatalog() ? command.getFkCatalog() : null;
+    final String fkSchema = command.hasFkDbSchema() ? command.getFkDbSchema() : null;
+    final String pkTable = command.getPkTable();
+    final String fkTable = command.getFkTable();
+
+    try (Connection connection = DriverManager.getConnection(DATABASE_URI);
+         ResultSet keys = connection.getMetaData()
+             .getCrossReference(pkCatalog, pkSchema, pkTable, fkCatalog, fkSchema, fkTable);
+         VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) {
+      listener.start(vectorSchemaRoot);
+      listener.putNext();
+    } catch (final SQLException e) {
+      listener.error(e);
+    } finally {
+      listener.completed();
+    }
+  }
+
+  private VectorSchemaRoot createVectors(ResultSet keys) throws SQLException {
+    final VarCharVector pkCatalogNameVector = new VarCharVector("pk_catalog_name", rootAllocator);
+    final VarCharVector pkSchemaNameVector = new VarCharVector("pk_db_schema_name", rootAllocator);
+    final VarCharVector pkTableNameVector = new VarCharVector("pk_table_name", rootAllocator);
+    final VarCharVector pkColumnNameVector = new VarCharVector("pk_column_name", rootAllocator);
+    final VarCharVector fkCatalogNameVector = new VarCharVector("fk_catalog_name", rootAllocator);
+    final VarCharVector fkSchemaNameVector = new VarCharVector("fk_db_schema_name", rootAllocator);
+    final VarCharVector fkTableNameVector = new VarCharVector("fk_table_name", rootAllocator);
+    final VarCharVector fkColumnNameVector = new VarCharVector("fk_column_name", rootAllocator);
+    final IntVector keySequenceVector = new IntVector("key_sequence", rootAllocator);
+    final VarCharVector fkKeyNameVector = new VarCharVector("fk_key_name", rootAllocator);
+    final VarCharVector pkKeyNameVector = new VarCharVector("pk_key_name", rootAllocator);
+    final UInt1Vector updateRuleVector = new UInt1Vector("update_rule", rootAllocator);
+    final UInt1Vector deleteRuleVector = new UInt1Vector("delete_rule", rootAllocator);
+
+    // Map each vector to the JDBC metadata column that populates it.
+    Map<FieldVector, String> vectorToColumnName = new HashMap<>();
+    vectorToColumnName.put(pkCatalogNameVector, "PKTABLE_CAT");
+    vectorToColumnName.put(pkSchemaNameVector, "PKTABLE_SCHEM");
+    vectorToColumnName.put(pkTableNameVector, "PKTABLE_NAME");
+    vectorToColumnName.put(pkColumnNameVector, "PKCOLUMN_NAME");
+    vectorToColumnName.put(fkCatalogNameVector, "FKTABLE_CAT");
+    vectorToColumnName.put(fkSchemaNameVector, "FKTABLE_SCHEM");
+    vectorToColumnName.put(fkTableNameVector, "FKTABLE_NAME");
+    vectorToColumnName.put(fkColumnNameVector, "FKCOLUMN_NAME");
+    vectorToColumnName.put(keySequenceVector, "KEY_SEQ");
+    vectorToColumnName.put(updateRuleVector, "UPDATE_RULE");
+    vectorToColumnName.put(deleteRuleVector, "DELETE_RULE");
+    vectorToColumnName.put(fkKeyNameVector, "FK_NAME");
+    vectorToColumnName.put(pkKeyNameVector, "PK_NAME");
+
+    final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.of(
+        pkCatalogNameVector, pkSchemaNameVector, pkTableNameVector, pkColumnNameVector, fkCatalogNameVector,
+        fkSchemaNameVector, fkTableNameVector, fkColumnNameVector, keySequenceVector, fkKeyNameVector,
+        pkKeyNameVector, updateRuleVector, deleteRuleVector);
+
+    vectorSchemaRoot.allocateNew();
+    final int rowCount = saveToVectors(vectorToColumnName, keys, true);
+
+    vectorSchemaRoot.setRowCount(rowCount);
+
+    return vectorSchemaRoot;
+  }
+
+  @Override
+  public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context,
+                                 final ServerStreamListener listener) {
+    final ByteString handle = ticketStatementQuery.getStatementHandle();
+    final StatementContext<Statement> statementContext =
+        Objects.requireNonNull(statementLoadingCache.getIfPresent(handle));
+    try (final ResultSet resultSet = statementContext.getStatement().getResultSet()) {
+      final Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), DEFAULT_CALENDAR);
+      try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
+        final VectorLoader loader = new VectorLoader(vectorSchemaRoot);
+        listener.start(vectorSchemaRoot);
+
+        final ArrowVectorIterator iterator = sqlToArrowVectorIterator(resultSet, rootAllocator);
+        while (iterator.hasNext()) {
+          final VectorUnloader unloader = new VectorUnloader(iterator.next());
+          loader.load(unloader.getRecordBatch());
+          listener.putNext();
+          vectorSchemaRoot.clear();
+        }
+
+        listener.putNext();
+      }
+    } catch (SQLException | IOException e) {
+      LOGGER.error(format("Failed to getStreamStatement: <%s>.", e.getMessage()), e);
+      listener.error(e);
+    } finally {
+      listener.completed();
+      statementLoadingCache.invalidate(handle);
+    }
+  }
+
+  private <T extends Message> FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor,
+                                                                final Schema schema) {
+    final Ticket ticket = new Ticket(pack(request).toByteArray());
+    // TODO Support multiple endpoints.
+    final List<FlightEndpoint> endpoints = singletonList(new FlightEndpoint(ticket, location));
+
+    return new FlightInfo(schema, descriptor, endpoints, -1, -1);
+  }
+
+  private static class StatementRemovalListener<T extends Statement>
+      implements RemovalListener<ByteString, StatementContext<T>> {
+    @Override
+    public void onRemoval(final RemovalNotification<ByteString, StatementContext<T>> notification) {
+      try {
+        AutoCloseables.close(notification.getValue());
+      } catch (final Exception e) {
+        // swallow
+      }
+    }
+  }
+}
diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/StatementContext.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/StatementContext.java
new file mode 100644
index 00000000000..764ef3f54aa
--- /dev/null
+++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/StatementContext.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.sql.example;
+
+import java.sql.Connection;
+import java.sql.Statement;
+import java.util.Objects;
+
+import org.apache.arrow.flight.sql.FlightSqlProducer;
+import org.apache.arrow.util.AutoCloseables;
+
+/**
+ * Context for a {@code T} statement to be persisted in memory in between {@link FlightSqlProducer} calls.
+ *
+ * @param <T> the {@link Statement} to be persisted.
+ */
+public final class StatementContext<T extends Statement> implements AutoCloseable {
+
+  private final T statement;
+  private final String query;
+
+  public StatementContext(final T statement, final String query) {
+    this.statement = Objects.requireNonNull(statement, "statement cannot be null.");
+    this.query = query;
+  }
+
+  /**
+   * Gets the statement wrapped by this {@link StatementContext}.
+   *
+   * @return the inner statement.
+   */
+  public T getStatement() {
+    return statement;
+  }
+
+  /**
+   * Gets the optional SQL query wrapped by this {@link StatementContext}.
+   *
+   * @return the SQL query if present; {@code null} otherwise.
+   */
+  public String getQuery() {
+    return query;
+  }
+
+  @Override
+  public void close() throws Exception {
+    Connection connection = statement.getConnection();
+    AutoCloseables.close(statement, connection);
+  }
+
+  @Override
+  public boolean equals(final Object other) {
+    if (this == other) {
+      return true;
+    }
+    if (!(other instanceof StatementContext)) {
+      return false;
+    }
+    final StatementContext<?> that = (StatementContext<?>) other;
+    return statement.equals(that.statement);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(statement);
+  }
+}
diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/AdhocTestOption.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/AdhocTestOption.java
new file mode 100644
index 00000000000..6988a86049d
--- /dev/null
+++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/AdhocTestOption.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.flight.sql.util; + +import com.google.protobuf.Descriptors.EnumDescriptor; +import com.google.protobuf.Descriptors.EnumValueDescriptor; +import com.google.protobuf.ProtocolMessageEnum; + +enum AdhocTestOption implements ProtocolMessageEnum { + OPTION_A, OPTION_B, OPTION_C; + + @Override + public int getNumber() { + return ordinal(); + } + + @Override + public EnumValueDescriptor getValueDescriptor() { + throw getUnsupportedException(); + } + + @Override + public EnumDescriptor getDescriptorForType() { + throw getUnsupportedException(); + } + + private UnsupportedOperationException getUnsupportedException() { + return new UnsupportedOperationException("Unimplemented method is irrelevant for the scope of this test."); + } +} diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskCreationTest.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskCreationTest.java new file mode 100644 index 00000000000..6f2b66646bb --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskCreationTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.sql.util; + +import static java.util.Arrays.asList; +import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_A; +import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_B; +import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_C; +import static org.apache.arrow.flight.sql.util.SqlInfoOptionsUtils.createBitmaskFromEnums; +import static org.hamcrest.CoreMatchers.is; + +import java.util.List; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public final class SqlInfoOptionsUtilsBitmaskCreationTest { + + @Parameter + public AdhocTestOption[] adhocTestOptions; + @Parameter(value = 1) + public long expectedBitmask; + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @Parameters + public static List provideParameters() { + return asList( + new Object[][]{ + {new AdhocTestOption[0], 0L}, + {new AdhocTestOption[]{OPTION_A}, 1L}, + {new AdhocTestOption[]{OPTION_B}, 0b10L}, + {new AdhocTestOption[]{OPTION_A, OPTION_B}, 0b11L}, + {new AdhocTestOption[]{OPTION_C}, 0b100L}, + {new AdhocTestOption[]{OPTION_A, OPTION_C}, 0b101L}, + {new AdhocTestOption[]{OPTION_B, OPTION_C}, 0b110L}, + {AdhocTestOption.values(), 0b111L}, + }); + } + + @Test + public void testShouldBuildBitmaskFromEnums() { + collector.checkThat(createBitmaskFromEnums(adhocTestOptions), is(expectedBitmask)); + } +} diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskParsingTest.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskParsingTest.java new file mode 100644 index 00000000000..decee38ee0a --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskParsingTest.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.sql.util; + +import static java.util.Arrays.asList; +import static java.util.Arrays.stream; +import static java.util.stream.Collectors.toCollection; +import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_A; +import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_B; +import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_C; +import static org.apache.arrow.flight.sql.util.SqlInfoOptionsUtils.doesBitmaskTranslateToEnum; +import static org.hamcrest.CoreMatchers.is; + +import java.util.EnumSet; +import java.util.List; +import java.util.Set; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public final class SqlInfoOptionsUtilsBitmaskParsingTest { + + @Parameter + public long bitmask; + @Parameter(value = 1) + public Set expectedOptions; + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @Parameters + public static List provideParameters() { + return asList( + new Object[][]{ + {0L, EnumSet.noneOf(AdhocTestOption.class)}, + {1L, EnumSet.of(OPTION_A)}, + {0b10L, EnumSet.of(OPTION_B)}, + {0b11L, EnumSet.of(OPTION_A, OPTION_B)}, + {0b100L, EnumSet.of(OPTION_C)}, + {0b101L, EnumSet.of(OPTION_A, OPTION_C)}, + {0b110L, EnumSet.of(OPTION_B, OPTION_C)}, + {0b111L, EnumSet.allOf(AdhocTestOption.class)}, + }); + } + + @Test + public void testShouldFilterOutEnumsBasedOnBitmask() { + final Set actualOptions = + stream(AdhocTestOption.values()) + .filter(enumInstance -> doesBitmaskTranslateToEnum(enumInstance, bitmask)) + .collect(toCollection(() -> EnumSet.noneOf(AdhocTestOption.class))); + collector.checkThat(actualOptions, is(expectedOptions)); + } +} diff --git a/java/flight/pom.xml b/java/flight/pom.xml new file mode 100644 index 00000000000..7cb0e1d7171 --- /dev/null +++ b/java/flight/pom.xml @@ -0,0 +1,58 @@ + + + + + arrow-java-root + org.apache.arrow + 7.0.0-SNAPSHOT + + 4.0.0 + + Arrow Flight + arrow-flight + + pom + + + 1.41.0 + 3.17.3 + + + + flight-core + flight-grpc + flight-sql + flight-integration-tests + + + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + + com.google.protobuf:protoc:${dep.protobuf.version}:exe:${os.detected.classifier} + + grpc-java + io.grpc:protoc-gen-grpc-java:${dep.grpc.version}:exe:${os.detected.classifier} + + + + + + + diff --git a/java/pom.xml b/java/pom.xml index 007f4533ad3..7059f0027f4 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -84,6 +84,14 @@ + + + + kr.motd.maven + os-maven-plugin + 1.5.0.Final + + @@ -565,6 +573,11 @@ 2.8.2 provided + + org.hamcrest + hamcrest + 2.2 + @@ -676,8 +689,7 @@ tools adapter/jdbc plasma - flight/flight-core - flight/flight-grpc + flight performance algorithm adapter/avro diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java index 3a5ef11537a..54c609d4a10 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java @@ -64,6 +64,10 @@ public static Field nullable(String name, ArrowType type) { return new Field(name, FieldType.nullable(type), null); } + public static Field notNullable(String name, ArrowType type) { + return new Field(name, 
FieldType.notNullable(type), null);
+  }
+
   private final String name;
   private final FieldType fieldType;
   private final List<Field> children;
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java
index bb3250ef102..d5c0d85671f 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java
@@ -41,6 +41,10 @@ public static FieldType nullable(ArrowType type) {
     return new FieldType(true, type, null, null);
   }
 
+  public static FieldType notNullable(ArrowType type) {
+    return new FieldType(false, type, null, null);
+  }
+
   private final boolean nullable;
   private final ArrowType type;
   private final DictionaryEncoding dictionary;
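For orientation, a minimal sketch of how the new `Field.notNullable`/`FieldType.notNullable` helpers introduced above might be used to declare a schema with a required column. The class and field names here are illustrative assumptions, not part of this change:

```java
import java.util.Arrays;

import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

public final class NotNullableFieldSketch {
  public static void main(String[] args) {
    // "table_name" is declared as required (non-nullable); "catalog_name" stays optional.
    final Schema schema = new Schema(Arrays.asList(
        Field.notNullable("table_name", MinorType.VARCHAR.getType()),
        Field.nullable("catalog_name", MinorType.VARCHAR.getType())));
    System.out.println(schema);
  }
}
```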