From 05ba8279073d229cb825ba00740cbbabda7b1203 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 9 Sep 2022 11:48:34 -0400 Subject: [PATCH 1/5] ARROW-17661: [C++][Python] Add Flight SQL ADBC driver Co-authored-by: Sutou Kouhei --- ci/scripts/python_build.sh | 1 + ci/scripts/python_test.sh | 4 + cpp/cmake_modules/ThirdpartyToolchain.cmake | 120 + cpp/src/arrow/c/adbc_internal.h | 1207 ++++++++++ cpp/src/arrow/flight/sql/CMakeLists.txt | 13 +- cpp/src/arrow/flight/sql/adbc_driver.cc | 2062 +++++++++++++++++ .../arrow/flight/sql/adbc_driver_internal.cc | 191 ++ .../arrow/flight/sql/adbc_driver_internal.h | 198 ++ cpp/src/arrow/flight/sql/adbc_driver_test.cc | 1057 +++++++++ cpp/src/arrow/flight/sql/client.cc | 1 + cpp/src/arrow/flight/sql/client.h | 14 + cpp/src/arrow/symbols.map | 3 + cpp/thirdparty/versions.txt | 6 + python/CMakeLists.txt | 29 + python/pyarrow/_flight_sql.pyx | 29 + python/pyarrow/conftest.py | 8 + python/pyarrow/flight_sql.py | 153 ++ .../pyarrow/includes/libarrow_flight_sql.pxd | 26 + python/pyarrow/tests/test_flight_sql.py | 67 + python/setup.py | 6 + 20 files changed, 5193 insertions(+), 2 deletions(-) create mode 100644 cpp/src/arrow/c/adbc_internal.h create mode 100644 cpp/src/arrow/flight/sql/adbc_driver.cc create mode 100644 cpp/src/arrow/flight/sql/adbc_driver_internal.cc create mode 100644 cpp/src/arrow/flight/sql/adbc_driver_internal.h create mode 100644 cpp/src/arrow/flight/sql/adbc_driver_test.cc create mode 100644 python/pyarrow/_flight_sql.pyx create mode 100644 python/pyarrow/flight_sql.py create mode 100644 python/pyarrow/includes/libarrow_flight_sql.pxd create mode 100644 python/pyarrow/tests/test_flight_sql.py diff --git a/ci/scripts/python_build.sh b/ci/scripts/python_build.sh index cfac68bd6ec..1275e6405f3 100755 --- a/ci/scripts/python_build.sh +++ b/ci/scripts/python_build.sh @@ -57,6 +57,7 @@ export PYARROW_BUILD_TYPE=${CMAKE_BUILD_TYPE:-debug} export PYARROW_WITH_CUDA=${ARROW_CUDA:-OFF} export PYARROW_WITH_DATASET=${ARROW_DATASET:-ON} export PYARROW_WITH_FLIGHT=${ARROW_FLIGHT:-OFF} +export PYARROW_WITH_FLIGHT_SQL=${ARROW_FLIGHT_SQL:-OFF} export PYARROW_WITH_GANDIVA=${ARROW_GANDIVA:-OFF} export PYARROW_WITH_GCS=${ARROW_GCS:-OFF} export PYARROW_WITH_HDFS=${ARROW_HDFS:-ON} diff --git a/ci/scripts/python_test.sh b/ci/scripts/python_test.sh index 2d5bd5dd9ff..d2e1cb206d1 100755 --- a/ci/scripts/python_test.sh +++ b/ci/scripts/python_test.sh @@ -39,6 +39,9 @@ export ARROW_DEBUG_MEMORY_POOL=trap : ${PYARROW_TEST_CUDA:=${ARROW_CUDA:-ON}} : ${PYARROW_TEST_DATASET:=${ARROW_DATASET:-ON}} : ${PYARROW_TEST_FLIGHT:=${ARROW_FLIGHT:-ON}} +# Flight SQL must be enabled explicitly due to optional dependency on +# ADBC driver manager +: ${PYARROW_TEST_FLIGHT_SQL:=OFF} : ${PYARROW_TEST_GANDIVA:=${ARROW_GANDIVA:-ON}} : ${PYARROW_TEST_GCS:=${ARROW_GCS:-ON}} : ${PYARROW_TEST_HDFS:=${ARROW_HDFS:-ON}} @@ -49,6 +52,7 @@ export ARROW_DEBUG_MEMORY_POOL=trap export PYARROW_TEST_CUDA export PYARROW_TEST_DATASET export PYARROW_TEST_FLIGHT +export PYARROW_TEST_FLIGHT_SQL export PYARROW_TEST_GANDIVA export PYARROW_TEST_GCS export PYARROW_TEST_HDFS diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 3eda538fb2e..365aa936585 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -45,6 +45,7 @@ set(ARROW_RE2_LINKAGE set(ARROW_THIRDPARTY_DEPENDENCIES absl + AdbcValidation AWSSDK benchmark Boost @@ -59,6 +60,7 @@ set(ARROW_THIRDPARTY_DEPENDENCIES jemalloc LLVM lz4 + nanoarrow nlohmann_json opentelemetry-cpp ORC @@ -149,6 +151,8 @@ endforeach() macro(build_dependency DEPENDENCY_NAME) if("${DEPENDENCY_NAME}" STREQUAL "absl") build_absl() + elseif("${DEPENDENCY_NAME}" STREQUAL "AdbcValidation") + build_adbc_validation() elseif("${DEPENDENCY_NAME}" STREQUAL "AWSSDK") build_awssdk() elseif("${DEPENDENCY_NAME}" STREQUAL "benchmark") @@ -175,6 +179,8 @@ macro(build_dependency DEPENDENCY_NAME) build_jemalloc() elseif("${DEPENDENCY_NAME}" STREQUAL "lz4") build_lz4() + elseif("${DEPENDENCY_NAME}" STREQUAL "nanoarrow") + build_nanoarrow() elseif("${DEPENDENCY_NAME}" STREQUAL "nlohmann_json") build_nlohmann_json() elseif("${DEPENDENCY_NAME}" STREQUAL "opentelemetry-cpp") @@ -345,6 +351,11 @@ if(ARROW_FLIGHT) set(ARROW_WITH_GRPC ON) endif() +if(ARROW_FLIGHT_SQL AND ARROW_BUILD_TESTS) + set(ARROW_WITH_ADBC_VALIDATION ON) + set(ARROW_WITH_NANOARROW ON) +endif() + if(ARROW_WITH_GRPC) set(ARROW_WITH_RE2 ON) set(ARROW_WITH_ZLIB ON) @@ -431,6 +442,14 @@ else() ) endif() +if(DEFINED ENV{ARROW_ADBC_URL}) + set(ADBC_SOURCE_URL "$ENV{ARROW_ADBC_URL}") +else() + set_urls(ADBC_SOURCE_URL + "https://github.com/apache/arrow-adbc/archive/${ARROW_ADBC_BUILD_VERSION}.tar.gz" + ) +endif() + if(DEFINED ENV{ARROW_AWS_C_COMMON_URL}) set(AWS_C_COMMON_SOURCE_URL "$ENV{ARROW_AWS_C_COMMON_URL}") else() @@ -577,6 +596,14 @@ else() "${THIRDPARTY_MIRROR_URL}/mimalloc-${ARROW_MIMALLOC_BUILD_VERSION}.tar.gz") endif() +if(DEFINED ENV{ARROW_NANOARROW_URL}) + set(NANOARROW_SOURCE_URL "$ENV{ARROW_NANOARROW_URL}") +else() + set_urls(NANOARROW_SOURCE_URL + "https://github.com/apache/arrow-nanoarrow/archive/${ARROW_NANOARROW_BUILD_VERSION}.tar.gz" + ) +endif() + if(DEFINED ENV{ARROW_NLOHMANN_JSON_URL}) set(NLOHMANN_JSON_SOURCE_URL "$ENV{ARROW_NLOHMANN_JSON_URL}") else() @@ -4584,6 +4611,99 @@ if(ARROW_WITH_OPENTELEMETRY) message(STATUS "Found OpenTelemetry headers: ${OPENTELEMETRY_INCLUDE_DIR}") endif() +# ---------------------------------------------------------------------- +# nanoarrow (only used for tests) + +macro(build_nanoarrow) + message(STATUS "Building nanoarrow from source") + set(NANOARROW_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/nanoarrow_ep-install") + set(NANOARROW_INCLUDE_DIR "${NANOARROW_PREFIX}/include") + set(NANOARROW_LIB_DIR "lib") + + set(NANOARROW_COMMON_CMAKE_ARGS + ${EP_COMMON_CMAKE_ARGS} + "-DCMAKE_INSTALL_LIBDIR=${NANOARROW_LIB_DIR}" + "-DCMAKE_INSTALL_PREFIX=${NANOARROW_PREFIX}" + "-DCMAKE_PREFIX_PATH=${NANOARROW_PREFIX}" + "-DCMAKE_UNITY_BUILD=ON") + set(NANOARROW_STATIC_LIBRARY + "${NANOARROW_PREFIX}/${NANOARROW_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}nanoarrow${CMAKE_STATIC_LIBRARY_SUFFIX}" + ) + + file(MAKE_DIRECTORY ${NANOARROW_INCLUDE_DIR}) + + add_library(nanoarrow STATIC IMPORTED) + set_target_properties(nanoarrow + PROPERTIES IMPORTED_LOCATION "${NANOARROW_STATIC_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES + "${NANOARROW_INCLUDE_DIR}") + externalproject_add(nanoarrow_ep + ${EP_LOG_OPTIONS} + URL ${NANOARROW_SOURCE_URL} + URL_HASH "SHA256=${ARROW_NANOARROW_BUILD_SHA256_CHECKSUM}" + CMAKE_ARGS ${NANOARROW_COMMON_CMAKE_ARGS} + BUILD_BYPRODUCTS "${NANOARROW_STATIC_LIBRARY}") + set(NANOARROW_LINK_LIBRARIES nanoarrow) + add_dependencies(nanoarrow nanoarrow_ep) +endmacro() + +if(ARROW_WITH_NANOARROW) + set(nanoarrow_SOURCE "AUTO") + resolve_dependency(nanoarrow HAVE_ALT FALSE) + + message(STATUS "Found nanoarrow headers: ${NANOARROW_INCLUDE_DIR}") + message(STATUS "Found nanoarrow libraries: ${NANOARROW_LINK_LIBRARIES}") +endif() + +# ---------------------------------------------------------------------- +# ADBC validation suite (only used for tests) + +macro(build_adbc_validation) + message(STATUS "Building ADBC validation suite from source") + set(ADBC_VALIDATION_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/adbcvalidation_ep-install") + set(ADBC_VALIDATION_INCLUDE_DIR "${ADBC_VALIDATION_PREFIX}/include") + set(ADBC_VALIDATION_LIB_DIR "lib") + + set(ADBC_VALIDATION_COMMON_CMAKE_ARGS + ${EP_COMMON_CMAKE_ARGS} + "-DCMAKE_INSTALL_LIBDIR=${ADBC_VALIDATION_LIB_DIR}" + "-DCMAKE_INSTALL_PREFIX=${ADBC_VALIDATION_PREFIX}" + "-DCMAKE_PREFIX_PATH=${ADBC_VALIDATION_PREFIX}" + "-DCMAKE_UNITY_BUILD=ON") + set(ADBC_VALIDATION_STATIC_LIBRARY + "${ADBC_VALIDATION_PREFIX}/${ADBC_VALIDATION_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}adbc_validation${CMAKE_STATIC_LIBRARY_SUFFIX}" + ) + + file(MAKE_DIRECTORY ${ADBC_VALIDATION_INCLUDE_DIR}) + + add_library(AdbcValidation::adbc_validation STATIC IMPORTED) + set_target_properties(AdbcValidation::adbc_validation + PROPERTIES IMPORTED_LOCATION "${ADBC_VALIDATION_STATIC_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES + "${ADBC_VALIDATION_INCLUDE_DIR}" + INTERFACE_LINK_LIBRARIES nanoarrow + GTest::gtest GTest::gmock) + externalproject_add(adbcvalidation_ep + ${EP_LOG_OPTIONS} + URL ${ADBC_SOURCE_URL} + URL_HASH "SHA256=${ARROW_ADBC_BUILD_SHA256_CHECKSUM}" + CMAKE_ARGS ${ADBC_VALIDATION_COMMON_CMAKE_ARGS} + BUILD_BYPRODUCTS "${ADBC_VALIDATION_STATIC_LIBRARY}" + SOURCE_SUBDIR c/validation) + add_dependencies(AdbcValidation::adbc_validation nanoarrow_ep) + set(ADBCVALIDATION_LINK_LIBRARIES AdbcValidation::adbc_validation) + add_dependencies(AdbcValidation::adbc_validation adbcvalidation_ep) +endmacro() + +if(ARROW_WITH_ADBC_VALIDATION) + set(AdbcValidation_SOURCE "AUTO") + resolve_dependency(AdbcValidation HAVE_ALT FALSE) + + message(STATUS "Found ADBC validation suite headers: ${ADBC_VALIDATION_INCLUDE_DIR}") + message(STATUS "Found ADBC validation suite libraries: ${ADBC_VALIDATION_LINK_LIBRARIES}" + ) +endif() + # ---------------------------------------------------------------------- # AWS SDK for C++ diff --git a/cpp/src/arrow/c/adbc_internal.h b/cpp/src/arrow/c/adbc_internal.h new file mode 100644 index 00000000000..a1ff53441db --- /dev/null +++ b/cpp/src/arrow/c/adbc_internal.h @@ -0,0 +1,1207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// \file adbc.h ADBC: Arrow Database connectivity +/// +/// An Arrow-based interface between applications and database +/// drivers. ADBC aims to provide a vendor-independent API for SQL +/// and Substrait-based database access that is targeted at +/// analytics/OLAP use cases. +/// +/// This API is intended to be implemented directly by drivers and +/// used directly by client applications. To assist portability +/// between different vendors, a "driver manager" library is also +/// provided, which implements this same API, but dynamically loads +/// drivers internally and forwards calls appropriately. +/// +/// ADBC uses structs with free functions that operate on those +/// structs to model objects. +/// +/// In general, objects allow serialized access from multiple threads, +/// but not concurrent access. Specific implementations may permit +/// multiple threads. +/// +/// \version 1.0.0 + +#pragma once + +#include +#include + +/// \defgroup Arrow C Data Interface +/// Definitions for the C Data Interface/C Stream Interface. +/// +/// See https://arrow.apache.org/docs/format/CDataInterface.html +/// +/// @{ + +//! @cond Doxygen_Suppress + +#ifdef __cplusplus +extern "C" { +#endif + +// Extra guard for versions of Arrow without the canonical guard +#ifndef ARROW_FLAG_DICTIONARY_ORDERED + +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DATA_INTERFACE + +#ifndef ARROW_C_STREAM_INTERFACE +#define ARROW_C_STREAM_INTERFACE + +struct ArrowArrayStream { + // Callback to get the stream type + // (will be the same for all arrays in the stream). + // + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowSchema must be released independently from the stream. + int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out); + + // Callback to get the next array + // (if no error and the array is released, the stream has ended) + // + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowArray must be released independently from the stream. + int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out); + + // Callback to get optional detailed error information. + // This must only be called if the last stream operation failed + // with a non-0 return code. + // + // Return value: pointer to a null-terminated character array describing + // the last error, or NULL if no description is available. + // + // The returned pointer is only valid until the next operation on this stream + // (including release). + const char* (*get_last_error)(struct ArrowArrayStream*); + + // Release callback: release the stream's own resources. + // Note that arrays returned by `get_next` must be individually released. + void (*release)(struct ArrowArrayStream*); + + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_STREAM_INTERFACE +#endif // ARROW_FLAG_DICTIONARY_ORDERED + +//! @endcond + +/// @} + +#ifndef ADBC +#define ADBC + +// Storage class macros for Windows +// Allow overriding/aliasing with application-defined macros +#if !defined(ADBC_EXPORT) +#if defined(_WIN32) +#if defined(ADBC_EXPORTING) +#define ADBC_EXPORT __declspec(dllexport) +#else +#define ADBC_EXPORT __declspec(dllimport) +#endif // defined(ADBC_EXPORTING) +#else +#define ADBC_EXPORT +#endif // defined(_WIN32) +#endif // !defined(ADBC_EXPORT) + +/// \defgroup adbc-error-handling Error Handling +/// ADBC uses integer error codes to signal errors. To provide more +/// detail about errors, functions may also return an AdbcError via an +/// optional out parameter, which can be inspected. If provided, it is +/// the responsibility of the caller to zero-initialize the AdbcError +/// value. +/// +/// @{ + +/// \brief Error codes for operations that may fail. +typedef uint8_t AdbcStatusCode; + +/// \brief No error. +#define ADBC_STATUS_OK 0 +/// \brief An unknown error occurred. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_UNKNOWN 1 +/// \brief The operation is not implemented or supported. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_NOT_IMPLEMENTED 2 +/// \brief A requested resource was not found. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_NOT_FOUND 3 +/// \brief A requested resource already exists. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_ALREADY_EXISTS 4 +/// \brief The arguments are invalid, likely a programming error. +/// +/// For instance, they may be of the wrong format, or out of range. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_INVALID_ARGUMENT 5 +/// \brief The preconditions for the operation are not met, likely a +/// programming error. +/// +/// For instance, the object may be uninitialized, or may have not +/// been fully configured. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_INVALID_STATE 6 +/// \brief Invalid data was processed (not a programming error). +/// +/// For instance, a division by zero may have occurred during query +/// execution. +/// +/// May indicate a database-side error only. +#define ADBC_STATUS_INVALID_DATA 7 +/// \brief The database's integrity was affected. +/// +/// For instance, a foreign key check may have failed, or a uniqueness +/// constraint may have been violated. +/// +/// May indicate a database-side error only. +#define ADBC_STATUS_INTEGRITY 8 +/// \brief An error internal to the driver or database occurred. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_INTERNAL 9 +/// \brief An I/O error occurred. +/// +/// For instance, a remote service may be unavailable. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_IO 10 +/// \brief The operation was cancelled, not due to a timeout. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_CANCELLED 11 +/// \brief The operation was cancelled due to a timeout. +/// +/// May indicate a driver-side or database-side error. +#define ADBC_STATUS_TIMEOUT 12 +/// \brief Authentication failed. +/// +/// May indicate a database-side error only. +#define ADBC_STATUS_UNAUTHENTICATED 13 +/// \brief The client is not authorized to perform the given operation. +/// +/// May indicate a database-side error only. +#define ADBC_STATUS_UNAUTHORIZED 14 + +/// \brief A detailed error message for an operation. +struct ADBC_EXPORT AdbcError { + /// \brief The error message. + char* message; + + /// \brief A vendor-specific error code, if applicable. + int32_t vendor_code; + + /// \brief A SQLSTATE error code, if provided, as defined by the + /// SQL:2003 standard. If not set, it should be set to + /// "\0\0\0\0\0". + char sqlstate[5]; + + /// \brief Release the contained error. + /// + /// Unlike other structures, this is an embedded callback to make it + /// easier for the driver manager and driver to cooperate. + void (*release)(struct AdbcError* error); +}; + +/// @} + +/// \defgroup adbc-constants Constants +/// @{ + +/// \brief ADBC revision 1.0.0. +/// +/// When passed to an AdbcDriverInitFunc(), the driver parameter must +/// point to an AdbcDriver. +#define ADBC_VERSION_1_0_0 1000000 + +/// \brief Canonical option value for enabling an option. +/// +/// For use as the value in SetOption calls. +#define ADBC_OPTION_VALUE_ENABLED "true" +/// \brief Canonical option value for disabling an option. +/// +/// For use as the value in SetOption calls. +#define ADBC_OPTION_VALUE_DISABLED "false" + +/// \brief The database vendor/product name (e.g. the server name). +/// (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_VENDOR_NAME 0 +/// \brief The database vendor/product version (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_VENDOR_VERSION 1 +/// \brief The database vendor/product Arrow library version (type: +/// utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_VENDOR_ARROW_VERSION 2 + +/// \brief The driver name (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_DRIVER_NAME 100 +/// \brief The driver version (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_DRIVER_VERSION 101 +/// \brief The driver Arrow library version (type: utf8). +/// +/// \see AdbcConnectionGetInfo +#define ADBC_INFO_DRIVER_ARROW_VERSION 102 + +/// \brief Return metadata on catalogs, schemas, tables, and columns. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_ALL 0 +/// \brief Return metadata on catalogs only. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_CATALOGS 1 +/// \brief Return metadata on catalogs and schemas. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_DB_SCHEMAS 2 +/// \brief Return metadata on catalogs, schemas, and tables. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_TABLES 3 +/// \brief Return metadata on catalogs, schemas, tables, and columns. +/// +/// \see AdbcConnectionGetObjects +#define ADBC_OBJECT_DEPTH_COLUMNS ADBC_OBJECT_DEPTH_ALL + +/// \brief The name of the canonical option for whether autocommit is +/// enabled. +/// +/// \see AdbcConnectionSetOption +#define ADBC_CONNECTION_OPTION_AUTOCOMMIT "adbc.connection.autocommit" + +/// \brief The name of the canonical option for whether the current +/// connection should be restricted to being read-only. +/// +/// \see AdbcConnectionSetOption +#define ADBC_CONNECTION_OPTION_READ_ONLY "adbc.connection.readonly" + +/// \brief The name of the canonical option for setting the isolation +/// level of a transaction. +/// +/// Should only be used in conjunction with autocommit disabled and +/// AdbcConnectionCommit / AdbcConnectionRollback. If the desired +/// isolation level is not supported by a driver, it should return an +/// appropriate error. +/// +/// \see AdbcConnectionSetOption +#define ADBC_CONNECTION_OPTION_ISOLATION_LEVEL \ + "adbc.connection.transaction.isolation_level" + +/// \brief Use database or driver default isolation level +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_DEFAULT \ + "adbc.connection.transaction.isolation.default" + +/// \brief The lowest isolation level. Dirty reads are allowed, so one +/// transaction may see not-yet-committed changes made by others. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_READ_UNCOMMITTED \ + "adbc.connection.transaction.isolation.read_uncommitted" + +/// \brief Lock-based concurrency control keeps write locks until the +/// end of the transaction, but read locks are released as soon as a +/// SELECT is performed. Non-repeatable reads can occur in this +/// isolation level. +/// +/// More simply put, Read Committed is an isolation level that guarantees +/// that any data read is committed at the moment it is read. It simply +/// restricts the reader from seeing any intermediate, uncommitted, +/// 'dirty' reads. It makes no promise whatsoever that if the transaction +/// re-issues the read, it will find the same data; data is free to change +/// after it is read. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_READ_COMMITTED \ + "adbc.connection.transaction.isolation.read_committed" + +/// \brief Lock-based concurrency control keeps read AND write locks +/// (acquired on selection data) until the end of the transaction. +/// +/// However, range-locks are not managed, so phantom reads can occur. +/// Write skew is possible at this isolation level in some systems. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_REPEATABLE_READ \ + "adbc.connection.transaction.isolation.repeatable_read" + +/// \brief This isolation guarantees that all reads in the transaction +/// will see a consistent snapshot of the database and the transaction +/// should only successfully commit if no updates conflict with any +/// concurrent updates made since that snapshot. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_SNAPSHOT \ + "adbc.connection.transaction.isolation.snapshot" + +/// \brief Serializability requires read and write locks to be released +/// only at the end of the transaction. This includes acquiring range- +/// locks when a select query uses a ranged WHERE clause to avoid +/// phantom reads. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_SERIALIZABLE \ + "adbc.connection.transaction.isolation.serializable" + +/// \brief The central distinction between serializability and linearizability +/// is that serializability is a global property; a property of an entire +/// history of operations and transactions. Linearizability is a local +/// property; a property of a single operation/transaction. +/// +/// Linearizability can be viewed as a special case of strict serializability +/// where transactions are restricted to consist of a single operation applied +/// to a single object. +/// +/// \see AdbcConnectionSetOption +#define ADBC_OPTION_ISOLATION_LEVEL_LINEARIZABLE \ + "adbc.connection.transaction.isolation.linearizable" + +/// \defgroup adbc-statement-ingestion Bulk Data Ingestion +/// While it is possible to insert data via prepared statements, it can +/// be more efficient to explicitly perform a bulk insert. For +/// compatible drivers, this can be accomplished by setting up and +/// executing a statement. Instead of setting a SQL query or Substrait +/// plan, bind the source data via AdbcStatementBind, and set the name +/// of the table to be created via AdbcStatementSetOption and the +/// options below. Then, call AdbcStatementExecute with +/// ADBC_OUTPUT_TYPE_UPDATE. +/// +/// @{ + +/// \brief The name of the target table for a bulk insert. +/// +/// The driver should attempt to create the table if it does not +/// exist. If the table exists but has a different schema, +/// ADBC_STATUS_ALREADY_EXISTS should be raised. Else, data should be +/// appended to the target table. +#define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table" +/// \brief Whether to create (the default) or append. +#define ADBC_INGEST_OPTION_MODE "adbc.ingest.mode" +/// \brief Create the table and insert data; error if the table exists. +#define ADBC_INGEST_OPTION_MODE_CREATE "adbc.ingest.mode.create" +/// \brief Do not create the table, and insert data; error if the +/// table does not exist (ADBC_STATUS_NOT_FOUND) or does not match +/// the schema of the data to append (ADBC_STATUS_ALREADY_EXISTS). +#define ADBC_INGEST_OPTION_MODE_APPEND "adbc.ingest.mode.append" + +/// @} + +/// @} + +/// \defgroup adbc-database Database Initialization +/// Clients first initialize a database, then create a connection +/// (below). This gives the implementation a place to initialize and +/// own any common connection state. For example, in-memory databases +/// can place ownership of the actual database in this object. +/// @{ + +/// \brief An instance of a database. +/// +/// Must be kept alive as long as any connections exist. +struct ADBC_EXPORT AdbcDatabase { + /// \brief Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + void* private_data; + /// \brief The associated driver (used by the driver manager to help + /// track state). + struct AdbcDriver* private_driver; +}; + +/// @} + +/// \defgroup adbc-connection Connection Establishment +/// Functions for creating, using, and releasing database connections. +/// @{ + +/// \brief An active database connection. +/// +/// Provides methods for query execution, managing prepared +/// statements, using transactions, and so on. +/// +/// Connections are not required to be thread-safe, but they can be +/// used from multiple threads so long as clients take care to +/// serialize accesses to a connection. +struct ADBC_EXPORT AdbcConnection { + /// \brief Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + void* private_data; + /// \brief The associated driver (used by the driver manager to help + /// track state). + struct AdbcDriver* private_driver; +}; + +/// @} + +/// \defgroup adbc-statement Managing Statements +/// Applications should first initialize a statement with +/// AdbcStatementNew. Then, the statement should be configured with +/// functions like AdbcStatementSetSqlQuery and +/// AdbcStatementSetOption. Finally, the statement can be executed +/// with AdbcStatementExecuteQuery (or call AdbcStatementPrepare first +/// to turn it into a prepared statement instead). +/// @{ + +/// \brief A container for all state needed to execute a database +/// query, such as the query itself, parameters for prepared +/// statements, driver parameters, etc. +/// +/// Statements may represent queries or prepared statements. +/// +/// Statements may be used multiple times and can be reconfigured +/// (e.g. they can be reused to execute multiple different queries). +/// However, executing a statement (and changing certain other state) +/// will invalidate result sets obtained prior to that execution. +/// +/// Multiple statements may be created from a single connection. +/// However, the driver may block or error if they are used +/// concurrently (whether from a single thread or multiple threads). +/// +/// Statements are not required to be thread-safe, but they can be +/// used from multiple threads so long as clients take care to +/// serialize accesses to a statement. +struct ADBC_EXPORT AdbcStatement { + /// \brief Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + void* private_data; + + /// \brief The associated driver (used by the driver manager to help + /// track state). + struct AdbcDriver* private_driver; +}; + +/// \defgroup adbc-statement-partition Partitioned Results +/// Some backends may internally partition the results. These +/// partitions are exposed to clients who may wish to integrate them +/// with a threaded or distributed execution model, where partitions +/// can be divided among threads or machines and fetched in parallel. +/// +/// To use partitioning, execute the statement with +/// AdbcStatementExecutePartitions to get the partition descriptors. +/// Call AdbcConnectionReadPartition to turn the individual +/// descriptors into ArrowArrayStream instances. This may be done on +/// a different connection than the one the partition was created +/// with, or even in a different process on another machine. +/// +/// Drivers are not required to support partitioning. +/// +/// @{ + +/// \brief The partitions of a distributed/partitioned result set. +struct AdbcPartitions { + /// \brief The number of partitions. + size_t num_partitions; + + /// \brief The partitions of the result set, where each entry (up to + /// num_partitions entries) is an opaque identifier that can be + /// passed to AdbcConnectionReadPartition. + const uint8_t** partitions; + + /// \brief The length of each corresponding entry in partitions. + const size_t* partition_lengths; + + /// \brief Opaque implementation-defined state. + /// This field is NULLPTR iff the connection is unintialized/freed. + void* private_data; + + /// \brief Release the contained partitions. + /// + /// Unlike other structures, this is an embedded callback to make it + /// easier for the driver manager and driver to cooperate. + void (*release)(struct AdbcPartitions* partitions); +}; + +/// @} + +/// @} + +/// \defgroup adbc-driver Driver Initialization +/// +/// These functions are intended to help support integration between a +/// driver and the driver manager. +/// @{ + +/// \brief An instance of an initialized database driver. +/// +/// This provides a common interface for vendor-specific driver +/// initialization routines. Drivers should populate this struct, and +/// applications can call ADBC functions through this struct, without +/// worrying about multiple definitions of the same symbol. +struct ADBC_EXPORT AdbcDriver { + /// \brief Opaque driver-defined state. + /// This field is NULL if the driver is unintialized/freed (but + /// it need not have a value even if the driver is initialized). + void* private_data; + /// \brief Opaque driver manager-defined state. + /// This field is NULL if the driver is unintialized/freed (but + /// it need not have a value even if the driver is initialized). + void* private_manager; + + /// \brief Release the driver and perform any cleanup. + /// + /// This is an embedded callback to make it easier for the driver + /// manager and driver to cooperate. + AdbcStatusCode (*release)(struct AdbcDriver* driver, struct AdbcError* error); + + AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase*, struct AdbcError*); + AdbcStatusCode (*DatabaseNew)(struct AdbcDatabase*, struct AdbcError*); + AdbcStatusCode (*DatabaseSetOption)(struct AdbcDatabase*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*); + + AdbcStatusCode (*ConnectionCommit)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetInfo)(struct AdbcConnection*, uint32_t*, size_t, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetObjects)(struct AdbcConnection*, int, const char*, + const char*, const char*, const char**, + const char*, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionGetTableSchema)(struct AdbcConnection*, const char*, + const char*, const char*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*ConnectionGetTableTypes)(struct AdbcConnection*, + struct ArrowArrayStream*, struct AdbcError*); + AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcDatabase*, + struct AdbcError*); + AdbcStatusCode (*ConnectionNew)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*ConnectionReadPartition)(struct AdbcConnection*, const uint8_t*, + size_t, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct AdbcError*); + AdbcStatusCode (*ConnectionRollback)(struct AdbcConnection*, struct AdbcError*); + + AdbcStatusCode (*StatementBind)(struct AdbcStatement*, struct ArrowArray*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*StatementBindStream)(struct AdbcStatement*, struct ArrowArrayStream*, + struct AdbcError*); + AdbcStatusCode (*StatementExecuteQuery)(struct AdbcStatement*, struct ArrowArrayStream*, + int64_t*, struct AdbcError*); + AdbcStatusCode (*StatementExecutePartitions)(struct AdbcStatement*, struct ArrowSchema*, + struct AdbcPartitions*, int64_t*, + struct AdbcError*); + AdbcStatusCode (*StatementGetParameterSchema)(struct AdbcStatement*, + struct ArrowSchema*, struct AdbcError*); + AdbcStatusCode (*StatementNew)(struct AdbcConnection*, struct AdbcStatement*, + struct AdbcError*); + AdbcStatusCode (*StatementPrepare)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementRelease)(struct AdbcStatement*, struct AdbcError*); + AdbcStatusCode (*StatementSetOption)(struct AdbcStatement*, const char*, const char*, + struct AdbcError*); + AdbcStatusCode (*StatementSetSqlQuery)(struct AdbcStatement*, const char*, + struct AdbcError*); + AdbcStatusCode (*StatementSetSubstraitPlan)(struct AdbcStatement*, const uint8_t*, + size_t, struct AdbcError*); +}; + +/// @} + +/// \addtogroup adbc-database +/// @{ + +/// \brief Allocate a new (but uninitialized) database. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error); + +/// \brief Set a char* option. +/// +/// Options may be set before AdbcDatabaseInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error); + +/// \brief Finish setting options and initialize the database. +/// +/// Some drivers may support setting options after initialization +/// as well. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error); + +/// \brief Destroy this database. No connections may exist. +/// \param[in] database The database to release. +/// \param[out] error An optional location to return an error +/// message if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, + struct AdbcError* error); + +/// @} + +/// \addtogroup adbc-connection +/// @{ + +/// \brief Allocate a new (but uninitialized) connection. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error); + +/// \brief Set a char* option. +/// +/// Options may be set before AdbcConnectionInit. Some drivers may +/// support setting options after initialization as well. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the option is not recognized +ADBC_EXPORT +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error); + +/// \brief Finish setting options and initialize the connection. +/// +/// Some drivers may support setting options after initialization +/// as well. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, struct AdbcError* error); + +/// \brief Destroy this connection. +/// +/// \param[in] connection The connection to release. +/// \param[out] error An optional location to return an error +/// message if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error); + +/// \defgroup adbc-connection-metadata Metadata +/// Functions for retrieving metadata about the database. +/// +/// Generally, these functions return an ArrowArrayStream that can be +/// consumed to get the metadata as Arrow data. The returned metadata +/// has an expected schema given in the function docstring. Schema +/// fields are nullable unless otherwise marked. While no +/// AdbcStatement is used in these functions, the result set may count +/// as an active statement to the driver for the purposes of +/// concurrency management (e.g. if the driver has a limit on +/// concurrent active statements and it must execute a SQL query +/// internally in order to implement the metadata function). +/// +/// Some functions accept "search pattern" arguments, which are +/// strings that can contain the special character "%" to match zero +/// or more characters, or "_" to match exactly one character. (See +/// the documentation of DatabaseMetaData in JDBC or "Pattern Value +/// Arguments" in the ODBC documentation.) Escaping is not currently +/// supported. +/// +/// @{ + +/// \brief Get metadata about the database/driver. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ----------------------------|------------------------ +/// info_name | uint32 not null +/// info_value | INFO_SCHEMA +/// +/// INFO_SCHEMA is a dense union with members: +/// +/// Field Name (Type Code) | Field Type +/// ----------------------------|------------------------ +/// string_value (0) | utf8 +/// bool_value (1) | bool +/// int64_value (2) | int64 +/// int32_bitmask (3) | int32 +/// string_list (4) | list +/// int32_to_int32_list_map (5) | map> +/// +/// Each metadatum is identified by an integer code. The recognized +/// codes are defined as constants. Codes [0, 10_000) are reserved +/// for ADBC usage. Drivers/vendors will ignore requests for +/// unrecognized codes (the row will be omitted from the result). +/// +/// \param[in] connection The connection to query. +/// \param[in] info_codes A list of metadata codes to fetch, or NULL +/// to fetch all. +/// \param[in] info_codes_length The length of the info_codes +/// parameter. Ignored if info_codes is NULL. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, + uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* out, + struct AdbcError* error); + +/// \brief Get a hierarchical view of all catalogs, database schemas, +/// tables, and columns. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// | Field Name | Field Type | +/// |--------------------------|-------------------------| +/// | catalog_name | utf8 | +/// | catalog_db_schemas | list | +/// +/// DB_SCHEMA_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|-------------------------| +/// | db_schema_name | utf8 | +/// | db_schema_tables | list | +/// +/// TABLE_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|-------------------------| +/// | table_name | utf8 not null | +/// | table_type | utf8 not null | +/// | table_columns | list | +/// | table_constraints | list | +/// +/// COLUMN_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | Comments | +/// |--------------------------|-------------------------|----------| +/// | column_name | utf8 not null | | +/// | ordinal_position | int32 | (1) | +/// | remarks | utf8 | (2) | +/// | xdbc_data_type | int16 | (3) | +/// | xdbc_type_name | utf8 | (3) | +/// | xdbc_column_size | int32 | (3) | +/// | xdbc_decimal_digits | int16 | (3) | +/// | xdbc_num_prec_radix | int16 | (3) | +/// | xdbc_nullable | int16 | (3) | +/// | xdbc_column_def | utf8 | (3) | +/// | xdbc_sql_data_type | int16 | (3) | +/// | xdbc_datetime_sub | int16 | (3) | +/// | xdbc_char_octet_length | int32 | (3) | +/// | xdbc_is_nullable | utf8 | (3) | +/// | xdbc_scope_catalog | utf8 | (3) | +/// | xdbc_scope_schema | utf8 | (3) | +/// | xdbc_scope_table | utf8 | (3) | +/// | xdbc_is_autoincrement | bool | (3) | +/// | xdbc_is_generatedcolumn | bool | (3) | +/// +/// 1. The column's ordinal position in the table (starting from 1). +/// 2. Database-specific description of the column. +/// 3. Optional value. Should be null if not supported by the driver. +/// xdbc_ values are meant to provide JDBC/ODBC-compatible metadata +/// in an agnostic manner. +/// +/// CONSTRAINT_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | Comments | +/// |--------------------------|-------------------------|----------| +/// | constraint_name | utf8 | | +/// | constraint_type | utf8 not null | (1) | +/// | constraint_column_names | list not null | (2) | +/// | constraint_column_usage | list | (3) | +/// +/// 1. One of 'CHECK', 'FOREIGN KEY', 'PRIMARY KEY', or 'UNIQUE'. +/// 2. The columns on the current table that are constrained, in +/// order. +/// 3. For FOREIGN KEY only, the referenced table and columns. +/// +/// USAGE_SCHEMA is a Struct with fields: +/// +/// | Field Name | Field Type | +/// |--------------------------|-------------------------| +/// | fk_catalog | utf8 | +/// | fk_db_schema | utf8 | +/// | fk_table | utf8 not null | +/// | fk_column_name | utf8 not null | +/// +/// \param[in] connection The database connection. +/// \param[in] depth The level of nesting to display. If 0, display +/// all levels. If 1, display only catalogs (i.e. catalog_schemas +/// will be null). If 2, display only catalogs and schemas +/// (i.e. db_schema_tables will be null), and so on. +/// \param[in] catalog Only show tables in the given catalog. If NULL, +/// do not filter by catalog. If an empty string, only show tables +/// without a catalog. May be a search pattern (see section +/// documentation). +/// \param[in] db_schema Only show tables in the given database schema. If +/// NULL, do not filter by database schema. If an empty string, only show +/// tables without a database schema. May be a search pattern (see section +/// documentation). +/// \param[in] table_name Only show tables with the given name. If NULL, do not +/// filter by name. May be a search pattern (see section documentation). +/// \param[in] table_type Only show tables matching one of the given table +/// types. If NULL, show tables of any type. Valid table types can be fetched +/// from GetTableTypes. Terminate the list with a NULL entry. +/// \param[in] column_name Only show columns with the given name. If +/// NULL, do not filter by name. May be a search pattern (see +/// section documentation). +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, + const char* catalog, const char* db_schema, + const char* table_name, const char** table_type, + const char* column_name, + struct ArrowArrayStream* out, + struct AdbcError* error); + +/// \brief Get the Arrow schema of a table. +/// +/// \param[in] connection The database connection. +/// \param[in] catalog The catalog (or nullptr if not applicable). +/// \param[in] db_schema The database schema (or nullptr if not applicable). +/// \param[in] table_name The table name. +/// \param[out] schema The table schema. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, + struct ArrowSchema* schema, + struct AdbcError* error); + +/// \brief Get a list of table types in the database. +/// +/// The result is an Arrow dataset with the following schema: +/// +/// Field Name | Field Type +/// ---------------|-------------- +/// table_type | utf8 not null +/// +/// \param[in] connection The database connection. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error); + +/// @} + +/// \defgroup adbc-connection-partition Partitioned Results +/// Some databases may internally partition the results. These +/// partitions are exposed to clients who may wish to integrate them +/// with a threaded or distributed execution model, where partitions +/// can be divided among threads or machines for processing. +/// +/// Drivers are not required to support partitioning. +/// +/// Partitions are not ordered. If the result set is sorted, +/// implementations should return a single partition. +/// +/// @{ + +/// \brief Construct a statement for a partition of a query. The +/// results can then be read independently. +/// +/// A partition can be retrieved from AdbcPartitions. +/// +/// \param[in] connection The connection to use. This does not have +/// to be the same connection that the partition was created on. +/// \param[in] serialized_partition The partition descriptor. +/// \param[in] serialized_length The partition descriptor length. +/// \param[out] out The result set. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, + const uint8_t* serialized_partition, + size_t serialized_length, + struct ArrowArrayStream* out, + struct AdbcError* error); + +/// @} + +/// \defgroup adbc-connection-transaction Transaction Semantics +/// +/// Connections start out in auto-commit mode by default (if +/// applicable for the given vendor). Use AdbcConnectionSetOption and +/// ADBC_CONNECTION_OPTION_AUTO_COMMIT to change this. +/// +/// @{ + +/// \brief Commit any pending transactions. Only used if autocommit is +/// disabled. +/// +/// Behavior is undefined if this is mixed with SQL transaction +/// statements. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error); + +/// \brief Roll back any pending transactions. Only used if autocommit +/// is disabled. +/// +/// Behavior is undefined if this is mixed with SQL transaction +/// statements. +ADBC_EXPORT +AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, + struct AdbcError* error); + +/// @} + +/// @} + +/// \addtogroup adbc-statement +/// @{ + +/// \brief Create a new statement for a given connection. +/// +/// Set options on the statement, then call AdbcStatementExecuteQuery +/// or AdbcStatementPrepare. +ADBC_EXPORT +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, struct AdbcError* error); + +/// \brief Destroy a statement. +/// \param[in] statement The statement to release. +/// \param[out] error An optional location to return an error +/// message if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error); + +/// \brief Execute a statement and get the results. +/// +/// This invalidates any prior result sets. +/// +/// \param[in] statement The statement to execute. +/// \param[out] out The results. Pass NULL if the client does not +/// expect a result set. +/// \param[out] rows_affected The number of rows affected if known, +/// else -1. Pass NULL if the client does not want this information. +/// \param[out] error An optional location to return an error +/// message if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, struct AdbcError* error); + +/// \brief Turn this statement into a prepared statement to be +/// executed multiple times. +/// +/// This invalidates any prior result sets. +ADBC_EXPORT +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error); + +/// \defgroup adbc-statement-sql SQL Semantics +/// Functions for executing SQL queries, or querying SQL-related +/// metadata. Drivers are not required to support both SQL and +/// Substrait semantics. If they do, it may be via converting +/// between representations internally. +/// @{ + +/// \brief Set the SQL query to execute. +/// +/// The query can then be executed with AdbcStatementExecute. For +/// queries expected to be executed repeatedly, AdbcStatementPrepare +/// the statement first. +/// +/// \param[in] statement The statement. +/// \param[in] query The query to execute. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error); + +/// @} + +/// \defgroup adbc-statement-substrait Substrait Semantics +/// Functions for executing Substrait plans, or querying +/// Substrait-related metadata. Drivers are not required to support +/// both SQL and Substrait semantics. If they do, it may be via +/// converting between representations internally. +/// @{ + +/// \brief Set the Substrait plan to execute. +/// +/// The query can then be executed with AdbcStatementExecute. For +/// queries expected to be executed repeatedly, AdbcStatementPrepare +/// the statement first. +/// +/// \param[in] statement The statement. +/// \param[in] plan The serialized substrait.Plan to execute. +/// \param[in] length The length of the serialized plan. +/// \param[out] error Error details, if an error occurs. +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error); + +/// @} + +/// \brief Bind Arrow data. This can be used for bulk inserts or +/// prepared statements. +/// +/// \param[in] statement The statement to bind to. +/// \param[in] values The values to bind. The driver will call the +/// release callback itself, although it may not do this until the +/// statement is released. +/// \param[in] schema The schema of the values to bind. +/// \param[out] error An optional location to return an error message +/// if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, + struct ArrowArray* values, struct ArrowSchema* schema, + struct AdbcError* error); + +/// \brief Bind Arrow data. This can be used for bulk inserts or +/// prepared statements. +/// \param[in] statement The statement to bind to. +/// \param[in] stream The values to bind. The driver will call the +/// release callback itself, although it may not do this until the +/// statement is released. +/// \param[out] error An optional location to return an error message +/// if necessary. +ADBC_EXPORT +AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, + struct ArrowArrayStream* stream, + struct AdbcError* error); + +/// \brief Get the schema for bound parameters. +/// +/// This retrieves an Arrow schema describing the number, names, and +/// types of the parameters in a parameterized statement. The fields +/// of the schema should be in order of the ordinal position of the +/// parameters; named parameters should appear only once. +/// +/// If the parameter does not have a name, or the name cannot be +/// determined, the name of the corresponding field in the schema will +/// be an empty string. If the type cannot be determined, the type of +/// the corresponding field will be NA (NullType). +/// +/// This should be called after AdbcStatementPrepare. +/// +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the schema cannot be determined. +ADBC_EXPORT +AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error); + +/// \brief Set a string option on a statement. +ADBC_EXPORT +AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, + const char* value, struct AdbcError* error); + +/// \addtogroup adbc-statement-partition +/// @{ + +/// \brief Execute a statement and get the results as a partitioned +/// result set. +/// +/// \param[in] statement The statement to execute. +/// \param[out] schema The schema of the result set. +/// \param[out] partitions The result partitions. +/// \param[out] rows_affected The number of rows affected if known, +/// else -1. Pass NULL if the client does not want this information. +/// \param[out] error An optional location to return an error +/// message if necessary. +/// \return ADBC_STATUS_NOT_IMPLEMENTED if the driver does not support +/// partitioned results +ADBC_EXPORT +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error); + +/// @} + +/// @} + +/// \addtogroup adbc-driver +/// @{ + +/// \brief Common entry point for drivers via the driver manager +/// (which uses dlopen(3)/LoadLibrary). The driver manager is told +/// to load a library and call a function of this type to load the +/// driver. +/// +/// Although drivers may choose any name for this function, the +/// recommended name is "AdbcDriverInit". +/// +/// \param[in] version The ADBC revision to attempt to initialize (see +/// ADBC_VERSION_1_0_0). +/// \param[out] driver The table of function pointers to +/// initialize. Should be a pointer to the appropriate struct for +/// the given version (see the documentation for the version). +/// \param[out] error An optional location to return an error message +/// if necessary. +/// \return ADBC_STATUS_OK if the driver was initialized, or +/// ADBC_STATUS_NOT_IMPLEMENTED if the version is not supported. In +/// that case, clients may retry with a different version. +typedef AdbcStatusCode (*AdbcDriverInitFunc)(int version, void* driver, + struct AdbcError* error); + +/// @} + +#endif // ADBC + +#ifdef __cplusplus +} +#endif diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index 628b02b9d28..bedbe440996 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -37,6 +37,8 @@ set_source_files_properties(${FLIGHT_SQL_GENERATED_PROTO_FILES} PROPERTIES GENER add_custom_target(flight_sql_protobuf_gen ALL DEPENDS ${FLIGHT_SQL_GENERATED_PROTO_FILES}) set(ARROW_FLIGHT_SQL_SRCS + adbc_driver.cc + adbc_driver_internal.cc server.cc sql_info_internal.cc column_metadata.cc @@ -94,7 +96,7 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) example/sqlite_server.cc example/sqlite_tables_schema_batch_reader.cc) - set(ARROW_FLIGHT_SQL_TEST_SRCS server_test.cc) + set(ARROW_FLIGHT_SQL_TEST_SRCS adbc_driver_test.cc server_test.cc) set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES}) set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc) @@ -124,8 +126,15 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) STATIC_LINK_LIBS ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} ${ARROW_FLIGHT_SQL_TEST_LIBS} + nanoarrow + AdbcValidation::adbc_validation + # Needs to come twice since the validation library + # also uses ADBC symbols, which get provided by + # libarrow_flight_sql here. + ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} EXTRA_INCLUDES - "${CMAKE_CURRENT_BINARY_DIR}/../" + # adbc_validation.h needs adbc.h + "${CMAKE_SOURCE_DIR}/../format/" LABELS "arrow_flight_sql") diff --git a/cpp/src/arrow/flight/sql/adbc_driver.cc b/cpp/src/arrow/flight/sql/adbc_driver.cc new file mode 100644 index 00000000000..96b93f31827 --- /dev/null +++ b/cpp/src/arrow/flight/sql/adbc_driver.cc @@ -0,0 +1,2062 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_binary.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/builder_base.h" +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_nested.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/array/builder_union.h" +#include "arrow/c/bridge.h" +#include "arrow/config.h" +#include "arrow/flight/client.h" +#include "arrow/flight/sql/adbc_driver_internal.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/server.h" +#include "arrow/flight/sql/types.h" +#include "arrow/flight/sql/visibility.h" +#include "arrow/io/memory.h" +#include "arrow/io/type_fwd.h" +#include "arrow/ipc/dictionary.h" +#include "arrow/ipc/reader.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/config.h" +#include "arrow/util/logging.h" +#include "arrow/util/value_parsing.h" + +#ifdef ARROW_COMPUTE +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/exec.h" +#endif + +namespace arrow::flight::sql { + +using arrow::internal::checked_cast; + +namespace { +static const std::vector> kDefaultTypeMapping = { + {Type::BINARY, "BLOB"}, + {Type::BOOL, "BOOLEAN"}, + {Type::DATE32, "DATE"}, + {Type::DATE64, "DATE"}, + {Type::DECIMAL128, "NUMERIC"}, + {Type::DECIMAL256, "NUMERIC"}, + {Type::DOUBLE, "DOUBLE PRECISION"}, + {Type::FLOAT, "REAL"}, + {Type::INT16, "SMALLINT"}, + {Type::INT32, "INT"}, + {Type::INT64, "BIGINT"}, + {Type::LARGE_BINARY, "BLOB"}, + {Type::LARGE_STRING, "TEXT"}, + {Type::STRING, "TEXT"}, + {Type::TIME32, "TIME"}, + {Type::TIME64, "TIME"}, + {Type::TIMESTAMP, "TIMESTAMP"}, +}; + +/// \brief Client-side configuration to help paper over SQL dialect differences +/// +/// This is needed when we have to generate SQL to implement ADBC +/// features that have no direct Flight SQL equivalent. In +/// particular, it's needed for bulk ingestion to a target table. +struct FlightSqlQuirks { + /// A mapping from Arrow type to SQL type string + std::unordered_map ingest_type_mapping; + + FlightSqlQuirks() + : ingest_type_mapping(kDefaultTypeMapping.begin(), kDefaultTypeMapping.end()) {} + + bool UpdateTypeMapping(std::string_view type_name, const char* value) { + if (type_name == "binary") { + ingest_type_mapping[Type::BINARY] = value; + } else if (type_name == "bool") { + ingest_type_mapping[Type::BOOL] = value; + } else if (type_name == "date32") { + ingest_type_mapping[Type::DATE32] = value; + } else if (type_name == "date64") { + ingest_type_mapping[Type::DATE64] = value; + } else if (type_name == "decimal128") { + ingest_type_mapping[Type::DECIMAL128] = value; + } else if (type_name == "decimal256") { + ingest_type_mapping[Type::DECIMAL256] = value; + } else if (type_name == "double") { + ingest_type_mapping[Type::DOUBLE] = value; + } else if (type_name == "float") { + ingest_type_mapping[Type::FLOAT] = value; + } else if (type_name == "int16") { + ingest_type_mapping[Type::INT16] = value; + } else if (type_name == "int32") { + ingest_type_mapping[Type::INT32] = value; + } else if (type_name == "int64") { + ingest_type_mapping[Type::INT64] = value; + } else if (type_name == "large_binary") { + ingest_type_mapping[Type::LARGE_BINARY] = value; + } else if (type_name == "large_string") { + ingest_type_mapping[Type::LARGE_STRING] = value; + } else if (type_name == "string") { + ingest_type_mapping[Type::STRING] = value; + } else if (type_name == "time32") { + ingest_type_mapping[Type::TIME32] = value; + } else if (type_name == "time64") { + ingest_type_mapping[Type::TIME64] = value; + } else if (type_name == "timestamp") { + ingest_type_mapping[Type::TIMESTAMP] = value; + } else { + return false; + } + return true; + } +}; + +/// Config options that map to FlightClientOptions +constexpr std::string_view kClientOptionTlsRootCerts = + "arrow.flight.sql.client_option.tls_root_certs"; +constexpr std::string_view kClientOptionOverrideHostname = + "arrow.flight.sql.client_option.override_hostname"; +constexpr std::string_view kClientOptionCertChain = + "arrow.flight.sql.client_option.cert_chain"; +constexpr std::string_view kClientOptionPrivateKey = + "arrow.flight.sql.client_option.private_key"; +constexpr std::string_view kClientOptionGenericIntOption = + "arrow.flight.sql.client_option.generic_int_option."; +constexpr std::string_view kClientOptionGenericStringOption = + "arrow.flight.sql.client_option.generic_string_option."; +constexpr std::string_view kClientOptionDisableServerVerification = + "arrow.flight.sql.client_option.disable_server_verification"; +/// Config option to enable JDBC driver-like authorization +constexpr std::string_view kAuthorizationHeaderKey = + "arrow.flight.sql.authorization_header"; +/// Config options used to override the type mapping in FlightSqlQuirks +constexpr std::string_view kIngestTypePrefix = "arrow.flight.sql.quirks.ingest_type."; +/// Explicitly specify the Substrait version for Flight SQL (although +/// Substrait will eventually embed this into the plan itself) +constexpr std::string_view kStatementSubstraitVersionKey = + "arrow.flight.sql.substrait.version"; +/// Attach arbitrary key-value headers via Flight +constexpr std::string_view kCallHeaderPrefix = "arrow.flight.sql.rpc.call_header."; +/// A timeout for any DoGet requests +constexpr std::string_view kConnectionTimeoutFetchKey = + "arrow.flight.sql.rpc.timeout_seconds.fetch"; +/// A timeout for any GetFlightInfo requests +constexpr std::string_view kConnectionTimeoutQueryKey = + "arrow.flight.sql.rpc.timeout_seconds.query"; +/// A timeout for any DoPut requests, or miscellaneous DoAction requests +constexpr std::string_view kConnectionTimeoutUpdateKey = + "arrow.flight.sql.rpc.timeout_seconds.update"; +constexpr std::string_view kConnectionOptionAutocommit = + ADBC_CONNECTION_OPTION_AUTOCOMMIT; +constexpr std::string_view kIngestOptionMode = ADBC_INGEST_OPTION_MODE; +constexpr std::string_view kIngestOptionModeAppend = ADBC_INGEST_OPTION_MODE_APPEND; +constexpr std::string_view kIngestOptionModeCreate = ADBC_INGEST_OPTION_MODE_CREATE; +constexpr std::string_view kIngestOptionTargetTable = ADBC_INGEST_OPTION_TARGET_TABLE; +constexpr std::string_view kOptionValueEnabled = ADBC_OPTION_VALUE_ENABLED; +constexpr std::string_view kOptionValueDisabled = ADBC_OPTION_VALUE_DISABLED; + +enum class CallContext { + kFetch, + kQuery, + kUpdate, +}; + +static const char kAuthorizationHeader[] = "authorization"; + +/// Implement auth in the same way as the Flight SQL JDBC driver +/// 1. Client ---[authorization: Basic XXX ]--> Server +/// 2. Client <--[authorization: Bearer YYY]--- Server +/// 3. Client ---[authorization: Bearer YYY]--> Server +class ClientAuthorizationMiddlewareFactory : public ClientMiddlewareFactory { + public: + explicit ClientAuthorizationMiddlewareFactory(std::string initial_header) + : token_(std::move(initial_header)) {} + + void StartCall(const CallInfo& info, + std::unique_ptr* middleware) override { + // TODO: maybe we want a shared_ptr and an atomic reference? + std::lock_guard guard(token_mutex_); + // Middleware instances outlive the factory + *middleware = std::make_unique(this, token_); + } + + void UpdateToken(std::string_view token) { + std::lock_guard guard(token_mutex_); + token_ = std::string(token); + } + + private: + std::mutex token_mutex_; + std::string token_; + + class Impl : public ClientMiddleware { + public: + explicit Impl(ClientAuthorizationMiddlewareFactory* factory, std::string token) + : factory_(factory), token_(std::move(token)) {} + void SendingHeaders(AddCallHeaders* outgoing_headers) override { + outgoing_headers->AddHeader(kAuthorizationHeader, token_); + } + void ReceivedHeaders(const CallHeaders& incoming_headers) override { + // Assume server isn't going to send multiple authorization headers + auto it = incoming_headers.find(std::string_view(kAuthorizationHeader)); + if (it == incoming_headers.end()) return; + if (it->second == token_) return; + factory_->UpdateToken(it->second); + } + void CallCompleted(const Status&) override {} + + private: + ClientAuthorizationMiddlewareFactory* factory_; + std::string token_; + }; +}; + +/// \brief AdbcDatabase implementation +class FlightSqlDatabaseImpl { + public: + AdbcStatusCode Connect(std::unique_ptr* client, + struct AdbcError* error) { + std::lock_guard guard(mutex_); + if (!location_.has_value()) { + SetError(error, "Cannot create connection from uninitialized database"); + return ADBC_STATUS_INVALID_STATE; + } + + std::unique_ptr flight_client; + ADBC_ARROW_RETURN_NOT_OK( + error, FlightClient::Connect(*location_, client_options_).Value(&flight_client)); + *client = std::make_unique(std::move(flight_client)); + ++connection_count_; + return ADBC_STATUS_OK; + } + + const std::shared_ptr& quirks() const { return quirks_; } + FlightCallOptions MakeCallOptions(CallContext context) const { + FlightCallOptions options; + for (const auto& header : call_headers_) { + options.headers.emplace_back(header.first, header.second); + } + return options; + } + + AdbcStatusCode Init(struct AdbcError* error) { + std::lock_guard guard(mutex_); + + if (location_.has_value()) { + SetError(error, "Database already initialized"); + return ADBC_STATUS_INVALID_STATE; + } + + auto it = options_.find("uri"); + if (it == options_.end()) { + SetError(error, "Must provide 'uri' option"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + Location location; + ADBC_ARROW_RETURN_NOT_OK(error, Location::Parse(it->second).Value(&location)); + location_ = location; + + options_.clear(); + return ADBC_STATUS_OK; + } + + AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) { + if (key == nullptr) { + SetError(error, "Key must not be null"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + std::string_view key_view(key); + std::string_view val_view = value ? value : ""; + if (key_view.rfind(kIngestTypePrefix, 0) == 0) { + // Changing the mapping from Arrow type <-> SQL type name, for when we + // do a bulk ingest and have to generate a CREATE TABLE statement + const std::string_view type_name = key_view.substr(kIngestTypePrefix.size()); + if (!quirks_->UpdateTypeMapping(type_name, value)) { + SetError(error, "Unknown option value ", key_view, "=", val_view, ": type name ", + type_name, " is not recognized"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + return ADBC_STATUS_OK; + } else if (key_view.rfind(kCallHeaderPrefix, 0) == 0) { + // Add a custom header to all outgoing calls (can also be set at + // connection, statement level) + std::string header(key_view.substr(kCallHeaderPrefix.size())); + if (value == nullptr) { + call_headers_.erase(header); + } else { + call_headers_.insert({std::move(header), std::string(val_view)}); + } + return ADBC_STATUS_OK; + } else if (key_view == kClientOptionTlsRootCerts) { + client_options_.tls_root_certs = val_view; + return ADBC_STATUS_OK; + } else if (key_view == kClientOptionOverrideHostname) { + client_options_.override_hostname = val_view; + return ADBC_STATUS_OK; + } else if (key_view == kClientOptionCertChain) { + client_options_.cert_chain = val_view; + return ADBC_STATUS_OK; + } else if (key_view == kClientOptionPrivateKey) { + client_options_.private_key = val_view; + return ADBC_STATUS_OK; + } else if (key_view == kClientOptionDisableServerVerification) { + if (val_view == kOptionValueEnabled) { + client_options_.disable_server_verification = true; + return ADBC_STATUS_OK; + } else if (val_view == kOptionValueDisabled) { + client_options_.disable_server_verification = false; + return ADBC_STATUS_OK; + } + SetError(error, + "Invalid boolean value for client option disable_server_verification: ", + val_view); + return ADBC_STATUS_INVALID_ARGUMENT; + } else if (key_view.rfind(kClientOptionGenericIntOption, 0) == 0) { + const std::string_view option_name = + key_view.substr(kClientOptionGenericIntOption.size()); + int32_t option_value = 0; + if (!arrow::internal::StringConverter().Convert( + Int32Type(), val_view.data(), val_view.size(), &option_value)) { + SetError(error, + "Invalid integer value for client option generic_options: ", val_view); + return ADBC_STATUS_INVALID_ARGUMENT; + } + client_options_.generic_options.emplace_back(option_name, option_value); + return ADBC_STATUS_OK; + } else if (key_view.rfind(kClientOptionGenericStringOption, 0) == 0) { + const std::string_view option_name = + key_view.substr(kClientOptionGenericStringOption.size()); + client_options_.generic_options.emplace_back(option_name, std::string(val_view)); + return ADBC_STATUS_OK; + } else if (key_view == kAuthorizationHeaderKey) { + client_options_.middleware.push_back( + std::make_shared(std::string(val_view))); + return ADBC_STATUS_OK; + } + + if (location_.has_value()) { + SetError(error, "Database already initialized"); + return ADBC_STATUS_INVALID_STATE; + } + options_[std::string(key_view)] = std::string(val_view); + return ADBC_STATUS_OK; + } + + AdbcStatusCode Disconnect(struct AdbcError* error) { + std::lock_guard guard(mutex_); + if (--connection_count_ < 0) { + SetError(error, "Connection count underflow"); + return ADBC_STATUS_INTERNAL; + } + return ADBC_STATUS_OK; + } + + AdbcStatusCode Release(struct AdbcError* error) { + std::lock_guard guard(mutex_); + + if (connection_count_ > 0) { + SetError(error, "Cannot release database with ", connection_count_, + " open connections"); + return ADBC_STATUS_INTERNAL; + } + + return ADBC_STATUS_OK; + } + + private: + std::optional location_; + FlightClientOptions client_options_ = DefaultClientOptions(); + std::shared_ptr quirks_ = std::make_shared(); + std::unordered_map call_headers_; + std::unordered_map options_; + std::mutex mutex_; + int connection_count_ = 0; +}; + +/// \brief A RecordBatchReader that reads the endpoints of a FlightInfo +class FlightInfoReader : public RecordBatchReader { + public: + explicit FlightInfoReader(FlightSqlClient* client, FlightCallOptions call_options, + std::unique_ptr info) + : client_(client), + call_options_(std::move(call_options)), + info_(std::move(info)), + next_endpoint_(0) {} + + std::shared_ptr schema() const override { return schema_; } + + Status ReadNext(std::shared_ptr* batch) override { + FlightStreamChunk chunk; + while (current_stream_ && !chunk.data) { + ARROW_ASSIGN_OR_RAISE(chunk, current_stream_->Next()); + if (chunk.data) { + *batch = chunk.data; + break; + } + if (!chunk.data && !chunk.app_metadata) { + RETURN_NOT_OK(NextStream()); + } + } + if (!current_stream_) *batch = nullptr; + return Status::OK(); + } + + Status Close() override { + if (current_stream_) { + current_stream_->Cancel(); + current_stream_.reset(); + } + return Status::OK(); + } + + AdbcStatusCode Init(struct AdbcError* error) { + ADBC_ARROW_RETURN_NOT_OK(error, NextStream()); + if (!schema_) { + // Empty result set - fall back on schema in FlightInfo + ipc::DictionaryMemo memo; + ADBC_ARROW_RETURN_NOT_OK(error, info_->GetSchema(&memo).Value(&schema_)); + } + return ADBC_STATUS_OK; + } + + /// \brief Export to an ArrowArrayStream + static AdbcStatusCode Export(FlightSqlClient* client, FlightCallOptions call_options, + std::unique_ptr info, + struct ArrowArrayStream* stream, struct AdbcError* error) { + auto reader = std::make_shared(client, std::move(call_options), + std::move(info)); + ADBC_RETURN_NOT_OK(reader->Init(error)); + ADBC_ARROW_RETURN_NOT_OK(error, ExportRecordBatchReader(std::move(reader), stream)); + return ADBC_STATUS_OK; + } + + private: + Status NextStream() { + if (next_endpoint_ >= info_->endpoints().size()) { + current_stream_ = nullptr; + return Status::OK(); + } + const FlightEndpoint& endpoint = info_->endpoints()[next_endpoint_]; + + if (endpoint.locations.empty()) { + ARROW_ASSIGN_OR_RAISE(current_stream_, + client_->DoGet(call_options_, endpoint.ticket)); + } else { + // TODO(lidavidm): this should come from a connection pool + std::string failures; + current_stream_ = nullptr; + for (const Location& location : endpoint.locations) { + auto status = + FlightClient::Connect(location, DefaultClientOptions()).Value(&data_client_); + if (status.ok()) { + status = + data_client_->DoGet(call_options_, endpoint.ticket).Value(¤t_stream_); + } + if (!status.ok()) { + if (!failures.empty()) { + failures += "; "; + } + failures += location.ToString(); + failures += ": "; + failures += status.ToString(); + data_client_.reset(); + continue; + } + break; + } + + if (!current_stream_) { + return Status::IOError("Failed to connect to all endpoints: ", failures); + } + } + next_endpoint_++; + if (!schema_) { + ARROW_ASSIGN_OR_RAISE(schema_, current_stream_->GetSchema()); + } + return Status::OK(); + } + + FlightSqlClient* client_; + FlightCallOptions call_options_; + std::unique_ptr info_; + size_t next_endpoint_; + std::shared_ptr schema_; + std::unique_ptr current_stream_; + // TODO(lidavidm): use a common pool of cached clients with expiration + std::unique_ptr data_client_; +}; + +class FlightSqlConnectionImpl { + public: + //---------------------------------------------------------- + // Common Functions + //---------------------------------------------------------- + + FlightSqlClient* client() const { return client_.get(); } + const FlightSqlQuirks& quirks() const { return *quirks_; } + const Transaction& transaction() const { return transaction_; } + FlightCallOptions MakeCallOptions(CallContext context) const { + FlightCallOptions options = database_->MakeCallOptions(context); + auto it = timeout_seconds_.find(context); + if (it != timeout_seconds_.end()) { + options.timeout = it->second; + } + for (const auto& header : call_headers_) { + options.headers.emplace_back(header.first, header.second); + } + return options; + } + + AdbcStatusCode Init(struct AdbcDatabase* database, struct AdbcError* error) { + if (!database->private_data) { + SetError(error, "Cannot create connection from uninitialized database"); + return ADBC_STATUS_INVALID_STATE; + } else if (client_) { + SetError(error, "Already initialized connection"); + return ADBC_STATUS_INVALID_STATE; + } + + database_ = *reinterpret_cast*>( + database->private_data); + ADBC_RETURN_NOT_OK(database_->Connect(&client_, error)); + quirks_ = database_->quirks(); + + // Ignore this if it fails - we'll just proceed + ARROW_UNUSED(QuerySqlInfo()); + return ADBC_STATUS_OK; + } + + AdbcStatusCode Close(struct AdbcError* error) { + AdbcStatusCode return_status = ADBC_STATUS_OK; + if (client_) { + auto status = client_->Close(); + client_.reset(); + if (!status.ok()) { + SetError(error, status); + return_status = ADBC_STATUS_IO; + } + + if (database_) { + ADBC_RETURN_NOT_OK(database_->Disconnect(error)); + } + } + + return return_status; + } + + AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) { + if (key == nullptr) { + SetError(error, "Key must not be null"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + auto set_timeout_option = [=](CallContext context) -> AdbcStatusCode { + double timeout = 0.0; + const size_t len = std::strlen(value); + if (!arrow::internal::StringToFloat(value, len, /*decimal_point=*/'.', &timeout)) { + SetError(error, "Invalid timeout option value ", key, '=', value, + ": invalid floating point value"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (std::isnan(timeout) || std::isinf(timeout) || timeout < 0) { + SetError(error, "Invalid timeout option value ", key, '=', value, + ": timeout must be positive and finite"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + if (timeout == 0) { + timeout_seconds_.erase(context); + } else { + timeout_seconds_[context] = std::chrono::duration(timeout); + } + return ADBC_STATUS_OK; + }; + + std::string_view key_view(key); + std::string_view val_view = value ? value : ""; + if (key == kConnectionOptionAutocommit) { + if (val_view == kOptionValueEnabled && !transaction_.is_valid()) { + // No-op - don't error even if the server didn't support transactions + return ADBC_STATUS_OK; + } + ADBC_RETURN_NOT_OK(CheckTransactionSupport(error)); + FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate); + if (val_view == kOptionValueEnabled) { + if (transaction_.is_valid()) { + ADBC_ARROW_RETURN_NOT_OK(error, client_->Commit(call_options, transaction_)); + transaction_ = no_transaction(); + } + return ADBC_STATUS_OK; + } else if (val_view == kOptionValueDisabled) { + if (transaction_.is_valid()) { + ADBC_ARROW_RETURN_NOT_OK(error, client_->Commit(call_options, transaction_)); + } + ADBC_ARROW_RETURN_NOT_OK( + error, client_->BeginTransaction(call_options).Value(&transaction_)); + return ADBC_STATUS_OK; + } + SetError(error, "Invalid connection option value ", key_view, '=', val_view); + return ADBC_STATUS_INVALID_ARGUMENT; + } else if (key == kConnectionTimeoutFetchKey) { + return set_timeout_option(CallContext::kFetch); + } else if (key == kConnectionTimeoutQueryKey) { + return set_timeout_option(CallContext::kQuery); + } else if (key == kConnectionTimeoutUpdateKey) { + return set_timeout_option(CallContext::kUpdate); + } else if (key_view.rfind(kCallHeaderPrefix, 0) == 0) { + std::string header(key_view.substr(kCallHeaderPrefix.size())); + if (value == nullptr) { + call_headers_.erase(header); + } else { + call_headers_.insert({std::move(header), value}); + } + return ADBC_STATUS_OK; + } + SetError(error, "Unknown connection option ", key_view, '=', val_view); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + //---------------------------------------------------------- + // Metadata + //---------------------------------------------------------- + + AdbcStatusCode GetInfo(uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* stream, struct AdbcError* error) { + static std::shared_ptr kInfoSchema = arrow::schema({ + arrow::field("info_name", arrow::uint32(), /*nullable=*/false), + arrow::field( + "info_value", + arrow::dense_union({ + arrow::field("string_value", arrow::utf8()), + arrow::field("bool_value", arrow::boolean()), + arrow::field("int64_value", arrow::int64()), + arrow::field("int32_bitmask", arrow::int32()), + arrow::field("string_list", arrow::list(arrow::utf8())), + arrow::field("int32_to_int32_list_map", + arrow::map(arrow::int32(), arrow::list(arrow::int32()))), + })), + }); + + // XXX(ARROW-17558): type should be uint32_t not int + std::vector flight_sql_codes; + std::vector codes; + if (info_codes && info_codes_length > 0) { + for (size_t i = 0; i < info_codes_length; i++) { + const uint32_t info_code = info_codes[i]; + switch (info_code) { + case ADBC_INFO_VENDOR_NAME: + case ADBC_INFO_VENDOR_VERSION: + case ADBC_INFO_VENDOR_ARROW_VERSION: + // These codes are equivalent between the two + flight_sql_codes.push_back(info_code); + break; + case ADBC_INFO_DRIVER_NAME: + case ADBC_INFO_DRIVER_VERSION: + case ADBC_INFO_DRIVER_ARROW_VERSION: + codes.push_back(info_code); + break; + default: + SetError(error, "Unknown info code: ", info_code); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } + } else { + flight_sql_codes = { + SqlInfoOptions::FLIGHT_SQL_SERVER_NAME, + SqlInfoOptions::FLIGHT_SQL_SERVER_VERSION, + SqlInfoOptions::FLIGHT_SQL_SERVER_ARROW_VERSION, + }; + codes = { + ADBC_INFO_DRIVER_NAME, + ADBC_INFO_DRIVER_VERSION, + ADBC_INFO_DRIVER_ARROW_VERSION, + }; + } + + RecordBatchVector result; + + UInt32Builder names; + std::unique_ptr values; + ADBC_ARROW_RETURN_NOT_OK(error, + MakeBuilder(kInfoSchema->field(1)->type()).Value(&values)); + auto* info_value = static_cast(values.get()); + auto* info_string = static_cast(info_value->child_builder(0).get()); + int64_t num_values = 0; + + constexpr int8_t kStringCode = 0; + + if (!flight_sql_codes.empty()) { + FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery); + std::unique_ptr info; + ADBC_ARROW_RETURN_NOT_OK( + error, client_->GetSqlInfo(call_options, flight_sql_codes).Value(&info)); + FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch), + std::move(info)); + ADBC_RETURN_NOT_OK(reader.Init(error)); + + if (!reader.schema()->Equals(*SqlSchema::GetSqlInfoSchema())) { + SetError(error, "Server returned wrong schema, got: ", *reader.schema()); + return ADBC_STATUS_INTERNAL; + } + + while (true) { + std::shared_ptr batch; + ADBC_ARROW_RETURN_NOT_OK(error, reader.Next().Value(&batch)); + if (!batch) break; + + const auto& sql_codes = checked_cast(*batch->column(0)); + const auto& sql_value = checked_cast(*batch->column(1)); + const auto& sql_string = + checked_cast(*sql_value.field(kStringCode)); + for (int64_t i = 0; i < batch->num_rows(); i++) { + // Shouldn't happen but oh well + if (!sql_codes.IsValid(i)) continue; + + switch (sql_codes.Value(i)) { + case SqlInfoOptions::FLIGHT_SQL_SERVER_NAME: + case SqlInfoOptions::FLIGHT_SQL_SERVER_VERSION: + case SqlInfoOptions::FLIGHT_SQL_SERVER_ARROW_VERSION: { + // These should all be string values where the codes are + // equivalent between ADBC/Flight SQL + ADBC_ARROW_RETURN_NOT_OK(error, names.Append(sql_codes.Value(i))); + if (sql_value.type_code(i) != kStringCode) { + SetError(error, "Server returned wrong type for info value ", + sql_codes.Value(i)); + return ADBC_STATUS_INTERNAL; + } + ADBC_ARROW_RETURN_NOT_OK( + error, + info_string->Append(sql_string.GetString(sql_value.value_offset(i)))); + ADBC_ARROW_RETURN_NOT_OK(error, info_value->Append(kStringCode)); + num_values++; + break; + } + default: + // Ignore if the server returns something unknown + continue; + } + } + } + } + + if (!codes.empty()) { + for (const uint32_t code : codes) { + switch (code) { + case ADBC_INFO_DRIVER_NAME: + ADBC_ARROW_RETURN_NOT_OK( + error, info_string->Append("Apache Arrow Flight SQL ADBC Driver")); + ADBC_ARROW_RETURN_NOT_OK(error, info_value->Append(kStringCode)); + break; + case ADBC_INFO_DRIVER_VERSION: + case ADBC_INFO_DRIVER_ARROW_VERSION: + ADBC_ARROW_RETURN_NOT_OK(error, + info_string->Append(GetBuildInfo().version_string)); + ADBC_ARROW_RETURN_NOT_OK(error, info_value->Append(kStringCode)); + break; + default: + SetError(error, "Unknown info code: ", code); + return ADBC_STATUS_INVALID_ARGUMENT; + } + ADBC_ARROW_RETURN_NOT_OK(error, names.Append(code)); + num_values++; + } + } + + ArrayVector cols(2); + ADBC_ARROW_RETURN_NOT_OK(error, names.Finish(&cols[0])); + ADBC_ARROW_RETURN_NOT_OK(error, values->Finish(&cols[1])); + result.push_back(RecordBatch::Make(kInfoSchema, num_values, std::move(cols))); + + ADBC_ARROW_RETURN_NOT_OK(error, + ExportRecordBatches(kInfoSchema, std::move(result), stream)); + return ADBC_STATUS_OK; + } + + AdbcStatusCode GetObjects(int depth, const char* catalog, const char* db_schema, + const char* table_name, const char** table_type, + const char* column_name, struct ArrowArrayStream* stream, + struct AdbcError* error) { + static std::shared_ptr kColumnSchema = arrow::struct_({ + arrow::field("column_name", arrow::utf8(), /*nullable=*/false), + arrow::field("ordinal_position", arrow::int32()), + arrow::field("remarks", arrow::utf8()), + arrow::field("xdbc_data_type", arrow::int16()), + arrow::field("xdbc_type_name", arrow::utf8()), + arrow::field("xdbc_column_size", arrow::int32()), + arrow::field("xdbc_decimal_digits", arrow::int16()), + arrow::field("xdbc_num_prec_radix", arrow::int16()), + arrow::field("xdbc_nullable", arrow::int16()), + arrow::field("xdbc_column_def", arrow::utf8()), + arrow::field("xdbc_sql_data_type", arrow::int16()), + arrow::field("xdbc_datetime_sub", arrow::int16()), + arrow::field("xdbc_char_octet_length", arrow::int32()), + arrow::field("xdbc_is_nullable", arrow::utf8()), + arrow::field("xdbc_scope_catalog", arrow::utf8()), + arrow::field("xdbc_scope_schema", arrow::utf8()), + arrow::field("xdbc_scope_table", arrow::utf8()), + arrow::field("xdbc_is_autoincrement", arrow::boolean()), + arrow::field("xdbc_is_generatedcolumn", arrow::boolean()), + }); + static std::shared_ptr kUsageSchema = arrow::struct_({ + arrow::field("fk_catalog", arrow::utf8()), + arrow::field("fk_db_schema", arrow::utf8()), + arrow::field("fk_table", arrow::utf8(), /*nullable=*/false), + arrow::field("fk_column_name", arrow::utf8(), /*nullable=*/false), + }); + static std::shared_ptr kConstraintSchema = arrow::struct_({ + arrow::field("constraint_name", arrow::utf8()), + arrow::field("constraint_type", arrow::utf8(), /*nullable=*/false), + arrow::field("constraint_column_names", arrow::list(arrow::utf8()), + /*nullable=*/false), + arrow::field("constraint_column_usage", arrow::list(kUsageSchema)), + }); + static std::shared_ptr kTableSchema = arrow::struct_({ + arrow::field("table_name", arrow::utf8(), /*nullable=*/false), + arrow::field("table_type", arrow::utf8(), /*nullable=*/false), + arrow::field("table_columns", arrow::list(kColumnSchema)), + arrow::field("table_constraints", arrow::list(kConstraintSchema)), + }); + static std::shared_ptr kDbSchemaSchema = arrow::struct_({ + arrow::field("db_schema_name", arrow::utf8()), + arrow::field("db_schema_tables", arrow::list(kTableSchema)), + }); + static std::shared_ptr kCatalogSchema = arrow::schema({ + arrow::field("catalog_name", arrow::utf8()), + arrow::field("catalog_db_schemas", arrow::list(kDbSchemaSchema)), + }); + FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery); + + // To avoid an N+1 query problem, we assume result sets here will + // fit in memory and build up a single response. + + std::string db_schema_filter = db_schema ? db_schema : ""; + std::string table_name_filter = table_name ? table_name : ""; + std::vector table_type_filter; + if (table_type) { + while (*table_type) { + table_type_filter.emplace_back(*table_type); + table_type++; + } + } + + TableIndex table_index; + std::shared_ptr table_data; + + if (depth == ADBC_OBJECT_DEPTH_ALL || depth >= ADBC_OBJECT_DEPTH_DB_SCHEMAS) { + std::unique_ptr info; + ADBC_ARROW_RETURN_NOT_OK(error, + client_ + ->GetDbSchemas(call_options, /*catalog=*/nullptr, + db_schema ? &db_schema_filter : nullptr) + .Value(&info)); + FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch), + std::move(info)); + ADBC_RETURN_NOT_OK(reader.Init(error)); + ADBC_ARROW_RETURN_NOT_OK(error, IndexDbSchemas(&reader, &table_index)); + } + if (depth == ADBC_OBJECT_DEPTH_ALL || depth >= ADBC_OBJECT_DEPTH_TABLES) { + std::unique_ptr info; + ADBC_ARROW_RETURN_NOT_OK(error, + client_ + ->GetTables(call_options, /*catalog=*/nullptr, + db_schema ? &db_schema_filter : nullptr, + table_name ? &table_name_filter : nullptr, + /*include_schemas=*/true, + table_type ? &table_type_filter : nullptr) + .Value(&info)); + FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch), + std::move(info)); + ADBC_RETURN_NOT_OK(reader.Init(error)); + ADBC_ARROW_RETURN_NOT_OK(error, IndexTables(&reader, &table_index, &table_data)); + } + + std::unique_ptr raw_builder; + ADBC_ARROW_RETURN_NOT_OK(error, MakeBuilder(kDbSchemaSchema).Value(&raw_builder)); + StringBuilder catalog_name_builder; + ListBuilder catalog_db_schemas_builder(default_memory_pool(), std::move(raw_builder)); + auto* catalog_db_schemas_items_builder = + checked_cast(catalog_db_schemas_builder.value_builder()); + auto* db_schema_name_builder = + checked_cast(catalog_db_schemas_items_builder->field_builder(0)); + auto* db_schema_tables_builder = + checked_cast(catalog_db_schemas_items_builder->field_builder(1)); + auto* db_schema_tables_items_builder = + checked_cast(db_schema_tables_builder->value_builder()); + auto* table_name_builder = + checked_cast(db_schema_tables_items_builder->field_builder(0)); + auto* table_type_builder = + checked_cast(db_schema_tables_items_builder->field_builder(1)); + auto* table_columns_builder = + checked_cast(db_schema_tables_items_builder->field_builder(2)); + auto* table_constraints_builder = + checked_cast(db_schema_tables_items_builder->field_builder(3)); + auto* table_columns_items_builder = + checked_cast(table_columns_builder->value_builder()); + auto* column_name_builder = + checked_cast(table_columns_items_builder->field_builder(0)); + auto* ordinal_position_builder = + checked_cast(table_columns_items_builder->field_builder(1)); + + std::unique_ptr info; + ADBC_ARROW_RETURN_NOT_OK(error, client_->GetCatalogs(call_options).Value(&info)); + FlightInfoReader catalog_reader(client_.get(), MakeCallOptions(CallContext::kFetch), + std::move(info)); + ADBC_RETURN_NOT_OK(catalog_reader.Init(error)); + + ReaderIterator catalogs(*SqlSchema::GetCatalogsSchema(), &catalog_reader); + ADBC_ARROW_RETURN_NOT_OK(error, catalogs.Init()); + while (true) { + bool have_data = false; + ADBC_ARROW_RETURN_NOT_OK(error, catalogs.Next().Value(&have_data)); + if (!have_data) break; + + std::optional cur_catalog_name(catalogs.GetNullable(0)); + // TODO(lidavidm): catalog is a filter string (evaluate with compute fn if + // available) + if (catalog && cur_catalog_name != catalog) continue; + + if (cur_catalog_name) { + ADBC_ARROW_RETURN_NOT_OK(error, catalog_name_builder.Append(*cur_catalog_name)); + } else { + ADBC_ARROW_RETURN_NOT_OK(error, catalog_name_builder.AppendNull()); + } + + if (depth == ADBC_OBJECT_DEPTH_ALL || depth >= ADBC_OBJECT_DEPTH_DB_SCHEMAS) { + ADBC_ARROW_RETURN_NOT_OK(error, catalog_db_schemas_builder.Append()); + } else { + ADBC_ARROW_RETURN_NOT_OK(error, catalog_db_schemas_builder.AppendNull()); + continue; + } + + auto it = table_index.find(cur_catalog_name); + if (it == table_index.end()) continue; + + for (const auto& schema_item : it->second) { + ADBC_ARROW_RETURN_NOT_OK(error, catalog_db_schemas_items_builder->Append()); + const std::optional& cur_schema_name = schema_item.first; + + if (cur_schema_name) { + ADBC_ARROW_RETURN_NOT_OK(error, + db_schema_name_builder->Append(*cur_schema_name)); + } else { + ADBC_ARROW_RETURN_NOT_OK(error, db_schema_name_builder->AppendNull()); + } + + if (depth == ADBC_OBJECT_DEPTH_ALL || depth >= ADBC_OBJECT_DEPTH_TABLES) { + ADBC_ARROW_RETURN_NOT_OK(error, db_schema_tables_builder->Append()); + } else { + ADBC_ARROW_RETURN_NOT_OK(error, db_schema_tables_builder->AppendNull()); + continue; + } + + std::pair schema_tables_range = schema_item.second; + for (int64_t i = schema_tables_range.first; i < schema_tables_range.second; i++) { + const std::string_view table_name = GetNotNull(*table_data, 2, i); + const std::string_view table_type = GetNotNull(*table_data, 3, i); + const std::string_view table_schema = GetNotNull(*table_data, 4, i); + + ADBC_ARROW_RETURN_NOT_OK(error, db_schema_tables_items_builder->Append()); + ADBC_ARROW_RETURN_NOT_OK(error, table_name_builder->Append(table_name)); + ADBC_ARROW_RETURN_NOT_OK(error, table_type_builder->Append(table_type)); + + if (depth == ADBC_OBJECT_DEPTH_COLUMNS) { + ADBC_ARROW_RETURN_NOT_OK(error, table_columns_builder->Append()); + + std::shared_ptr schema; + io::BufferReader reader((Buffer(table_schema))); + ipc::DictionaryMemo memo; + ADBC_ARROW_RETURN_NOT_OK(error, + ipc::ReadSchema(&reader, &memo).Value(&schema)); + + int32_t ordinal_position = 1; + for (const std::shared_ptr& field : schema->fields()) { +#ifdef ARROW_COMPUTE + if (column_name && std::strlen(column_name) > 0) { + Datum result; + arrow::compute::MatchSubstringOptions options(column_name); + ADBC_ARROW_RETURN_NOT_OK( + error, arrow::compute::CallFunction( + "match_like", {Datum(MakeScalar(field->name()))}, &options) + .Value(&result)); + DCHECK_EQ(result.kind(), Datum::Kind::SCALAR); + if (!result.scalar_as().value) continue; + } +#else + if (column_name && std::strlen(column_name) > 0) { + SetError(error, "Cannot filter on column name without ARROW_COMPUTE"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } +#endif + ADBC_ARROW_RETURN_NOT_OK(error, table_columns_items_builder->Append()); + ADBC_ARROW_RETURN_NOT_OK(error, column_name_builder->Append(field->name())); + ADBC_ARROW_RETURN_NOT_OK( + error, ordinal_position_builder->Append(ordinal_position++)); + for (int i = 2; i < table_columns_items_builder->num_fields(); i++) { + ADBC_ARROW_RETURN_NOT_OK( + error, table_columns_items_builder->field_builder(i)->AppendNull()); + } + } + + // TODO(lidavidm): unimplemented for now + ADBC_ARROW_RETURN_NOT_OK(error, table_constraints_builder->Append()); + } else { + ADBC_ARROW_RETURN_NOT_OK(error, table_columns_builder->AppendNull()); + ADBC_ARROW_RETURN_NOT_OK(error, table_constraints_builder->AppendNull()); + } + } + } + } + + ArrayVector arrays(2); + ADBC_ARROW_RETURN_NOT_OK(error, catalog_name_builder.Finish(&arrays[0])); + ADBC_ARROW_RETURN_NOT_OK(error, catalog_db_schemas_builder.Finish(&arrays[1])); + const int64_t num_rows = arrays[0]->length(); + auto batch = RecordBatch::Make(kCatalogSchema, num_rows, std::move(arrays)); + + ADBC_ARROW_RETURN_NOT_OK( + error, ExportRecordBatches(kCatalogSchema, {std::move(batch)}, stream)); + return ADBC_STATUS_OK; + } + + AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, + const char* table_name, struct ArrowSchema* schema, + struct AdbcError* error) { + FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery); + std::string catalog_str, db_schema_str, table_name_str; + + if (catalog) catalog_str = catalog; + if (db_schema) db_schema_str = db_schema; + if (table_name) table_name_str = table_name; + + std::unique_ptr info; + ADBC_ARROW_RETURN_NOT_OK( + error, client_ + ->GetTables(call_options, catalog ? &catalog_str : nullptr, + db_schema ? &db_schema_str : nullptr, + table_name ? &table_name_str : nullptr, + /*include_schema=*/true, /*table_types=*/nullptr) + .Value(&info)); + + FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch), + std::move(info)); + ADBC_RETURN_NOT_OK(reader.Init(error)); + + if (!reader.schema()->Equals(*SqlSchema::GetTablesSchemaWithIncludedSchema())) { + SetError(error, "Server returned wrong schema, got:\n", *reader.schema(), + "\nExpected:\n", *SqlSchema::GetTablesSchemaWithIncludedSchema()); + return ADBC_STATUS_INTERNAL; + } + + while (true) { + std::shared_ptr batch; + ADBC_ARROW_RETURN_NOT_OK(error, reader.Next().Value(&batch)); + if (!batch) break; + if (batch->num_rows() == 0) continue; + + const auto& schema_col = batch->column(4); + if (schema_col->type()->id() != Type::BINARY) { + SetError(error, "Server returned invalid schema; expected binary, found ", + *schema_col->type()); + return ADBC_STATUS_INTERNAL; + } + const auto& binary = checked_cast(*schema_col); + if (!binary.IsValid(0)) { + SetError(error, "Schema was null though field is non-null"); + return ADBC_STATUS_INTERNAL; + } + + ipc::DictionaryMemo memo; + io::BufferReader stream(Buffer::FromString(binary.GetString(0))); + std::shared_ptr arrow_schema; + ADBC_ARROW_RETURN_NOT_OK(error, + ipc::ReadSchema(&stream, &memo).Value(&arrow_schema)); + ADBC_ARROW_RETURN_NOT_OK(error, arrow::ExportSchema(*arrow_schema, schema)); + return ADBC_STATUS_OK; + } + + SetError(error, "No table found meeting criteria"); + return ADBC_STATUS_NOT_FOUND; + } + + AdbcStatusCode GetTableTypes(struct ArrowArrayStream* stream, struct AdbcError* error) { + FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery); + std::unique_ptr flight_info; + auto status = client_->GetTableTypes(call_options).Value(&flight_info); + if (!status.ok()) { + SetError(error, status); + return ADBC_STATUS_IO; + } + return FlightInfoReader::Export(client_.get(), MakeCallOptions(CallContext::kFetch), + std::move(flight_info), stream, error); + } + + //---------------------------------------------------------- + // Partitioned Results + //---------------------------------------------------------- + + AdbcStatusCode ReadPartition(const uint8_t* serialized_partition, + size_t serialized_length, struct ArrowArrayStream* out, + struct AdbcError* error) { + std::unique_ptr info; + ADBC_ARROW_RETURN_NOT_OK( + error, FlightInfo::Deserialize( + std::string_view(reinterpret_cast(serialized_partition), + serialized_length)) + .Value(&info)); + return FlightInfoReader::Export(client_.get(), MakeCallOptions(CallContext::kFetch), + std::move(info), out, error); + } + + //---------------------------------------------------------- + // Transactions + //---------------------------------------------------------- + + AdbcStatusCode Commit(struct AdbcError* error) { + if (transaction_.is_valid()) { + ADBC_RETURN_NOT_OK(CheckTransactionSupport(error)); + FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate); + ADBC_ARROW_RETURN_NOT_OK(error, client_->Commit(call_options, transaction_)); + ADBC_ARROW_RETURN_NOT_OK( + error, client_->BeginTransaction(call_options).Value(&transaction_)); + return ADBC_STATUS_OK; + } + SetError(error, "Cannot commit when autocommit is enabled"); + return ADBC_STATUS_INVALID_STATE; + } + + AdbcStatusCode Rollback(struct AdbcError* error) { + if (transaction_.is_valid()) { + ADBC_RETURN_NOT_OK(CheckTransactionSupport(error)); + FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate); + ADBC_ARROW_RETURN_NOT_OK(error, client_->Rollback(call_options, transaction_)); + ADBC_ARROW_RETURN_NOT_OK( + error, client_->BeginTransaction(call_options).Value(&transaction_)); + return ADBC_STATUS_OK; + } + SetError(error, "Cannot rollback when autocommit is enabled"); + return ADBC_STATUS_INVALID_STATE; + } + + private: + // Check SqlInfo to determine what the server does/doesn't support. + Status QuerySqlInfo() { + FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery); + ARROW_ASSIGN_OR_RAISE( + auto info, client_->GetSqlInfo(call_options, + { + SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION, + })); + + FlightInfoReader reader(client_.get(), MakeCallOptions(CallContext::kFetch), + std::move(info)); + struct AdbcError error = {}; + if (reader.Init(&error) != ADBC_STATUS_OK) { + std::string message = "Could not initialize reader: "; + if (error.message) { + message += error.message; + } + if (error.release) { + error.release(&error); + } + return Status::IOError(std::move(message)); + } + + constexpr int8_t kInt32Code = 3; + while (true) { + ARROW_ASSIGN_OR_RAISE(auto batch, reader.Next()); + if (!batch) break; + + const auto& sql_codes = checked_cast(*batch->column(0)); + const auto& sql_value = checked_cast(*batch->column(1)); + const auto& sql_int32 = + checked_cast(*sql_value.field(kInt32Code)); + for (int64_t i = 0; i < batch->num_rows(); i++) { + if (!sql_codes.IsValid(i)) continue; + + switch (sql_codes.Value(i)) { + case SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION: { + if (sql_value.type_code(i) != kInt32Code) { + continue; + } + const int32_t idx = sql_value.value_offset(i); + if (!sql_int32.IsValid(idx)) { + continue; + } + const int32_t value = sql_int32.IsValid(idx) && sql_int32.Value(idx); + support_.transactions = + (value == SqlInfoOptions::SQL_SUPPORTED_TRANSACTION_TRANSACTION || + value == SqlInfoOptions::SQL_SUPPORTED_TRANSACTION_SAVEPOINT); + break; + } + default: + continue; + } + } + } + + return reader.Close(); + } + + AdbcStatusCode CheckTransactionSupport(struct AdbcError* error) { + if (!support_.transactions) { + SetError(error, "Server does not report transaction support"); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + return ADBC_STATUS_OK; + } + + struct { + bool transactions = false; + } support_; + std::shared_ptr database_; + std::unique_ptr client_; + Transaction transaction_ = no_transaction(); + std::shared_ptr quirks_; + std::unordered_map> timeout_seconds_; + std::unordered_map call_headers_; +}; + +class FlightSqlPartitionsImpl { + public: + explicit FlightSqlPartitionsImpl(std::vector partitions) + : partitions_(std::move(partitions)) { + pointers_.resize(partitions_.size()); + lengths_.reserve(partitions_.size()); + + for (size_t i = 0; i < partitions_.size(); i++) { + pointers_[i] = reinterpret_cast(partitions_[i].data()); + lengths_[i] = partitions_[i].size(); + } + } + + static AdbcStatusCode Export(const FlightInfo& info, struct AdbcPartitions* out, + struct AdbcError* error) { + std::vector partitions; + partitions.reserve(info.endpoints().size()); + for (const FlightEndpoint& endpoint : info.endpoints()) { + FlightInfo partition_info(FlightInfo::Data{ + info.serialized_schema(), + info.descriptor(), + {endpoint}, + /*total_records=*/-1, + /*total_bytes=*/-1, + }); + std::string serialized; + ADBC_ARROW_RETURN_NOT_OK(error, + partition_info.SerializeToString().Value(&serialized)); + partitions.push_back(std::move(serialized)); + } + + auto* impl = new FlightSqlPartitionsImpl(std::move(partitions)); + out->num_partitions = impl->partitions_.size(); + out->partitions = impl->pointers_.data(); + out->partition_lengths = impl->lengths_.data(); + out->private_data = impl; + out->release = &Release; + return ADBC_STATUS_OK; + } + + static void Release(struct AdbcPartitions* partitions) { + auto* impl = static_cast(partitions->private_data); + delete impl; + partitions->num_partitions = 0; + partitions->partitions = nullptr; + partitions->partition_lengths = nullptr; + partitions->private_data = nullptr; + partitions->release = nullptr; + } + + private: + std::vector partitions_; + std::vector pointers_; + std::vector lengths_; +}; + +class FlightSqlStatementImpl { + public: + explicit FlightSqlStatementImpl(std::shared_ptr connection) + : connection_(std::move(connection)) {} + + AdbcStatusCode Bind(struct ArrowArray* array, struct ArrowSchema* schema, + struct AdbcError* error) { + std::shared_ptr batch; + ADBC_ARROW_RETURN_NOT_OK(error, ImportRecordBatch(array, schema).Value(&batch)); + ADBC_ARROW_RETURN_NOT_OK( + error, RecordBatchReader::Make({std::move(batch)}).Value(&bind_parameters_)); + return ADBC_STATUS_OK; + } + + AdbcStatusCode Bind(struct ArrowArrayStream* stream, struct AdbcError* error) { + ADBC_ARROW_RETURN_NOT_OK(error, + ImportRecordBatchReader(stream).Value(&bind_parameters_)); + return ADBC_STATUS_OK; + } + + AdbcStatusCode Close(struct AdbcError* error) { + ADBC_RETURN_NOT_OK(ClosePreparedStatement(error)); + return ADBC_STATUS_OK; + } + + AdbcStatusCode ExecutePartitions(struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, struct AdbcError* error) { + std::unique_ptr info; + ADBC_ARROW_RETURN_NOT_OK(error, ExecuteFlightInfo().Value(&info)); + if (rows_affected) *rows_affected = info->total_records(); + + ipc::DictionaryMemo memo; + std::shared_ptr arrow_schema; + ADBC_ARROW_RETURN_NOT_OK(error, info->GetSchema(&memo).Value(&arrow_schema)); + ADBC_ARROW_RETURN_NOT_OK(error, ExportSchema(*arrow_schema, schema)); + ADBC_RETURN_NOT_OK(FlightSqlPartitionsImpl::Export(*info, partitions, error)); + return ADBC_STATUS_OK; + } + + AdbcStatusCode ExecuteQuery(struct ArrowArrayStream* out, int64_t* rows_affected, + struct AdbcError* error) { + if (!ingest_.target_table.empty()) { + if (out) { + SetError(error, "Must not provide out for bulk ingest"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + return ExecuteIngest(rows_affected, error); + } else if (plan_.plan.empty() && query_.empty() && !prepared_statement_) { + SetError(error, "Must provide query"); + return ADBC_STATUS_INVALID_STATE; + } + + if (!out) { + int64_t out_rows = 0; + ADBC_ARROW_RETURN_NOT_OK(error, ExecuteUpdate().Value(&out_rows)); + if (rows_affected) *rows_affected = out_rows; + return ADBC_STATUS_OK; + } + + std::unique_ptr info; + ADBC_ARROW_RETURN_NOT_OK(error, ExecuteFlightInfo().Value(&info)); + if (rows_affected) *rows_affected = info->total_records(); + return FlightInfoReader::Export(connection_->client(), + MakeCallOptions(CallContext::kFetch), std::move(info), + out, error); + } + + AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error) { + if (!prepared_statement_) { + SetError(error, "Must Prepare() before GetParameterSchema()"); + return ADBC_STATUS_INVALID_STATE; + } + ADBC_ARROW_RETURN_NOT_OK( + error, arrow::ExportSchema(*prepared_statement_->parameter_schema(), schema)); + return ADBC_STATUS_OK; + } + + AdbcStatusCode Prepare(struct AdbcError* error) { + if (plan_.plan.empty() && query_.empty()) { + SetError(error, "Must provide query"); + return ADBC_STATUS_INVALID_STATE; + } + FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate); + ADBC_RETURN_NOT_OK(ClosePreparedStatement(error)); + if (!plan_.plan.empty()) { + DCHECK(query_.empty()); + ADBC_ARROW_RETURN_NOT_OK( + error, connection_->client() + ->PrepareSubstrait(call_options, plan_, connection_->transaction()) + .Value(&prepared_statement_)); + } else { + ADBC_ARROW_RETURN_NOT_OK( + error, connection_->client() + ->Prepare(call_options, query_, connection_->transaction()) + .Value(&prepared_statement_)); + } + return ADBC_STATUS_OK; + } + + AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) { + if (key == nullptr) { + SetError(error, "Key must not be null"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + std::string_view key_view(key); + std::string_view val_view = value ? value : ""; + if (key_view == kIngestOptionTargetTable) { + ADBC_RETURN_NOT_OK(ClearQueryParams(error)); + ingest_.target_table = val_view; + } else if (key_view == kIngestOptionMode) { + if (value == kIngestOptionModeCreate) { + ingest_.append = false; + } else if (value == kIngestOptionModeAppend) { + ingest_.append = true; + } else { + SetError(error, "Invalid statement option value ", key_view, "=", val_view); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } else if (key_view == kStatementSubstraitVersionKey) { + plan_.version = value; + } else if (key_view.rfind(kCallHeaderPrefix, 0) == 0) { + std::string header(key_view.substr(kCallHeaderPrefix.size())); + if (value == nullptr) { + call_headers_.erase(header); + } else { + call_headers_.insert({std::move(header), std::string(val_view)}); + } + return ADBC_STATUS_OK; + } else { + SetError(error, "Unknown statement option ", key_view, "=", val_view); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + return ADBC_STATUS_OK; + } + + AdbcStatusCode SetSqlQuery(const char* query, struct AdbcError* error) { + ADBC_RETURN_NOT_OK(ClearQueryParams(error)); + query_ = query; + return ADBC_STATUS_OK; + } + + AdbcStatusCode SetSubstraitPlan(const uint8_t* plan, size_t length, + struct AdbcError* error) { + ADBC_RETURN_NOT_OK(ClearQueryParams(error)); + plan_.plan = std::string(reinterpret_cast(plan), length); + return ADBC_STATUS_OK; + } + + private: + AdbcStatusCode ClosePreparedStatement(struct AdbcError* error) { + if (prepared_statement_) { + FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate); + ADBC_ARROW_RETURN_NOT_OK(error, prepared_statement_->Close(call_options)); + prepared_statement_ = nullptr; + } + return ADBC_STATUS_OK; + } + + AdbcStatusCode ClearQueryParams(struct AdbcError* error) { + ADBC_RETURN_NOT_OK(ClosePreparedStatement(error)); + ingest_.target_table.clear(); + plan_.plan.clear(); + query_.clear(); + if (bind_parameters_) { + ADBC_ARROW_RETURN_NOT_OK(error, bind_parameters_->Close()); + bind_parameters_.reset(); + } + return ADBC_STATUS_OK; + } + + AdbcStatusCode ExecuteIngest(int64_t* rows_affected, struct AdbcError* error) { + if (!bind_parameters_) { + SetError(error, "Must Bind() before bulk ingestion"); + return ADBC_STATUS_INVALID_STATE; + } + if (!IsApproximatelyValidIdentifier(ingest_.target_table)) { + SetError(error, "Invalid target table ", ingest_.target_table, + ": must be alphanumeric with underscores"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate); + if (!ingest_.append) { + std::string create = "CREATE TABLE "; + + create += ingest_.target_table; + create += " ("; + + bool first = true; + for (const std::shared_ptr& field : bind_parameters_->schema()->fields()) { + if (!first) { + create += ", "; + } + first = false; + + if (!IsApproximatelyValidIdentifier(field->name())) { + SetError(error, "Invalid column name ", field->name(), + ": must be alphanumeric with underscores"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + create += field->name(); + const std::unordered_map& mapping = + connection_->quirks().ingest_type_mapping; + const auto it = mapping.find(field->type()->id()); + if (it == mapping.end()) { + SetError(error, "Data type not supported for bulk ingest: ", field->ToString()); + return ADBC_STATUS_NOT_IMPLEMENTED; + } else { + create += " "; + create += it->second; + } + } + create += ")"; + SetError(error, "Creating table via: ", create); + + ADBC_ARROW_RETURN_NOT_OK( + error, connection_->client() + ->ExecuteUpdate(call_options, create, connection_->transaction()) + .status()); + } + + std::string append = "INSERT INTO "; + { + append += ingest_.target_table; + append += " VALUES ("; + for (int i = 0; i < bind_parameters_->schema()->num_fields(); i++) { + if (i > 0) append += ", "; + // TODO(lidavidm): add config option for parameter symbol (or + // query Flight SQL metadata?) + append += "?"; + } + append += ")"; + SetError(error, "Updating table via: ", append); + } + + std::shared_ptr stmt; + ADBC_ARROW_RETURN_NOT_OK( + error, connection_->client() + ->Prepare(call_options, append, connection_->transaction()) + .Value(&stmt)); + + int64_t total_rows = 0; + ADBC_ARROW_RETURN_NOT_OK(error, stmt->SetParameters(std::move(bind_parameters_))); + ADBC_ARROW_RETURN_NOT_OK(error, stmt->ExecuteUpdate().Value(&total_rows)); + ADBC_ARROW_RETURN_NOT_OK(error, stmt->Close()); + if (rows_affected) *rows_affected = total_rows; + return ADBC_STATUS_OK; + } + + arrow::Result ExecuteUpdate() { + FlightCallOptions call_options = MakeCallOptions(CallContext::kUpdate); + if (prepared_statement_) { + if (bind_parameters_) { + ARROW_RETURN_NOT_OK( + prepared_statement_->SetParameters(std::move(bind_parameters_))); + } + return prepared_statement_->ExecuteUpdate(call_options); + } + + if (!plan_.plan.empty()) { + return connection_->client()->ExecuteSubstraitUpdate(call_options, plan_, + connection_->transaction()); + } + return connection_->client()->ExecuteUpdate(call_options, query_, + connection_->transaction()); + } + + arrow::Result> ExecuteFlightInfo() { + FlightCallOptions call_options = MakeCallOptions(CallContext::kQuery); + if (prepared_statement_) { + ARROW_RETURN_NOT_OK( + prepared_statement_->SetParameters(std::move(bind_parameters_))); + return prepared_statement_->Execute(call_options); + } + if (!plan_.plan.empty()) { + return connection_->client()->ExecuteSubstrait(call_options, plan_, + connection_->transaction()); + } + return connection_->client()->Execute(call_options, query_, + connection_->transaction()); + } + + FlightCallOptions MakeCallOptions(CallContext context) const { + FlightCallOptions options = connection_->MakeCallOptions(context); + for (const auto& header : call_headers_) { + options.headers.emplace_back(header.first, header.second); + } + return options; + } + + std::shared_ptr connection_; + std::shared_ptr prepared_statement_; + SubstraitPlan plan_; + std::string query_; + std::shared_ptr bind_parameters_; + std::unordered_map call_headers_; + // Bulk ingest state + struct { + std::string target_table; + bool append = false; + } ingest_; +}; + +AdbcStatusCode FlightSqlDatabaseNew(struct AdbcDatabase* database, + struct AdbcError* error) { + auto impl = std::make_shared(); + database->private_data = new std::shared_ptr(impl); + return ADBC_STATUS_OK; +} + +AdbcStatusCode FlightSqlDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + if (!database || !database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(database->private_data); + return (*ptr)->SetOption(key, value, error); +} + +AdbcStatusCode FlightSqlDatabaseInit(struct AdbcDatabase* database, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(database->private_data); + return (*ptr)->Init(error); +} + +AdbcStatusCode FlightSqlDatabaseRelease(struct AdbcDatabase* database, + struct AdbcError* error) { + if (!database->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = + reinterpret_cast*>(database->private_data); + AdbcStatusCode status = (*ptr)->Release(error); + delete ptr; + database->private_data = nullptr; + return status; +} + +AdbcStatusCode FlightSqlConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = reinterpret_cast*>( + connection->private_data); + return (*ptr)->Commit(error); +} + +AdbcStatusCode FlightSqlConnectionGetInfo(struct AdbcConnection* connection, + uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* stream, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = reinterpret_cast*>( + connection->private_data); + return (*ptr)->GetInfo(info_codes, info_codes_length, stream, error); +} + +AdbcStatusCode FlightSqlConnectionGetObjects( + struct AdbcConnection* connection, int depth, const char* catalog, + const char* db_schema, const char* table_name, const char** table_type, + const char* column_name, struct ArrowArrayStream* stream, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = reinterpret_cast*>( + connection->private_data); + return (*ptr)->GetObjects(depth, catalog, db_schema, table_name, table_type, + column_name, stream, error); +} + +AdbcStatusCode FlightSqlConnectionGetTableSchema( + struct AdbcConnection* connection, const char* catalog, const char* db_schema, + const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = reinterpret_cast*>( + connection->private_data); + return (*ptr)->GetTableSchema(catalog, db_schema, table_name, schema, error); +} + +AdbcStatusCode FlightSqlConnectionGetTableTypes(struct AdbcConnection* connection, + struct ArrowArrayStream* stream, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = reinterpret_cast*>( + connection->private_data); + return (*ptr)->GetTableTypes(stream, error); +} + +AdbcStatusCode FlightSqlConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto ptr = reinterpret_cast*>( + connection->private_data); + return (*ptr)->Init(database, error); +} + +AdbcStatusCode FlightSqlConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + auto impl = std::make_shared(); + connection->private_data = new std::shared_ptr(impl); + return ADBC_STATUS_OK; +} + +AdbcStatusCode FlightSqlConnectionReadPartition(struct AdbcConnection* connection, + const uint8_t* serialized_partition, + size_t serialized_length, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = reinterpret_cast*>( + connection->private_data); + return (*ptr)->ReadPartition(serialized_partition, serialized_length, out, error); +} + +AdbcStatusCode FlightSqlConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = reinterpret_cast*>( + connection->private_data); + auto status = (*ptr)->Close(error); + delete ptr; + connection->private_data = nullptr; + return status; +} + +AdbcStatusCode FlightSqlConnectionRollback(struct AdbcConnection* connection, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = reinterpret_cast*>( + connection->private_data); + return (*ptr)->Rollback(error); +} + +AdbcStatusCode FlightSqlConnectionSetOption(struct AdbcConnection* connection, + const char* key, const char* value, + struct AdbcError* error) { + if (!connection->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = reinterpret_cast*>( + connection->private_data); + return (*ptr)->SetOption(key, value, error); +} + +AdbcStatusCode FlightSqlStatementBind(struct AdbcStatement* statement, + struct ArrowArray* values, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->Bind(values, schema, error); +} + +AdbcStatusCode FlightSqlStatementBindStream(struct AdbcStatement* statement, + struct ArrowArrayStream* stream, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->Bind(stream, error); +} + +AdbcStatusCode FlightSqlStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->ExecutePartitions(schema, partitions, rows_affected, error); +} + +AdbcStatusCode FlightSqlStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->ExecuteQuery(out, rows_affected, error); +} + +AdbcStatusCode FlightSqlStatementGetParameterSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->GetParameterSchema(schema, error); +} + +AdbcStatusCode FlightSqlStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + auto* ptr = reinterpret_cast*>( + connection->private_data); + auto impl = std::make_shared(*ptr); + statement->private_data = new std::shared_ptr(impl); + return ADBC_STATUS_OK; +} + +AdbcStatusCode FlightSqlStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->Prepare(error); +} + +AdbcStatusCode FlightSqlStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + auto status = (*ptr)->Close(error); + delete ptr; + statement->private_data = nullptr; + return status; +} + +AdbcStatusCode FlightSqlStatementSetOption(struct AdbcStatement* statement, + const char* key, const char* value, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetOption(key, value, error); +} + +AdbcStatusCode FlightSqlStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetSqlQuery(query, error); +} + +AdbcStatusCode FlightSqlStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + if (!statement->private_data) return ADBC_STATUS_INVALID_STATE; + auto* ptr = + reinterpret_cast*>(statement->private_data); + return (*ptr)->SetSubstraitPlan(plan, length, error); +} +} // namespace +} // namespace arrow::flight::sql + +AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { + return arrow::flight::sql::FlightSqlDatabaseInit(database, error); +} + +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return arrow::flight::sql::FlightSqlDatabaseNew(database, error); +} + +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return arrow::flight::sql::FlightSqlDatabaseSetOption(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlDatabaseRelease(database, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionCommit(connection, error); +} + +AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, + uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* stream, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionGetInfo(connection, info_codes, + info_codes_length, stream, error); +} + +AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, + const char* catalog, const char* db_schema, + const char* table_name, const char** table_type, + const char* column_name, + struct ArrowArrayStream* stream, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionGetObjects( + connection, depth, catalog, db_schema, table_name, table_type, column_name, stream, + error); +} + +AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, + struct ArrowSchema* schema, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionGetTableSchema( + connection, catalog, db_schema, table_name, schema, error); +} + +AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, + struct ArrowArrayStream* stream, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionGetTableTypes(connection, stream, error); +} + +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionNew(connection, error); +} + +AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, + const uint8_t* serialized_partition, + size_t serialized_length, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionReadPartition( + connection, serialized_partition, serialized_length, out, error); +} + +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionRelease(connection, error); +} + +AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionRollback(connection, error); +} + +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return arrow::flight::sql::FlightSqlConnectionSetOption(connection, key, value, error); +} + +AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, + struct ArrowArray* values, struct ArrowSchema* schema, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementBind(statement, values, schema, error); +} + +AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, + struct ArrowArrayStream* stream, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementBindStream(statement, stream, error); +} + +// XXX: cpplint gets confused if declared as struct ArrowSchema* +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementExecutePartitions( + statement, schema, partitions, rows_affected, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementExecuteQuery(statement, out, rows_affected, + error); +} + +AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementGetParameterSchema(statement, schema, + error); +} + +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementRelease(statement, error); +} + +AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, + const char* value, struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementSetOption(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementSetSqlQuery(statement, query, error); +} + +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + return arrow::flight::sql::FlightSqlStatementSetSubstraitPlan(statement, plan, length, + error); +} + +extern "C" { +ARROW_FLIGHT_SQL_EXPORT +AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) { + if (version != ADBC_VERSION_1_0_0) return ADBC_STATUS_NOT_IMPLEMENTED; + + auto* driver = reinterpret_cast(raw_driver); + std::memset(driver, 0, sizeof(*driver)); + + driver->DatabaseInit = arrow::flight::sql::FlightSqlDatabaseInit; + driver->DatabaseNew = arrow::flight::sql::FlightSqlDatabaseNew; + driver->DatabaseRelease = arrow::flight::sql::FlightSqlDatabaseRelease; + driver->DatabaseSetOption = arrow::flight::sql::FlightSqlDatabaseSetOption; + + driver->ConnectionCommit = arrow::flight::sql::FlightSqlConnectionCommit; + driver->ConnectionGetInfo = arrow::flight::sql::FlightSqlConnectionGetInfo; + driver->ConnectionGetObjects = arrow::flight::sql::FlightSqlConnectionGetObjects; + driver->ConnectionGetTableSchema = + arrow::flight::sql::FlightSqlConnectionGetTableSchema; + driver->ConnectionGetTableTypes = arrow::flight::sql::FlightSqlConnectionGetTableTypes; + driver->ConnectionInit = arrow::flight::sql::FlightSqlConnectionInit; + driver->ConnectionNew = arrow::flight::sql::FlightSqlConnectionNew; + driver->ConnectionReadPartition = arrow::flight::sql::FlightSqlConnectionReadPartition; + driver->ConnectionRelease = arrow::flight::sql::FlightSqlConnectionRelease; + driver->ConnectionRollback = arrow::flight::sql::FlightSqlConnectionRollback; + driver->ConnectionSetOption = arrow::flight::sql::FlightSqlConnectionSetOption; + + driver->StatementBind = arrow::flight::sql::FlightSqlStatementBind; + driver->StatementBindStream = arrow::flight::sql::FlightSqlStatementBindStream; + driver->StatementExecutePartitions = + arrow::flight::sql::FlightSqlStatementExecutePartitions; + driver->StatementExecuteQuery = arrow::flight::sql::FlightSqlStatementExecuteQuery; + driver->StatementGetParameterSchema = + arrow::flight::sql::FlightSqlStatementGetParameterSchema; + driver->StatementNew = arrow::flight::sql::FlightSqlStatementNew; + driver->StatementPrepare = arrow::flight::sql::FlightSqlStatementPrepare; + driver->StatementRelease = arrow::flight::sql::FlightSqlStatementRelease; + driver->StatementSetOption = arrow::flight::sql::FlightSqlStatementSetOption; + driver->StatementSetSqlQuery = arrow::flight::sql::FlightSqlStatementSetSqlQuery; + driver->StatementSetSubstraitPlan = + arrow::flight::sql::FlightSqlStatementSetSubstraitPlan; + + return ADBC_STATUS_OK; +} +} + +namespace arrow::flight::sql { + +AdbcStatusCode AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error) { + return ::AdbcDriverInit(version, raw_driver, error); +} + +} // namespace arrow::flight::sql diff --git a/cpp/src/arrow/flight/sql/adbc_driver_internal.cc b/cpp/src/arrow/flight/sql/adbc_driver_internal.cc new file mode 100644 index 00000000000..1b0c893531b --- /dev/null +++ b/cpp/src/arrow/flight/sql/adbc_driver_internal.cc @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/sql/adbc_driver_internal.h" + +#include "arrow/c/bridge.h" + +namespace arrow::flight::sql { + +AdbcStatusCode ArrowToAdbcStatusCode(const Status& status) { + if (auto detail = FlightStatusDetail::UnwrapStatus(status)) { + switch (detail->code()) { + case FlightStatusCode::Internal: + return ADBC_STATUS_INTERNAL; + case FlightStatusCode::TimedOut: + return ADBC_STATUS_TIMEOUT; + case FlightStatusCode::Cancelled: + return ADBC_STATUS_CANCELLED; + case FlightStatusCode::Unauthenticated: + return ADBC_STATUS_UNAUTHENTICATED; + case FlightStatusCode::Unauthorized: + return ADBC_STATUS_UNAUTHORIZED; + case FlightStatusCode::Unavailable: + return ADBC_STATUS_IO; + case FlightStatusCode::Failed: + return ADBC_STATUS_INTERNAL; + default: + break; + } + } + switch (status.code()) { + case StatusCode::OK: + return ADBC_STATUS_OK; + case StatusCode::OutOfMemory: + return ADBC_STATUS_INTERNAL; + case StatusCode::KeyError: + return ADBC_STATUS_NOT_FOUND; + case StatusCode::TypeError: + return ADBC_STATUS_INVALID_ARGUMENT; + case StatusCode::Invalid: + return ADBC_STATUS_INVALID_ARGUMENT; + case StatusCode::IOError: + return ADBC_STATUS_IO; + case StatusCode::CapacityError: + return ADBC_STATUS_INTERNAL; + case StatusCode::IndexError: + return ADBC_STATUS_INVALID_ARGUMENT; + case StatusCode::Cancelled: + return ADBC_STATUS_CANCELLED; + case StatusCode::UnknownError: + return ADBC_STATUS_UNKNOWN; + case StatusCode::NotImplemented: + return ADBC_STATUS_NOT_IMPLEMENTED; + case StatusCode::SerializationError: + case StatusCode::RError: + case StatusCode::CodeGenError: + case StatusCode::ExpressionValidationError: + case StatusCode::ExecutionError: + return ADBC_STATUS_INTERNAL; + case StatusCode::AlreadyExists: + return ADBC_STATUS_INVALID_STATE; + default: + break; + } + return ADBC_STATUS_UNKNOWN; +} + +void ReleaseError(struct AdbcError* error) { + if (error->message) { + delete[] error->message; + error->message = nullptr; + } +} + +FlightClientOptions DefaultClientOptions() { + FlightClientOptions client_options = FlightClientOptions::Defaults(); + client_options.middleware.push_back(MakeTracingClientMiddlewareFactory()); + client_options.generic_options.emplace_back("grpc.primary_user_agent", + "ADBC Flight SQL " ARROW_VERSION_STRING); + return client_options; +} + +Status IndexDbSchemas(RecordBatchReader* reader, TableIndex* table_index) { + // TODO: unit test + ReaderIterator db_schemas(*SqlSchema::GetDbSchemasSchema(), reader); + ARROW_RETURN_NOT_OK(db_schemas.Init()); + while (true) { + ARROW_ASSIGN_OR_RAISE(bool have_data, db_schemas.Next()); + if (!have_data) break; + std::optional catalog(db_schemas.GetNullable(0)); + std::optional schema(db_schemas.GetNullable(1)); + + auto it = table_index->insert({catalog, {}}); + it.first->second.insert({schema, {0, 0}}); + } + return Status::OK(); +} + +Status IndexTables(RecordBatchReader* reader, TableIndex* table_index, + std::shared_ptr* table_data) { + ReaderIterator tables(*SqlSchema::GetTablesSchemaWithIncludedSchema(), reader); + ARROW_RETURN_NOT_OK(tables.Init()); + + int64_t start_index = 0; + int64_t index = 0; + RecordBatchVector batches; + + std::optional catalog_name; + std::optional db_schema_name; + bool first = true; + + while (true) { + ARROW_ASSIGN_OR_RAISE(bool have_data, tables.Next()); + if (!have_data) break; + + if (std::shared_ptr batch = tables.NewBatch(); batch != nullptr) { + batches.push_back(std::move(batch)); + } + + std::optional catalog(tables.GetNullable(0)); + std::optional schema(tables.GetNullable(1)); + + if (first || catalog != catalog_name || schema != db_schema_name) { + if (!first) { + auto it = table_index->find(catalog_name); + if (it != table_index->end()) { + it->second.insert_or_assign(db_schema_name, std::make_pair(start_index, index)); + } + } + catalog_name = std::move(catalog); + db_schema_name = std::move(schema); + start_index = index; + first = false; + } + + index++; + } + if (!first) { + auto it = table_index->find(catalog_name); + if (it != table_index->end()) { + it->second.insert_or_assign(db_schema_name, std::make_pair(start_index, index)); + } + } + + ARROW_ASSIGN_OR_RAISE(auto all_data, Table::FromRecordBatches( + SqlSchema::GetTablesSchemaWithIncludedSchema(), + std::move(batches))); + ARROW_ASSIGN_OR_RAISE(*table_data, all_data->CombineChunksToBatch()); + return Status::OK(); +} + +Status ExportRecordBatches(std::shared_ptr schema, RecordBatchVector batches, + struct ArrowArrayStream* stream) { + std::shared_ptr reader; + ARROW_ASSIGN_OR_RAISE(reader, + RecordBatchReader::Make(std::move(batches), std::move(schema))); + return arrow::ExportRecordBatchReader(std::move(reader), stream); +} + +bool IsApproximatelyValidIdentifier(std::string_view name) { + if (name.empty()) return false; + // First character must be a letter or underscore + if (!(name[0] == '_' || (name[0] >= 'a' && name[0] <= 'z') || + (name[0] >= 'A' && name[0] <= 'Z'))) { + return false; + } + // Subsequent characters must be underscores, letters, or digits + for (size_t i = 1; i < name.size(); i++) { + if (!(name[i] == '_' || (name[i] >= 'a' && name[i] <= 'z') || + (name[i] >= 'A' && name[i] <= 'Z') || (name[i] >= '0' && name[i] <= '9'))) { + return false; + } + } + return true; +} + +} // namespace arrow::flight::sql diff --git a/cpp/src/arrow/flight/sql/adbc_driver_internal.h b/cpp/src/arrow/flight/sql/adbc_driver_internal.h new file mode 100644 index 00000000000..4eec8389525 --- /dev/null +++ b/cpp/src/arrow/flight/sql/adbc_driver_internal.h @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_binary.h" +#include "arrow/flight/client.h" +#include "arrow/flight/client_tracing_middleware.h" +#include "arrow/flight/sql/server.h" +#include "arrow/flight/sql/visibility.h" +#include "arrow/record_batch.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/macros.h" +#include "arrow/util/string_builder.h" + +// Override the built-in visibility defines with Arrow's +#define ADBC_EXPORT ARROW_FLIGHT_SQL_EXPORT +#include "arrow/c/adbc_internal.h" // IWYU pragma: export + +namespace arrow::flight::sql { + +// Internal utilities for the Flight SQL driver. + +#define ADBC_RETURN_NOT_OK(EXPR) \ + do { \ + auto _s = (EXPR); \ + if (_s != ADBC_STATUS_OK) return _s; \ + } while (false) + +#ifdef ARROW_EXTRA_ERROR_CONTEXT +#define ADBC_ARROW_RETURN_NOT_OK(ERROR, EXPR) \ + do { \ + if (Status _s = (EXPR); !_s.ok()) { \ + _s.AddContextLine(__FILE__, __LINE__, ARROW_STRINGIFY(EXPR)); \ + SetError(error, _s); \ + return ::arrow::flight::sql::ArrowToAdbcStatusCode(_s); \ + } \ + } while (false) +#else +#define ADBC_ARROW_RETURN_NOT_OK(ERROR, EXPR) \ + do { \ + if (Status _s = (EXPR); !_s.ok()) { \ + SetError(error, _s); \ + return ::arrow::flight::sql::ArrowToAdbcStatusCode(_s); \ + } \ + } while (false) +#endif + +ARROW_FLIGHT_SQL_EXPORT +AdbcStatusCode ArrowToAdbcStatusCode(const Status& status); +ARROW_FLIGHT_SQL_EXPORT +void ReleaseError(struct AdbcError* error); +ARROW_FLIGHT_SQL_EXPORT +FlightClientOptions DefaultClientOptions(); + +/// Helper to populate an AdbcError +template +void SetError(struct AdbcError* error, Args&&... args) { + if (!error) return; + std::string message = util::StringBuilder("[Flight SQL] ", std::forward(args)...); + if (error->message) { + message.reserve(message.size() + 1 + std::strlen(error->message)); + message.append(1, '\n'); + message.append(error->message); + delete[] error->message; + } + error->message = new char[message.size() + 1]; + message.copy(error->message, message.size()); + error->message[message.size()] = '\0'; + error->release = ReleaseError; +} + +template ::ArrayType> +auto GetNotNull(const RecordBatch& batch, int col_index, int64_t row_index) { + return arrow::internal::checked_cast(*batch.column(col_index)) + .GetView(row_index); +} + +template ::ArrayType> +std::optional> +GetNullable(const RecordBatch& batch, int col_index, int64_t row_index) { + // compiler unfortunately can't infer the return type here + const auto& arr = + arrow::internal::checked_cast(*batch.column(col_index)); + if (arr.IsNull(row_index)) return std::nullopt; + return arr.GetView(row_index); +} + +/// \brief Helper to iterate over a RecordBatchReader in a rowwise +/// manner +class ARROW_FLIGHT_SQL_EXPORT ReaderIterator { + public: + explicit ReaderIterator(const Schema& schema, RecordBatchReader* reader) + : schema_(schema), reader_(reader) {} + + Status Init() { + if (!schema_.Equals(*reader_->schema())) { + return Status::Invalid("Server sent the wrong schema.\nExpected:", schema_, + "\nActual:", *reader_->schema()); + } + return reader_->Next().Value(¤t_); + } + + std::shared_ptr NewBatch() const { + return current_row_ == 0 ? current_ : nullptr; + } + + arrow::Result Next() { + if (done_) return false; + + current_row_++; + while (current_ && current_row_ >= current_->num_rows()) { + ARROW_ASSIGN_OR_RAISE(current_, reader_->Next()); + if (!current_) { + done_ = true; + return false; + } + current_row_ = 0; + } + return current_ != nullptr; + } + + template ::ArrayType> + std::invoke_result_t GetNotNull( + int col_index) const { + return sql::GetNotNull(*current_, col_index, current_row_); + } + + template ::ArrayType> + std::optional> + GetNullable(int col_index) const { + return sql::GetNullable(*current_, col_index, current_row_); + } + + private: + const Schema& schema_; + RecordBatchReader* reader_; + std::shared_ptr current_; + int64_t current_row_ = -1; + bool done_ = false; +}; + +// {[catalog name]: {[db schema name]: (inclusive_lower_bound, exclusive_upper_bound)}} +using TableIndex = std::unordered_map< + std::optional, + std::unordered_map, std::pair>>; + +/// Build up an index of database metadata in-memory to help implement GetObjects. +ARROW_FLIGHT_SQL_EXPORT +Status IndexDbSchemas(RecordBatchReader* reader, TableIndex* table_index); + +/// Build up an index of database metadata in-memory to help implement GetObjects. +ARROW_FLIGHT_SQL_EXPORT +Status IndexTables(RecordBatchReader* reader, TableIndex* table_index, + std::shared_ptr* table_data); + +/// Export batches of data as an ArrayStream. +ARROW_FLIGHT_SQL_EXPORT +Status ExportRecordBatches(std::shared_ptr schema, RecordBatchVector batches, + struct ArrowArrayStream* stream); + +/// Check if a name is (probably) a valid SQL identifier. We don't +/// know the exact SQL syntax (though we could try to take advantage +/// of the info in GetSqlInfo), so this function is conservative about +/// allowed identifiers. +ARROW_FLIGHT_SQL_EXPORT +bool IsApproximatelyValidIdentifier(std::string_view name); + +} // namespace arrow::flight::sql diff --git a/cpp/src/arrow/flight/sql/adbc_driver_test.cc b/cpp/src/arrow/flight/sql/adbc_driver_test.cc new file mode 100644 index 00000000000..1b060ebc2cb --- /dev/null +++ b/cpp/src/arrow/flight/sql/adbc_driver_test.cc @@ -0,0 +1,1057 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include +#include +#include + +#include "arrow/c/bridge.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/adbc_driver_internal.h" +#include "arrow/flight/sql/example/sqlite_server.h" +#include "arrow/flight/sql/server.h" +#include "arrow/flight/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/config.h" + +#include "adbc_validation/adbc_validation.h" +#include "adbc_validation/adbc_validation_util.h" + +using arrow::internal::checked_cast; + +namespace arrow::flight::sql { + +using adbc_validation::IsOkStatus; + +#define ARROW_ADBC_RETURN_NOT_OK(STATUS, ERROR, EXPR) \ + do { \ + if (AdbcStatusCode _s = (EXPR); _s != ADBC_STATUS_OK) { \ + return Status::STATUS(adbc_validation::StatusCodeToString(_s), ": ", \ + (ERROR)->message ? (ERROR)->message : "(no message)"); \ + } \ + } while (false) + +// ------------------------------------------------------------ +// ADBC Test Suite + +class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks { + public: + AdbcStatusCode SetupDatabase(struct AdbcDatabase* database, + struct AdbcError* error) const override { + std::string location = location_.ToString(); + if (auto status = AdbcDatabaseSetOption( + database, "arrow.flight.sql.quirks.ingest_type.int64", "INT", error); + status != ADBC_STATUS_OK) { + return status; + } + return AdbcDatabaseSetOption(database, "uri", location.c_str(), error); + } + + [[nodiscard]] std::string BindParameter(int index) const override { return "?"; } + [[nodiscard]] bool supports_concurrent_statements() const override { return true; } + [[nodiscard]] bool supports_partitioned_data() const override { return true; } + + void StartServer() { + ASSERT_OK_AND_ASSIGN(auto bind_location, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options(bind_location); + ASSERT_OK_AND_ASSIGN(server_, example::SQLiteFlightSqlServer::Create()); + ASSERT_OK(server_->Init(options)); + + ASSERT_OK_AND_ASSIGN(location_, Location::ForGrpcTcp("localhost", server_->port())); + } + + void StopServer() { ASSERT_OK(server_->Shutdown()); } + + private: + std::shared_ptr server_; + Location location_; +}; + +class AdbcSqliteDatabaseTest : public ::testing::Test, + public adbc_validation::DatabaseTest { + public: + [[nodiscard]] const adbc_validation::DriverQuirks* quirks() const override { + return &quirks_; + } + void SetUp() override { + ASSERT_NO_FATAL_FAILURE(quirks_.StartServer()); + ASSERT_NO_FATAL_FAILURE(SetUpTest()); + } + void TearDown() override { + ASSERT_NO_FATAL_FAILURE(TearDownTest()); + ASSERT_NO_FATAL_FAILURE(quirks_.StopServer()); + } + + protected: + SqliteFlightSqlQuirks quirks_; +}; +ADBCV_TEST_DATABASE(AdbcSqliteDatabaseTest) + +class AdbcSqliteConnectionTest : public ::testing::Test, + public adbc_validation::ConnectionTest { + public: + [[nodiscard]] const adbc_validation::DriverQuirks* quirks() const override { + return &quirks_; + } + void SetUp() override { + ASSERT_NO_FATAL_FAILURE(quirks_.StartServer()); + ASSERT_NO_FATAL_FAILURE(SetUpTest()); + } + void TearDown() override { + ASSERT_NO_FATAL_FAILURE(TearDownTest()); + ASSERT_NO_FATAL_FAILURE(quirks_.StopServer()); + } + +#if !defined(ARROW_COMPUTE) + void TestMetadataGetObjectsColumns() { + GTEST_SKIP() << "Test fails without ARROW_COMPUTE"; + } +#endif + + protected: + SqliteFlightSqlQuirks quirks_; +}; +ADBCV_TEST_CONNECTION(AdbcSqliteConnectionTest) + +class AdbcSqliteStatementTest : public ::testing::Test, + public adbc_validation::StatementTest { + public: + [[nodiscard]] const adbc_validation::DriverQuirks* quirks() const override { + return &quirks_; + } + void SetUp() override { + ASSERT_NO_FATAL_FAILURE(quirks_.StartServer()); + ASSERT_NO_FATAL_FAILURE(SetUpTest()); + } + void TearDown() override { + ASSERT_NO_FATAL_FAILURE(TearDownTest()); + ASSERT_NO_FATAL_FAILURE(quirks_.StopServer()); + } + + protected: + SqliteFlightSqlQuirks quirks_; +}; +ADBCV_TEST_STATEMENT(AdbcSqliteStatementTest) + +// Ensure partitions can be introspected by clients who know they're +// using Flight SQL +TEST_F(AdbcSqliteStatementTest, IntrospectPartitions) { + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error), + IsOkStatus(&error)); + + struct ArrowSchema c_schema; + struct AdbcPartitions partitions; + ASSERT_THAT(AdbcStatementExecutePartitions(&statement, &c_schema, &partitions, + /*rows_affected=*/nullptr, &error), + IsOkStatus(&error)); + + ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema)); + + ASSERT_GT(partitions.num_partitions, 0); + for (size_t i = 0; i < partitions.num_partitions; i++) { + ASSERT_OK_AND_ASSIGN(auto info, + FlightInfo::Deserialize(std::string_view( + reinterpret_cast(partitions.partitions[i]), + partitions.partition_lengths[i]))); + ipc::DictionaryMemo memo; + ASSERT_OK_AND_ASSIGN(auto info_schema, info->GetSchema(&memo)); + ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(*schema, *info_schema, /*verbose=*/true)); + ASSERT_EQ(1, info->endpoints().size()); + } + + partitions.release(&partitions); +} + +constexpr std::initializer_list kInvalidNames = { + "???", "", "名前", "[quoted]", "9abc", "foo-bar"}; + +TEST_F(AdbcSqliteStatementTest, InvalidTableName) { + for (const auto& name : kInvalidNames) { + ARROW_SCOPED_TRACE(name); + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT( + AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, name, &error), + IsOkStatus(&error)); + auto schema = arrow::schema({field("ints", int64())}); + auto batch = RecordBatchFromJSON(schema, R"([[1]])"); + struct ArrowSchema c_schema = {}; + struct ArrowArray c_batch = {}; + ASSERT_OK(ExportSchema(*schema, &c_schema)); + ASSERT_OK(ExportRecordBatch(*batch, &c_batch)); + ASSERT_THAT(AdbcStatementBind(&statement, &c_batch, &c_schema, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, /*out=*/nullptr, + /*rows_affected=*/nullptr, &error), + ::testing::Not(IsOkStatus(&error))); + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); + } +} + +TEST_F(AdbcSqliteStatementTest, InvalidColumnName) { + for (const auto& name : kInvalidNames) { + ARROW_SCOPED_TRACE(name); + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, + "foobar", &error), + IsOkStatus(&error)); + auto schema = arrow::schema({field(name, int64())}); + auto batch = RecordBatchFromJSON(schema, R"([[1]])"); + struct ArrowSchema c_schema = {}; + struct ArrowArray c_batch = {}; + ASSERT_OK(ExportSchema(*schema, &c_schema)); + ASSERT_OK(ExportRecordBatch(*batch, &c_batch)); + ASSERT_THAT(AdbcStatementBind(&statement, &c_batch, &c_schema, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, /*out=*/nullptr, + /*rows_affected=*/nullptr, &error), + ::testing::Not(IsOkStatus(&error))); + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); + } +} + +TEST(FlightSqlInternals, NameSanitizer) { + for (const auto& name : kInvalidNames) { + ASSERT_FALSE(IsApproximatelyValidIdentifier(name)) << name; + } + for (const auto& name : {"_foo", "a9", "foo_bar", "ABCD"}) { + ASSERT_TRUE(IsApproximatelyValidIdentifier(name)) << name; + } +} + +class AdbcTimeoutTestServer : public FlightSqlServerBase { + arrow::Result BeginTransaction( + const ServerCallContext& context, + const ActionBeginTransactionRequest& request) override { + while (!context.is_cancelled()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return Status::NotImplemented("NYI"); + } + + arrow::Result> DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) override { + while (!context.is_cancelled()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return Status::NotImplemented("NYI"); + } + + arrow::Result DoPutCommandStatementUpdate( + const ServerCallContext& context, const StatementUpdate& command) override { + if (command.query == "timeout") { + while (!context.is_cancelled()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } + return Status::NotImplemented("NYI"); + } + + arrow::Result> GetFlightInfoStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor) override { + if (command.query == "timeout") { + while (!context.is_cancelled()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } else if (command.query == "fetch") { + ARROW_ASSIGN_OR_RAISE(std::string ticket, CreateStatementQueryTicket("fetch")); + ARROW_ASSIGN_OR_RAISE( + auto info, FlightInfo::Make(Schema({}), descriptor, + { + FlightEndpoint{Ticket{std::move(ticket)}, {}}, + }, + /*total_records=*/-1, /*total_bytes=*/-1)); + return std::make_unique(std::move(info)); + } + return Status::NotImplemented("NYI"); + } +}; + +template +class AdbcServerTest : public ::testing::Test { + protected: + virtual arrow::Result GetLocation(const std::string& host, int port) { + return Location::ForGrpcTcp(host, port); + } + virtual void Configure(FlightServerOptions* options) {} + void SetUp() override { + ASSERT_OK_AND_ASSIGN(auto bind_location, GetLocation("0.0.0.0", 0)); + FlightServerOptions options(bind_location); + Configure(&options); + server_ = std::make_unique(); + ASSERT_OK(server_->Init(options)); + + ASSERT_OK_AND_ASSIGN(auto connect_location, + GetLocation("localhost", server_->port())); + std::string uri = connect_location.ToString(); + ASSERT_THAT(AdbcDatabaseNew(&database_, /*error=*/nullptr), IsOkStatus(&error_)); + ASSERT_THAT(AdbcDatabaseSetOption(&database_, "uri", uri.c_str(), + /*error=*/nullptr), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcDatabaseInit(&database_, /*error=*/nullptr), IsOkStatus(&error_)); + + ASSERT_THAT(AdbcConnectionNew(&connection_, /*error=*/nullptr), IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionInit(&connection_, &database_, /*error=*/nullptr), + IsOkStatus(&error_)); + error_ = {}; + } + void TearDown() override { + if (error_.release) error_.release(&error_); + ASSERT_THAT(AdbcConnectionRelease(&connection_, /*error=*/nullptr), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcDatabaseRelease(&database_, /*error=*/nullptr), IsOkStatus(&error_)); + ASSERT_OK(server_->Shutdown()); + } + + std::unique_ptr server_; + AdbcDatabase database_; + AdbcConnection connection_; + AdbcError error_; +}; + +class AdbcTimeoutTest : public AdbcServerTest {}; + +TEST_F(AdbcTimeoutTest, InvalidValues) { + for (const auto& key : { + "arrow.flight.sql.rpc.timeout_seconds.fetch", + "arrow.flight.sql.rpc.timeout_seconds.query", + "arrow.flight.sql.rpc.timeout_seconds.update", + }) { + for (const auto& value : {"1.1f", "asdf", "inf", "NaN", "-1"}) { + ARROW_SCOPED_TRACE("key=", key, " value=", value); + ASSERT_EQ(ADBC_STATUS_INVALID_ARGUMENT, + AdbcConnectionSetOption(&connection_, key, value, &error_)); + ASSERT_THAT(error_.message, ::testing::HasSubstr("Invalid timeout option value")); + } + } +} + +TEST_F(AdbcTimeoutTest, RemoveTimeout) { + for (const auto& key : { + "arrow.flight.sql.rpc.timeout_seconds.fetch", + "arrow.flight.sql.rpc.timeout_seconds.query", + "arrow.flight.sql.rpc.timeout_seconds.update", + }) { + ARROW_SCOPED_TRACE("key=", key); + ASSERT_THAT(AdbcConnectionSetOption(&connection_, key, "1.0", &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionSetOption(&connection_, key, "0", &error_), + IsOkStatus(&error_)); + } +} + +TEST_F(AdbcTimeoutTest, DoActionTimeout) { + AdbcStatusCode status; + + status = AdbcConnectionSetOption( + &connection_, "arrow.flight.sql.rpc.timeout_seconds.update", "0.1", &error_); + ASSERT_THAT(status, IsOkStatus(&error_)); + + ASSERT_THAT(AdbcConnectionSetOption(&connection_, ADBC_CONNECTION_OPTION_AUTOCOMMIT, + ADBC_OPTION_VALUE_DISABLED, &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(error_.message, ::testing::Not(::testing::HasSubstr("NYI"))); +} + +TEST_F(AdbcTimeoutTest, DoGetTimeout) { + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + + ASSERT_THAT( + AdbcConnectionSetOption(&connection_, "arrow.flight.sql.rpc.timeout_seconds.fetch", + "0.1", &error_), + IsOkStatus(&error_)); + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "fetch", &error_), IsOkStatus(&error_)); + Status st = ([&]() -> Status { + auto status = + AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_); + if (status != ADBC_STATUS_OK) return Status::Invalid(error_.message); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr reader, + arrow::ImportRecordBatchReader(&out)); + while (true) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr rb, reader->Next()); + if (!rb) break; + } + return Status::OK(); + })(); + ASSERT_NOT_OK(st); + ASSERT_THAT(st.ToString(), ::testing::Not(::testing::HasSubstr("NYI"))); + + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); +} + +TEST_F(AdbcTimeoutTest, DoPutTimeout) { + struct AdbcStatement stmt = {}; + + ASSERT_THAT( + AdbcConnectionSetOption(&connection_, "arrow.flight.sql.rpc.timeout_seconds.update", + "0.1", &error_), + IsOkStatus(&error_)); + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, /*out=*/nullptr, /*rows_affected=*/nullptr, + &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(error_.message, ::testing::Not(::testing::HasSubstr("NYI"))); + + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); +} + +TEST_F(AdbcTimeoutTest, GetFlightInfoTimeout) { + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + + ASSERT_THAT( + AdbcConnectionSetOption(&connection_, "arrow.flight.sql.rpc.timeout_seconds.query", + "0.1", &error_), + IsOkStatus(&error_)) + << error_.message; + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(error_.message, ::testing::Not(::testing::HasSubstr("NYI"))); + + if (out.release) out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); +} + +/// Ensure options to add custom headers make it through. +class HeaderServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override { + for (const auto& header : incoming_headers) { + recorded_headers_.emplace_back(header.first, header.second); + } + return Status::OK(); + } + + std::vector> recorded_headers_; +}; + +class AdbcHeaderTest : public AdbcServerTest { + protected: + void Configure(FlightServerOptions* options) override { + headers_ = std::make_shared(); + options->middleware.emplace_back("headers", headers_); + } + std::shared_ptr headers_; +}; + +TEST_F(AdbcHeaderTest, Connection) { + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + + ASSERT_THAT( + AdbcConnectionSetOption(&connection_, "arrow.flight.sql.rpc.call_header.x-span-id", + "my span id", &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionSetOption(&connection_, + "arrow.flight.sql.rpc.call_header.x-user-agent", + "Flight SQL ADBC", &error_), + IsOkStatus(&error_)) + << error_.message; + + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + + ASSERT_THAT(AdbcConnectionSetOption(&connection_, ADBC_CONNECTION_OPTION_AUTOCOMMIT, + ADBC_OPTION_VALUE_DISABLED, &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("x-span-id", "my span id"))); + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("x-user-agent", "Flight SQL ADBC"))); + headers_->recorded_headers_.clear(); + + if (out.release) out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); +} + +TEST_F(AdbcHeaderTest, Database) { + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + + ASSERT_THAT( + AdbcDatabaseSetOption(&database_, "arrow.flight.sql.rpc.call_header.x-span-id", + "my span id", &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionSetOption(&connection_, + "arrow.flight.sql.rpc.call_header.x-user-agent", + "Flight SQL ADBC", &error_), + IsOkStatus(&error_)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("x-span-id", "my span id"))); + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("x-user-agent", "Flight SQL ADBC"))); + headers_->recorded_headers_.clear(); + + if (out.release) out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); +} + +TEST_F(AdbcHeaderTest, Statement) { + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetOption(&stmt, "arrow.flight.sql.rpc.call_header.x-span-id", + "my span id", &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("x-span-id", "my span id"))); + headers_->recorded_headers_.clear(); + + // Set header to NULL to erase it + ASSERT_THAT(AdbcStatementSetOption(&stmt, "arrow.flight.sql.rpc.call_header.x-span-id", + nullptr, &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT( + headers_->recorded_headers_, + ::testing::Not(::testing::Contains(std::make_pair("x-span-id", "my span id")))); + headers_->recorded_headers_.clear(); + + // Connection headers are inherited + ASSERT_THAT( + AdbcConnectionSetOption(&connection_, "arrow.flight.sql.rpc.call_header.x-span-id", + "my span id", &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("x-span-id", "my span id"))); + headers_->recorded_headers_.clear(); + + if (out.release) out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); +} + +TEST_F(AdbcHeaderTest, Combined) { + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + + ASSERT_THAT( + AdbcDatabaseSetOption(&database_, "arrow.flight.sql.rpc.call_header.x-header-one", + "value 1", &error_), + IsOkStatus(&error_)); + + ASSERT_THAT(AdbcConnectionSetOption(&connection_, + "arrow.flight.sql.rpc.call_header.x-header-two", + "value 2", &error_), + IsOkStatus(&error_)); + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT( + AdbcStatementSetOption(&stmt, "arrow.flight.sql.rpc.call_header.x-header-three", + "value 3", &error_), + IsOkStatus(&error_)); + + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "timeout", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("x-header-one", "value 1"))); + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("x-header-two", "value 2"))); + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("x-header-three", "value 3"))); + + if (out.release) out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); +} + +class SubstraitTestServer : public FlightSqlServerBase { + arrow::Result> DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) override { + if (command.statement_handle != "expected plan") { + return Status::Invalid("invalid plan"); + } + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({}, arrow::schema({}))); + return std::make_unique(reader); + } + + arrow::Result DoPutCommandSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command) override { + if (command.plan.plan != "expected plan") return Status::Invalid("invalid plan"); + return 42; + } + + arrow::Result> GetFlightInfoSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) override { + if (command.plan.plan != "expected plan") { + return Status::Invalid("invalid plan"); + } + ARROW_ASSIGN_OR_RAISE(std::string ticket, + CreateStatementQueryTicket(command.plan.plan)); + std::vector endpoints = { + FlightEndpoint{Ticket{std::move(ticket)}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto info, + FlightInfo::Make(Schema({}), descriptor, std::move(endpoints), + /*total_records=*/-1, /*total_bytes=*/-1)); + return std::make_unique(std::move(info)); + } +}; + +class AdbcSubstraitTest : public AdbcServerTest {}; + +TEST_F(AdbcSubstraitTest, DoGet) { + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + + std::string plan = "expected plan"; + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT( + AdbcStatementSetSubstraitPlan(&stmt, reinterpret_cast(plan.data()), + plan.size(), &error_), + IsOkStatus(&error_)); + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, &rows_affected, &error_), + IsOkStatus(&error_)); + ASSERT_NE(nullptr, out.release); + + out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); +} + +TEST_F(AdbcSubstraitTest, DoPut) { + struct AdbcStatement stmt = {}; + + std::string plan = "expected plan"; + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT( + AdbcStatementSetSubstraitPlan(&stmt, reinterpret_cast(plan.data()), + plan.size(), &error_), + IsOkStatus(&error_)); + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, /*out=*/nullptr, &rows_affected, &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); + + ASSERT_EQ(42, rows_affected); +} + +class FallbackTestServer : public FlightSqlServerBase { + public: + arrow::Result> DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) override { + if (command.statement_handle == "FailedGet" && counter_++ == 0) { + // Fail the first time around + return Status::IOError("First time fails"); + } + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({}, arrow::schema({}))); + return std::make_unique(reader); + } + + arrow::Result> GetFlightInfoStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_ASSIGN_OR_RAISE(std::string ticket, CreateStatementQueryTicket(command.query)); + std::vector endpoints; + if (command.query == "NoLocation") { + endpoints = {FlightEndpoint{Ticket{std::move(ticket)}, {}}}; + } else if (command.query == "FailedConnection") { + ARROW_ASSIGN_OR_RAISE(auto location1, + Location::ForGrpcTcp("unreachable", this->port())); + ARROW_ASSIGN_OR_RAISE(auto location2, + Location::ForGrpcTcp("localhost", this->port())); + endpoints = { + FlightEndpoint{Ticket{std::move(ticket)}, + { + location1, + location2, + }}, + }; + } else if (command.query == "FailedGet") { + ARROW_ASSIGN_OR_RAISE(auto location, + Location::ForGrpcTcp("localhost", this->port())); + endpoints = { + FlightEndpoint{Ticket{std::move(ticket)}, + { + location, + location, + }}, + }; + } else if (command.query == "Failure") { + ARROW_ASSIGN_OR_RAISE(auto location, + Location::ForGrpcTcp("unreachable", this->port())); + endpoints = { + FlightEndpoint{Ticket{std::move(ticket)}, + { + location, + location, + }}, + }; + } else { + return Status::Invalid(command.query); + } + ARROW_ASSIGN_OR_RAISE(auto info, + FlightInfo::Make(Schema({}), descriptor, std::move(endpoints), + /*total_records=*/-1, /*total_bytes=*/-1)); + return std::make_unique(std::move(info)); + } + + private: + int32_t counter_ = 0; +}; + +class AdbcFallbackTest : public AdbcServerTest { + protected: + Status DoQuery(const std::string& query) { + struct AdbcStatement stmt = {}; + ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, + AdbcStatementNew(&connection_, &stmt, &error_)); + ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, + AdbcStatementSetSqlQuery(&stmt, query.data(), &error_)); + + struct ArrowArrayStream out; + AdbcStatusCode status = + AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_); + + ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, AdbcStatementRelease(&stmt, &error_)); + ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, status); + if (out.release) out.release(&out); + return Status::OK(); + } + + arrow::Result>> GetPartitions( + const std::string& query) { + struct AdbcStatement stmt = {}; + ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, + AdbcStatementNew(&connection_, &stmt, &error_)); + ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, + AdbcStatementSetSqlQuery(&stmt, query.data(), &error_)); + + struct ArrowSchema c_schema; + struct AdbcPartitions partitions; + ARROW_ADBC_RETURN_NOT_OK( + UnknownError, &error_, + AdbcStatementExecutePartitions(&stmt, &c_schema, &partitions, + /*rows_affected=*/nullptr, &error_)); + EXPECT_GT(partitions.num_partitions, 0); + c_schema.release(&c_schema); + + std::vector> result; + for (size_t i = 0; i < partitions.num_partitions; i++) { + EXPECT_OK_AND_ASSIGN(auto info, + FlightInfo::Deserialize(std::string_view( + reinterpret_cast(partitions.partitions[i]), + partitions.partition_lengths[i]))); + result.emplace_back(std::move(info)); + } + + if (c_schema.release) c_schema.release(&c_schema); + partitions.release(&partitions); + + ARROW_ADBC_RETURN_NOT_OK(UnknownError, &error_, AdbcStatementRelease(&stmt, &error_)); + return result; + } +}; + +TEST_F(AdbcFallbackTest, NoLocation) { + ASSERT_OK(DoQuery("NoLocation")); + ASSERT_OK_AND_ASSIGN(auto partitions, GetPartitions("NoLocation")); + ASSERT_EQ(1, partitions.size()); + ASSERT_EQ(1, partitions[0]->endpoints().size()); + ASSERT_TRUE(partitions[0]->endpoints()[0].locations.empty()); +} + +TEST_F(AdbcFallbackTest, FailedConnection) { + ASSERT_OK(DoQuery("FailedConnection")); + ASSERT_OK_AND_ASSIGN(auto partitions, GetPartitions("FailedConnection")); + ASSERT_EQ(1, partitions.size()); + ASSERT_EQ(1, partitions[0]->endpoints().size()); + ASSERT_EQ(2, partitions[0]->endpoints()[0].locations.size()); +} + +TEST_F(AdbcFallbackTest, FailedGet) { + ASSERT_OK(DoQuery("FailedGet")); + ASSERT_OK_AND_ASSIGN(auto partitions, GetPartitions("FailedGet")); + ASSERT_EQ(1, partitions.size()); + ASSERT_EQ(1, partitions[0]->endpoints().size()); + ASSERT_EQ(2, partitions[0]->endpoints()[0].locations.size()); +} + +TEST_F(AdbcFallbackTest, Failure) { + ASSERT_NOT_OK(DoQuery("Failure")); + ASSERT_OK_AND_ASSIGN(auto partitions, GetPartitions("Failure")); + ASSERT_EQ(1, partitions.size()); + ASSERT_EQ(1, partitions[0]->endpoints().size()); + ASSERT_EQ(2, partitions[0]->endpoints()[0].locations.size()); +} + +class NoOpTestServer : public FlightSqlServerBase { + public: + arrow::Result> DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) override { + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({}, arrow::schema({}))); + return std::make_unique(reader); + } + + arrow::Result> GetFlightInfoStatement( + const ServerCallContext& context, const StatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_ASSIGN_OR_RAISE(std::string ticket, CreateStatementQueryTicket(command.query)); + ARROW_ASSIGN_OR_RAISE(auto info, + FlightInfo::Make(Schema({}), descriptor, {}, + /*total_records=*/-1, /*total_bytes=*/-1)); + return std::make_unique(std::move(info)); + } +}; + +class AdbcTlsTest : public AdbcServerTest { + protected: + arrow::Result GetLocation(const std::string& host, int port) override { + return Location::ForGrpcTls(host, port); + } + + void Configure(FlightServerOptions* options) override { + ASSERT_OK(ExampleTlsCertificates(&options->tls_certificates)); + headers_ = std::make_shared(); + options->middleware.emplace_back("headers", headers_); + } + std::shared_ptr headers_; +}; + +TEST_F(AdbcTlsTest, DisableVerification) { + { + // Default connection tries to verify certs, fails + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "query", &error_), IsOkStatus(&error_)); + ASSERT_THAT( + AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); + } + { + // Disabling verification works + ASSERT_THAT(AdbcConnectionRelease(&connection_, /*error=*/nullptr), + IsOkStatus(&error_)); + ASSERT_THAT( + AdbcDatabaseSetOption( + &database_, "arrow.flight.sql.client_option.disable_server_verification", + ADBC_OPTION_VALUE_ENABLED, &error_), + IsOkStatus(&error_)); + + ASSERT_THAT(AdbcConnectionNew(&connection_, /*error=*/nullptr), IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionInit(&connection_, &database_, /*error=*/nullptr), + IsOkStatus(&error_)); + + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "query", &error_), IsOkStatus(&error_)); + ASSERT_THAT( + AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + IsOkStatus(&error_)); + out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); + } +} + +TEST_F(AdbcTlsTest, GenericIntOption) { + ASSERT_THAT(AdbcDatabaseSetOption(&database_, + "arrow.flight.sql.client_option.generic_int_option." + "grpc.max_send_message_length", + "invalid", &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(AdbcDatabaseSetOption(&database_, + "arrow.flight.sql.client_option.generic_int_option." + "grpc.max_send_message_length", + "0", &error_), + IsOkStatus(&error_)); +} + +TEST_F(AdbcTlsTest, GenericStringOption) { + ASSERT_THAT(AdbcConnectionRelease(&connection_, /*error=*/nullptr), + IsOkStatus(&error_)); + ASSERT_THAT( + AdbcDatabaseSetOption(&database_, + "arrow.flight.sql.client_option.disable_server_verification", + ADBC_OPTION_VALUE_ENABLED, &error_), + IsOkStatus(&error_)); + ASSERT_THAT( + AdbcDatabaseSetOption( + &database_, + "arrow.flight.sql.client_option.generic_string_option.grpc.primary_user_agent", + "custom user agent", &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionNew(&connection_, /*error=*/nullptr), IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionInit(&connection_, &database_, /*error=*/nullptr), + IsOkStatus(&error_)); + + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "query", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + IsOkStatus(&error_)); + out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); + std::string found_header; + for (const auto& pair : headers_->recorded_headers_) { + if (pair.first == "user-agent") { + found_header = pair.second; + } + } + ASSERT_THAT(found_header, ::testing::HasSubstr("custom user agent")); +} + +class AdbcJdbcStyleAuthTest : public AdbcServerTest { + protected: + void Configure(FlightServerOptions* options) override { + headers_ = std::make_shared(); + options->middleware.emplace_back("headers", headers_); + options->middleware.emplace_back("auth", + std::make_shared()); + } + + void ReconnectWithToken(const std::string& token) { + ASSERT_THAT(AdbcConnectionRelease(&connection_, /*error=*/nullptr), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcDatabaseSetOption(&database_, "arrow.flight.sql.authorization_header", + token.c_str(), &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionNew(&connection_, /*error=*/nullptr), IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionInit(&connection_, &database_, /*error=*/nullptr), + IsOkStatus(&error_)); + } + + class TestAuthServerMiddleware : public ServerMiddleware { + public: + std::string name() const override { return "auth"; } + void SendingHeaders(AddCallHeaders* outgoing_headers) override { + outgoing_headers->AddHeader("authorization", "response token"); + } + void CallCompleted(const Status&) override {} + }; + + class TestAuthServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) { + if (info.method == FlightMethod::Handshake) return Status::OK(); + + auto it = incoming_headers.find("authorization"); + if (it == incoming_headers.end()) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Missing auth header"); + } + if (it->second == "initial token" || it->second == "response token") { + *middleware = std::make_shared(); + return Status::OK(); + } + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid auth header"); + } + }; + + std::shared_ptr headers_; +}; + +TEST_F(AdbcJdbcStyleAuthTest, SuccessfulAuth) { + ASSERT_NO_FATAL_FAILURE(ReconnectWithToken("initial token")); + + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "query", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + IsOkStatus(&error_)); + out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); + + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("authorization", "initial token"))); + headers_->recorded_headers_.clear(); + + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "query", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + IsOkStatus(&error_)); + out.release(&out); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); + + ASSERT_THAT(headers_->recorded_headers_, + ::testing::Contains(std::make_pair("authorization", "response token"))); +} + +TEST_F(AdbcJdbcStyleAuthTest, FailedAuth) { + ASSERT_NO_FATAL_FAILURE(ReconnectWithToken("wrong token")); + + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "query", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(error_.message, ::testing::HasSubstr("Invalid auth header")); +} + +TEST_F(AdbcJdbcStyleAuthTest, NoAuth) { + struct AdbcStatement stmt = {}; + struct ArrowArrayStream out = {}; + ASSERT_THAT(AdbcStatementNew(&connection_, &stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&stmt, "query", &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementExecuteQuery(&stmt, &out, /*rows_affected=*/nullptr, &error_), + ::testing::Not(IsOkStatus(&error_))); + ASSERT_THAT(AdbcStatementRelease(&stmt, &error_), IsOkStatus(&error_)); + ASSERT_THAT(error_.message, ::testing::HasSubstr("Missing auth header")); +} + +class AdbcSqlInfoTest : public AdbcServerTest {}; + +TEST_F(AdbcSqlInfoTest, NoTransactionSupport) { + ASSERT_EQ(ADBC_STATUS_NOT_IMPLEMENTED, + AdbcConnectionSetOption(&connection_, ADBC_CONNECTION_OPTION_AUTOCOMMIT, + ADBC_OPTION_VALUE_DISABLED, &error_)); + ASSERT_THAT(error_.message, + ::testing::HasSubstr("Server does not report transaction support")); +} + +} // namespace arrow::flight::sql diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc index 25bf8e384ef..7095d49941e 100644 --- a/cpp/src/arrow/flight/sql/client.cc +++ b/cpp/src/arrow/flight/sql/client.cc @@ -532,6 +532,7 @@ arrow::Result> PreparedStatement::ParseRespon ReadSchema(¶meter_schema_reader, &in_memo)); } auto handle = prepared_statement_result.prepared_statement_handle(); + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); return std::make_shared(client, handle, dataset_schema, parameter_schema); diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index 648f71563e9..d151b41ce0b 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -28,6 +28,11 @@ #include "arrow/result.h" #include "arrow/status.h" +// Forward declaration +extern "C" { +struct AdbcError; +} + namespace arrow { namespace flight { namespace sql { @@ -450,6 +455,15 @@ class ARROW_FLIGHT_SQL_EXPORT Transaction { std::string transaction_id_; }; +/// \brief ADBC entry point. +/// +/// A pointer to this function can be passed to the ADBC driver +/// manager to initialize an ADBC connection using the Flight SQL +/// driver. Applications should generally not call this function +/// directly. +ARROW_FLIGHT_SQL_EXPORT +uint8_t AdbcDriverInit(int version, void* raw_driver, struct AdbcError* error); + } // namespace sql } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/symbols.map b/cpp/src/arrow/symbols.map index 9ef0e404bc0..56c9da35043 100644 --- a/cpp/src/arrow/symbols.map +++ b/cpp/src/arrow/symbols.map @@ -32,6 +32,9 @@ }; # Also export C-level helpers arrow_*; + pyarrow_*; + # Also export ADBC driver symbols + Adbc*; # ARROW-14771: export Protobuf symbol table descriptor_table_Flight_2eproto; descriptor_table_FlightSql_2eproto; diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index 2611944cf26..8a02c2a3688 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -25,6 +25,8 @@ ARROW_ABSL_BUILD_VERSION=20211102.0 ARROW_ABSL_BUILD_SHA256_CHECKSUM=dcf71b9cba8dc0ca9940c4b316a0c796be8fab42b070bb6b7cab62b48f0e66c4 +ARROW_ADBC_BUILD_VERSION=apache-arrow-adbc-0.1.0 +ARROW_ADBC_BUILD_SHA256_CHECKSUM=cdb151c1ad355bf3ff67a22eb606ec18b34527dab491419331648e9153adbde3 ARROW_AWSSDK_BUILD_VERSION=1.8.133 ARROW_AWSSDK_BUILD_SHA256_CHECKSUM=d6c495bc06be5e21dac716571305d77437e7cfd62a2226b8fe48d9ab5785a8d6 ARROW_AWS_CHECKSUMS_BUILD_VERSION=v0.1.12 @@ -61,6 +63,8 @@ ARROW_LZ4_BUILD_VERSION=v1.9.4 ARROW_LZ4_BUILD_SHA256_CHECKSUM=0b0e3aa07c8c063ddf40b082bdf7e37a1562bda40a0ff5272957f3e987e0e54b ARROW_MIMALLOC_BUILD_VERSION=v2.0.6 ARROW_MIMALLOC_BUILD_SHA256_CHECKSUM=9f05c94cc2b017ed13698834ac2a3567b6339a8bde27640df5a1581d49d05ce5 +ARROW_NANOARROW_BUILD_VERSION=bcc7dfb80a09a897124da3f10b40124b150659b5 +ARROW_NANOARROW_BUILD_SHA256_CHECKSUM=2d5655c877c26a7a96ae4abe2f15b211ccd38eacc85422b23c493fb0298d227a ARROW_NLOHMANN_JSON_BUILD_VERSION=v3.10.5 ARROW_NLOHMANN_JSON_BUILD_SHA256_CHECKSUM=5daca6ca216495edf89d167f808d1d03c4a4d929cef7da5e10f135ae1540c7e4 ARROW_OPENTELEMETRY_BUILD_VERSION=v1.4.1 @@ -103,6 +107,7 @@ ARROW_ZSTD_BUILD_SHA256_CHECKSUM=f7de13462f7a82c29ab865820149e778cbfe01087b3a55b # given version. DEPENDENCIES=( "ARROW_ABSL_URL absl-${ARROW_ABSL_BUILD_VERSION}.tar.gz https://github.com/abseil/abseil-cpp/archive/${ARROW_ABSL_BUILD_VERSION}.tar.gz" + "ARROW_ADBC_URL adbc-${ARROW_ADBC_BUILD_VERSION}.tar.gz https://github.com/apache/arrow-adbc/archive/${ARROW_ADBC_BUILD_VERSION}.tar.gz" "ARROW_AWSSDK_URL aws-sdk-cpp-${ARROW_AWSSDK_BUILD_VERSION}.tar.gz https://github.com/aws/aws-sdk-cpp/archive/${ARROW_AWSSDK_BUILD_VERSION}.tar.gz" "ARROW_AWS_CHECKSUMS_URL aws-checksums-${ARROW_AWS_CHECKSUMS_BUILD_VERSION}.tar.gz https://github.com/awslabs/aws-checksums/archive/${ARROW_AWS_CHECKSUMS_BUILD_VERSION}.tar.gz" "ARROW_AWS_C_COMMON_URL aws-c-common-${ARROW_AWS_C_COMMON_BUILD_VERSION}.tar.gz https://github.com/awslabs/aws-c-common/archive/${ARROW_AWS_C_COMMON_BUILD_VERSION}.tar.gz" @@ -121,6 +126,7 @@ DEPENDENCIES=( "ARROW_JEMALLOC_URL jemalloc-${ARROW_JEMALLOC_BUILD_VERSION}.tar.bz2 https://github.com/jemalloc/jemalloc/releases/download/${ARROW_JEMALLOC_BUILD_VERSION}/jemalloc-${ARROW_JEMALLOC_BUILD_VERSION}.tar.bz2" "ARROW_LZ4_URL lz4-${ARROW_LZ4_BUILD_VERSION}.tar.gz https://github.com/lz4/lz4/archive/${ARROW_LZ4_BUILD_VERSION}.tar.gz" "ARROW_MIMALLOC_URL mimalloc-${ARROW_MIMALLOC_BUILD_VERSION}.tar.gz https://github.com/microsoft/mimalloc/archive/${ARROW_MIMALLOC_BUILD_VERSION}.tar.gz" + "ARROW_NANOARROW_URL nanoarrow-${ARROW_NANOARROW_BUILD_VERSION}.tar.gz https://github.com/apache/arrow-nanoarrow/archive/${ARROW_NANOARROW_BUILD_VERSION}.tar.gz" "ARROW_NLOHMANN_JSON_URL nlohmann-json-${ARROW_NLOHMANN_JSON_BUILD_VERSION}.tar.gz https://github.com/nlohmann/json/archive/refs/tags/${ARROW_NLOHMANN_JSON_BUILD_VERSION}.tar.gz" "ARROW_OPENTELEMETRY_URL opentelemetry-cpp-${ARROW_OPENTELEMETRY_BUILD_VERSION}.tar.gz https://github.com/open-telemetry/opentelemetry-cpp/archive/refs/tags/${ARROW_OPENTELEMETRY_BUILD_VERSION}.tar.gz" "ARROW_OPENTELEMETRY_PROTO_URL opentelemetry-proto-${ARROW_OPENTELEMETRY_PROTO_BUILD_VERSION}.tar.gz https://github.com/open-telemetry/opentelemetry-proto/archive/refs/tags/${ARROW_OPENTELEMETRY_PROTO_BUILD_VERSION}.tar.gz" diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 51b8eaae400..521721c7853 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -112,6 +112,7 @@ endif() if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") option(PYARROW_BUILD_CUDA "Build the PyArrow CUDA support" OFF) option(PYARROW_BUILD_FLIGHT "Build the PyArrow Flight integration" OFF) + option(PYARROW_BUILD_FLIGHT_SQL "Build the PyArrow Flight SQL ADBC driver" OFF) option(PYARROW_BUILD_SUBSTRAIT "Build the PyArrow Substrait integration" OFF) option(PYARROW_BUILD_DATASET "Build the PyArrow Dataset integration" OFF) option(PYARROW_BUILD_GANDIVA "Build the PyArrow Gandiva integration" OFF) @@ -377,6 +378,12 @@ install(TARGETS arrow_python RUNTIME DESTINATION .) set(PYARROW_CPP_FLIGHT_SRCS ${PYARROW_CPP_SOURCE_DIR}/flight.cc) + +if(PYARROW_BUILD_FLIGHT_SQL) + set(ARROW_FLIGHT_SQL TRUE) + set(PYARROW_BUILD_FLIGHT TRUE) +endif() + if(PYARROW_BUILD_FLIGHT) if(NOT ARROW_FLIGHT) message(FATAL_ERROR "You must build Arrow C++ with ARROW_FLIGHT=ON") @@ -657,6 +664,24 @@ else() set(FLIGHT_LINK_LIBS "") endif() +# Flight SQL +if(PYARROW_BUILD_FLIGHT_SQL) + # Arrow Flight SQL + find_package(ArrowFlightSql REQUIRED) + + if(PYARROW_BUNDLE_ARROW_CPP) + bundle_arrow_lib(${ARROW_FLIGHT_SQL_SHARED_LIB} SO_VERSION ${ARROW_SO_VERSION}) + if(MSVC) + bundle_arrow_import_lib(${ARROW_FLIGHT_SQL_IMPORT_LIB}) + endif() + endif() + + set(FLIGHT_SQL_LINK_LIBS ArrowFlightSql::arrow_flight_sql_shared) + set(CYTHON_EXTENSIONS ${CYTHON_EXTENSIONS} _flight_sql) +else() + set(FLIGHT_SQL_LINK_LIBS "") +endif() + # Substrait if(PYARROW_BUILD_SUBSTRAIT) if(NOT ARROW_SUBSTRAIT) @@ -778,6 +803,10 @@ if(PYARROW_BUILD_FLIGHT) target_link_libraries(_flight PRIVATE ${FLIGHT_LINK_LIBS}) endif() +if(PYARROW_BUILD_FLIGHT_SQL) + target_link_libraries(_flight_sql PRIVATE ${FLIGHT_SQL_LINK_LIBS}) +endif() + if(PYARROW_BUILD_SUBSTRAIT) target_link_libraries(_substrait PRIVATE ${SUBSTRAIT_LINK_LIBS}) endif() diff --git a/python/pyarrow/_flight_sql.pyx b/python/pyarrow/_flight_sql.pyx new file mode 100644 index 00000000000..4a6892e5e61 --- /dev/null +++ b/python/pyarrow/_flight_sql.pyx @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from libc.stdint cimport uintptr_t + +from pyarrow.includes.libarrow_flight_sql cimport * + + +def connect_raw(uri: str, **kwargs): + """Create a low level ADBC connection via Flight SQL.""" + import adbc_driver_manager + return adbc_driver_manager.AdbcDatabase(init_func=int( &AdbcDriverInit), + uri=uri, **kwargs) diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index bea735bd3ac..7fbe45cc022 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -45,6 +45,7 @@ 'substrait', 'tensorflow', 'flight', + 'flight_sql', 'slow', 'requires_testing_data', 'zstd', @@ -57,6 +58,7 @@ 'dataset': False, 'fastparquet': False, 'flight': False, + 'flight_sql': False, 'gandiva': False, 'gcs': False, 'gdb': True, @@ -148,6 +150,12 @@ except ImportError: pass +try: + from pyarrow.flight_sql import connect # noqa + defaults['flight_sql'] = True +except (ImportError, RuntimeError): + pass + try: from pyarrow.fs import GcsFileSystem # noqa defaults['gcs'] = True diff --git a/python/pyarrow/flight_sql.py b/python/pyarrow/flight_sql.py new file mode 100644 index 00000000000..0d87e5168e8 --- /dev/null +++ b/python/pyarrow/flight_sql.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +A Flight SQL client with a DBAPI 2.0/PEP 249-compatible API. + +Depends on the ADBC driver manager. +""" + +from typing import Dict, Optional + +_ADBC_FOUND = False + +try: + import adbc_driver_manager + import adbc_driver_manager.dbapi +except ImportError: + pass +else: + _ADBC_FOUND = True + + +_dbapi_names = [ + "BINARY", + "DATETIME", + "NUMBER", + "ROWID", + "STRING", + "Connection", + "Cursor", + "DataError", + "DatabaseError", + "Date", + "DateFromTicks", + "Error", + "IntegrityError", + "InterfaceError", + "InternalError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "Warning", + "apilevel", + "connect", + "paramstyle", + "threadsafety", +] + + +if _ADBC_FOUND: + import pyarrow._flight_sql + + __all__ = _dbapi_names + + # ---------------------------------------------------------- + # Globals + + apilevel = adbc_driver_manager.dbapi.apilevel + threadsafety = adbc_driver_manager.dbapi.threadsafety + # XXX: the param style can't be determined up front + paramstyle = "qmark" + + Warning = adbc_driver_manager.dbapi.Warning + Error = adbc_driver_manager.dbapi.Error + InterfaceError = adbc_driver_manager.dbapi.InterfaceError + DatabaseError = adbc_driver_manager.dbapi.DatabaseError + DataError = adbc_driver_manager.dbapi.DataError + OperationalError = adbc_driver_manager.dbapi.OperationalError + IntegrityError = adbc_driver_manager.dbapi.IntegrityError + InternalError = adbc_driver_manager.dbapi.InternalError + ProgrammingError = adbc_driver_manager.dbapi.ProgrammingError + NotSupportedError = adbc_driver_manager.dbapi.NotSupportedError + + # ---------------------------------------------------------- + # Types + + Date = adbc_driver_manager.dbapi.Date + Time = adbc_driver_manager.dbapi.Time + Timestamp = adbc_driver_manager.dbapi.Timestamp + DateFromTicks = adbc_driver_manager.dbapi.DateFromTicks + TimeFromTicks = adbc_driver_manager.dbapi.TimeFromTicks + TimestampFromTicks = adbc_driver_manager.dbapi.TimestampFromTicks + STRING = adbc_driver_manager.dbapi.STRING + BINARY = adbc_driver_manager.dbapi.BINARY + NUMBER = adbc_driver_manager.dbapi.NUMBER + DATETIME = adbc_driver_manager.dbapi.DATETIME + ROWID = adbc_driver_manager.dbapi.ROWID + + # ---------------------------------------------------------- + # Functions + + def connect(uri: str, *, db_kwargs: Optional[Dict[str, str]] = None, + conn_kwargs: Optional[Dict[str, str]] = None) -> "Connection": + """ + Connect to a Flight SQL server via ADBC. + + Parameters + ---------- + uri : str + The Flight URI to connect to. + db_kwargs : dict, optional + Additional arguments to pass when creating the AdbcDatabase. + conn_kwargs : dict, optional + Additional arguments to pass when creating the AdbcConnection. + """ + db = None + conn = None + db_kwargs = db_kwargs or {} + conn_kwargs = conn_kwargs or {} + + try: + db = pyarrow._flight_sql.connect_raw(uri, **db_kwargs) + conn = adbc_driver_manager.AdbcConnection(db, **conn_kwargs) + return adbc_driver_manager.dbapi.Connection(db, conn) + except Exception: + if conn: + conn.close() + if db: + db.close() + raise + + # ---------------------------------------------------------- + # Classes + + Connection = adbc_driver_manager.dbapi.Connection + Cursor = adbc_driver_manager.dbapi.Cursor +else: + + def __getattr__(name): + if name in _dbapi_names: + raise RuntimeError( + f"{__name__}.{name} requires adbc_driver_manager") + else: + raise AttributeError( + f"module '{__name__}' has no attribute '{name}'") diff --git a/python/pyarrow/includes/libarrow_flight_sql.pxd b/python/pyarrow/includes/libarrow_flight_sql.pxd new file mode 100644 index 00000000000..df7d3e6b8a5 --- /dev/null +++ b/python/pyarrow/includes/libarrow_flight_sql.pxd @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from libc.stdint cimport uint8_t + +cdef extern from "arrow/flight/sql/api.h" namespace "arrow" nogil: + cdef struct CAdbcError "AdbcError": + pass + + uint8_t AdbcDriverInit"arrow::flight::sql::AdbcDriverInit"(int version, void* raw_driver, CAdbcError* error) diff --git a/python/pyarrow/tests/test_flight_sql.py b/python/pyarrow/tests/test_flight_sql.py new file mode 100644 index 00000000000..2871938e484 --- /dev/null +++ b/python/pyarrow/tests/test_flight_sql.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +try: + from pyarrow import flight_sql + from pyarrow.flight_sql import connect +except (ImportError, RuntimeError): + flight_sql = None + connect = None + +# Marks all of the tests in this module +# Ignore these with pytest ... -m 'not flight_sql' +pytestmark = pytest.mark.flight_sql + + +# Without a Flight SQL server implementation in Python, the best we +# can do is just basic smoke tests + +def test_import(): + assert flight_sql.connect + + +def test_invalid_uri(): + match = ".*No client transport implementation for invalid.*" + with pytest.raises(flight_sql.Error, match=match): + connect("invalid://foo") + + +def test_valid_uri(): + # Connection still fails, but we should get something different + match = ".*ADBC_STATUS_IO.*" + with pytest.raises(flight_sql.OperationalError, match=match): + with connect("grpc+tcp://localhost:1234") as conn: + conn.adbc_get_info() + + +def test_db_kwargs(): + match = ".*type name invalid_type is not recognized.*" + db_kwargs = {"arrow.flight.sql.quirks.ingest_type.invalid_type": "integer"} + with pytest.raises(flight_sql.ProgrammingError, match=match): + with connect("grpc+tcp://localhost:1234", db_kwargs=db_kwargs) as conn: + conn.adbc_get_info() + + +def test_conn_kwargs(): + match = ".*Invalid timeout option value.*" + conn_kwargs = {"arrow.flight.sql.rpc.timeout_seconds.fetch": "invalid"} + with pytest.raises(flight_sql.ProgrammingError, match=match): + with connect("grpc+tcp://localhost:1234", + conn_kwargs=conn_kwargs) as conn: + conn.adbc_get_info() diff --git a/python/setup.py b/python/setup.py index 2e184e6411a..243a8ae457a 100755 --- a/python/setup.py +++ b/python/setup.py @@ -167,6 +167,8 @@ def initialize_options(self): os.environ.get('PYARROW_WITH_SUBSTRAIT', '0')) self.with_flight = strtobool( os.environ.get('PYARROW_WITH_FLIGHT', '0')) + self.with_flight_sql = self.with_flight and strtobool( + os.environ.get('PYARROW_WITH_FLIGHT_SQL', '0')) self.with_dataset = strtobool( os.environ.get('PYARROW_WITH_DATASET', '0')) self.with_parquet = strtobool( @@ -209,6 +211,7 @@ def initialize_options(self): '_compute', '_cuda', '_flight', + '_flight_sql', '_dataset', '_dataset_orc', '_dataset_parquet', @@ -284,6 +287,7 @@ def append_cmake_bool(value, varname): append_cmake_bool(self.with_cuda, 'PYARROW_BUILD_CUDA') append_cmake_bool(self.with_substrait, 'PYARROW_BUILD_SUBSTRAIT') append_cmake_bool(self.with_flight, 'PYARROW_BUILD_FLIGHT') + append_cmake_bool(self.with_flight_sql, 'PYARROW_BUILD_FLIGHT_SQL') append_cmake_bool(self.with_gandiva, 'PYARROW_BUILD_GANDIVA') append_cmake_bool(self.with_dataset, 'PYARROW_BUILD_DATASET') append_cmake_bool(self.with_orc, 'PYARROW_BUILD_ORC') @@ -370,6 +374,8 @@ def _failure_permitted(self, name): return True if name == '_flight' and not self.with_flight: return True + if name == '_flight_sql' and not self.with_flight_sql: + return True if name == '_substrait' and not self.with_substrait: return True if name == '_gcsfs' and not self.with_gcs: From d48308adfcd428bebab408da7a1628cadaf38d21 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 10 Jan 2023 14:20:13 -0500 Subject: [PATCH 2/5] Enable in wheel builds --- ci/scripts/python_wheel_macos_build.sh | 3 +++ ci/scripts/python_wheel_manylinux_build.sh | 3 +++ ci/scripts/python_wheel_windows_build.bat | 3 +++ 3 files changed, 9 insertions(+) diff --git a/ci/scripts/python_wheel_macos_build.sh b/ci/scripts/python_wheel_macos_build.sh index 7c7ef7745c0..e0cafb30c5e 100755 --- a/ci/scripts/python_wheel_macos_build.sh +++ b/ci/scripts/python_wheel_macos_build.sh @@ -61,6 +61,7 @@ pip install "delocate>=0.10.3" echo "=== (${PYTHON_VERSION}) Building Arrow C++ libraries ===" : ${ARROW_DATASET:=ON} : ${ARROW_FLIGHT:=ON} +: ${ARROW_FLIGHT_SQL:=${ARROW_FLIGHT}} : ${ARROW_GANDIVA:=OFF} : ${ARROW_GCS:=ON} : ${ARROW_HDFS:=ON} @@ -101,6 +102,7 @@ cmake \ -DARROW_DEPENDENCY_USE_SHARED=OFF \ -DARROW_FILESYSTEM=ON \ -DARROW_FLIGHT=${ARROW_FLIGHT} \ + -DARROW_FLIGHT_SQL=${ARROW_FLIGHT_SQL} \ -DARROW_GANDIVA=${ARROW_GANDIVA} \ -DARROW_GCS=${ARROW_GCS} \ -DARROW_HDFS=${ARROW_HDFS} \ @@ -146,6 +148,7 @@ export PYARROW_CMAKE_GENERATOR=${CMAKE_GENERATOR} export PYARROW_INSTALL_TESTS=1 export PYARROW_WITH_DATASET=${ARROW_DATASET} export PYARROW_WITH_FLIGHT=${ARROW_FLIGHT} +export PYARROW_WITH_FLIGHT_SQL=${ARROW_FLIGHT_SQL} export PYARROW_WITH_GANDIVA=${ARROW_GANDIVA} export PYARROW_WITH_GCS=${ARROW_GCS} export PYARROW_WITH_HDFS=${ARROW_HDFS} diff --git a/ci/scripts/python_wheel_manylinux_build.sh b/ci/scripts/python_wheel_manylinux_build.sh index 2aea55ed70f..3118cb1d852 100755 --- a/ci/scripts/python_wheel_manylinux_build.sh +++ b/ci/scripts/python_wheel_manylinux_build.sh @@ -50,6 +50,7 @@ rm -rf /arrow/python/pyarrow/*.so.* echo "=== (${PYTHON_VERSION}) Building Arrow C++ libraries ===" : ${ARROW_DATASET:=ON} : ${ARROW_FLIGHT:=ON} +: ${ARROW_FLIGHT_SQL:=${ARROW_FLIGHT}} : ${ARROW_GANDIVA:=OFF} : ${ARROW_GCS:=ON} : ${ARROW_HDFS:=ON} @@ -99,6 +100,7 @@ cmake \ -DARROW_DEPENDENCY_USE_SHARED=OFF \ -DARROW_FILESYSTEM=ON \ -DARROW_FLIGHT=${ARROW_FLIGHT} \ + -DARROW_FLIGHT_SQL=${ARROW_FLIGHT_SQL} \ -DARROW_GANDIVA=${ARROW_GANDIVA} \ -DARROW_GCS=${ARROW_GCS} \ -DARROW_HDFS=${ARROW_HDFS} \ @@ -146,6 +148,7 @@ export PYARROW_CMAKE_GENERATOR=${CMAKE_GENERATOR} export PYARROW_INSTALL_TESTS=1 export PYARROW_WITH_DATASET=${ARROW_DATASET} export PYARROW_WITH_FLIGHT=${ARROW_FLIGHT} +export PYARROW_WITH_FLIGHT_SQL=${ARROW_FLIGHT_SQL} export PYARROW_WITH_GANDIVA=${ARROW_GANDIVA} export PYARROW_WITH_GCS=${ARROW_GCS} export PYARROW_WITH_HDFS=${ARROW_HDFS} diff --git a/ci/scripts/python_wheel_windows_build.bat b/ci/scripts/python_wheel_windows_build.bat index d137cd8a985..5900953fe63 100644 --- a/ci/scripts/python_wheel_windows_build.bat +++ b/ci/scripts/python_wheel_windows_build.bat @@ -32,6 +32,7 @@ del /s /q C:\arrow\python\pyarrow\*.so.* echo "=== (%PYTHON_VERSION%) Building Arrow C++ libraries ===" set ARROW_DATASET=ON set ARROW_FLIGHT=ON +set ARROW_FLIGHT_SQL=ON set ARROW_GANDIVA=OFF set ARROW_HDFS=ON set ARROW_ORC=OFF @@ -70,6 +71,7 @@ cmake ^ -DARROW_DEPENDENCY_USE_SHARED=OFF ^ -DARROW_FILESYSTEM=ON ^ -DARROW_FLIGHT=%ARROW_FLIGHT% ^ + -DARROW_FLIGHT_SQL=%ARROW_FLIGHT_SQL% ^ -DARROW_GANDIVA=%ARROW_GANDIVA% ^ -DARROW_HDFS=%ARROW_HDFS% ^ -DARROW_JSON=ON ^ @@ -108,6 +110,7 @@ set PYARROW_CMAKE_GENERATOR=%CMAKE_GENERATOR% set PYARROW_INSTALL_TESTS=ON set PYARROW_WITH_DATASET=%ARROW_DATASET% set PYARROW_WITH_FLIGHT=%ARROW_FLIGHT% +set PYARROW_WITH_FLIGHT_SQL=%ARROW_FLIGHT_SQL% set PYARROW_WITH_GANDIVA=%ARROW_GANDIVA% set PYARROW_WITH_HDFS=%ARROW_HDFS% set PYARROW_WITH_ORC=%ARROW_ORC% From 3c51ea6b64fce023f41c0413a176395a130fc5d1 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 11 Jan 2023 08:13:06 -0500 Subject: [PATCH 3/5] Enable Flight SQL in tests --- python/requirements-test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/python/requirements-test.txt b/python/requirements-test.txt index 9fc37a5f157..2716434c0e0 100644 --- a/python/requirements-test.txt +++ b/python/requirements-test.txt @@ -1,3 +1,4 @@ +adbc-driver-manager; python_version >= "3.9" cffi hypothesis pandas From b18f4d26180f0bd523553e742c61ac5d37240e59 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 11 Jan 2023 08:58:56 -0500 Subject: [PATCH 4/5] Actually enable Flight SQL in tests --- python/requirements-wheel-test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/python/requirements-wheel-test.txt b/python/requirements-wheel-test.txt index dd07f0358d7..9f98a4b760b 100644 --- a/python/requirements-wheel-test.txt +++ b/python/requirements-wheel-test.txt @@ -1,3 +1,4 @@ +adbc-driver-manager; python_version >= "3.9" cffi cython hypothesis From 3e342475d6cd1764e9b3b4575e1ee6fc62c5d36a Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 16 Jan 2023 08:32:44 -0500 Subject: [PATCH 5/5] Fix CMake config with bundled gtest --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 13 ++++++++----- cpp/src/arrow/flight/sql/CMakeLists.txt | 6 +----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 365aa936585..feff24d0aac 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -4665,11 +4665,11 @@ macro(build_adbc_validation) set(ADBC_VALIDATION_LIB_DIR "lib") set(ADBC_VALIDATION_COMMON_CMAKE_ARGS - ${EP_COMMON_CMAKE_ARGS} - "-DCMAKE_INSTALL_LIBDIR=${ADBC_VALIDATION_LIB_DIR}" - "-DCMAKE_INSTALL_PREFIX=${ADBC_VALIDATION_PREFIX}" - "-DCMAKE_PREFIX_PATH=${ADBC_VALIDATION_PREFIX}" - "-DCMAKE_UNITY_BUILD=ON") + ${EP_COMMON_CMAKE_ARGS} "-DCMAKE_INSTALL_LIBDIR=${ADBC_VALIDATION_LIB_DIR}" + "-DCMAKE_INSTALL_PREFIX=${ADBC_VALIDATION_PREFIX}" "-DCMAKE_UNITY_BUILD=ON") + if(GTEST_VENDORED) + list(APPEND ADBC_VALIDATION_COMMON_CMAKE_ARGS "-DCMAKE_PREFIX_PATH=${GTEST_PREFIX}") + endif() set(ADBC_VALIDATION_STATIC_LIBRARY "${ADBC_VALIDATION_PREFIX}/${ADBC_VALIDATION_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}adbc_validation${CMAKE_STATIC_LIBRARY_SUFFIX}" ) @@ -4693,6 +4693,9 @@ macro(build_adbc_validation) add_dependencies(AdbcValidation::adbc_validation nanoarrow_ep) set(ADBCVALIDATION_LINK_LIBRARIES AdbcValidation::adbc_validation) add_dependencies(AdbcValidation::adbc_validation adbcvalidation_ep) + if(GTEST_VENDORED) + add_dependencies(adbcvalidation_ep googletest_ep) + endif() endmacro() if(ARROW_WITH_ADBC_VALIDATION) diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index bedbe440996..4940094ee28 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -124,14 +124,10 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) ${ARROW_FLIGHT_SQL_TEST_SRCS} ${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS} STATIC_LINK_LIBS - ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} - ${ARROW_FLIGHT_SQL_TEST_LIBS} nanoarrow AdbcValidation::adbc_validation - # Needs to come twice since the validation library - # also uses ADBC symbols, which get provided by - # libarrow_flight_sql here. ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} + ${ARROW_FLIGHT_SQL_TEST_LIBS} EXTRA_INCLUDES # adbc_validation.h needs adbc.h "${CMAKE_SOURCE_DIR}/../format/"